@@ -290,65 +290,45 @@ def evaluate_detection_batch(
290290
291291 class_id_idx = 4
292292 true_classes = np .array (targets [:, class_id_idx ], dtype = np .int16 )
293- detection_classes = np .array (
294- detection_batch_filtered [:, class_id_idx ], dtype = np .int16
295- )
293+ detection_classes = np .array (detection_batch_filtered [:, class_id_idx ], dtype = np .int16 )
296294 true_boxes = targets [:, :class_id_idx ]
297295 detection_boxes = detection_batch_filtered [:, :class_id_idx ]
298296
299- iou_batch = box_iou_batch (
300- boxes_true = true_boxes , boxes_detection = detection_boxes
301- )
302- # matched_idx = np.asarray(iou_batch > iou_threshold).nonzero()
303-
304- # if matched_idx[0].shape[0]:
305- # matches = np.stack(
306- # (matched_idx[0], matched_idx[1], iou_batch[matched_idx]), axis=1
307- # )
308- # matches = ConfusionMatrix._drop_extra_matches(matches=matches)
309- # else:
310- # matches = np.zeros((0, 3))
311-
312- matched_idx = np .asarray (iou_batch > iou_threshold ).nonzero ()
313-
314- if matched_idx [0 ].shape [0 ]:
315- # Filter matches by class equality
316- valid_matches_mask = (
317- detection_classes [matched_idx [1 ]] == true_classes [matched_idx [0 ]]
318- )
319- if np .any (valid_matches_mask ):
320- valid_true_idx = matched_idx [0 ][valid_matches_mask ]
321- valid_pred_idx = matched_idx [1 ][valid_matches_mask ]
297+ iou_batch = box_iou_batch (boxes_true = true_boxes , boxes_detection = detection_boxes )
322298
323- ious = iou_batch [ valid_true_idx , valid_pred_idx ]
324- matches = np . stack (( valid_true_idx , valid_pred_idx , ious ), axis = 1 )
299+ matched_gt_idx = set ()
300+ matched_det_idx = set ( )
325301
326- # Now drop extra matches with highest IoU per GT/pred
327- matches = ConfusionMatrix ._drop_extra_matches (matches = matches )
328- else :
329- matches = np .zeros ((0 , 3 ))
330- else :
331- matches = np .zeros ((0 , 3 ))
302+ # For each GT, find best matching detection (highest IoU > threshold)
303+ for gt_idx , gt_class in enumerate (true_classes ):
304+ candidate_det_idxs = np .where (iou_batch [gt_idx ] > iou_threshold )[0 ]
332305
333- matched_true_idx , matched_detection_idx , _ = matches .transpose ().astype (
334- np .int16
335- )
306+ if len (candidate_det_idxs ) == 0 :
307+ # No matching detection → FN for this GT
308+ result_matrix [gt_class , num_classes ] += 1
309+ continue
336310
337- for i , true_class_value in enumerate (true_classes ):
338- j = matched_true_idx == i
339- if matches .shape [0 ] > 0 and sum (j ) == 1 :
340- result_matrix [
341- true_class_value , detection_classes [matched_detection_idx [j ]]
342- ] += 1 # TP
311+ best_det_idx = candidate_det_idxs [np .argmax (iou_batch [gt_idx , candidate_det_idxs ])]
312+ det_class = detection_classes [best_det_idx ]
313+
314+ if best_det_idx not in matched_det_idx :
315+ # Count as matched regardless of class:
316+ # same class → TP, different class → misclassification
317+ result_matrix [gt_class , det_class ] += 1
318+ matched_gt_idx .add (gt_idx )
319+ matched_det_idx .add (best_det_idx )
343320 else :
344- result_matrix [true_class_value , num_classes ] += 1 # FN
321+ # Detection already matched, GT is FN
322+ result_matrix [gt_class , num_classes ] += 1
345323
346- for i , detection_class_value in enumerate (detection_classes ):
347- if not any (matched_detection_idx == i ):
348- result_matrix [num_classes , detection_class_value ] += 1 # FP
324+ # unmatched detections are FP
325+ for det_idx , det_class in enumerate (detection_classes ):
326+ if det_idx not in matched_det_idx :
327+ result_matrix [num_classes , det_class ] += 1
349328
350329 return result_matrix
351330
331+
352332 @staticmethod
353333 def _drop_extra_matches (matches : np .ndarray ) -> np .ndarray :
354334 """
0 commit comments