Commit 17da6f8

Author: Juv Chan
Commit message: Initial commit
1 parent: 9b88c4e

File tree

8 files changed: +2059 −2 lines


README.md

Lines changed: 21 additions & 2 deletions

@@ -1,2 +1,21 @@
-# amazon-sagemaker-tensorflow-custom-containers
-This project shows step-by-step guide on how to build a real-world flower classifier of 102 flower types using TensorFlow, Amazon SageMaker, Docker and Python in a Jupyter Notebook.
+# **Build, Train and Deploy A Real-World Flower Classifier of 102 Flower Types**
+## *With TensorFlow 2.3, Amazon SageMaker Python SDK 2.5.x and Custom SageMaker Training & Serving Docker Containers*
+
+## **Introduction**
+This project is a step-by-step guide to building a **real-world flower classifier** of **102 flower types** using **TensorFlow**, **Amazon SageMaker**, **Docker** and **Python** in a **Jupyter Notebook**. It has been tested with the Python packages in **requirements.txt** on **Python 3.8.5**.
+
+## **Installation**
+Clone this project from GitHub.
+Create a new Python virtual environment targeting Python 3.6 or above.
+
+Install the required Python packages from **requirements.txt**, or install them by running the project's Jupyter notebook.
+Start and run the notebook with **Jupyter Lab**.
+
+Note that the external flower images used in the notebook are not provided as part of the project.
+You may use any freely available flower images, at your own discretion, to evaluate the flower classification model that you build and deploy.
+
+## **Contributing**
+Pull requests, suggestions and feedback are welcome. For major changes or issues, please open an issue to discuss them first.
+
+## **License**
+[Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0)
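
The README's build–train–deploy flow corresponds to a short sequence of SageMaker Python SDK 2.x calls. The sketch below is illustrative only and is not taken from the project's notebook: the image URI, IAM role and instance types are placeholders, and because train.py downloads its own dataset via tensorflow_datasets, fit() needs no input channels.

# Illustrative sketch (SageMaker Python SDK 2.x); image_uri, role and
# instance types below are placeholders, not values from this project.
from sagemaker.estimator import Estimator

estimator = Estimator(
    image_uri="123456789012.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tf-flowers:latest",
    role="arn:aws:iam::123456789012:role/SageMakerExecutionRole",
    instance_count=1,
    instance_type="ml.p3.2xlarge",
    hyperparameters={"epochs": 5, "batch_size": 32, "learning_rate": 0.001},
)

# train.py fetches oxford_flowers102 itself, so no input channels are passed
estimator.fit()

# host the trained model behind a real-time endpoint served by this
# project's nginx + TensorFlow Serving container
predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.m5.large")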

container/Dockerfile

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
# Copyright 2020 Juv Chan. All Rights Reserved.
FROM tensorflow/tensorflow:2.3.0-gpu

LABEL maintainer="Juv Chan <[email protected]>"

RUN apt-get update && apt-get install -y --no-install-recommends nginx curl
RUN pip install --no-cache-dir --upgrade pip tensorflow-hub tensorflow-datasets sagemaker-tensorflow-training

RUN echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list
RUN curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -
# -y keeps the install non-interactive so the docker build does not stall
RUN apt-get update && apt-get install -y tensorflow-model-server

ENV PATH="/opt/ml/code:${PATH}"

# /opt/ml and all subdirectories are utilized by SageMaker; we use the /code subdirectory to store our user code.
COPY /code /opt/ml/code
WORKDIR /opt/ml/code

RUN chmod 755 serve
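
SageMaker's custom-container contract is that training jobs run a `train` executable and hosting runs `serve`, both resolved via the PATH set above; `serve` is the script shown below, while the training entrypoint is supplied by the installed sagemaker-tensorflow-training toolkit. Getting the image into ECR is a prerequisite for both. A hedged build-and-push sketch, with the region and repository name as placeholders and the `docker login` step omitted:

# Hypothetical build-and-push helper; the region and repository name are
# placeholders, and authenticating docker to ECR is omitted here.
import subprocess
import boto3

region = "us-east-1"
repo = "sagemaker-tf-flowers"
account = boto3.client("sts").get_caller_identity()["Account"]
image_uri = f"{account}.dkr.ecr.{region}.amazonaws.com/{repo}:latest"

# create the ECR repository on first run
ecr = boto3.client("ecr", region_name=region)
try:
    ecr.create_repository(repositoryName=repo)
except ecr.exceptions.RepositoryAlreadyExistsException:
    pass

# build from the `container` directory so COPY /code resolves correctly
subprocess.check_call(["docker", "build", "-t", image_uri, "container"])
subprocess.check_call(["docker", "push", image_uri])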

container/code/nginx.conf

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
events {
    # determines how many requests can simultaneously be served
    # https://www.digitalocean.com/community/tutorials/how-to-optimize-nginx-configuration
    # for more information
    worker_connections 2048;
}

http {
    server {
        # configures the server to listen on port 8080
        # Amazon SageMaker sends inference requests to port 8080.
        # For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
        listen 8080 deferred;
        client_max_body_size 10M;

        # redirects requests from SageMaker to TF Serving
        location /invocations {
            proxy_pass http://localhost:8501/v1/models/flowers_model:predict;
        }

        # Used by SageMaker to confirm that the server is alive.
        # https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
        location /ping {
            return 200 "OK";
        }
    }
}
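
Tracing one request through this config: SageMaker POSTs a JSON body to port 8080 at /invocations, and nginx hands it to TensorFlow Serving's REST API on port 8501. A minimal sketch of that proxied call, using a dummy all-zeros 299x299x3 image; the model name and ports come from the files in this commit, and the payload shape follows the TF Serving REST API:

import json
import requests

# dummy 299x299 RGB image with pixel values already scaled to [0, 1],
# matching the preprocessing in train.py's parse_image
image = [[[0.0, 0.0, 0.0]] * 299] * 299

# the request nginx forwards when SageMaker calls POST /invocations
resp = requests.post(
    "http://localhost:8501/v1/models/flowers_model:predict",
    data=json.dumps({"instances": [image]}),
)
print(len(resp.json()["predictions"][0]))  # 102 softmax probabilities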

container/code/serve

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
#!/usr/bin/env python

# This file implements the hosting solution, which just starts TensorFlow Model Serving.
import subprocess
import os

TF_SERVING_DEFAULT_PORT = 8501
MODEL_NAME = 'flowers_model'
MODEL_BASE_PATH = '/opt/ml/model'


def start_server():
    print('Starting TensorFlow Serving.')

    # link the log streams to stdout/err so they will be logged to the container logs
    subprocess.check_call(
        ['ln', '-sf', '/dev/stdout', '/var/log/nginx/access.log'])
    subprocess.check_call(
        ['ln', '-sf', '/dev/stderr', '/var/log/nginx/error.log'])

    # start nginx server
    nginx = subprocess.Popen(['nginx', '-c', '/opt/ml/code/nginx.conf'])

    # start TensorFlow Serving
    # https://www.tensorflow.org/serving/api_rest#start_modelserver_with_the_rest_api_endpoint
    tf_model_server = subprocess.call(['tensorflow_model_server',
                                       '--rest_api_port=' +
                                       str(TF_SERVING_DEFAULT_PORT),
                                       '--model_name=' + MODEL_NAME,
                                       '--model_base_path=' + MODEL_BASE_PATH])


# The main routine just invokes the start function.
if __name__ == '__main__':
    start_server()
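
Note the design: subprocess.call blocks on tensorflow_model_server, which keeps the container in the foreground while nginx runs alongside as the reverse proxy. Once this container backs a SageMaker endpoint, clients go through the SageMaker runtime API rather than the container's ports. A hedged sketch, with the endpoint name as a placeholder:

import json
import boto3

runtime = boto3.client("sagemaker-runtime")

# "flowers-endpoint" is a placeholder name; the payload shape matches the
# TF Serving REST API that nginx proxies to inside the container
image = [[[0.0, 0.0, 0.0]] * 299] * 299
response = runtime.invoke_endpoint(
    EndpointName="flowers-endpoint",
    ContentType="application/json",
    Body=json.dumps({"instances": [image]}),
)
print(json.loads(response["Body"].read()))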

container/code/train.py

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
import argparse
import numpy as np
import os
import logging
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds


EPOCHS = 5
BATCH_SIZE = 32
LEARNING_RATE = 0.001
DROPOUT_RATE = 0.3
EARLY_STOPPING_TRAIN_ACCURACY = 0.995
TF_AUTOTUNE = tf.data.experimental.AUTOTUNE
TF_HUB_MODEL_URL = 'https://tfhub.dev/google/inaturalist/inception_v3/feature_vector/4'
TF_DATASET_NAME = 'oxford_flowers102'
IMAGE_SIZE = (299, 299)
SHUFFLE_BUFFER_SIZE = 473
MODEL_VERSION = '1'


class EarlyStoppingCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if logs.get('accuracy', 0) > EARLY_STOPPING_TRAIN_ACCURACY:
            print(
                f"\nEarly stopping at {logs.get('accuracy'):.4f} > {EARLY_STOPPING_TRAIN_ACCURACY}!\n")
            self.model.stop_training = True


def parse_args():
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script
    parser.add_argument('--epochs', type=int, default=EPOCHS)
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE)
    parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE)

    # model_dir is always passed in from SageMaker. By default it is an S3 path under the default bucket.
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--sm_model_dir', type=str,
                        default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--model_version', type=str, default=MODEL_VERSION)

    return parser.parse_known_args()


def set_gpu_memory_growth():
    gpus = tf.config.list_physical_devices('GPU')

    if gpus:
        print("\nGPU Available.")
        print(f"Number of GPU: {len(gpus)}")
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
                print(f"Enabled Memory Growth on {gpu.name}")
        except RuntimeError as e:
            print(e)

    print()


def get_datasets(dataset_name):
    tfds.disable_progress_bar()

    # note: the 'test' split (6,149 images) is the largest of oxford_flowers102,
    # so it is mapped to ds_train; the 1,020-image 'train' split becomes the
    # held-out test set
    splits = ['test', 'validation', 'train']
    splits, ds_info = tfds.load(dataset_name, split=splits, with_info=True)
    (ds_train, ds_validation, ds_test) = splits

    return (ds_train, ds_validation, ds_test), ds_info


def parse_image(features):
    image = features['image']
    image = tf.image.resize(image, IMAGE_SIZE) / 255.0
    return image, features['label']


def training_pipeline(train_raw, batch_size):
    train_preprocessed = train_raw.shuffle(SHUFFLE_BUFFER_SIZE).map(
        parse_image, num_parallel_calls=TF_AUTOTUNE).cache().batch(batch_size).prefetch(TF_AUTOTUNE)

    return train_preprocessed


def test_pipeline(test_raw, batch_size):
    test_preprocessed = test_raw.map(parse_image, num_parallel_calls=TF_AUTOTUNE).cache(
    ).batch(batch_size).prefetch(TF_AUTOTUNE)

    return test_preprocessed


def create_model(train_batches, val_batches, learning_rate):
    # NUM_CLASSES and args are module-level names set in the __main__ block
    # before this function is called
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    base_model = hub.KerasLayer(TF_HUB_MODEL_URL,
                                input_shape=IMAGE_SIZE + (3,), trainable=False)

    early_stop_callback = EarlyStoppingCallback()

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.Dropout(DROPOUT_RATE),
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])

    model.compile(optimizer=optimizer,
                  loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    model.summary()

    model.fit(train_batches, epochs=args.epochs,
              validation_data=val_batches,
              callbacks=[early_stop_callback])

    return model


if __name__ == "__main__":
    args, _ = parse_args()
    batch_size = args.batch_size
    epochs = args.epochs
    learning_rate = args.learning_rate
    print(
        f"\nBatch Size = {batch_size}, Epochs = {epochs}, Learning Rate = {learning_rate}\n")

    set_gpu_memory_growth()

    (ds_train, ds_validation, ds_test), ds_info = get_datasets(TF_DATASET_NAME)
    NUM_CLASSES = ds_info.features['label'].num_classes

    print(
        f"\nNumber of Training dataset samples: {tf.data.experimental.cardinality(ds_train)}")
    print(
        f"Number of Validation dataset samples: {tf.data.experimental.cardinality(ds_validation)}")
    print(
        f"Number of Test dataset samples: {tf.data.experimental.cardinality(ds_test)}")
    print(f"Number of Flower Categories: {NUM_CLASSES}\n")

    train_batches = training_pipeline(ds_train, batch_size)
    validation_batches = test_pipeline(ds_validation, batch_size)
    test_batches = test_pipeline(ds_test, batch_size)

    model = create_model(train_batches, validation_batches, learning_rate)
    eval_results = model.evaluate(test_batches)

    for metric, value in zip(model.metrics_names, eval_results):
        print(metric + ': {:.4f}'.format(value))

    export_path = os.path.join(args.sm_model_dir, args.model_version)
    print(
        f'\nModel version: {args.model_version} exported to: {export_path}\n')

    model.save(export_path)
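
The README leaves the evaluation images up to the reader; whatever photo you choose must be preprocessed the same way parse_image does before it reaches the endpoint. A minimal sketch, with the file path as a placeholder:

import json
import tensorflow as tf

raw = tf.io.read_file("my_flower.jpg")  # placeholder path to any flower photo
image = tf.io.decode_jpeg(raw, channels=3)
image = tf.image.resize(image, (299, 299)) / 255.0  # same as parse_image

# ready to send to /invocations or via invoke_endpoint as shown earlier
payload = json.dumps({"instances": [image.numpy().tolist()]})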

data/juvchan_flower.jpg

Binary file added: 274 KB