Skip to content

Commit

Permalink
fix fp32 dd
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Oct 17, 2023
1 parent 671775b commit 8a642b2
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions paddlenlp/data/dist_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,16 @@ def __next__(self):
if self._data_int64_keys is None:
if self.mp_group is not None and self.pp_rank == 0:
paddle.distributed.broadcast_object_list(data_int64_keys, src=self.mp_src_rank, group=self.mp_group)
paddle.distributed.broadcast_object_list(data_fp32_keys, src=self.mp_src_rank, group=self.mp_group)
if self._data_fp32_keys_size > 0:
paddle.distributed.broadcast_object_list(data_fp32_keys, src=self.mp_src_rank, group=self.mp_group)
if self._pp_data_group is not None:
paddle.distributed.broadcast_object_list(
data_int64_keys, src=self._pp_data_group.ranks[0], group=self._pp_data_group
)
paddle.distributed.broadcast_object_list(
data_fp32_keys, src=self._pp_data_group.ranks[0], group=self._pp_data_group
)
if self._data_fp32_keys_size > 0:
paddle.distributed.broadcast_object_list(
data_fp32_keys, src=self._pp_data_group.ranks[0], group=self._pp_data_group
)
self._data_int64_keys = data_int64_keys
self._data_fp32_keys = data_fp32_keys

Expand All @@ -203,9 +205,10 @@ def __next__(self):
data_int64_list = broadcast_data_list(
data_int64_list, paddle.int64, self.mp_rank, self.mp_group, self.mp_src_rank
)
data_fp32_list = broadcast_data_list(
data_fp32_list, paddle.float32, self.mp_rank, self.mp_group, self.mp_src_rank
)
if self._data_fp32_keys_size > 0:
data_fp32_list = broadcast_data_list(
data_fp32_list, paddle.float32, self.mp_rank, self.mp_group, self.mp_src_rank
)

if self._pp_data_group is not None:
# Note(daisimng): In last stage of pp, we don't need input_ids.
Expand All @@ -217,13 +220,14 @@ def __next__(self):
self._pp_data_group,
self._pp_data_group.ranks[0],
)
data_fp32_list = broadcast_data_list(
data_fp32_list,
paddle.float32,
self.pp_rank,
self._pp_data_group,
self._pp_data_group.ranks[0],
)
if self._data_fp32_keys_size > 0:
data_fp32_list = broadcast_data_list(
data_fp32_list,
paddle.float32,
self.pp_rank,
self._pp_data_group,
self._pp_data_group.ranks[0],
)

out = dict([(key, data) for key, data in zip(self._data_int64_keys, data_int64_list)])
out.update([(key, data) for key, data in zip(self._data_fp32_keys, data_fp32_list)])
Expand Down

0 comments on commit 8a642b2

Please sign in to comment.