From f403c19c860becfcecb0471c382229be3c8ce124 Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Mon, 18 Nov 2024 21:27:09 +0800
Subject: [PATCH] [DistDataloader] fix eval new_group (#9438)

---
 paddlenlp/data/dist_dataloader.py | 28 ++++++++++++++--------------
 paddlenlp/trainer/trainer.py      | 31 ++++++++++++++++---------------
 2 files changed, 30 insertions(+), 29 deletions(-)

diff --git a/paddlenlp/data/dist_dataloader.py b/paddlenlp/data/dist_dataloader.py
index a6330ce1fe08..01f1828b535a 100644
--- a/paddlenlp/data/dist_dataloader.py
+++ b/paddlenlp/data/dist_dataloader.py
@@ -66,6 +66,7 @@ def __init__(
 
         eval = kwargs.pop("eval", False)
         is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
+        self._pp_data_group = kwargs.pop("pp_data_group", None)
 
         if dataset is None:
             dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
@@ -78,10 +79,8 @@ def __init__(
 
         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
-            self._pp_data_group = self._init_dataloader_comm_group()
             self._pp_group = self._hcg.get_pipe_parallel_group()
         else:
-            self._pp_data_group = None
             self._pp_group = None
 
         self.mp_group = self._hcg.get_model_parallel_group()
@@ -132,18 +131,6 @@ def __len__(self):
         else:
             raise ValueError("raise error for `paddlenlp.trainer.trainer_utils.has_length`")
 
-    def _init_dataloader_comm_group(self):
-        topo = self._hcg._topo
-        parallel_comm_group = None
-        parallel_groups = topo.get_comm_list("pipe")
-
-        for group in parallel_groups:
-            ranks = [group[0], group[-1]]
-            comm_group = paddle.distributed.new_group(ranks=ranks)
-            if paddle.distributed.get_rank() in ranks:
-                parallel_comm_group = comm_group
-        return parallel_comm_group
-
     def __iter__(self):
         return self
 
@@ -212,3 +199,16 @@ def __next__(self):
                 logger.debug(e)
             data = self._broadcast_data(data)
         return data
+
+
+def init_dataloader_comm_group():
+    hcg = fleet.get_hybrid_communicate_group()
+    topo = hcg._topo
+    parallel_groups = topo.get_comm_list("pipe")
+
+    for group in parallel_groups:
+        ranks = [group[0], group[-1]]
+        comm_group = paddle.distributed.new_group(ranks=ranks)
+        if paddle.distributed.get_rank() in ranks:
+            parallel_comm_group = comm_group
+    return parallel_comm_group
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 174a322382f2..a489a2b7566c 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -79,6 +79,7 @@
     DataCollatorWithPadding,
     DistDataLoader,
     default_data_collator,
+    init_dataloader_comm_group,
 )
 from ..peft import LoRAModel, PrefixModelForCausalLM, VeRAModel
 
@@ -422,6 +423,10 @@ def fn(layer):
 
             model.apply(fn)
 
+        self._pp_data_group = None
+        if self.args.pipeline_parallel_degree > 1 and self.args.distributed_dataloader:
+            self._pp_data_group = init_dataloader_comm_group()
+
         default_label_names = (
             ["start_positions", "end_positions"]
             if "QusetionAnswering" in type(self.model).__name__ or "UIE" in type(self.model).__name__
@@ -1505,6 +1510,7 @@ def get_train_dataloader(self):
             train_dataset = self._remove_unused_columns(train_dataset, description="training")
 
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
+        additional_configs = {}
         if is_iterable_dataset:  # For iterable dataset
             if self.args.dataset_world_size > 1 and train_dataset is not None:
                 train_dataset = IterableDatasetShard(
@@ -1517,9 +1523,7 @@ def get_train_dataloader(self):
 
             if self.args.distributed_dataloader:
                 logger.info("Training using DistDataLoader.")
-                additional_configs = {"is_iterable_dataset": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 train_dataset,
                 batch_size=self.args.per_device_train_batch_size,
@@ -1531,11 +1535,13 @@ def get_train_dataloader(self):
             train_sampler = self._get_train_sampler()
             if self.args.distributed_dataloader:
                 logger.info("Training using DistDataLoader.")
+                additional_configs = {"pp_data_group": self._pp_data_group}
             return _DataLoader(
                 train_dataset,
                 batch_sampler=train_sampler,
                 collate_fn=self.data_collator,
                 num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
             )
 
     def _get_eval_sampler(self, eval_dataset: Dataset):
@@ -1591,6 +1597,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
 
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
+        additional_configs = {}
         if is_iterable_dataset:
             if self.args.dataset_world_size > 1 and eval_dataset is not None:
                 eval_dataset = IterableDatasetShard(
@@ -1600,11 +1607,10 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
                     num_processes=self.args.dataset_world_size,
                     process_index=self.args.dataset_rank,
                 )
+
             if self.args.distributed_dataloader:
                 logger.info("Eval using DistDataLoader.")
-                additional_configs = {"eval": True, "is_iterable_dataset": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"eval": True, "is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 eval_dataset,
                 batch_size=self.args.per_device_eval_batch_size,
@@ -1616,9 +1622,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
             eval_sampler = self._get_eval_sampler(eval_dataset)
             if self.args.distributed_dataloader:
                 logger.info("Eval using DistDataLoader.")
-                additional_configs = {"eval": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"eval": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 eval_dataset,
                 batch_sampler=eval_sampler,
@@ -1651,6 +1655,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
             test_dataset = self._remove_unused_columns(test_dataset, description="test")
 
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
+        additional_config = {}
         if is_iterable_dataset:
             if self.args.dataset_world_size > 1 and test_dataset is not None:
                 test_dataset = IterableDatasetShard(
@@ -1663,9 +1668,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
 
             if self.args.distributed_dataloader:
                 logger.info("Test using DistDataLoader.")
-                additional_config = {"eval": True, "is_iterable_dataset": True}
-            else:
-                additional_config = {}
+                additional_config = {"eval": True, "is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 test_dataset,
                 batch_size=self.args.per_device_eval_batch_size * self.world_size,
@@ -1677,9 +1680,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
             test_sampler = self._get_eval_sampler(test_dataset)
             if self.args.distributed_dataloader:
                 logger.info("Test using DistDataLoader.")
-                additional_config = {"eval": True}
-            else:
-                additional_config = {}
+                additional_config = {"eval": True, "pp_data_group": self._pp_data_group}
             # We use the same batch_size as for eval.
             return _DataLoader(
                 test_dataset,
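
Usage sketch (illustrative, not part of the patch): after this change the pipe-endpoint communication group is created once per process by init_dataloader_comm_group() and handed to every DistDataLoader through the new pp_data_group keyword, instead of each loader calling paddle.distributed.new_group itself. The snippet below assumes a fleet hybrid-parallel setup with pipeline degree > 1 is already initialized (normally done by PaddleNLP's Trainer); train_dataset, eval_dataset, and collator are hypothetical placeholders.

    # Illustrative sketch only. Assumes the process was started with
    # paddle.distributed.launch and fleet hybrid parallelism (pp degree > 1)
    # is already configured; train_dataset, eval_dataset, and collator are
    # hypothetical placeholders, not names from the patch.
    from paddlenlp.data import DistDataLoader, init_dataloader_comm_group

    # Build the pipe-endpoint group exactly once per process, so every rank
    # issues the same new_group calls in the same order for train and eval.
    pp_data_group = init_dataloader_comm_group()

    train_loader = DistDataLoader(
        train_dataset,
        batch_size=8,
        collate_fn=collator,
        pp_data_group=pp_data_group,
    )
    eval_loader = DistDataLoader(
        eval_dataset,
        batch_size=8,
        collate_fn=collator,
        eval=True,  # same flag the Trainer passes for evaluation/test loaders
        pp_data_group=pp_data_group,
    )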