Skip to content

Commit

Permalink
config validator, tqdm add
Browse files Browse the repository at this point in the history
  • Loading branch information
saanikat committed Dec 10, 2024
1 parent e016842 commit 971e896
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ model = AttrStandardizer(
)
results = model.standardize(pep="geo/gse228634:default")

assert results
print(results) #Dictionary of suggested predictions with their confidence: {'attr_1':{'prediction_1': 0.70, 'prediction_2':0.30}}
```

### Training custom schemas
Expand Down
21 changes: 8 additions & 13 deletions bedms/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
model_testing,
plot_confusion_matrix,
auc_roc_curve,
validate_config,
)
from .const import PROJECT_NAME
from .model import BoWSTModel
Expand Down Expand Up @@ -65,19 +66,13 @@ def __init__(self, config: str) -> None:
self.all_labels: List[int] = []
self.all_preds: List[int] = []

with open(config, "r") as file:
self.config = yaml.safe_load(file)

# self.validate_config(self.config)

def validate_config(config):
"""
Validates the given config file dictionary
:param dict config: The config that needs to be validated.
:raises
"""
try:
with open(config, "r") as file:
self.config = yaml.safe_load(file)
validate_config(self.config)
print("Config file provided is valid!")
except (ValueError, TypeError) as e:
print(f"Config validation error: {e}")

def load_data(
self,
Expand Down
97 changes: 89 additions & 8 deletions bedms/utils_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from glob import glob
import warnings
from collections import Counter
from typing import List, Tuple, Iterator, Dict
from typing import List, Tuple, Iterator, Dict, Union
import pickle
import random

from tqdm import tqdm

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -331,6 +331,85 @@ def encode_data(
)


def validate_config(config: Dict[str, Union[str, int]]) -> None:
"""
Validates the given config file dictionary
:param dict config: The config that needs to be validated.
:raises
ValueError: Raised when there is an error in the Values provided for a key.
TypeError: Raised when there is an error in the Data Types provided.
"""
config_structure = {
"dataset": {
"values_dir_pth": str,
"headers_dir_pth": str,
},
"data_split": {
"train_set": float,
"test_set": float,
"val_set": float,
},
"model": {
"hidden_size": int,
"dropout_prob": float,
},
"training": {
"batch_size": int,
"num_epochs": int,
"learning_rate": float,
"l2_regularization": float,
"model_pth": str,
"num_cluster": int,
"vectorizer_pth": str,
"label_encoder_pth": str,
"sentence_transformer_model": str,
"bow_drops": int,
"embedding_size": int,
},
"visualization": {
"accuracy_fig_pth": str,
"loss_fig_pth": str,
"confusion_matrix_fig_pth": str,
"roc_fig_pth": str,
},
}

def validate_section(section, expected_structure, parent_key=""):
for key, expected_type in expected_structure.items():
full_key = f"{parent_key}.{key}" if parent_key else key
if key not in section:
raise ValueError(f"Missing required key: '{full_key}'")
value = section[key]
if isinstance(expected_type, dict):
if not isinstance(value, dict):
raise TypeError(f"Key '{full_key}' must be a dictionary.")
validate_section(value, expected_type, full_key)
else:
if not isinstance(value, expected_type):
raise TypeError(
f"Key '{full_key}' must be of type {expected_type.__name__}, but got {type(value).__name__}."
)
if expected_type == str and value.strip() == "":
raise ValueError(f"Key '{full_key}' cannot be an empty string.")
if (
expected_type == float
and not (0 <= value <= 1)
and parent_key == "data_split"
):
raise ValueError(
f"Key '{full_key}' must be a float between 0 and 1 for data split values of train, test, validation."
)

validate_section(config, config_structure)

if sum(config["data_split"].values()) != 1.0:
raise ValueError(
"You should provide the percentages for data split. Data split values must add up to 1."
)


def data_loader(
encoded_data: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
batch_size: int,
Expand Down Expand Up @@ -441,11 +520,13 @@ def train_model(

model.train()

for epoch in range(num_epochs):
for epoch in tqdm(range(num_epochs), desc="Training Progress"):
total_samples = 0
correct_predictions = 0
train_loss = 0.0
for x_values_bow, x_values_embeddings, x_headers_embeddings, y in train_loader:
for x_values_bow, x_values_embeddings, x_headers_embeddings, y in tqdm(
train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"
):
x_values_bow = x_values_bow.to(device)
x_values_embeddings = x_values_embeddings.to(device)
x_headers_embeddings = x_headers_embeddings.to(device)
Expand Down Expand Up @@ -502,10 +583,10 @@ def train_model(
val_accuracies.append(val_accuracy)
val_losses.append(val_loss)

print(
f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {train_loss:.4f}, \
Training Accuracy: {train_accuracy:.2f}%, Validation Loss: {val_loss:.4f}, \
Validation Accuracy: {val_accuracy:.2f}%"
tqdm.write(
f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {train_loss:.4f}, "
f"Training Accuracy: {train_accuracy:.2f}%, Validation Loss: {val_loss:.4f}, "
f"Validation Accuracy: {val_accuracy:.2f}%"
)

# Early stop
Expand Down

0 comments on commit 971e896

Please sign in to comment.