    print_writeout,
    run_task_tests,
)
-from lm_eval.logging_utils import add_env_info, get_git_commit_hash
+from lm_eval.loggers import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
from lm_eval import utils
@@ -509,9 +509,14 @@ def evaluate(
        # aggregate results; run bootstrap CIs
        for task_output in eval_tasks:
            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
-        results, samples, configs, versions, num_fewshot = consolidate_results(
-            eval_tasks
-        )
+        (
+            results,
+            samples,
+            configs,
+            versions,
+            num_fewshot,
+            higher_is_better,
+        ) = consolidate_results(eval_tasks)

        ### Calculate group metrics ###
        if bool(results):
@@ -522,6 +527,23 @@ def evaluate(
                    # or `task_name: []`.
                    # we only want to operate on groups here.
                    continue
+
+                # collect all higher_is_better values for metrics
+                # in the group's subtasks.
+                # TODO: clean this up; unify with the below metric_list loop?
+                _higher_is_better = {}
+                for task in task_list:
+                    for m, h in higher_is_better[task].items():
+                        if m not in _higher_is_better.keys():
+                            _higher_is_better[m] = h
+                        if m in _higher_is_better and _higher_is_better[m] is not None and _higher_is_better[m] != h:
+                            eval_logger.warning(
+                                f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
+                            )
+                            _higher_is_better[m] = None
+                higher_is_better[group] = _higher_is_better
+
+                # collect all metric keys used by a subtask in the group.
                metric_list = list(
                    {
                        key
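A minimal standalone sketch of the merge behavior introduced in the block above: per-metric `higher_is_better` flags from a group's subtasks are combined into one group-level entry, falling back to `None` when subtasks disagree. The task names and metric flags here are made up for illustration and are not the harness's real data structures.

```python
# Illustration only: hypothetical per-subtask higher_is_better maps.
higher_is_better = {
    "subtask_a": {"acc": True, "perplexity": False},
    "subtask_b": {"acc": True, "perplexity": True},  # disagrees on perplexity
}

merged = {}
for task, flags in higher_is_better.items():
    for metric, flag in flags.items():
        if metric not in merged:
            merged[metric] = flag
        elif merged[metric] is not None and merged[metric] != flag:
            # inconsistent across subtasks -> default to None, as the diff does
            merged[metric] = None

print(merged)  # {'acc': True, 'perplexity': None}
```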
@@ -534,38 +556,20 @@ def evaluate(
                    stderr = "_stderr,".join(metric.split(","))

                    # gather metrics, sizes, and stderrs from subtasks
-                    metrics = [
-                        results[task][metric]
-                        for task in task_list
-                        if metric in results[task]
-                    ]  # TODO: copy?
-                    stderrs = [
-                        results[task][stderr]
-                        for task in task_list
-                        if stderr in results[task]
-                    ]
-                    sizes = [
-                        results[task]["samples"]
-                        for task in task_list
-                        if metric in results[task]
-                    ]
+                    metrics = [results[task][metric] for task in task_list if metric in results[task]]  # TODO: copy?
+                    stderrs = [results[task][stderr] for task in task_list if stderr in results[task]]
+                    sizes = [results[task]["samples"] for task in task_list if metric in results[task]]

                    # compute group's pooled metric and stderr
-                    results[group][metric] = (
-                        lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
-                    )
+                    results[group][metric] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
                    # TODO: calculate grouped metric using aggregation fn
                    if "N/A" in stderrs:
                        results[group][stderr] = "N/A"
                    else:
-                        results[group][stderr] = (
-                            lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
-                        )
+                        results[group][stderr] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
                        # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
-                        # To use the old (likely incorrect) variance formula,
-                        # comment out the above and uncomment this line:
-                        # results[group][stderr] = \
-                        #     lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
+                        # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
+                        # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)

                    results[group]["samples"] = sum(sizes)

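A toy sketch of what size-weighted pooling of subtask results looks like. `aggregate_subtask_metrics` and `pooled_sample_stderr` are the harness's own helpers; the formulas below are only a plausible stand-in (weighted mean plus a pooled-variance standard error), not their exact implementations.

```python
import math

# Illustration only: toy subtask scores, stderrs, and sample counts.
metrics = [0.62, 0.70]   # per-subtask metric values
stderrs = [0.02, 0.015]  # per-subtask standard errors
sizes = [500, 1500]      # per-subtask sample counts

# size-weighted mean of the subtask metrics
pooled_metric = sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)

# one common way to pool standard errors: combine variances weighted by size
pooled_var = sum((n - 1) * se**2 * n for se, n in zip(stderrs, sizes)) / (
    sum(sizes) - len(sizes)
)
pooled_stderr = math.sqrt(pooled_var / sum(sizes))

print(pooled_metric, pooled_stderr)
```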
@@ -578,19 +582,15 @@ def evaluate(
            if len(left_tasks_list) == 0:
                break

-            _task_hierarchy = {
-                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
-            }
+            _task_hierarchy = {k: v for k, v in task_hierarchy.items() if k in left_tasks_list}
            _results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list:
-                num_fewshot[group_name] = num_fewshot[
-                    task_list[0]
-                ]  # TODO: validate this
+                num_fewshot[group_name] = num_fewshot[task_list[0]]  # TODO: validate this

        results_dict = {
            "results": dict(results_agg.items()),
@@ -599,6 +599,17 @@ def evaluate(
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
+            "higher_is_better": dict(sorted(higher_is_better.items())),
+            "n-samples": {
+                task_output.task_name: {
+                    "original": len(task_output.task.eval_docs),
+                    "effective": min(
+                        limit if limit else len(task_output.task.eval_docs),
+                        len(task_output.task.eval_docs),
+                    ),
+                }
+                for task_output in eval_tasks
+            },
        }
        if log_samples:
            results_dict["samples"] = dict(samples)
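The new `n-samples` entry records, per task, the original number of evaluation documents and the effective count after any limit is applied. A small sketch of that `min(...)` logic; `effective_samples` is a made-up helper name used purely for illustration.

```python
from typing import Optional

# Illustration only: how the effective sample count relates to a limit.
def effective_samples(num_docs: int, limit: Optional[int]) -> int:
    # with no limit, everything is evaluated; otherwise cap at the task size
    return min(limit if limit else num_docs, num_docs)

assert effective_samples(1000, None) == 1000
assert effective_samples(1000, 100) == 100
assert effective_samples(50, 100) == 50  # limit larger than the task
```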
@@ -608,7 +619,6 @@ def evaluate(
    else:
        return None

-
def request_caching_arg_to_dict(cache_requests: str) -> dict:
    request_caching_args = {
        "cache_requests": cache_requests in {"true", "refresh"},