From 971e89676d05cc394170a21bdfa06f9b8b582ae8 Mon Sep 17 00:00:00 2001 From: saanikat Date: Tue, 10 Dec 2024 02:53:45 -0500 Subject: [PATCH] config validator, tqdm add --- README.md | 2 +- bedms/train.py | 21 ++++------ bedms/utils_train.py | 97 ++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 98 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index a1e6789..c19c376 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ model = AttrStandardizer( ) results = model.standardize(pep="geo/gse228634:default") -assert results +print(results) #Dictionary of suggested predictions with their confidence: {'attr_1':{'prediction_1': 0.70, 'prediction_2':0.30}} ``` ### Training custom schemas diff --git a/bedms/train.py b/bedms/train.py index 8cb26cd..8cc31ed 100644 --- a/bedms/train.py +++ b/bedms/train.py @@ -26,6 +26,7 @@ model_testing, plot_confusion_matrix, auc_roc_curve, + validate_config, ) from .const import PROJECT_NAME from .model import BoWSTModel @@ -65,19 +66,13 @@ def __init__(self, config: str) -> None: self.all_labels: List[int] = [] self.all_preds: List[int] = [] - with open(config, "r") as file: - self.config = yaml.safe_load(file) - - # self.validate_config(self.config) - - def validate_config(config): - """ - Validates the given config file dictionary - - :param dict config: The config that needs to be validated. - :raises - - """ + try: + with open(config, "r") as file: + self.config = yaml.safe_load(file) + validate_config(self.config) + print("Config file provided is valid!") + except (ValueError, TypeError) as e: + print(f"Config validation error: {e}") def load_data( self, diff --git a/bedms/utils_train.py b/bedms/utils_train.py index 9661988..d9ccc43 100644 --- a/bedms/utils_train.py +++ b/bedms/utils_train.py @@ -7,10 +7,10 @@ from glob import glob import warnings from collections import Counter -from typing import List, Tuple, Iterator, Dict +from typing import List, Tuple, Iterator, Dict, Union import pickle import random - +from tqdm import tqdm import numpy as np import pandas as pd @@ -331,6 +331,85 @@ def encode_data( ) +def validate_config(config: Dict[str, Union[str, int]]) -> None: + """ + Validates the given config file dictionary + + :param dict config: The config that needs to be validated. + :raises + ValueError: Raised when there is an error in the Values provided for a key. + TypeError: Raised when there is an error in the Data Types provided. + + """ + config_structure = { + "dataset": { + "values_dir_pth": str, + "headers_dir_pth": str, + }, + "data_split": { + "train_set": float, + "test_set": float, + "val_set": float, + }, + "model": { + "hidden_size": int, + "dropout_prob": float, + }, + "training": { + "batch_size": int, + "num_epochs": int, + "learning_rate": float, + "l2_regularization": float, + "model_pth": str, + "num_cluster": int, + "vectorizer_pth": str, + "label_encoder_pth": str, + "sentence_transformer_model": str, + "bow_drops": int, + "embedding_size": int, + }, + "visualization": { + "accuracy_fig_pth": str, + "loss_fig_pth": str, + "confusion_matrix_fig_pth": str, + "roc_fig_pth": str, + }, + } + + def validate_section(section, expected_structure, parent_key=""): + for key, expected_type in expected_structure.items(): + full_key = f"{parent_key}.{key}" if parent_key else key + if key not in section: + raise ValueError(f"Missing required key: '{full_key}'") + value = section[key] + if isinstance(expected_type, dict): + if not isinstance(value, dict): + raise TypeError(f"Key '{full_key}' must be a dictionary.") + validate_section(value, expected_type, full_key) + else: + if not isinstance(value, expected_type): + raise TypeError( + f"Key '{full_key}' must be of type {expected_type.__name__}, but got {type(value).__name__}." + ) + if expected_type == str and value.strip() == "": + raise ValueError(f"Key '{full_key}' cannot be an empty string.") + if ( + expected_type == float + and not (0 <= value <= 1) + and parent_key == "data_split" + ): + raise ValueError( + f"Key '{full_key}' must be a float between 0 and 1 for data split values of train, test, validation." + ) + + validate_section(config, config_structure) + + if sum(config["data_split"].values()) != 1.0: + raise ValueError( + "You should provide the percentages for data split. Data split values must add up to 1." + ) + + def data_loader( encoded_data: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], batch_size: int, @@ -441,11 +520,13 @@ def train_model( model.train() - for epoch in range(num_epochs): + for epoch in tqdm(range(num_epochs), desc="Training Progress"): total_samples = 0 correct_predictions = 0 train_loss = 0.0 - for x_values_bow, x_values_embeddings, x_headers_embeddings, y in train_loader: + for x_values_bow, x_values_embeddings, x_headers_embeddings, y in tqdm( + train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}" + ): x_values_bow = x_values_bow.to(device) x_values_embeddings = x_values_embeddings.to(device) x_headers_embeddings = x_headers_embeddings.to(device) @@ -502,10 +583,10 @@ def train_model( val_accuracies.append(val_accuracy) val_losses.append(val_loss) - print( - f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {train_loss:.4f}, \ - Training Accuracy: {train_accuracy:.2f}%, Validation Loss: {val_loss:.4f}, \ - Validation Accuracy: {val_accuracy:.2f}%" + tqdm.write( + f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {train_loss:.4f}, " + f"Training Accuracy: {train_accuracy:.2f}%, Validation Loss: {val_loss:.4f}, " + f"Validation Accuracy: {val_accuracy:.2f}%" ) # Early stop