
percentages for data split
saanikat committed Dec 10, 2024
1 parent fef2ed7 commit e016842
Showing 2 changed files with 26 additions and 15 deletions.
35 changes: 23 additions & 12 deletions bedms/train.py
@@ -1,6 +1,7 @@
""" This is the training script with which the user can train their own models."""

import logging
import random
import torch
from torch import nn
from torch import optim
@@ -67,6 +68,17 @@ def __init__(self, config: str) -> None:
with open(config, "r") as file:
self.config = yaml.safe_load(file)

# self.validate_config(self.config)

def validate_config(config):
"""
Validates the given config file dictionary
:param dict config: The config that needs to be validated.
:raises
"""

def load_data(
self,
) -> Tuple[
@@ -104,26 +116,25 @@ def load_data(
)
return

total_files = len(values_files_list)

paired_files = list(zip(values_files_list, headers_files_list))

random.shuffle(paired_files)

train_size = self.config["data_split"]["train_set"]
test_size = self.config["data_split"]["test_set"]
val_size = self.config["data_split"]["val_set"]

if train_size + val_size + test_size > total_files:
logger.error(
f"Data split sizes exceed total number of files: "
f"train({train_size}) + val({val_size}) + \
test({test_size}) > total_files({total_files})"
)
return
num_train_files = int(train_size * len(paired_files))
num_test_files = int(test_size * len(paired_files))
num_val_files = int(val_size * len(paired_files))

train_files = paired_files[:train_size]
val_files = paired_files[train_size : train_size + val_size]
train_files = paired_files[:num_train_files]
val_files = paired_files[num_train_files : num_train_files + num_val_files]
test_files = paired_files[
train_size + val_size : train_size + val_size + test_size
num_train_files
+ num_val_files : num_train_files
+ num_val_files
+ num_test_files
]

logger.info(f"Training on {len(train_files)} file sets")
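The new split logic converts the fractional sizes from the config into file counts and slices the shuffled list of (values, headers) pairs. Below is a minimal standalone sketch of the same slicing, using made-up file names and the 0.80/0.10/0.10 fractions from the updated training_config.yaml; everything in it is illustrative and not part of the commit:

import random

# Toy stand-ins for the (values, headers) csv file pairs built in load_data().
paired_files = [(f"values_{i}.csv", f"headers_{i}.csv") for i in range(10)]
random.shuffle(paired_files)

train_frac, test_frac, val_frac = 0.80, 0.10, 0.10  # data_split values from the config

# Truncate each fraction to a whole number of file pairs, then slice.
num_train = int(train_frac * len(paired_files))
num_test = int(test_frac * len(paired_files))
num_val = int(val_frac * len(paired_files))

train_files = paired_files[:num_train]
val_files = paired_files[num_train : num_train + num_val]
test_files = paired_files[num_train + num_val : num_train + num_val + num_test]

print(len(train_files), len(val_files), len(test_files))  # 8 1 1

Because each count is truncated with int(), fractions that do not divide the number of file pairs evenly can leave a few pairs outside all three slices. The earlier count-based bounds check is no longer needed: as long as the fractions sum to at most 1.0, the split can never request more files than exist.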
6 changes: 3 additions & 3 deletions training_config.yaml
@@ -3,9 +3,9 @@ dataset:
headers_dir_pth: "path/to/training/headers/directory" #Path to the attributes directory

data_split:
train_set: 8000 #Number of csv value-attribute file pairs for training set
test_set: 100 #Number of csv value-attribute file pairs for testing set
val_set: 100 #Number of csv value-attribute file pairs for validation set
train_set: 0.80 #Percentage of csv value-attribute file pairs for training set
test_set: 0.10 #Percentage of csv value-attribute file pairs for testing set
val_set: 0.10 #Percentage of csv value-attribute file pairs for validation set

model:
hidden_size: 32 #Hidden size for training the model
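The three fractions in the updated config sum to 1.0. The commit also adds an empty validate_config stub in train.py and a commented-out call to it in __init__. Below is a minimal sketch of what such a check could look like for the new percentage-based split, assuming it only needs to confirm that the three fractions are valid and sum to no more than 1.0; the function name matches the stub, but the body and error messages are hypothetical:

def validate_config(config: dict) -> None:
    """Illustrative sketch: validate the data_split fractions of the training config."""
    split = config["data_split"]
    fractions = {key: split[key] for key in ("train_set", "test_set", "val_set")}

    # Each fraction must be a proportion of the available file pairs.
    for name, value in fractions.items():
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"data_split.{name} must be between 0 and 1, got {value}")

    # Together the splits cannot claim more than the full dataset.
    if sum(fractions.values()) > 1.0:
        raise ValueError(
            f"data_split fractions sum to {sum(fractions.values())}, which exceeds 1.0"
        )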
