@@ -290,18 +290,26 @@ def evaluate_detection_batch(
290
290
291
291
class_id_idx = 4
292
292
true_classes = np .array (targets [:, class_id_idx ], dtype = np .int16 )
293
- detection_classes = np .array (detection_batch_filtered [:, class_id_idx ], dtype = np .int16 )
293
+ detection_classes = np .array (
294
+ detection_batch_filtered [:, class_id_idx ], dtype = np .int16
295
+ )
294
296
true_boxes = targets [:, :class_id_idx ]
295
297
detection_boxes = detection_batch_filtered [:, :class_id_idx ]
296
298
297
- iou_batch = box_iou_batch (boxes_true = true_boxes , boxes_detection = detection_boxes )
299
+ iou_batch = box_iou_batch (
300
+ boxes_true = true_boxes , boxes_detection = detection_boxes
301
+ )
302
+ # matched_idx = np.asarray(iou_batch > iou_threshold).nonzero()
298
303
299
- matched_gt_idx = set ()
300
- matched_det_idx = set ()
304
+ # if matched_idx[0].shape[0]:
305
+ # matches = np.stack(
306
+ # (matched_idx[0], matched_idx[1], iou_batch[matched_idx]), axis=1
307
+ # )
308
+ # matches = ConfusionMatrix._drop_extra_matches(matches=matches)
309
+ # else:
310
+ # matches = np.zeros((0, 3))
301
311
302
- # For each GT, find best matching detection (highest IoU > threshold)
303
- for gt_idx , gt_class in enumerate (true_classes ):
304
- candidate_det_idxs = np .where (iou_batch [gt_idx ] > iou_threshold )[0 ]
312
+ matched_idx = np .asarray (iou_batch > iou_threshold ).nonzero ()
305
313
306
314
if matched_idx [0 ].shape [0 ]:
307
315
# Filter matches by class equality
@@ -312,18 +320,15 @@ def evaluate_detection_batch(
312
320
valid_true_idx = matched_idx [0 ][valid_matches_mask ]
313
321
valid_pred_idx = matched_idx [1 ][valid_matches_mask ]
314
322
315
- best_det_idx = candidate_det_idxs [ np . argmax ( iou_batch [gt_idx , candidate_det_idxs ]) ]
316
- det_class = detection_classes [ best_det_idx ]
323
+ ious = iou_batch [valid_true_idx , valid_pred_idx ]
324
+ matches = np . stack (( valid_true_idx , valid_pred_idx , ious ), axis = 1 )
317
325
318
- if best_det_idx not in matched_det_idx :
319
- # Count as matched regardless of class:
320
- # same class → TP, different class → misclassification
321
- result_matrix [gt_class , det_class ] += 1
322
- matched_gt_idx .add (gt_idx )
323
- matched_det_idx .add (best_det_idx )
326
+ # Now drop extra matches with highest IoU per GT/pred
327
+ matches = ConfusionMatrix ._drop_extra_matches (matches = matches )
324
328
else :
325
- # Detection already matched, GT is FN
326
- result_matrix [gt_class , num_classes ] += 1
329
+ matches = np .zeros ((0 , 3 ))
330
+ else :
331
+ matches = np .zeros ((0 , 3 ))
327
332
328
333
matched_true_idx , matched_detection_idx , _ = matches .transpose ().astype (
329
334
np .int16
@@ -344,7 +349,6 @@ def evaluate_detection_batch(
344
349
345
350
return result_matrix
346
351
347
-
348
352
@staticmethod
349
353
def _drop_extra_matches (matches : np .ndarray ) -> np .ndarray :
350
354
"""
0 commit comments