[DistDataloader] fix eval new_group (PaddlePaddle#9438)
DesmonDay authored and lugimzzz committed Nov 19, 2024
1 parent 0503fe0 commit f403c19
Showing 2 changed files with 30 additions and 29 deletions.
28 changes: 14 additions & 14 deletions paddlenlp/data/dist_dataloader.py
@@ -66,6 +66,7 @@ def __init__(

        eval = kwargs.pop("eval", False)
        is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
+        self._pp_data_group = kwargs.pop("pp_data_group", None)

        if dataset is None:
            dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
@@ -78,10 +79,8 @@ def __init__(

        # Init pp data comm group.
        if self._hcg.get_pipe_parallel_world_size() > 1:
-            self._pp_data_group = self._init_dataloader_comm_group()
            self._pp_group = self._hcg.get_pipe_parallel_group()
        else:
-            self._pp_data_group = None
            self._pp_group = None

        self.mp_group = self._hcg.get_model_parallel_group()
@@ -132,18 +131,6 @@ def __len__(self):
        else:
            raise ValueError("raise error for `paddlenlp.trainer.trainer_utils.has_length`")

-    def _init_dataloader_comm_group(self):
-        topo = self._hcg._topo
-        parallel_comm_group = None
-        parallel_groups = topo.get_comm_list("pipe")
-
-        for group in parallel_groups:
-            ranks = [group[0], group[-1]]
-            comm_group = paddle.distributed.new_group(ranks=ranks)
-            if paddle.distributed.get_rank() in ranks:
-                parallel_comm_group = comm_group
-        return parallel_comm_group
-
    def __iter__(self):
        return self

@@ -212,3 +199,16 @@ def __next__(self):
                logger.debug(e)
        data = self._broadcast_data(data)
        return data
+
+
+def init_dataloader_comm_group():
+    hcg = fleet.get_hybrid_communicate_group()
+    topo = hcg._topo
+    parallel_groups = topo.get_comm_list("pipe")
+
+    for group in parallel_groups:
+        ranks = [group[0], group[-1]]
+        comm_group = paddle.distributed.new_group(ranks=ranks)
+        if paddle.distributed.get_rank() in ranks:
+            parallel_comm_group = comm_group
+    return parallel_comm_group
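
The helper above replaces the per-instance `_init_dataloader_comm_group` method: the pipe-ends group is now created once and reused, so building an eval dataloader no longer calls `paddle.distributed.new_group` again. A minimal sketch of the intended call pattern follows (import path taken from the trainer.py hunk below; the dataset, sampler, and collate_fn names are placeholders; assumes `fleet.init` has already set up a hybrid strategy with pipeline parallelism):

# Sketch only: assumes fleet.init(...) has been called with a hybrid strategy that
# enables pipeline parallelism; train_dataset, eval_dataset, the samplers, and
# collate_fn are placeholder names, not part of this commit.
from paddlenlp.data import DistDataLoader, init_dataloader_comm_group

# Create the pipe-ends communication group once per process ...
pp_data_group = init_dataloader_comm_group()

# ... and hand the same group to every loader, so evaluation does not open
# additional communication groups of its own.
train_loader = DistDataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    collate_fn=collate_fn,
    pp_data_group=pp_data_group,
)
eval_loader = DistDataLoader(
    eval_dataset,
    batch_sampler=eval_sampler,
    collate_fn=collate_fn,
    eval=True,
    pp_data_group=pp_data_group,
)

In the trainer changes below, this is exactly what happens: the group is built in `Trainer.__init__` only when `pipeline_parallel_degree > 1` and `distributed_dataloader` is enabled, and then forwarded to the train, eval, and test loaders through the `pp_data_group` entry of `additional_configs`.
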
31 changes: 16 additions & 15 deletions paddlenlp/trainer/trainer.py
@@ -79,6 +79,7 @@
    DataCollatorWithPadding,
    DistDataLoader,
    default_data_collator,
+    init_dataloader_comm_group,
)
from ..peft import LoRAModel, PrefixModelForCausalLM, VeRAModel

@@ -422,6 +423,10 @@ def fn(layer):

        model.apply(fn)

+        self._pp_data_group = None
+        if self.args.pipeline_parallel_degree > 1 and self.args.distributed_dataloader:
+            self._pp_data_group = init_dataloader_comm_group()
+
        default_label_names = (
            ["start_positions", "end_positions"]
            if "QusetionAnswering" in type(self.model).__name__ or "UIE" in type(self.model).__name__
@@ -1505,6 +1510,7 @@ def get_train_dataloader(self):
        train_dataset = self._remove_unused_columns(train_dataset, description="training")
        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

+        additional_configs = {}
        if is_iterable_dataset:  # For iterable dataset
            if self.args.dataset_world_size > 1 and train_dataset is not None:
                train_dataset = IterableDatasetShard(
@@ -1517,9 +1523,7 @@

            if self.args.distributed_dataloader:
                logger.info("Training using DistDataLoader.")
-                additional_configs = {"is_iterable_dataset": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
            return _DataLoader(
                train_dataset,
                batch_size=self.args.per_device_train_batch_size,
@@ -1531,11 +1535,13 @@
            train_sampler = self._get_train_sampler()
            if self.args.distributed_dataloader:
                logger.info("Training using DistDataLoader.")
+                additional_configs = {"pp_data_group": self._pp_data_group}
            return _DataLoader(
                train_dataset,
                batch_sampler=train_sampler,
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
            )

    def _get_eval_sampler(self, eval_dataset: Dataset):
@@ -1591,6 +1597,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
        eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

+        additional_configs = {}
        if is_iterable_dataset:
            if self.args.dataset_world_size > 1 and eval_dataset is not None:
                eval_dataset = IterableDatasetShard(
@@ -1600,11 +1607,10 @@
                    num_processes=self.args.dataset_world_size,
                    process_index=self.args.dataset_rank,
                )
+
            if self.args.distributed_dataloader:
                logger.info("Eval using DistDataLoader.")
-                additional_configs = {"eval": True, "is_iterable_dataset": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"eval": True, "is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
            return _DataLoader(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
@@ -1616,9 +1622,7 @@
            eval_sampler = self._get_eval_sampler(eval_dataset)
            if self.args.distributed_dataloader:
                logger.info("Eval using DistDataLoader.")
-                additional_configs = {"eval": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"eval": True, "pp_data_group": self._pp_data_group}
            return _DataLoader(
                eval_dataset,
                batch_sampler=eval_sampler,
@@ -1651,6 +1655,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        test_dataset = self._remove_unused_columns(test_dataset, description="test")
        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

+        additional_config = {}
        if is_iterable_dataset:
            if self.args.dataset_world_size > 1 and test_dataset is not None:
                test_dataset = IterableDatasetShard(
@@ -1663,9 +1668,7 @@

            if self.args.distributed_dataloader:
                logger.info("Test using DistDataLoader.")
-                additional_config = {"eval": True, "is_iterable_dataset": True}
-            else:
-                additional_config = {}
+                additional_config = {"eval": True, "is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
            return _DataLoader(
                test_dataset,
                batch_size=self.args.per_device_eval_batch_size * self.world_size,
@@ -1677,9 +1680,7 @@
            test_sampler = self._get_eval_sampler(test_dataset)
            if self.args.distributed_dataloader:
                logger.info("Test using DistDataLoader.")
-                additional_config = {"eval": True}
-            else:
-                additional_config = {}
+                additional_config = {"eval": True, "pp_data_group": self._pp_data_group}
            # We use the same batch_size as for eval.
            return _DataLoader(
                test_dataset,