Skip to content

Commit

Permalink
fix(args): parsing fix
Browse files: browse the repository at this point in the history
  • Loading branch information
AdityaNG committed Mar 3, 2024
1 parent 9b8917c commit 38b040b
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions LLaVA/llava/train/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -781,16 +781,18 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

eval_dataset = None
validation_data_path = data_args.data_path.replace('train', 'val')
if os.path.isfile(validation_data_path):
eval_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=validation_data_path,
data_args=data_args)
print('Eval dataset found at {}'.format(validation_data_path))
else:
print('No validation dataset found at {}'.format(validation_data_path))
exit(1)

assert os.path.isfile(validation_data_path), (
'Validation file not found at {}.'.format(validation_data_path) +
'Make sure you have run preprocess.py first.'
)

eval_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=validation_data_path,
data_args=data_args)
print('Eval dataset found at {}'.format(validation_data_path))

return dict(train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator)
Expand Down

0 comments on commit 38b040b

Please sign in to comment.