Skip to content

Commit c500ca0

Browse files
author
Ubuntu
committed
matched inference .generate max_length to that of data.py truncation limit
1 parent 3ee0d56 commit c500ca0

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/bart_reddit_lora/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def main() -> None:
143143
# fast batched generate (with arguments for higher quality generations)
144144
out = model.generate(
145145
**enc,
146-
max_length=500,
146+
max_length=128,
147147
num_beams=5, # improves quality
148148
do_sample=True, # add stochasticity
149149
length_penalty=1.2, # >1 favors longer answers

src/bart_reddit_lora/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
6363
num_train_epochs: int = 12
6464
per_device_train_batch_size: int = 32
6565
per_device_eval_batch_size: int = 64
66-
learning_rate: float = 6e-5
66+
learning_rate: float = 3e-5
6767
lr_scheduler_type: str = "cosine"
6868
warmup_ratio: float = 0.1
6969
max_grad_norm: float = 0.5

0 commit comments

Comments
 (0)