From 585d60f39c45d5bf51412b82403885a4d1a09f8d Mon Sep 17 00:00:00 2001
From: peterjc123
Date: Fri, 19 Jan 2024 02:37:22 +0000
Subject: [PATCH] fixes for eval

---
 megatron/training.py | 7 -------
 pretrain_gpt.py      | 2 +-
 2 files changed, 1 insertion(+), 8 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index edc591e473..1076d8cd3d 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -844,13 +844,6 @@ def evaluate(forward_step_func,
     for key in total_loss_dict:
         total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
 
-    # Sum LBLs across pipeline-model-parallel shards.
-    if args.model_type == ModelType.encoder_or_decoder_with_lbl:
-        assert "load balancing loss" in total_loss_dict
-        torch.distributed.all_reduce(
-            total_loss_dict["load balancing loss"],
-            group=mpu.get_pipeline_model_parallel_group())
-
     return total_loss_dict, collected_non_loss_data
 
 def evaluate_and_print_results(prefix, forward_step_func,

diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index 2e045ff331..9882496008 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -125,7 +125,7 @@ def forward_step(data_iterator, model):
                                           labels=labels)
 
     loss_fn = (
-        moe_loss_func if args.moe_num_experts is not None else loss_func)
+        moe_loss_func if args.moe_num_experts is not None and model.training else loss_func)
 
     return output_tensor, partial(loss_fn, loss_mask)
 
 def train_valid_test_datasets_provider(train_val_test_num_samples):
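
A minimal sketch of the selection logic after this patch (not the Megatron code itself; DummyArgs and DummyModel are hypothetical stand-ins for the real args and model objects in pretrain_gpt.py): with the added `and model.training` check, evaluation always falls back to the plain `loss_func`, so no "load balancing loss" entry reaches `evaluate()`'s total_loss_dict and the removed pipeline-parallel all-reduce is no longer needed.

# loss_fn_selection_sketch.py -- illustrative only, assumed simplified stand-ins
from functools import partial


def loss_func(loss_mask, output_tensor):
    # Plain language-model loss; placeholder body for the sketch.
    return output_tensor


def moe_loss_func(loss_mask, output_tensor):
    # LM loss plus MoE load-balancing loss; placeholder body for the sketch.
    return output_tensor


class DummyArgs:
    moe_num_experts = 8  # MoE enabled (hypothetical value)


class DummyModel:
    training = False  # evaluation mode


args, model = DummyArgs(), DummyModel()

# The patched condition: use moe_loss_func only when MoE is enabled AND the
# model is in training mode; during eval the plain loss_func is used instead.
loss_fn = (
    moe_loss_func if args.moe_num_experts is not None and model.training
    else loss_func)

loss_closure = partial(loss_fn, None)  # mirrors partial(loss_fn, loss_mask)
print(loss_fn is loss_func)  # True in eval mode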