@@ -311,15 +311,9 @@ def evaluate_detection_batch(
311
311
)
312
312
true_boxes = targets [:, :class_id_idx ]
313
313
detection_boxes = detection_batch_filtered [:, :class_id_idx ]
314
-
315
- # # Debug: Print IoU calculations
316
- # print("Debug IoU calculations:")
317
- # print(f"GT boxes: {true_boxes}")
318
- # print(f"Detection boxes: {detection_boxes}")
319
314
320
315
# Calculate IoU matrix
321
316
iou_batch = box_iou_batch (boxes_true = true_boxes , boxes_detection = detection_boxes )
322
- # print(f"IoU matrix:\n{iou_batch}")
323
317
324
318
# Find all valid matches (IoU > threshold, regardless of class)
325
319
valid_matches = []
@@ -331,14 +325,10 @@ def evaluate_detection_batch(
331
325
det_class = detection_classes [det_idx ]
332
326
class_match = (gt_class == det_class )
333
327
valid_matches .append ((gt_idx , det_idx , iou , class_match ))
334
- # print(f"Valid match: GT[{gt_idx}] class={gt_class} vs
335
- # Det[{det_idx}] class={det_class}, IoU={iou:.3f},
336
- # class_match={class_match}")
337
328
338
329
# Sort matches by class match first (True before False), then by IoU descending
339
330
# This prioritizes correct class predictions over higher IoU with wrong class
340
331
valid_matches .sort (key = lambda x : (x [3 ], x [2 ]), reverse = True )
341
- # print(f"Sorted matches: {valid_matches}")
342
332
343
333
# Greedily assign matches, ensuring each GT
344
334
# and detection is matched at most once
@@ -350,8 +340,7 @@ def evaluate_detection_batch(
350
340
# Valid spatial match - record the class prediction
351
341
gt_class = true_classes [gt_idx ]
352
342
det_class = detection_classes [det_idx ]
353
- # print(f"Assigning match: GT[{gt_idx}] class={gt_class} ->
354
- # Det[{det_idx}] class={det_class}")
343
+
355
344
# This handles both correct classification (TP) and misclassification
356
345
result_matrix [gt_class , det_class ] += 1
357
346
matched_gt_idx .add (gt_idx )
@@ -360,16 +349,13 @@ def evaluate_detection_batch(
360
349
# Count unmatched ground truth as FN
361
350
for gt_idx , gt_class in enumerate (true_classes ):
362
351
if gt_idx not in matched_gt_idx :
363
- # print(f"Unmatched GT[{gt_idx}] class={gt_class} -> FN")
364
352
result_matrix [gt_class , num_classes ] += 1
365
353
366
354
# Count unmatched detections as FP
367
355
for det_idx , det_class in enumerate (detection_classes ):
368
356
if det_idx not in matched_det_idx :
369
- # print(f"Unmatched Det[{det_idx}] class={det_class} -> FP")
370
357
result_matrix [num_classes , det_class ] += 1
371
-
372
- # print(f"Final matrix:\n{result_matrix}")
358
+
373
359
return result_matrix
374
360
375
361
@staticmethod
0 commit comments