Skip to content

Commit

Permalink
feat(train): eval dataset pointed
Browse files: browse the repository at this point in the history
  • Loading branch information
AdityaNG committed Mar 3, 2024
1 parent 1fb637d commit 819cd8f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
6 changes: 5 additions & 1 deletion LLaVA/llava/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,8 +780,12 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
data_path=data_args.data_path,
data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

eval_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.validation_data_path,
data_args=data_args)
return dict(train_dataset=train_dataset,
eval_dataset=None,
eval_dataset=eval_dataset,
data_collator=data_collator)


Expand Down
12 changes: 9 additions & 3 deletions drivellava/scripts/compile_jsons.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from drivellava.constants import ENCODED_JSON, VAL_ENCODED_JSON
from drivellava.sparse_llava_dataset import get_drivellava_prompt
from drivellava.trajectory_encoder import (
TrajectoryEncoder, NUM_TRAJECTORY_TEMPLATES
NUM_TRAJECTORY_TEMPLATES,
TrajectoryEncoder,
)


def load_json_dataset(
json_list: List[str],
trajectory_encoder: TrajectoryEncoder,
Expand Down Expand Up @@ -57,8 +59,12 @@ def main():
random.shuffle(train)
random.shuffle(val)

new_train_json_path = os.path.abspath(f"checkpoints/train_{str(NUM_TRAJECTORY_TEMPLATES)}.json")
new_val_json_path = os.path.abspath(f"checkpoints/val_{NUM_TRAJECTORY_TEMPLATES}.json")
new_train_json_path = os.path.abspath(
f"checkpoints/train_{str(NUM_TRAJECTORY_TEMPLATES)}.json"
)
new_val_json_path = os.path.abspath(
f"checkpoints/val_{NUM_TRAJECTORY_TEMPLATES}.json"
)

# Save train to a temp file
with open(new_train_json_path, "w", encoding="utf-8") as f:
Expand Down
11 changes: 8 additions & 3 deletions drivellava/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,12 @@ def load_json_dataset_balanced(

def main():

train_json_path = os.path.join(COMMAVQ_DIR, f"train_{str(NUM_TRAJECTORY_TEMPLATES)}.json")
val_json_path = os.path.join(COMMAVQ_DIR, f"val_{str(NUM_TRAJECTORY_TEMPLATES)}.json")
train_json_path = os.path.join(
COMMAVQ_DIR, f"train_{str(NUM_TRAJECTORY_TEMPLATES)}.json"
)
val_json_path = os.path.join(
COMMAVQ_DIR, f"val_{str(NUM_TRAJECTORY_TEMPLATES)}.json"
)

train = load_json_dataset_balanced(
[
Expand Down Expand Up @@ -113,7 +117,7 @@ def main():
DEEPSPEED_JSON = os.path.abspath("./config/zero3.json")
MODEL_NAME = "liuhaotian/llava-v1.5-7b"
DATA_PATH = new_train_json_path # Replace with your JSON data path
# VAL_DATA_PATH = new_val_json_path
VAL_DATA_PATH = new_val_json_path
IMAGE_FOLDER = os.path.expanduser(
"~/Datasets/commavq"
) # Replace with your image folder path
Expand All @@ -131,6 +135,7 @@ def main():
--model_name_or_path {MODEL_NAME} \
--version llava_llama_2 \
--data_path {DATA_PATH} \
--validation_data_path {VAL_DATA_PATH} \
--image_folder {IMAGE_FOLDER} \
--vision_tower {VISION_TOWER} \
--mm_projector_type mlp2x_gelu \
Expand Down

0 comments on commit 819cd8f

Please sign in to comment.