Skip to content

Commit

Permalink
refactor(train,compile_jsons): separated training from json generation
Browse files Browse the repository at this point in the history
  • Loading branch information
AdityaNG committed Mar 3, 2024
1 parent c50aaa9 commit af74918
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 57 deletions.
73 changes: 73 additions & 0 deletions drivellava/scripts/compile_jsons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Trains LLAVA model on the cumulative dataset.
"""

import json
import os
import random
from typing import List

from drivellava.constants import ENCODED_JSON, VAL_ENCODED_JSON
from drivellava.sparse_llava_dataset import get_drivellava_prompt
from drivellava.trajectory_encoder import TrajectoryEncoder


def load_json_dataset(
    json_list: List[str],
    trajectory_encoder: TrajectoryEncoder,
) -> List[dict]:
    """Load and normalize LLaVA-format conversation records.

    Each JSON file in ``json_list`` must contain a list of records with a
    ``conversations`` pair. For every record:

    * the assistant turn (``conversations[1]["value"]``) is expected to be a
      single-character trajectory token and is rewritten to
      ``"Selected Trajectory: <token>"``;
    * the human turn (``conversations[0]["value"]``) is replaced with the
      standard DriveLLaVA prompt built from ``trajectory_encoder``.

    Args:
        json_list: Paths of JSON files to load.
        trajectory_encoder: Encoder used to build the prompt text.

    Returns:
        A single flat list containing all normalized records.
    """
    data: List[dict] = []
    for json_path in json_list:
        with open(json_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)

        for record in loaded:  # iterate records directly instead of range(len())
            # NOTE(review): assumes the encoded trajectory token is exactly
            # one character; this assert is stripped under ``python -O``.
            assert len(record["conversations"][1]["value"]) == 1

            record["conversations"][1]["value"] = (
                "Selected Trajectory: " + record["conversations"][1]["value"]
            )
            record["conversations"][0]["value"] = get_drivellava_prompt(
                trajectory_encoder
            )
        data.extend(loaded)

    return data


def main():
    """Compile, shuffle, and save the train/val datasets.

    Loads the encoded train/val JSON lists, prints their sizes, shuffles
    them, and writes the results to ``checkpoints/train.json`` and
    ``checkpoints/val.json``.
    """
    trajectory_encoder = TrajectoryEncoder()

    train = load_json_dataset(
        ENCODED_JSON,
        trajectory_encoder,
    )
    val = load_json_dataset(
        VAL_ENCODED_JSON,
        trajectory_encoder,
    )

    print(f"Train: {len(train)}")
    print(f"Val: {len(val)}")

    # Shuffle so the training order is not biased by file order
    random.shuffle(train)
    random.shuffle(val)

    new_train_json_path = os.path.abspath("checkpoints/train.json")
    new_val_json_path = os.path.abspath("checkpoints/val.json")

    # Robustness fix: the original crashed with FileNotFoundError when the
    # checkpoints/ directory did not exist yet.
    os.makedirs(os.path.dirname(new_train_json_path), exist_ok=True)

    # json.dump writes the same bytes as dumps()+write() with these options
    for out_path, split in (
        (new_train_json_path, train),
        (new_val_json_path, val),
    ):
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(split, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    main()
57 changes: 0 additions & 57 deletions drivellava/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,27 @@

from drivellava.constants import COMMAVQ_DIR

# from drivellava.constants import ENCODED_JSON, VAL_ENCODED_JSON
from drivellava.trajectory_encoder import TrajectoryEncoder


def load_json_dataset(
    json_list: List[str],
    trajectory_encoder: TrajectoryEncoder,
) -> List[dict]:
    """Load LLaVA-format conversation records from the given JSON files.

    For each record, prefixes the assistant turn (expected to be a
    single-character trajectory token) with "Selected Trajectory: " and
    replaces the human turn with the DriveLLaVA prompt.

    Args:
        json_list: Paths of JSON files, each holding a list of records.
        trajectory_encoder: Encoder used to build the prompt text.

    Returns:
        One flat list with all normalized records.
    """
    # Local import to avoid a module-level circular dependency
    from drivellava.sparse_llava_dataset import get_drivellava_prompt

    data = []
    for json_path in json_list:
        with open(json_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
            for index in range(len(loaded)):
                # Trajectory token must be exactly one character
                assert len(loaded[index]["conversations"][1]["value"]) == 1
                # print('val', loaded[index]["conversations"][1]["value"])
                # exit()

                loaded[index]["conversations"][1]["value"] = (
                    "Selected Trajectory: "
                    + loaded[index]["conversations"][1]["value"]
                )
                loaded[index]["conversations"][0]["value"] = (
                    get_drivellava_prompt(trajectory_encoder)
                )
        data.extend(loaded)

    return data


def load_json_dataset_balanced(
json_list: List[str],
trajectory_encoder: TrajectoryEncoder,
):
from drivellava.sparse_llava_dataset import get_drivellava_prompt

data = []
for json_path in json_list:
with open(json_path, "r", encoding="utf-8") as f:
loaded = json.load(f)
for index in range(len(loaded)):
assert len(loaded[index]["conversations"][1]["value"]) == 1
# print('val', loaded[index]["conversations"][1]["value"])
# exit()

loaded[index]["conversations"][1]["value"] = (
"Selected Trajectory: "
+ loaded[index]["conversations"][1]["value"]
)
loaded[index]["conversations"][0]["value"] = (
get_drivellava_prompt(trajectory_encoder)
)
data.extend(loaded)

# Balance by the class given by data[index]["conversations"][1]["value"]
Expand Down Expand Up @@ -105,43 +73,18 @@ def load_json_dataset_balanced(

def main():

trajectory_encoder = TrajectoryEncoder()

# train = load_json_dataset(
# ENCODED_JSON,
# trajectory_encoder,
# )
# val = load_json_dataset(
# VAL_ENCODED_JSON,
# trajectory_encoder,
# )

# train_json_path = os.path.abspath("checkpoints/train.json")
# val_json_path = os.path.abspath("checkpoints/val.json")

# # Save train to a temp file
# with open(train_json_path, "w", encoding="utf-8") as f:
# json_data = json.dumps(train, ensure_ascii=False, indent=4)
# f.write(json_data)

# with open(val_json_path, "w", encoding="utf-8") as f:
# json_data = json.dumps(val, ensure_ascii=False, indent=4)
# f.write(json_data)

train_json_path = os.path.join(COMMAVQ_DIR, "train.json")
val_json_path = os.path.join(COMMAVQ_DIR, "val.json")

train = load_json_dataset_balanced(
[
train_json_path,
],
trajectory_encoder,
)
val = load_json_dataset(
[
val_json_path,
],
trajectory_encoder,
)

print(f"Train: {len(train)}")
Expand Down

0 comments on commit af74918

Please sign in to comment.