Modify dist dataloader (PaddlePaddle#7286)
DesmonDay authored Oct 24, 2023
1 parent 5359c0e commit 4b86006
Showing 1 changed file with 47 additions and 70 deletions.
117 changes: 47 additions & 70 deletions paddlenlp/data/dist_dataloader.py
@@ -65,13 +65,12 @@ def __init__(
 
         self._hcg = fleet.get_hybrid_communicate_group()
 
-        # init pp data comm group
+        # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
             self._pp_data_group = self._init_dataloader_comm_group()
         else:
             self._pp_data_group = None
 
-        # tensor parallel message
         self.mp_group = self._hcg.get_model_parallel_group()
         self.mp_rank = self._hcg.get_model_parallel_rank()
         self.mp_src_rank = self._hcg.get_model_parallel_group_src_rank()
@@ -80,8 +79,10 @@ def __init__(
         self.dp_rank = self._hcg.get_data_parallel_rank()
         sharding_rank = self._hcg.get_sharding_parallel_rank()
         self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)
-        self._data_int64_keys, self._data_int64_keys_size = None, None
-        self._data_fp32_keys, self._data_fp32_keys_size = None, None
+
+        # If other data types are needed, we can modify dtype_list.
+        self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
+        self._data_keys_list, self._data_keys_size = None, None
 
         if self._need_data:
             self._dataloader = paddle.io.DataLoader(
@@ -139,99 +140,75 @@ def __iter__(self):
         return self
 
     def __next__(self):
-        data_int64_keys_size, data_fp32_keys_size = 0, 0
+        data_keys_size = [0 for i in range(len(self.dtype_list))]
         if self._need_data:
             # {'input_ids': int64, 'labels': int64}
             data = next(self._dataloader_iter)
             data_keys = list(data.keys())
 
-            # TODO(daisiming): Better methods are needed to support new data types.
-            type_check = [paddle.int64, paddle.float32]
             for key in data_keys:
-                if data[key].dtype not in type_check:
+                if data[key].dtype not in self.dtype_list:
                     raise ValueError(
-                        f"Dist dataloader requires dtype == `int64` or dtype == 'float32', but got: {data[key].dtype}"
+                        f"Dist dataloader currently requires dtype `int64`, `float32` or `int32`, but got: {data[key].dtype}"
                     )
 
-            data_int64_list = [data[key] for key in data_keys if data[key].dtype == paddle.int64]
-            data_int64_keys = [key for key in data_keys if data[key].dtype == paddle.int64]
-            data_fp32_list = [data[key] for key in data_keys if data[key].dtype == paddle.float32]
-            data_fp32_keys = [key for key in data_keys if data[key].dtype == paddle.float32]
-            data_int64_keys_size, data_fp32_keys_size = len(data_int64_keys), len(data_fp32_keys)
+            data_list, data_keys_list = [], []
+            for i, dtype in enumerate(self.dtype_list):
+                data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
+                data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
+            data_keys_size = [len(keys) for keys in data_keys_list]
 
-        # broadcast data keys size
-        data_int64_keys_size = paddle.to_tensor(data_int64_keys_size)
-        data_fp32_keys_size = paddle.to_tensor(data_fp32_keys_size)
-        if self._data_int64_keys_size is None:
+        # Broadcast data keys size.
+        if self._data_keys_size is None:
             if self.mp_group is not None and self.pp_rank == 0:
-                paddle.distributed.broadcast(data_int64_keys_size, src=self.mp_src_rank, group=self.mp_group)
-                paddle.distributed.broadcast(data_fp32_keys_size, src=self.mp_src_rank, group=self.mp_group)
+                paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
             if self._pp_data_group is not None:
-                paddle.distributed.broadcast(
-                    data_int64_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                )
-                paddle.distributed.broadcast(
-                    data_fp32_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
+                paddle.distributed.broadcast_object_list(
+                    data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
                 )
-            self._data_int64_keys_size = int(data_int64_keys_size.item())
-            self._data_fp32_keys_size = int(data_fp32_keys_size.item())
+            self._data_keys_size = data_keys_size
 
         if not self._need_data:
-            data_int64_keys = [None for i in range(self._data_int64_keys_size)]
-            data_fp32_keys = [None for i in range(self._data_fp32_keys_size)]
+            data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
 
-        # broadcast data keys name
-        if self._data_int64_keys is None:
+        # Broadcast data keys name.
+        if self._data_keys_list is None:
             if self.mp_group is not None and self.pp_rank == 0:
-                paddle.distributed.broadcast_object_list(data_int64_keys, src=self.mp_src_rank, group=self.mp_group)
-                if self._data_fp32_keys_size > 0:
-                    paddle.distributed.broadcast_object_list(data_fp32_keys, src=self.mp_src_rank, group=self.mp_group)
+                paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
             if self._pp_data_group is not None:
                 paddle.distributed.broadcast_object_list(
-                    data_int64_keys, src=self._pp_data_group.ranks[0], group=self._pp_data_group
+                    data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
                 )
-                if self._data_fp32_keys_size > 0:
-                    paddle.distributed.broadcast_object_list(
-                        data_fp32_keys, src=self._pp_data_group.ranks[0], group=self._pp_data_group
-                    )
-            self._data_int64_keys = data_int64_keys
-            self._data_fp32_keys = data_fp32_keys
+            self._data_keys_list = data_keys_list
 
-        # broadcast data
+        # Broadcast data.
         if not self._need_data:
-            data_int64_list = [None for i in range(self._data_int64_keys_size)]
-            data_fp32_list = [None for i in range(self._data_fp32_keys_size)]
+            data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]
 
         if self.mp_group is not None and self.pp_rank == 0:
-            data_int64_list = broadcast_data_list(
-                data_int64_list, paddle.int64, self.mp_rank, self.mp_group, self.mp_src_rank
-            )
-            if self._data_fp32_keys_size > 0:
-                data_fp32_list = broadcast_data_list(
-                    data_fp32_list, paddle.float32, self.mp_rank, self.mp_group, self.mp_src_rank
-                )
+            for i, dtype in enumerate(self.dtype_list):
+                if self._data_keys_size[i] > 0:
+                    data_list[i] = broadcast_data_list(
+                        data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
+                    )
 
         if self._pp_data_group is not None:
             # Note(daisimng): In last stage of pp, we don't need input_ids.
             # It will be removed in future.
-            data_int64_list = broadcast_data_list(
-                data_int64_list,
-                paddle.int64,
-                self.pp_rank,
-                self._pp_data_group,
-                self._pp_data_group.ranks[0],
-            )
-            if self._data_fp32_keys_size > 0:
-                data_fp32_list = broadcast_data_list(
-                    data_fp32_list,
-                    paddle.float32,
-                    self.pp_rank,
-                    self._pp_data_group,
-                    self._pp_data_group.ranks[0],
-                )
+            for i, dtype in enumerate(self.dtype_list):
+                if self._data_keys_size[i] > 0:
+                    data_list[i] = broadcast_data_list(
+                        data_list[i],
+                        dtype,
+                        self.pp_rank,
+                        self._pp_data_group,
+                        self._pp_data_group.ranks[0],
+                    )
 
-        out = dict([(key, data) for key, data in zip(self._data_int64_keys, data_int64_list)])
-        out.update([(key, data) for key, data in zip(self._data_fp32_keys, data_fp32_list)])
-        return out
+        out_data = {}
+        for keys, datas in zip(self._data_keys_list, data_list):
+            out_data.update([(k, d) for k, d in zip(keys, datas)])
+
+        return out_data
 
 
 def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
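
The heart of the change is replacing the hard-coded int64/fp32 pairs with a loop over dtype_list. Below is a minimal, single-process sketch of that bucketing and the final dict reassembly, not part of the commit; the batch contents are invented for illustration and the distributed broadcasts are elided.

    import paddle

    dtype_list = [paddle.int64, paddle.float32, paddle.int32]

    # A made-up batch in the {'input_ids': int64, ...} style the loader expects.
    data = {
        "input_ids": paddle.to_tensor([[1, 2, 3]], dtype="int64"),
        "labels": paddle.to_tensor([[4, 5, 6]], dtype="int64"),
        "loss_mask": paddle.to_tensor([[1.0, 1.0, 0.0]], dtype="float32"),
        "position_ids": paddle.to_tensor([[0, 1, 2]], dtype="int32"),
    }
    data_keys = list(data.keys())

    # One bucket of tensors (and one of keys) per entry in dtype_list;
    # empty buckets are simply skipped at broadcast time.
    data_list, data_keys_list = [], []
    for dtype in dtype_list:
        data_list.append([data[k] for k in data_keys if data[k].dtype == dtype])
        data_keys_list.append([k for k in data_keys if data[k].dtype == dtype])
    data_keys_size = [len(keys) for keys in data_keys_list]
    print(data_keys_size)  # [2, 1, 1]

    # ... in the real dataloader the broadcasts happen here ...

    # Every rank rebuilds the original dict from (keys, tensors) pairs.
    out_data = {}
    for keys, datas in zip(data_keys_list, data_list):
        out_data.update(zip(keys, datas))
    assert sorted(out_data.keys()) == sorted(data_keys)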
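
Note also that the per-dtype key counts and key names now travel via paddle.distributed.broadcast_object_list, which sends picklable Python objects in place, so the old paddle.to_tensor(...) / .item() round-trip for the sizes disappears. A hedged two-rank sketch of those in-place semantics, assuming a two-device run under python -m paddle.distributed.launch (not part of the commit):

    import paddle.distributed as dist

    dist.init_parallel_env()

    # Rank 0 knows the real counts; other ranks pass placeholders of the same
    # length, which broadcast_object_list overwrites in place.
    if dist.get_rank() == 0:
        data_keys_size = [2, 1, 1]
    else:
        data_keys_size = [0, 0, 0]
    dist.broadcast_object_list(data_keys_size, src=0)
    print(dist.get_rank(), data_keys_size)  # every rank prints [2, 1, 1]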

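The body of broadcast_data_list is collapsed in this view. Purely as an illustration of what such a helper has to do, broadcast shapes first and then a flattened buffer, here is a hypothetical sketch; it is an assumption, not the committed implementation:

    import numpy as np
    import paddle
    import paddle.distributed as dist


    def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
        # Hypothetical sketch, not the committed code. The source rank
        # (comm_rank == 0) holds real tensors; every other rank holds None
        # placeholders of the same length.
        shapes = [t.shape for t in data_list] if comm_rank == 0 else [None] * len(data_list)
        dist.broadcast_object_list(shapes, src=src_rank, group=comm_group)

        if comm_rank == 0:
            flat = paddle.concat([t.flatten() for t in data_list])
        else:
            numel = sum(int(np.prod(s)) for s in shapes)
            flat = paddle.empty([numel], dtype=datatype)
        dist.broadcast(flat, src=src_rank, group=comm_group)

        # Slice the flat buffer back into tensors of the original shapes.
        out, offset = [], 0
        for s in shapes:
            n = int(np.prod(s))
            out.append(flat[offset:offset + n].reshape(s))
            offset += n
        return out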