diff --git a/LLaVA/llava/train/train.py b/LLaVA/llava/train/train.py index a868800..4f4a121 100644 --- a/LLaVA/llava/train/train.py +++ b/LLaVA/llava/train/train.py @@ -781,16 +781,18 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args=data_args) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) - eval_dataset = None validation_data_path = data_args.data_path.replace('train', 'val') - if os.path.isfile(validation_data_path): - eval_dataset = LazySupervisedDataset(tokenizer=tokenizer, - data_path=validation_data_path, - data_args=data_args) - print('Eval dataset found at {}'.format(validation_data_path)) - else: - print('No validation dataset found at {}'.format(validation_data_path)) - exit(1) + + assert os.path.isfile(validation_data_path), ( + 'Validation file not found at {}.'.format(validation_data_path) + + 'Make sure you have run preprocess.py first.' + ) + + eval_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=validation_data_path, + data_args=data_args) + print('Eval dataset found at {}'.format(validation_data_path)) + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator)