Skip to content

Commit bbcef84

Browse files
Update
1 parent 0d42787 commit bbcef84

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

supervision/metrics/detection.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -290,18 +290,26 @@ def evaluate_detection_batch(
290290

291291
class_id_idx = 4
292292
true_classes = np.array(targets[:, class_id_idx], dtype=np.int16)
293-
detection_classes = np.array(detection_batch_filtered[:, class_id_idx], dtype=np.int16)
293+
detection_classes = np.array(
294+
detection_batch_filtered[:, class_id_idx], dtype=np.int16
295+
)
294296
true_boxes = targets[:, :class_id_idx]
295297
detection_boxes = detection_batch_filtered[:, :class_id_idx]
296298

297-
iou_batch = box_iou_batch(boxes_true=true_boxes, boxes_detection=detection_boxes)
299+
iou_batch = box_iou_batch(
300+
boxes_true=true_boxes, boxes_detection=detection_boxes
301+
)
302+
# matched_idx = np.asarray(iou_batch > iou_threshold).nonzero()
298303

299-
matched_gt_idx = set()
300-
matched_det_idx = set()
304+
# if matched_idx[0].shape[0]:
305+
# matches = np.stack(
306+
# (matched_idx[0], matched_idx[1], iou_batch[matched_idx]), axis=1
307+
# )
308+
# matches = ConfusionMatrix._drop_extra_matches(matches=matches)
309+
# else:
310+
# matches = np.zeros((0, 3))
301311

302-
# For each GT, find best matching detection (highest IoU > threshold)
303-
for gt_idx, gt_class in enumerate(true_classes):
304-
candidate_det_idxs = np.where(iou_batch[gt_idx] > iou_threshold)[0]
312+
matched_idx = np.asarray(iou_batch > iou_threshold).nonzero()
305313

306314
if matched_idx[0].shape[0]:
307315
# Filter matches by class equality
@@ -312,18 +320,15 @@ def evaluate_detection_batch(
312320
valid_true_idx = matched_idx[0][valid_matches_mask]
313321
valid_pred_idx = matched_idx[1][valid_matches_mask]
314322

315-
best_det_idx = candidate_det_idxs[np.argmax(iou_batch[gt_idx, candidate_det_idxs])]
316-
det_class = detection_classes[best_det_idx]
323+
ious = iou_batch[valid_true_idx, valid_pred_idx]
324+
matches = np.stack((valid_true_idx, valid_pred_idx, ious), axis=1)
317325

318-
if best_det_idx not in matched_det_idx:
319-
# Count as matched regardless of class:
320-
# same class → TP, different class → misclassification
321-
result_matrix[gt_class, det_class] += 1
322-
matched_gt_idx.add(gt_idx)
323-
matched_det_idx.add(best_det_idx)
326+
# Now drop extra matches with highest IoU per GT/pred
327+
matches = ConfusionMatrix._drop_extra_matches(matches=matches)
324328
else:
325-
# Detection already matched, GT is FN
326-
result_matrix[gt_class, num_classes] += 1
329+
matches = np.zeros((0, 3))
330+
else:
331+
matches = np.zeros((0, 3))
327332

328333
matched_true_idx, matched_detection_idx, _ = matches.transpose().astype(
329334
np.int16
@@ -344,7 +349,6 @@ def evaluate_detection_batch(
344349

345350
return result_matrix
346351

347-
348352
@staticmethod
349353
def _drop_extra_matches(matches: np.ndarray) -> np.ndarray:
350354
"""

0 commit comments

Comments
 (0)