 """
-Metrics computation module for sequence-to-sequence models.
-
-This module provides a factory function to create a `compute_metrics` callable
-for Hugging Face's `Trainer`. The returned function computes ROUGE-L, BLEU, and
-BERTScore (F1) on decoded model predictions versus labels.
+Module for building a ROUGE-L metric computation function
+for Hugging Face Seq2SeqTrainer.
 """
 
 import numpy as np
 import evaluate
-from transformers import EvalPrediction
-from typing import Callable, Dict, Any, Union
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers import EvalPrediction, PreTrainedTokenizerBase
+from typing import Callable, Dict
 
 
-def build_compute_metrics(
-    tok: PreTrainedTokenizerBase,
-    num_process_workers: int = 2
-) -> Callable[[EvalPrediction], Dict[str, float]]:
+def build_compute_metrics(tok: PreTrainedTokenizerBase) -> Callable[[EvalPrediction], Dict[str, float]]:
     """
-    Create a metrics computation function for use with Hugging Face `Trainer`.
+    Create a compute_metrics function for Seq2SeqTrainer that returns the ROUGE-L score.
 
     Args:
-        tokenizer: A Hugging Face tokenizer for decoding predictions/labels.
-        num_process_workers: Number of worker processes for metric computation.
+        tok (PreTrainedTokenizerBase): Tokenizer for decoding predictions and labels.
 
     Returns:
-        A callable that takes an `EvalPrediction` and returns a dict with:
-        - "rougeL": ROUGE-L score (%)
-        - "bleu": BLEU score (%)
-        - "bertscore_f1": average BERTScore F1
+        Callable[[EvalPrediction], Dict[str, float]]: Function computing "rougeL" percentage.
     """
-    rouge = evaluate.load("rouge")          # longest-substring overlap
-    bleu = evaluate.load("bleu")            # n-gram precision
-    bertscore = evaluate.load("bertscore")  # semantic similarity
+    rouge = evaluate.load("rouge", keep_in_memory=True)  # keep_in_memory avoids disk I/O
 
-    def _compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
+    # Metric fn: decode → strip → compute → return only rougeL
+    def compute_metrics(eval_pred):
         """
-        Compute ROUGE-L, BLEU, and BERTScore given model predictions and labels.
+        Decode predictions and references, compute ROUGE-L, and return as percentage.
 
         Args:
-            eval_pred: An `EvalPrediction` with `predictions` and `label_ids`.
+            eval_pred (EvalPrediction): Object with .predictions and .label_ids.
 
         Returns:
-            A dict mapping metric names to rounded scores.
+            Dict[str, float]: Dictionary with key "rougeL" and its percentage score.
         """
-        preds, labels = eval_pred.predictions, eval_pred.label_ids
-
-        # handle tuple output (some models return (generated_ids, ...))
-        if isinstance(preds, tuple):
+        preds, labels = eval_pred
+        if isinstance(preds, tuple):  # HF sometimes returns (logits, ...)
             preds = preds[0]
 
-        # decode
-        decoded_preds = tok.batch_decode(preds, skip_special_tokens=True)
+        # Replace label pad tokens (-100) so they can be decoded
         labels = np.where(labels != -100, labels, tok.pad_token_id)
-        decoded_labels = tok.batch_decode(labels, skip_special_tokens=True)
 
-        # metrics
-        rouge_l = rouge.compute(
+        decoded_preds = tok.batch_decode(preds, skip_special_tokens=True,
+                                         clean_up_tokenization_spaces=True)
+        decoded_labels = tok.batch_decode(labels, skip_special_tokens=True,
+                                          clean_up_tokenization_spaces=True)
+
+        # Strip whitespace/newlines that can hurt ROUGE scores
+        decoded_preds = [s.strip() for s in decoded_preds]
+        decoded_labels = [s.strip() for s in decoded_labels]
+
+        score_dict = rouge.compute(
             predictions=decoded_preds,
             references=decoded_labels,
-            use_stemmer=True,
-            num_process_workers=num_process_workers,
-        )["rougeL"]
-        bleu_score = bleu.compute(
-            predictions=decoded_preds,
-            references=[[ref] for ref in decoded_labels],  # BLEU expects list-of-lists
-            smooth=True,
-            num_process_workers=num_process_workers,
-        )["bleu"]
-        bert_f1 = np.mean(
-            bertscore.compute(
-                predictions=decoded_preds,
-                references=decoded_labels,
-                lang="en",
-                num_process_workers=num_process_workers,
-            )["f1"]
+            use_stemmer=True,  # standard setting for ROUGE-* in HF evaluate
         )
 
-        # round for nice logging
-        return {
-            "rougeL": round(rouge_l * 100, 4),
-            "bleu": round(bleu_score * 100, 4),
-            "bertscore_f1": round(bert_f1, 4),
-        }
+        # HF's rouge.compute() returns fractional scores; convert to %
+        rougeL = round(score_dict["rougeL"] * 100, 4)
+
+        return {"rougeL": rougeL}
 
-    return _compute_metrics
+    return compute_metrics
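
Usage sketch for the new factory (a hedged illustration, not part of this commit; `model`, `tokenizer`, `train_ds`, `val_ds`, and the argument values are hypothetical placeholders). The setting that matters for the metric is `predict_with_generate=True`, so `Seq2SeqTrainer` passes generated token ids rather than logits into `compute_metrics`:

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hypothetical setup; model, tokenizer, train_ds and val_ds are assumed to exist,
# and build_compute_metrics from the module above is assumed to be in scope.
args = Seq2SeqTrainingArguments(
    output_dir="out",
    predict_with_generate=True,  # hand generated token ids (not logits) to compute_metrics
)
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    compute_metrics=build_compute_metrics(tokenizer),
)
trainer.evaluate()  # reports the score as "eval_rougeL", already scaled to a percentage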