diff --git a/concept_linking/solutions/MachineLearning/src/model_training.py b/concept_linking/solutions/MachineLearning/src/model_training.py index 6191c40..aefa47b 100644 --- a/concept_linking/solutions/MachineLearning/src/model_training.py +++ b/concept_linking/solutions/MachineLearning/src/model_training.py @@ -11,8 +11,11 @@ def custom_collate(batch): # Separate inputs, targets, and lengths + #inputs = [torch.LongTensor(item['input']) for item in batch] + #targets = [torch.tensor(item['target'], dtype=torch.float32) for item in batch] + #lengths = [item['length'] for item in batch] inputs = [torch.LongTensor(item['input']) for item in batch] - targets = [torch.tensor(item['target'], dtype=torch.float32) for item in batch] + targets = [item['target'].clone().detach() for item in batch] lengths = [item['length'] for item in batch] # Pad sequences