diff --git a/python/paddle/nn/clip.py b/python/paddle/nn/clip.py
index b3fe014b27a350..6b8e47083d4d50 100644
--- a/python/paddle/nn/clip.py
+++ b/python/paddle/nn/clip.py
@@ -717,6 +717,7 @@ def _dygraph_clip(self, params_grads):
         sum_square_list = []
         sum_square_list_fp16 = []
         sum_square_list_fp32 = []
+        flag_new_pp = True
         if len(params_grads) > 0 and len(params_grads[0]) > 0:
             src_mesh = params_grads[0][0].process_mesh
         else:
@@ -742,6 +743,7 @@ def _dygraph_clip(self, params_grads):
             # if the gradient mesh is not equal to src mesh
             # do reshard to get the result of squared_l2 from other pp stage mesh
             if src_mesh is not None and g.process_mesh != src_mesh:
+                flag_new_pp = False
                 pp_mesh = get_complete_pp_mesh(g.process_mesh)
                 if set(g.process_mesh.process_ids) < set(pp_mesh.process_ids):
                     sum_square = dist.reshard(
@@ -790,6 +792,44 @@ def async_add_n(var_list):
             global_norm_var.append(global_norm_var_fp64)
         global_norm_var = async_add_n(global_norm_var)
+        global_mesh = dist.get_mesh()
+        is_pp_enable = False
+        if global_mesh is not None:
+            is_pp_enable = (
+                "pp" in global_mesh.dim_names
+                and global_mesh.get_dim_size("pp") > 1
+            )
+        if (
+            flag_new_pp and src_mesh is not None and is_pp_enable
+        ):  # New pp path: global_norm_var only holds this pp stage's partial sum, so it still has to be summed across pp stages.
+            global_pp_mesh = global_mesh.get_mesh_with_dim("pp")
+            reorder_mesh = global_pp_mesh._mesh.reshape(
+                global_mesh.get_dim_size("pp"), -1
+            )
+            curr_rank = dist.get_rank()
+            assert (
+                curr_rank in global_pp_mesh.process_ids
+            ), "current rank is not in pp process mesh"
+            curr_rank_sub_group = None
+            for col in range(
+                reorder_mesh.shape[-1]
+            ):  # Each sub-mesh needs its own new group; otherwise the sub-meshes would share a group id and all_gather would fail.
+                sub_mesh = dist.ProcessMesh(reorder_mesh[:, col], ["pp"])
+                sub_group = dist.new_group(sub_mesh.process_ids)
+                if curr_rank in reorder_mesh[:, col]:
+                    curr_rank_sub_group = sub_group
+            global_norm_var_list = []
+            dist.all_gather(
+                global_norm_var_list,
+                global_norm_var._local_value(),
+                group=curr_rank_sub_group,
+            )
+            real_global_norm_var = async_add_n(global_norm_var_list)
+            global_norm_var = dist.shard_tensor(
+                real_global_norm_var,
+                global_norm_var.process_mesh,
+                global_norm_var.placements,
+            )

         if self.should_comm_on_shard_dim and hasattr(self, 'sharding_group'):
             paddle.distributed.all_reduce(
diff --git a/test/auto_parallel/PP_Schedules_demo.py b/test/auto_parallel/PP_Schedules_demo.py
index 865fe881920cad..4e41ab055702b3 100644
--- a/test/auto_parallel/PP_Schedules_demo.py
+++ b/test/auto_parallel/PP_Schedules_demo.py
@@ -15,15 +15,19 @@
 import random

 import numpy as np
-from schedules import Schedule1F1B, ScheduleFThenB, ScheduleVPP
-from stage import (
-    PipelineStage,
-)

 import paddle
 import paddle.distributed as dist
 from paddle import nn
 from paddle.distributed import fleet
+from paddle.distributed.auto_parallel.pipelining.schedules import (
+    Schedule1F1B,
+    ScheduleFThenB,
+    ScheduleVPP,
+)
+from paddle.distributed.auto_parallel.pipelining.stage import (
+    PipelineStage,
+)
 from paddle.io import DataLoader, Dataset


@@ -405,6 +409,67 @@ def test_dp_pp(self):
             opt.clear_grad()
         return losses_by_step

+    def test_pp_model_with_ClipGradByGlobalNorm(self):
+        """Test pipeline parallel model with ClipGradByGlobalNorm, using PPMyModel as the baseline."""
+        fix_seeds()
+        pp_model = PPMyModel()
+        opt = paddle.optimizer.AdamW(
+            learning_rate=0.001,
+            parameters=pp_model.parameters(),
+            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
+        )
+        loss_fn = nn.MSELoss()
+        dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
+        loader = DataLoader(dataset, batch_size=1)
+        pp_losses_step = []
+        num_iterations = 20
+
+        for iter_idx in range(num_iterations):
+            pp_losses_micro_batch = []
+            for i, (data, label) in enumerate(loader):
+                output = pp_model(data)
+                loss = loss_fn(output, label)
+                pp_losses_micro_batch.append(loss.item())
+                loss.backward()
+            pp_losses_step.append(
+                np.array(pp_losses_micro_batch, dtype=np.float32).mean()
+            )
+            opt.step()
+            opt.clear_grad()
+        return pp_losses_step
+
+    def test_ScheduleFThenB_with_ClipGradByGlobalNorm(self):
+        fix_seeds()
+        self.model = PPMyModel_SingleStage()
+        self.micro_batches = 8
+        self.stage = PipelineStage(self.model, self.rank, 4, group=self.group)
+        self.stage.has_backward = True
+        loss_fn_ = nn.MSELoss()
+        schedule = ScheduleFThenB(
+            self.stage, self.micro_batches, loss_fn=loss_fn_
+        )
+        opt = paddle.optimizer.AdamW(
+            learning_rate=0.001,
+            parameters=self.model.parameters(),
+            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
+        )
+        dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
+        loader = DataLoader(dataset, batch_size=8)
+        losses_by_step = []
+        num_iterations = 20
+
+        for iter_idx in range(num_iterations):
+            losses_by_micro_batch = []
+            for i, (data, label) in enumerate(loader):
+                schedule.step(data, target=label, losses=losses_by_micro_batch)
+            if self.rank == 3:
+                losses_by_step.append(
+                    np.array(losses_by_micro_batch, dtype=np.float32).mean()
+                )
+            opt.step()
+            opt.clear_grad()
+        return losses_by_step
+
     def run_test(self):
         """Compare losses between three training methods"""
         self.setUpClass()
@@ -412,6 +477,12 @@ def run_test(self):
         scheduleFThenB_losses = self.test_ScheduleFThenB()
         schedule1f1b_losses = self.test_Schedule1F1B()
         schedulevpp_losses = self.test_ScheduleVPP()
+        pp_model_with_ClipGradByGlobalNorm_losses = (
+            self.test_pp_model_with_ClipGradByGlobalNorm()
+        )
+        scheduleFThenB_with_ClipGradByGlobalNorm_losses = (
+            self.test_ScheduleFThenB_with_ClipGradByGlobalNorm()
+        )
         dp_pp_losses = self.test_dp_pp()

         if self.rank == 3:
@@ -439,6 +510,12 @@ def run_test(self):
                 rtol=1e-5,
             )

+            np.testing.assert_allclose(
+                pp_model_with_ClipGradByGlobalNorm_losses,
+                scheduleFThenB_with_ClipGradByGlobalNorm_losses,
+                rtol=1e-5,
+            )
+

 if __name__ == '__main__':
     Test_Schedules().run_test()