From e01684236fb60cba886be1c958d75d26d1a9f4fd Mon Sep 17 00:00:00 2001
From: saanikat
Date: Tue, 10 Dec 2024 01:05:08 -0500
Subject: [PATCH] percentages for data split

---
 bedms/train.py       | 35 +++++++++++++++++++++++------------
 training_config.yaml |  6 +++---
 2 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/bedms/train.py b/bedms/train.py
index b7a5c77..8cb26cd 100644
--- a/bedms/train.py
+++ b/bedms/train.py
@@ -1,6 +1,7 @@
 """
 This is the training script with which the user can train their own models."""
 import logging
+import random
 import torch
 from torch import nn
 from torch import optim
@@ -67,6 +68,17 @@ def __init__(self, config: str) -> None:
         with open(config, "r") as file:
             self.config = yaml.safe_load(file)
 
+        # self.validate_config(self.config)
+
+    def validate_config(config):
+        """
+        Validates the given config file dictionary
+
+        :param dict config: The config that needs to be validated.
+        :raises
+
+        """
+
     def load_data(
         self,
     ) -> Tuple[
@@ -104,26 +116,25 @@ def load_data(
             )
             return
 
-        total_files = len(values_files_list)
-
         paired_files = list(zip(values_files_list, headers_files_list))
+        random.shuffle(paired_files)
+
         train_size = self.config["data_split"]["train_set"]
         test_size = self.config["data_split"]["test_set"]
         val_size = self.config["data_split"]["val_set"]
 
-        if train_size + val_size + test_size > total_files:
-            logger.error(
-                f"Data split sizes exceed total number of files: "
-                f"train({train_size}) + val({val_size}) + \
-                test({test_size}) > total_files({total_files})"
-            )
-            return
+        num_train_files = int(train_size * len(paired_files))
+        num_test_files = int(test_size * len(paired_files))
+        num_val_files = int(val_size * len(paired_files))
 
-        train_files = paired_files[:train_size]
-        val_files = paired_files[train_size : train_size + val_size]
+        train_files = paired_files[:num_train_files]
+        val_files = paired_files[num_train_files : num_train_files + num_val_files]
         test_files = paired_files[
-            train_size + val_size : train_size + val_size + test_size
+            num_train_files
+            + num_val_files : num_train_files
+            + num_val_files
+            + num_test_files
         ]
 
         logger.info(f"Training on {len(train_files)} file sets")
 
diff --git a/training_config.yaml b/training_config.yaml
index 75910d2..4ac6756 100644
--- a/training_config.yaml
+++ b/training_config.yaml
@@ -3,9 +3,9 @@ dataset:
   headers_dir_pth: "path/to/training/headers/directory" #Path to the attributes directory
 
 data_split:
-  train_set: 8000 #Number of csv value-attribute file pairs for training set
-  test_set: 100 #Number of csv value-attribute file pairs for testing set
-  val_set: 100 #Number of csv value-attribute file pairs for validation set
+  train_set: 0.80 #Percentage of csv value-attribute file pairs for training set
+  test_set: 0.10 #Percentage of csv value-attribute file pairs for testing set
+  val_set: 0.10 #Percentage of csv value-attribute file pairs for validation set
 
 model:
   hidden_size: 32 #Hidden size for training the model
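
For reference, below is a minimal standalone sketch of the shuffled, fraction-based split this patch introduces. The ten example file names are hypothetical stand-ins for the values/headers directory scan done elsewhere in bedms/train.py; the variable names mirror the patch, and the final print line is illustrative only. Because each count is truncated with int(), a few trailing pairs can end up in none of the three sets when the fractions do not divide the file count evenly.

import random

# Hypothetical stand-ins for the CSV value/attribute pairs found on disk.
values_files_list = [f"values_{i}.csv" for i in range(10)]
headers_files_list = [f"headers_{i}.csv" for i in range(10)]

# Fractions matching the updated training_config.yaml defaults.
train_size, test_size, val_size = 0.80, 0.10, 0.10

# Pair the files and shuffle before splitting, as the patch does.
paired_files = list(zip(values_files_list, headers_files_list))
random.shuffle(paired_files)

# int() truncates, so each split size is rounded down.
num_train_files = int(train_size * len(paired_files))
num_test_files = int(test_size * len(paired_files))
num_val_files = int(val_size * len(paired_files))

train_files = paired_files[:num_train_files]
val_files = paired_files[num_train_files : num_train_files + num_val_files]
test_files = paired_files[
    num_train_files + num_val_files : num_train_files + num_val_files + num_test_files
]

print(len(train_files), len(val_files), len(test_files))  # 8 1 1 for ten pairs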