From 38b040b5f51dd3a7bde6e5e291b0113f6a00c16b Mon Sep 17 00:00:00 2001
From: Aditya NG
Date: Sun, 3 Mar 2024 11:03:16 +0530
Subject: [PATCH] fix(args): parsing fix

---
 LLaVA/llava/train/train.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/LLaVA/llava/train/train.py b/LLaVA/llava/train/train.py
index a868800..4f4a121 100644
--- a/LLaVA/llava/train/train.py
+++ b/LLaVA/llava/train/train.py
@@ -781,16 +781,18 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                 data_args=data_args)
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
 
-    eval_dataset = None
     validation_data_path = data_args.data_path.replace('train', 'val')
-    if os.path.isfile(validation_data_path):
-        eval_dataset = LazySupervisedDataset(tokenizer=tokenizer,
-                                data_path=validation_data_path,
-                                data_args=data_args)
-        print('Eval dataset found at {}'.format(validation_data_path))
-    else:
-        print('No validation dataset found at {}'.format(validation_data_path))
-        exit(1)
+
+    assert os.path.isfile(validation_data_path), (
+        'Validation file not found at {}. '.format(validation_data_path) +
+        'Make sure you have run preprocess.py first.'
+    )
+
+    eval_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+                                data_path=validation_data_path,
+                                data_args=data_args)
+    print('Eval dataset found at {}'.format(validation_data_path))
+
     return dict(train_dataset=train_dataset,
                 eval_dataset=eval_dataset,
                 data_collator=data_collator)