Correct confusion matrix calculation-function evaluate_detection_batch #1853
base: develop
Changes from 6 commits
16c1070
0468b9a
0d42787
bbcef84
e56ee44
6d6b6dc
71656bf
65411ae
8d2354a
c8c9783
2cdd991
@@ -299,32 +299,38 @@ def evaluate_detection_batch(
         iou_batch = box_iou_batch(
             boxes_true=true_boxes, boxes_detection=detection_boxes
         )
-        matched_idx = np.asarray(iou_batch > iou_threshold).nonzero()

-        if matched_idx[0].shape[0]:
-            matches = np.stack(
-                (matched_idx[0], matched_idx[1], iou_batch[matched_idx]), axis=1
-            )
-            matches = ConfusionMatrix._drop_extra_matches(matches=matches)
-        else:
-            matches = np.zeros((0, 3))
+        matched_gt_idx = set()
+        matched_det_idx = set()

-        matched_true_idx, matched_detection_idx, _ = matches.transpose().astype(
-            np.int16
-        )
+        # For each GT, find best matching detection (highest IoU > threshold)
+        for gt_idx, gt_class in enumerate(true_classes):
+            candidate_det_idxs = np.where(iou_batch[gt_idx] > iou_threshold)[0]

-        for i, true_class_value in enumerate(true_classes):
-            j = matched_true_idx == i
-            if matches.shape[0] > 0 and sum(j) == 1:
-                result_matrix[
-                    true_class_value, detection_classes[matched_detection_idx[j]]
-                ] += 1  # TP
+            if len(candidate_det_idxs) == 0:
+                # No matching detection → FN for this GT
+                result_matrix[gt_class, num_classes] += 1
+                continue

+            best_det_idx = candidate_det_idxs[
+                np.argmax(iou_batch[gt_idx, candidate_det_idxs])
+            ]
+            det_class = detection_classes[best_det_idx]

+            if best_det_idx not in matched_det_idx:
+                # Count as matched regardless of class:
+                # same class → TP, different class → misclassification
+                result_matrix[gt_class, det_class] += 1
+                matched_gt_idx.add(gt_idx)
+                matched_det_idx.add(best_det_idx)
             else:
-                result_matrix[true_class_value, num_classes] += 1  # FN
+                # Detection already matched, GT is FN
+                result_matrix[gt_class, num_classes] += 1
It seems that this logic iterates through the ground truth boxes and, for each one, finds the best-matching detection box, i.e., the one with the highest IoU above the threshold that hasn't been matched yet. The issue with this logic is that the matching process depends on the order of the ground truth boxes in [...]

Thanks for pointing that out, @soumik12345. You're absolutely right about the order-dependent matching in the original logic. To address this, the updated implementation builds a full IoU matrix between all ground truth and detection boxes, collects all valid matches (IoU above the threshold), and sorts them globally, prioritizing class-correct matches first and then higher IoU. This removes any dependency on the order of the ground truth boxes. We then greedily assign matches while ensuring that each GT and each detection is matched at most once, which avoids conflicts where multiple GTs compete for a single detection. This approach ensures that:
I wrote some test cases that helped me correct the logic. At first most of them failed, but now all of them pass.
If you find any other issues, I'll be happy to address them!
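The global matching described in the reply can be sketched roughly as below. This is a minimal illustration, not the code from the PR: `box_iou` and `confusion_matrix_global` are hypothetical names, boxes are assumed to be in `xyxy` format, and the last row and column of the matrix are assumed to hold FP and FN counts as in the diff.

```python
import numpy as np


def box_iou(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise IoU for xyxy boxes: a is (N, 4), b is (M, 4), result is (N, M)."""
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    top_left = np.maximum(a[:, None, :2], b[None, :, :2])
    bottom_right = np.minimum(a[:, None, 2:], b[None, :, 2:])
    inter = np.clip(bottom_right - top_left, 0, None).prod(axis=2)
    return inter / (area_a[:, None] + area_b[None, :] - inter)


def confusion_matrix_global(true_boxes, true_classes, det_boxes, det_classes,
                            num_classes, iou_threshold=0.5):
    """Hypothetical helper: greedy one-to-one matching over globally sorted pairs."""
    result = np.zeros((num_classes + 1, num_classes + 1), dtype=int)
    iou = box_iou(true_boxes, det_boxes)

    # Collect every (GT, detection) pair above the threshold.
    gt_idx, det_idx = np.nonzero(iou > iou_threshold)
    wrong_class = (true_classes[gt_idx] != det_classes[det_idx]).astype(int)

    # np.lexsort uses the LAST key as the primary key:
    # class-correct pairs first, then higher IoU first.
    order = np.lexsort((-iou[gt_idx, det_idx], wrong_class))

    used_gt, used_det = set(), set()
    for k in order:
        g, d = int(gt_idx[k]), int(det_idx[k])
        if g in used_gt or d in used_det:
            continue  # each GT and each detection is matched at most once
        result[true_classes[g], det_classes[d]] += 1
        used_gt.add(g)
        used_det.add(d)

    for g, cls in enumerate(true_classes):
        if g not in used_gt:
            result[cls, num_classes] += 1  # unmatched GT -> FN
    for d, cls in enumerate(det_classes):
        if d not in used_det:
            result[num_classes, cls] += 1  # unmatched detection -> FP
    return result


# Two ground truths compete for one detection; only the second GT shares its class.
gt_boxes = np.array([[0, 0, 10, 10], [1, 0, 11, 10]], dtype=float)
gt_classes = np.array([1, 0])
det_boxes = np.array([[1, 0, 11, 10]], dtype=float)
det_classes = np.array([0])

forward = confusion_matrix_global(gt_boxes, gt_classes, det_boxes, det_classes, 2)
reverse = confusion_matrix_global(gt_boxes[::-1], gt_classes[::-1], det_boxes, det_classes, 2)
```

Both `forward` and `reverse` record one true positive for class 0 and one false negative for class 1, so reversing the ground truth order does not change the result here. Exact ties in both class agreement and IoU would still be broken by index order in this sketch.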

-        for i, detection_class_value in enumerate(detection_classes):
-            if not any(matched_detection_idx == i):
-                result_matrix[num_classes, detection_class_value] += 1  # FP
+        # unmatched detections are FP
+        for det_idx, det_class in enumerate(detection_classes):
+            if det_idx not in matched_det_idx:
+                result_matrix[num_classes, det_class] += 1

         return result_matrix

The selection of the best match is happening based solely on IoU, which means a wrong-class prediction can still be chosen over a right-class one if it has a higher IoU.
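
A small illustration of this point, using made-up IoU values for one ground truth box and two overlapping detections. The first selection mirrors the `np.argmax` step in the diff; the class-aware variant is only one possible alternative, sketched here as an assumption rather than taken from the PR.

```python
import numpy as np

# One ground truth of class 0 and two candidate detections (values are illustrative).
iou_row = np.array([0.80, 0.75])  # IoU of the GT with detection 0 and detection 1
det_classes = np.array([1, 0])    # detection 0 has the wrong class, detection 1 the right one
gt_class = 0
iou_threshold = 0.5

candidates = np.where(iou_row > iou_threshold)[0]

# IoU-only selection: the wrong-class detection wins because its IoU is higher.
best_by_iou = candidates[np.argmax(iou_row[candidates])]

# Class-aware selection: prefer class-correct candidates, then higher IoU.
wrong_class = (det_classes[candidates] != gt_class).astype(int)
order = np.lexsort((-iou_row[candidates], wrong_class))  # last key is primary
best_class_aware = candidates[order[0]]
```

With these values, `best_by_iou` points at detection 0 (class 1), which would be counted as a misclassification, while `best_class_aware` points at detection 1 (class 0).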