Skip to content

Commit

Permalink
ML UnitTest
Browse files Browse the repository at this point in the history
  • Loading branch information
Gamm0 committed Dec 11, 2023
1 parent 471c7cf commit d16b7be
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@

def custom_collate(batch):
# Separate inputs, targets, and lengths
#inputs = [torch.LongTensor(item['input']) for item in batch]
#targets = [torch.tensor(item['target'], dtype=torch.float32) for item in batch]
#lengths = [item['length'] for item in batch]
inputs = [torch.LongTensor(item['input']) for item in batch]
targets = [torch.tensor(item['target'], dtype=torch.float32) for item in batch]
targets = [item['target'].clone().detach() for item in batch]
lengths = [item['length'] for item in batch]

# Pad sequences
Expand Down

0 comments on commit d16b7be

Please sign in to comment.