Skip to content

Commit

Permalink
pick from json fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gshennvm committed Nov 15, 2023
1 parent 7687661 commit 98ccd45
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions nemo_aligner/data/nlp/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def __init__(

self.max_sample_length = max_sample_length

self.use_json = self.cfg.data.data_impl.startswith("json")

self.shuffled_indices = list(range(len(self)))

np_rng = np.random.default_rng(seed=seed)
Expand Down Expand Up @@ -98,7 +100,7 @@ def __getitem__(self, idx):
while True:
shuffled_idx = self.shuffled_indices[idx]
sample = self.data[shuffled_idx]
if self.cfg.data.data_impl.startswith("json"):
if self.use_json:
sample, _ = self.encode(sample["text"])
if len(sample) <= self.max_sample_length:
break
Expand All @@ -115,7 +117,12 @@ def __getitem__(self, idx):
)
mask_sample = True

sample_tensor = torch.from_numpy(sample.astype(np.int64))
if self.use_json:
# `sample` is a regular Python list.
sample_tensor = torch.tensor(sample, dtype=torch.int64)
else:
# `sample` is a NumPy array.
sample_tensor = torch.from_numpy(sample.astype(np.int64))

# if we want to mask the sample we should
# set the loss multiplier to 0
Expand Down

0 comments on commit 98ccd45

Please sign in to comment.