
Commit 3fbd23f

Author: Ubuntu
Commit message: fixed inference errors
Parent: c500ca0

3 files changed: +57, -70 lines


src/bart_reddit_lora/evaluation.py

Lines changed: 36 additions & 58 deletions
@@ -1,87 +1,65 @@
 """
-Metrics computation module for sequence-to-sequence models.
-
-This module provides a factory function to create a `compute_metrics` callable
-for Hugging Face's `Trainer`. The returned function computes ROUGE-L, BLEU, and
-BERTScore (F1) on decoded model predictions versus labels.
+Module for building a ROUGE-L metric computation function
+for Hugging Face Seq2SeqTrainer.
 """
 
 import numpy as np
 import evaluate
-from transformers import EvalPrediction
-from typing import Callable, Dict, Any, Union
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, EarlyStoppingCallback, PreTrainedTokenizerBase, EvalPrediction
+from typing import Callable, Dict
 
 
-def build_compute_metrics(
-    tok: PreTrainedTokenizerBase,
-    num_process_workers: int = 2
-) -> Callable[[EvalPrediction], Dict[str, float]]:
+def build_compute_metrics(tok: PreTrainedTokenizerBase) -> Callable[[EvalPrediction], Dict[str, float]]:
     """
-    Create a metrics computation function for use with Hugging Face `Trainer`.
+    Create a compute_metrics function for Seq2SeqTrainer that returns the ROUGE-L score.
 
     Args:
-        tokenizer: A Hugging Face tokenizer for decoding predictions/labels.
-        num_process_workers: Number of worker processes for metric computation.
+        tok (PreTrainedTokenizerBase): Tokenizer for decoding predictions and labels.
 
     Returns:
-        A callable that takes an `EvalPrediction` and returns a dict with:
-        - "rougeL": ROUGE-L score (%)
-        - "bleu": BLEU score (%)
-        - "bertscore_f1": average BERTScore F1
+        Callable[[EvalPrediction], Dict[str, float]]: Function computing "rougeL" percentage.
     """
-    rouge = evaluate.load("rouge")  # longest-substring overlap
-    bleu = evaluate.load("bleu")  # n-gram precision
-    bertscore = evaluate.load("bertscore")  # semantic similarity
+    rouge = evaluate.load("rouge", keep_in_memory=True)  # keep_in_memory avoids disk I/O
 
-    def _compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
+    # 2️⃣ Metric fn: decode → strip → compute → return only rougeL
+    def compute_metrics(eval_pred):
         """
-        Compute ROUGE-L, BLEU, and BERTScore given model predictions and labels.
+        Decode predictions and references, compute ROUGE-L, and return as percentage.
 
         Args:
-            eval_pred: An `EvalPrediction` with `predictions` and `label_ids`.
+            eval_pred (EvalPrediction): Object with .predictions and .label_ids.
 
         Returns:
-            A dict mapping metric names to rounded scores.
+            Dict[str, float]: Dictionary with key "rougeL" and its percentage score.
         """
-        preds, labels = eval_pred.predictions, eval_pred.label_ids
-
-        # handle tuple output (some models return (generated_ids, ...))
-        if isinstance(preds, tuple):
+        preds, labels = eval_pred
+        if isinstance(preds, tuple):  # HF sometimes returns (logits, ...)
            preds = preds[0]
 
-        # decode
-        decoded_preds = tok.batch_decode(preds, skip_special_tokens=True)
+        # Replace label pad tokens (-100) so they can be decoded
         labels = np.where(labels != -100, labels, tok.pad_token_id)
-        decoded_labels = tok.batch_decode(labels, skip_special_tokens=True)
 
-        # metrics
-        rouge_l = rouge.compute(
+        decoded_preds = tok.batch_decode(preds, skip_special_tokens=True,
+                                         clean_up_tokenization_spaces=True)
+        decoded_labels = tok.batch_decode(labels, skip_special_tokens=True,
+                                          clean_up_tokenization_spaces=True)
+
+        # Strip white-space/newlines that can hurt ROUGE scores
+        decoded_preds = [s.strip() for s in decoded_preds]
+        decoded_labels = [s.strip() for s in decoded_labels]
+
+        score_dict = rouge.compute(
             predictions=decoded_preds,
             references=decoded_labels,
-            use_stemmer=True,
-            num_process_workers=num_process_workers,
-        )["rougeL"]
-        bleu_score = bleu.compute(
-            predictions=decoded_preds,
-            references=[[ref] for ref in decoded_labels],  # BLEU expects list-of-lists
-            smooth=True,
-            num_process_workers=num_process_workers,
-        )["bleu"]
-        bert_f1 = np.mean(
-            bertscore.compute(
-                predictions=decoded_preds,
-                references=decoded_labels,
-                lang="en",
-                num_process_workers=num_process_workers,
-            )["f1"]
+            use_stemmer=True,  # standard setting for ROUGE-* in HF evaluate
         )
 
-        # round for nice logging
-        return {
-            "rougeL": round(rouge_l * 100, 4),
-            "bleu": round(bleu_score * 100, 4),
-            "bertscore_f1": round(bert_f1, 4),
-        }
+        # HF's rouge.compute() returns fractional scores; convert to %
+        rougeL = round(score_dict["rougeL"] * 100, 4)
+
+        return {"rougeL": rougeL}
+
+    return compute_metrics
+
+
 
-    return _compute_metrics
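
For context, a minimal sketch of how the rewritten factory is typically wired into a Seq2SeqTrainer. The model, tok, ds_tok, and data_collator objects are assumed to come from the repo's setup code (they are not part of this commit), and predict_with_generate=True is needed so ROUGE-L is computed on generated ids rather than logits.

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from bart_reddit_lora.evaluation import build_compute_metrics

# Sketch only: `model`, `tok`, `ds_tok`, and `data_collator` are assumed to be
# created elsewhere in the repo; they are not defined in this commit.
trainer = Seq2SeqTrainer(
    model=model,
    args=Seq2SeqTrainingArguments(
        output_dir="outputs/eval",
        per_device_eval_batch_size=64,
        predict_with_generate=True,   # decode generated ids so ROUGE-L is meaningful
        generation_max_length=384,
        report_to=[],
    ),
    eval_dataset=ds_tok["validation"],
    data_collator=data_collator,
    tokenizer=tok,
    compute_metrics=build_compute_metrics(tok),
)
metrics = trainer.evaluate()  # eval metrics will include "eval_rougeL" as a percentage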

src/bart_reddit_lora/inference.py

Lines changed: 20 additions & 11 deletions
@@ -35,7 +35,6 @@ class InferenceArgs:
         mode: Either 'test' to evaluate on the test dataset or 'predict' to generate outputs for raw texts.
         batch_size: Batch size used for evaluation or prediction.
         texts: List of input strings when running in 'predict' mode.
-        num_process_workers: Number of processes for parallel metric computation.
         use_sdpa_attention: Whether to enable SDPA attention for memory efficiency.
     """
     mode: str = field(
@@ -51,10 +50,6 @@ class InferenceArgs:
         default_factory=list,
         metadata={"help": "One or more input texts for `predict` mode."},
     )
-    num_process_workers: int = field(
-        default=2,
-        metadata={"help": "Number of workers to parallelize n-gram counting."},
-    )
     use_sdpa_attention: bool = field(
         default=True, metadata={"help": "Enable Sdpa for mem-efficient kernel."}
     )
@@ -114,19 +109,33 @@ def main() -> None:
             args=Seq2SeqTrainingArguments(
                 output_dir="outputs/inference",
                 per_device_eval_batch_size=inf_args.batch_size,
-                predict_with_generate=True,
-                generation_max_length=384,
+                predict_with_generate=False,
+                # generation_max_length=640,
                 report_to=[],
             ),
             eval_dataset=ds_tok["test"],
             data_collator=data_collator,
             tokenizer=tok,
-            compute_metrics=build_compute_metrics(tok, inf_args.num_process_workers),
+            # compute_metrics=build_compute_metrics(tok),
         )
-
-        pred_output = trainer.predict(ds_tok["test"])
-        metrics = pred_output.metrics
+        metrics = trainer.evaluate(ds_tok["test"])
         logger.info(f"Test metrics: {metrics}")
+
+        test_loader = trainer.get_eval_dataloader()
+        model = trainer.model
+        device = trainer.args.device
+
+        losses = []
+        with torch.no_grad():
+            for batch in test_loader:
+                # move inputs → device
+                batch = {k: v.to(device) for k, v in batch.items()}
+                # forward pass: passing in labels returns `loss`
+                outputs = model(**batch)
+                losses.append(outputs.loss.item())
+
+        mean_loss = sum(losses) / len(losses)
+        logger.info(f"Test loss: {mean_loss:.4f}")
 
     elif inf_args.mode == "predict":
         if not inf_args.texts:
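
With predict_with_generate switched off and compute_metrics commented out, trainer.evaluate now reports only the cross-entropy loss; ROUGE-L on generated text would have to be computed outside the Trainer. A rough sketch of that, reusing model, tok, test_loader, and device from the loop above (the max_new_tokens budget is an assumption, not a value from this commit):

import evaluate
import numpy as np
import torch

rouge = evaluate.load("rouge")
preds, refs = [], []

model.eval()
with torch.no_grad():
    for batch in test_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # generate summaries explicitly instead of relying on predict_with_generate
        gen_ids = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_new_tokens=128,  # assumed budget, not taken from this commit
        )
        labels = batch["labels"].cpu().numpy()
        labels = np.where(labels != -100, labels, tok.pad_token_id)
        preds.extend(tok.batch_decode(gen_ids, skip_special_tokens=True))
        refs.extend(tok.batch_decode(labels, skip_special_tokens=True))

rouge_l = rouge.compute(predictions=preds, references=refs, use_stemmer=True)["rougeL"]
print(f"Test ROUGE-L: {rouge_l * 100:.4f}")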

src/bart_reddit_lora/train.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
     num_train_epochs: int = 12
     per_device_train_batch_size: int = 32
     per_device_eval_batch_size: int = 64
-    learning_rate: float = 3e-5
+    learning_rate: float = 6e-5
     lr_scheduler_type: str = "cosine"
     warmup_ratio: float = 0.1
     max_grad_norm: float = 0.5
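
For intuition about the raised learning rate: with lr_scheduler_type="cosine" and warmup_ratio=0.1, the per-step rate ramps linearly to the new 6e-5 peak over the first 10% of steps and then decays along a half-cosine. A small sketch of that shape, illustrative only and not code from the repo:

import math

def cosine_lr_with_warmup(step: int, total_steps: int,
                          base_lr: float = 6e-5, warmup_ratio: float = 0.1) -> float:
    """Approximate per-step learning rate for the settings above."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)              # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))   # cosine decay to 0

# Example: the peak rate is reached right at the end of warmup
# cosine_lr_with_warmup(step=100, total_steps=1000) == 6e-5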
