-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmodel.py
402 lines (349 loc) · 16 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
import tensorflow as tf
import numpy as np
import math
from utils import convert_to_corners, compute_iou
from data_processing import resize_and_pad_image
from tensorflow import keras
def get_backbone(name="resnet50", weight=None):
    """Build a multi-output backbone exposing the C3, C4 and C5 feature maps.

    Supported backbones: resnet50, resnet101, densenet121.

    Arguments:
      name: Name of the backbone architecture.
      weight: Forwarded as ``weights`` to the ``keras.applications``
        constructor (e.g. ``None`` or ``"imagenet"``).

    Returns:
      A ``keras.Model`` sharing the backbone's inputs whose outputs are the
      three feature maps ``[c3, c4, c5]``.

    Raises:
      ValueError: If ``name`` is not one of the supported backbones.
    """
    # name -> (keras.applications constructor name, [C3, C4, C5] layer names).
    # For the ResNets each layer is the last block of its stage. ResNet50's
    # conv4 stage has 6 blocks, but ResNet101's has 23, so C4 must come from
    # conv4_block23_out there (conv4_block6_out is mid-stage).
    specs = {
        "resnet50": (
            "ResNet50",
            ["conv3_block4_out", "conv4_block6_out", "conv5_block3_out"],
        ),
        "resnet101": (
            "ResNet101",
            ["conv3_block4_out", "conv4_block23_out", "conv5_block3_out"],
        ),
        "densenet121": (
            "DenseNet121",
            ["pool3_conv", "pool4_conv", "relu"],
        ),
    }
    if name not in specs:
        # Fail loudly instead of calling None / hitting an unbound local.
        raise ValueError(
            f"Unsupported backbone {name!r}; expected one of {sorted(specs)}"
        )
    constructor_name, output_layers = specs[name]
    backbone_model = getattr(keras.applications, constructor_name)(
        include_top=False, input_shape=[None, None, 3], weights=weight
    )
    c3_output, c4_output, c5_output = [
        backbone_model.get_layer(layer_name).output
        for layer_name in output_layers
    ]
    return keras.Model(
        inputs=[backbone_model.inputs], outputs=[c3_output, c4_output, c5_output]
    )
class FeaturePyramid(keras.layers.Layer):
    """Builds the Feature Pyramid with the feature maps from the backbone.

    Arguments:
      backbone: Backbone name passed to ``get_backbone`` (resnet50,
        resnet101 or densenet121).
      weight: Forwarded to ``get_backbone`` as the backbone's weights.

    Attributes:
      backbone: Model returning the C3, C4 and C5 feature maps.
      conv_c3_1x1 / conv_c4_1x1 / conv_c5_1x1: 1x1 lateral convolutions that
        project C3-C5 to 256 channels.
      conv_c3_3x3 / conv_c4_3x3 / conv_c5_3x3: 3x3 convolutions applied to
        the merged maps to produce P3-P5.
      conv_c6_3x3 / conv_c7_3x3: Stride-2 3x3 convolutions producing P6
        (from C5) and P7 (from ReLU(P6)).
      upsample_2x: 2x upsampling used in the top-down pathway.
    """

    def __init__(self, backbone="resnet50", weight=None, **kwargs):
        super(FeaturePyramid, self).__init__(name="FeaturePyramid", **kwargs)
        self.backbone = get_backbone(backbone, weight)
        self.conv_c3_1x1 = keras.layers.Conv2D(256, 1, 1, "same")
        self.conv_c4_1x1 = keras.layers.Conv2D(256, 1, 1, "same")
        self.conv_c5_1x1 = keras.layers.Conv2D(256, 1, 1, "same")
        self.conv_c3_3x3 = keras.layers.Conv2D(256, 3, 1, "same")
        self.conv_c4_3x3 = keras.layers.Conv2D(256, 3, 1, "same")
        self.conv_c5_3x3 = keras.layers.Conv2D(256, 3, 1, "same")
        self.conv_c6_3x3 = keras.layers.Conv2D(256, 3, 2, "same")
        self.conv_c7_3x3 = keras.layers.Conv2D(256, 3, 2, "same")
        # Built once and reused: UpSampling2D has no weights, so a single
        # instance serves both top-down merges. This avoids constructing new
        # layer objects on every forward pass.
        self.upsample_2x = keras.layers.UpSampling2D(2)

    def call(self, images, training=False):
        """Return the pyramid levels ``(P3, P4, P5, P6, P7)`` for ``images``."""
        c3_output, c4_output, c5_output = self.backbone(images, training=training)
        p3_output = self.conv_c3_1x1(c3_output)
        p4_output = self.conv_c4_1x1(c4_output)
        p5_output = self.conv_c5_1x1(c5_output)
        # Top-down pathway: add the 2x-upsampled coarser level to each lateral
        # map. NOTE(review): this assumes the input height/width make the
        # upsampled maps line up with the finer ones (e.g. padded to a
        # multiple of the coarsest stride) — confirm in the data pipeline.
        p4_output = p4_output + self.upsample_2x(p5_output)
        p3_output = p3_output + self.upsample_2x(p4_output)
        p3_output = self.conv_c3_3x3(p3_output)
        p4_output = self.conv_c4_3x3(p4_output)
        p5_output = self.conv_c5_3x3(p5_output)
        # P6 is computed from the backbone's C5, not from P5.
        p6_output = self.conv_c6_3x3(c5_output)
        p7_output = self.conv_c7_3x3(tf.nn.relu(p6_output))
        return p3_output, p4_output, p5_output, p6_output, p7_output
def build_head(output_filters, bias_init):
    """Builds the class/box predictions head.

    The head consists of four (3x3 Conv2D with 256 filters + ReLU) stages,
    followed by a 3x3 output convolution with ``output_filters`` channels.
    It accepts inputs of any spatial size with 256 channels.

    Arguments:
      output_filters: Number of convolution filters in the final layer.
      bias_init: Bias initializer for the final convolution layer (any value
        ``keras.layers.Conv2D`` accepts as ``bias_initializer``).

    Returns:
      A keras sequential model representing either the classification
      or the box regression head depending on `output_filters`.
    """

    def new_kernel_init():
        # Build a fresh initializer for every layer. Some Keras 2 releases
        # return identical values (and warn) when one unseeded initializer
        # instance is called repeatedly. Sharing one instance would then give
        # every same-shaped hidden conv the same starting kernel.
        return tf.initializers.RandomNormal(0.0, 0.01)

    head = keras.Sequential([keras.Input(shape=[None, None, 256])])
    for _ in range(4):
        head.add(
            keras.layers.Conv2D(
                256, 3, padding="same", kernel_initializer=new_kernel_init()
            )
        )
        head.add(keras.layers.ReLU())
    head.add(
        keras.layers.Conv2D(
            output_filters,
            3,
            1,
            padding="same",
            kernel_initializer=new_kernel_init(),
            bias_initializer=bias_init,
        )
    )
    return head
class RetinaNet(keras.Model):
    """A subclassed Keras model implementing the RetinaNet architecture.

    Arguments:
      num_classes: Number of classes in the dataset.
      backbone: Backbone name passed to the feature pyramid (resnet50,
        resnet101 or densenet121). ``None`` selects ``"resnet50"``.
      weight: Forwarded to the backbone constructor as its weights.

    Attributes:
      num_classes: Number of classes in the dataset.
      backbone_name: Name of the backbone actually used.
      fpn: Feature pyramid producing the P3-P7 feature maps.
      cls_head: Classification head, shared across all pyramid levels.
      box_head: Box-regression head, shared across all pyramid levels.
    """

    def __init__(self, num_classes, backbone=None, weight=None, **kwargs):
        super(RetinaNet, self).__init__(name="RetinaNet", **kwargs)
        if backbone is None:
            # get_backbone cannot handle None. Fall back to the same default
            # that FeaturePyramid and get_backbone use.
            backbone = "resnet50"
        self.backbone_name = backbone
        self.fpn = FeaturePyramid(backbone, weight)
        self.num_classes = num_classes
        # Bias = -log((1 - pi) / pi) with pi = 0.01, so the sigmoid of the
        # initial class logits is about 0.01 for every anchor.
        prior_probability = tf.constant_initializer(-np.log((1 - 0.01) / 0.01))
        # 9 anchors per location: 3 aspect ratios x 3 scales (see AnchorBox).
        num_anchors = 9
        self.cls_head = build_head(num_anchors * num_classes, prior_probability)
        self.box_head = build_head(num_anchors * 4, "zeros")

    def call(self, image, training=False):
        """Run the detector on a batch of images.

        Returns:
          A tensor of shape ``(batch, total_anchors, 4 + num_classes)``. For
          each anchor, the first 4 entries are box-regression outputs and the
          remaining entries are raw class logits.
        """
        # Propagate the caller's training flag. Forcing training=True would
        # make BatchNorm in the backbone use batch statistics at inference.
        features = self.fpn(image, training=training)
        batch_size = tf.shape(image)[0]
        cls_outputs = []
        box_outputs = []
        for feature in features:
            box_outputs.append(
                tf.reshape(self.box_head(feature), [batch_size, -1, 4])
            )
            cls_outputs.append(
                tf.reshape(
                    self.cls_head(feature), [batch_size, -1, self.num_classes]
                )
            )
        cls_outputs = tf.concat(cls_outputs, axis=1)
        box_outputs = tf.concat(box_outputs, axis=1)
        return tf.concat([box_outputs, cls_outputs], axis=-1)
class AnchorBox:
    """Anchor-box generator for the five pyramid levels P3-P7.

    Anchors are placed on feature maps with strides ``[8, 16, 32, 64, 128]``.
    Each anchor is encoded as ``[x_center, y_center, width, height]``.

    Attributes:
      aspect_ratios: Width-to-height ratios used at every feature-map
        location.
      scales: Size multipliers 2^0, 2^(1/3) and 2^(2/3), applied to each
        level's base anchor size.
      _num_anchors: Anchors per location, len(aspect_ratios) * len(scales).
      _strides: Stride of each pyramid level, levels 3 through 7.
      _areas: Base anchor area of each level, 32^2 up to 512^2.
      _anchor_dims: One tensor per level, shaped ``(1, 1, num_anchors, 2)``,
        holding ``[width, height]`` for every ratio/scale combination.
    """

    def __init__(self):
        self.aspect_ratios = [0.5, 1.0, 2.0]
        self.scales = [2 ** exponent for exponent in [0, 1 / 3, 2 / 3]]
        self._num_anchors = len(self.aspect_ratios) * len(self.scales)
        self._strides = [2 ** level for level in range(3, 8)]
        self._areas = [side ** 2 for side in [32.0, 64.0, 128.0, 256.0, 512.0]]
        self._anchor_dims = self._compute_dims()

    def _compute_dims(self):
        """Return, per pyramid level, a ``(1, 1, num_anchors, 2)`` tensor of
        ``[width, height]`` for every aspect ratio and scale.
        """

        def dims_for_area(area):
            # Ratio-major order: all scales of the first ratio, then the next.
            combos = []
            for ratio in self.aspect_ratios:
                # height = sqrt(area / ratio) and width = area / height,
                # hence width / height == ratio.
                height = tf.math.sqrt(area / ratio)
                base = tf.reshape(
                    tf.stack([area / height, height], axis=-1), [1, 1, 2]
                )
                combos.extend(scale * base for scale in self.scales)
            return tf.stack(combos, axis=-2)

        return [dims_for_area(area) for area in self._areas]

    def _get_anchors(self, feature_height, feature_width, level):
        """Build every anchor for one feature map.

        Arguments:
          feature_height: Height of the feature map (get_anchors passes the
            result of ``tf.math.ceil``).
          feature_width: Width of the feature map.
          level: Pyramid level of the feature map (3 through 7).

        Returns:
          Anchor boxes with shape
          ``(feature_height * feature_width * num_anchors, 4)``.
        """
        level_index = level - 3
        xs = tf.range(feature_width, dtype=tf.float32) + 0.5
        ys = tf.range(feature_height, dtype=tf.float32) + 0.5
        # Cell centres in input-image pixels, one copy per anchor shape.
        grid = tf.stack(tf.meshgrid(xs, ys), axis=-1) * self._strides[level_index]
        grid = tf.tile(tf.expand_dims(grid, axis=-2), [1, 1, self._num_anchors, 1])
        sizes = tf.tile(
            self._anchor_dims[level_index], [feature_height, feature_width, 1, 1]
        )
        boxes = tf.concat([grid, sizes], axis=-1)
        return tf.reshape(
            boxes, [feature_height * feature_width * self._num_anchors, 4]
        )

    def get_anchors(self, image_height, image_width):
        """Build the anchors of all pyramid levels for one image size.

        Arguments:
          image_height: Height of the input image.
          image_width: Width of the input image.

        Returns:
          All anchors, concatenated level by level (P3 first), with shape
          ``(total_anchors, 4)``.
        """
        per_level = []
        for level in range(3, 8):
            stride = 2 ** level
            per_level.append(
                self._get_anchors(
                    tf.math.ceil(image_height / stride),
                    tf.math.ceil(image_width / stride),
                    level,
                )
            )
        return tf.concat(per_level, axis=0)
class DecodePredictions(tf.keras.layers.Layer):
    """A Keras layer that decodes predictions of the RetinaNet model.

    Attributes:
      num_classes: Number of classes in the dataset.
      confidence_threshold: Minimum class probability, below which detections
        are pruned.
      nms_iou_threshold: IOU threshold for the NMS operation.
      max_detections_per_class: Maximum number of detections to retain per
        class.
      max_detections: Maximum number of detections to retain across all
        classes.
      box_variance: The scaling factors used to scale the bounding box
        predictions. These must match the variance used when the training
        targets were encoded.
      verbose: Stored on the layer; not used by the decoding logic here.
    """

    def __init__(
        self,
        num_classes=80,
        confidence_threshold=0.05,
        nms_iou_threshold=0.5,
        max_detections_per_class=100,
        max_detections=100,
        box_variance=(0.1, 0.1, 0.2, 0.2),
        verbose=0,
        **kwargs
    ):
        super(DecodePredictions, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.verbose = verbose
        self.confidence_threshold = confidence_threshold
        self.nms_iou_threshold = nms_iou_threshold
        self.max_detections_per_class = max_detections_per_class
        self.max_detections = max_detections
        self._anchor_box = AnchorBox()
        # Use the caller's variance rather than a hard-coded copy, so decoding
        # stays the inverse of whatever encoding was used in training.
        self._box_variance = tf.convert_to_tensor(box_variance, dtype=tf.float32)

    def _decode_box_predictions(self, anchor_boxes, box_predictions):
        """Turn regression outputs into corner-format boxes.

        Arguments:
          anchor_boxes: Anchors as ``[x, y, width, height]``, with a leading
            broadcast axis (``call`` passes ``anchor_boxes[None, ...]``).
          box_predictions: Regression outputs of shape
            ``(batch, num_anchors, 4)``.

        Returns:
          The decoded boxes after ``convert_to_corners`` (presumably
          ``[x_min, y_min, x_max, y_max]`` — see utils).
        """
        boxes = box_predictions * self._box_variance
        boxes = tf.concat(
            [
                # Centre offsets are relative to the anchor size.
                boxes[:, :, :2] * anchor_boxes[:, :, 2:] + anchor_boxes[:, :, :2],
                # Sizes are predicted in log-space relative to the anchor.
                tf.math.exp(boxes[:, :, 2:]) * anchor_boxes[:, :, 2:],
            ],
            axis=-1,
        )
        boxes_transformed = convert_to_corners(boxes)
        return boxes_transformed

    def call(self, images, predictions):
        """Decode raw model outputs and apply class-wise NMS.

        Arguments:
          images: The input batch. Only its shape is read (dims 1 and 2 as
            height and width) to regenerate the anchors.
          predictions: Model outputs of shape
            ``(batch, num_anchors, 4 + num_classes)``: box regressions
            followed by class logits.

        Returns:
          The result of ``tf.image.combined_non_max_suppression``.
        """
        image_shape = tf.cast(tf.shape(images), dtype=tf.float32)
        anchor_boxes = self._anchor_box.get_anchors(image_shape[1], image_shape[2])
        box_predictions = predictions[:, :, :4]
        cls_predictions = tf.nn.sigmoid(predictions[:, :, 4:])
        boxes = self._decode_box_predictions(anchor_boxes[None, ...], box_predictions)
        return tf.image.combined_non_max_suppression(
            # One box per anchor is shared by all classes (q = 1).
            tf.expand_dims(boxes, axis=2),
            cls_predictions,
            self.max_detections_per_class,
            self.max_detections,
            self.nms_iou_threshold,
            self.confidence_threshold,
            clip_boxes=False,
        )
class LabelEncoder:
    """Convert raw annotations into RetinaNet training targets.

    For each sample in a batch (images, ground-truth boxes, class ids), this
    produces one row per anchor. Each row holds the 4 encoded box-regression
    targets followed by a class target.

    Attributes:
      anchor_box: Anchor generator used to lay out anchors for the input size.
      box_variance: Scaling factors dividing the box-regression targets.
    """

    def __init__(self):
        self._anchor_box = AnchorBox()
        self._box_variance = tf.convert_to_tensor(
            [0.1, 0.1, 0.2, 0.2], dtype=tf.float32
        )

    def _match_anchor_boxes(
        self, anchor_boxes, gt_boxes, match_iou=0.5, ignore_iou=0.4
    ):
        """Assign each anchor to a ground-truth box, background, or "ignore".

        1. Compute the pairwise IOU between the M ``anchor_boxes`` and the N
           ``gt_boxes``, giving an ``(M, N)`` matrix.
        2. An anchor whose best IOU is at least ``match_iou`` is positive and
           matched to that ground-truth box.
        3. An anchor whose best IOU is below ``ignore_iou`` is background.
        4. Every other anchor is ignored during training.

        Arguments:
          anchor_boxes: Float tensor ``(total_anchors, 4)`` of anchors in
            ``[x, y, width, height]`` format.
          gt_boxes: Float tensor ``(num_objects, 4)`` of ground-truth boxes
            in ``[x, y, width, height]`` format.
          match_iou: Minimum IOU for an anchor to be matched to a box.
          ignore_iou: IOU below which an anchor is labelled background.

        Returns:
          matched_gt_idx: Index of the best-IOU ground-truth box per anchor.
          positive_mask: float32 0/1 mask of matched anchors.
          ignore_mask: float32 0/1 mask of anchors to be ignored in training.
        """
        iou = compute_iou(anchor_boxes, gt_boxes)
        best_iou = tf.reduce_max(iou, axis=1)
        best_gt = tf.argmax(iou, axis=1)
        is_positive = best_iou >= match_iou
        is_negative = best_iou < ignore_iou
        # Neither confidently matched nor clearly background.
        is_ignored = ~(is_positive | is_negative)
        return (
            best_gt,
            tf.cast(is_positive, dtype=tf.float32),
            tf.cast(is_ignored, dtype=tf.float32),
        )

    def _compute_box_target(self, anchor_boxes, matched_gt_boxes):
        """Encode matched ground-truth boxes relative to their anchors."""
        anchor_xy = anchor_boxes[:, :2]
        anchor_wh = anchor_boxes[:, 2:]
        # Centre offsets are normalised by anchor size; sizes go to log-space.
        offsets = (matched_gt_boxes[:, :2] - anchor_xy) / anchor_wh
        log_scales = tf.math.log(matched_gt_boxes[:, 2:] / anchor_wh)
        return tf.concat([offsets, log_scales], axis=-1) / self._box_variance

    def _encode_sample(self, image_shape, gt_boxes, cls_ids):
        """Build the ``(total_anchors, 5)`` target for one sample.

        ``image_shape`` is the batch shape vector; entries 1 and 2 are read as
        height and width. In the class column, a positive anchor holds its
        matched class id, a background anchor holds -1.0, and an ignored
        anchor holds -2.0.
        """
        anchors = self._anchor_box.get_anchors(image_shape[1], image_shape[2])
        class_ids = tf.cast(cls_ids, dtype=tf.float32)
        gt_index, positive_mask, ignore_mask = self._match_anchor_boxes(
            anchors, gt_boxes
        )
        box_target = self._compute_box_target(anchors, tf.gather(gt_boxes, gt_index))
        cls_target = tf.where(
            tf.equal(positive_mask, 1.0), tf.gather(class_ids, gt_index), -1.0
        )
        cls_target = tf.where(tf.equal(ignore_mask, 1.0), -2.0, cls_target)
        return tf.concat([box_target, cls_target[:, None]], axis=-1)

    def encode_batch(self, batch_images, gt_boxes, cls_ids):
        """Encode a whole batch and preprocess its images.

        Returns:
          ``(preprocessed_images, labels)``, where ``labels`` stacks the
          per-sample targets along a new leading batch axis.

        NOTE(review): images always go through the ResNet ``preprocess_input``,
        even when the model uses a DenseNet backbone. Confirm this is intended.
        """
        shape = tf.shape(batch_images)
        count = shape[0]
        encoded = tf.TensorArray(dtype=tf.float32, size=count, dynamic_size=True)
        for idx in range(count):
            encoded = encoded.write(
                idx, self._encode_sample(shape, gt_boxes[idx], cls_ids[idx])
            )
        preprocessed = tf.keras.applications.resnet.preprocess_input(batch_images)
        return preprocessed, encoded.stack()