Skip to content

Commit e56ee44

Browse files
Correct confusion matrix computation
1 parent bbcef84 commit e56ee44

File tree

1 file changed

+27
-47
lines changed

1 file changed

+27
-47
lines changed

supervision/metrics/detection.py

Lines changed: 27 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -290,65 +290,45 @@ def evaluate_detection_batch(
290290

291291
class_id_idx = 4
292292
true_classes = np.array(targets[:, class_id_idx], dtype=np.int16)
293-
detection_classes = np.array(
294-
detection_batch_filtered[:, class_id_idx], dtype=np.int16
295-
)
293+
detection_classes = np.array(detection_batch_filtered[:, class_id_idx], dtype=np.int16)
296294
true_boxes = targets[:, :class_id_idx]
297295
detection_boxes = detection_batch_filtered[:, :class_id_idx]
298296

299-
iou_batch = box_iou_batch(
300-
boxes_true=true_boxes, boxes_detection=detection_boxes
301-
)
302-
# matched_idx = np.asarray(iou_batch > iou_threshold).nonzero()
303-
304-
# if matched_idx[0].shape[0]:
305-
# matches = np.stack(
306-
# (matched_idx[0], matched_idx[1], iou_batch[matched_idx]), axis=1
307-
# )
308-
# matches = ConfusionMatrix._drop_extra_matches(matches=matches)
309-
# else:
310-
# matches = np.zeros((0, 3))
311-
312-
matched_idx = np.asarray(iou_batch > iou_threshold).nonzero()
313-
314-
if matched_idx[0].shape[0]:
315-
# Filter matches by class equality
316-
valid_matches_mask = (
317-
detection_classes[matched_idx[1]] == true_classes[matched_idx[0]]
318-
)
319-
if np.any(valid_matches_mask):
320-
valid_true_idx = matched_idx[0][valid_matches_mask]
321-
valid_pred_idx = matched_idx[1][valid_matches_mask]
297+
iou_batch = box_iou_batch(boxes_true=true_boxes, boxes_detection=detection_boxes)
322298

323-
ious = iou_batch[valid_true_idx, valid_pred_idx]
324-
matches = np.stack((valid_true_idx, valid_pred_idx, ious), axis=1)
299+
matched_gt_idx = set()
300+
matched_det_idx = set()
325301

326-
# Now drop extra matches with highest IoU per GT/pred
327-
matches = ConfusionMatrix._drop_extra_matches(matches=matches)
328-
else:
329-
matches = np.zeros((0, 3))
330-
else:
331-
matches = np.zeros((0, 3))
302+
# For each GT, find best matching detection (highest IoU > threshold)
303+
for gt_idx, gt_class in enumerate(true_classes):
304+
candidate_det_idxs = np.where(iou_batch[gt_idx] > iou_threshold)[0]
332305

333-
matched_true_idx, matched_detection_idx, _ = matches.transpose().astype(
334-
np.int16
335-
)
306+
if len(candidate_det_idxs) == 0:
307+
# No matching detection → FN for this GT
308+
result_matrix[gt_class, num_classes] += 1
309+
continue
336310

337-
for i, true_class_value in enumerate(true_classes):
338-
j = matched_true_idx == i
339-
if matches.shape[0] > 0 and sum(j) == 1:
340-
result_matrix[
341-
true_class_value, detection_classes[matched_detection_idx[j]]
342-
] += 1 # TP
311+
best_det_idx = candidate_det_idxs[np.argmax(iou_batch[gt_idx, candidate_det_idxs])]
312+
det_class = detection_classes[best_det_idx]
313+
314+
if best_det_idx not in matched_det_idx:
315+
# Count as matched regardless of class:
316+
# same class → TP, different class → misclassification
317+
result_matrix[gt_class, det_class] += 1
318+
matched_gt_idx.add(gt_idx)
319+
matched_det_idx.add(best_det_idx)
343320
else:
344-
result_matrix[true_class_value, num_classes] += 1 # FN
321+
# Detection already matched, GT is FN
322+
result_matrix[gt_class, num_classes] += 1
345323

346-
for i, detection_class_value in enumerate(detection_classes):
347-
if not any(matched_detection_idx == i):
348-
result_matrix[num_classes, detection_class_value] += 1 # FP
324+
# unmatched detections are FP
325+
for det_idx, det_class in enumerate(detection_classes):
326+
if det_idx not in matched_det_idx:
327+
result_matrix[num_classes, det_class] += 1
349328

350329
return result_matrix
351330

331+
352332
@staticmethod
353333
def _drop_extra_matches(matches: np.ndarray) -> np.ndarray:
354334
"""

0 commit comments

Comments
 (0)