
Commit

improve typing across the package and make some slight quality of life improvements
AlexanderFengler committed Dec 8, 2024
1 parent 725630c commit a313a33
Showing 3 changed files with 253 additions and 218 deletions.
117 changes: 67 additions & 50 deletions lanfactory/trainers/jax_mlp.py
@@ -11,6 +11,7 @@

import jax
from jax import numpy as jnp
from typing import Sequence, Callable, Any

import flax
from flax.training import train_state
@@ -27,12 +28,12 @@
"""


def MLPJaxFactory(network_config={}, train=True):
def MLPJaxFactory(network_config: dict | str = {}, train: bool = True) -> "MLPJax":
"""Factory function to create a MLPJax object.
Arguments
---------
network_config (dict):
Dictionary containing the network configuration.
network_config (dict | str):
Dictionary containing the network configuration or path to pickled config.
train (bool):
Whether the model should be trained or not.
Returns
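
A minimal usage sketch of the factory under the new union-typed signature; the pickle path below is hypothetical and not taken from this commit:

from lanfactory.trainers.jax_mlp import MLPJaxFactory

# network_config may be an in-memory dict or a path to a pickled config file
model = MLPJaxFactory(
    network_config="data/network_config.pickle",  # hypothetical path; a plain dict also works
    train=True,
)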
@@ -87,7 +88,7 @@ class MLPJax(nn.Module):
# Define network type
# network_type = "lan" if train_output_type == "logprob" else "cpn"

def setup(self):
def setup(self) -> None:
"""Setup function for the JaxMLP class.
Initializes the layers and activation functions.
"""
@@ -103,7 +104,7 @@ def setup(self):
# Identification
self.network_type = self.network_type_dict[self.train_output_type]

def __call__(self, inputs):
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
"""Call function for the JaxMLP class.
Performs forward pass through the network.
@@ -137,7 +138,9 @@ def __call__(self, inputs):

return x
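
Since __call__ now maps jnp.ndarray to jnp.ndarray, the standard flax.linen init/apply pattern applies directly; a sketch, assuming a model instance such as the one above and an arbitrary (batch, input_dim) shape:

import jax
from jax import numpy as jnp

rng = jax.random.PRNGKey(42)
batch = jnp.ones((128, 6))           # assumed batch of 128 inputs with input_dim = 6
variables = model.init(rng, batch)   # initialize parameters
out = model.apply(variables, batch)  # forward pass through __call__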

def load_state_from_file(self, seed=42, input_dim=6, file_path=None):
def load_state_from_file(
self, seed: int = 42, input_dim: int = 6, file_path: str | None = None
) -> flax.core.frozen_dict.FrozenDict:
"""Loads the state dictionary from a file.
Arguments
@@ -183,11 +186,11 @@ def load_state_from_file(self, seed=42, input_dim=6, file_path=None):

def make_forward_partial(
self,
seed=42,
input_dim=6,
state=None,
add_jitted=False,
):
seed: int = 42,
input_dim: int = 6,
state: str | dict | None = None,
add_jitted: bool = False,
) -> tuple[Callable, Callable | None]:
"""Creates a partial function for the forward pass of the network.
Arguments
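
Judging from the tuple[Callable, Callable | None] return annotation, make_forward_partial hands back a plain forward function plus an optionally jitted one; a hedged usage sketch, in which the state path and the inputs-only calling convention are assumptions:

# second element is presumably None unless add_jitted=True
forward, forward_jit = model.make_forward_partial(
    seed=42,
    input_dim=6,
    state="data/trained_state.jax",  # hypothetical path to a saved state
    add_jitted=True,
)
preds = forward_jit(jnp.ones((10, 6)))  # assumed: parameters are baked in, inputs only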
@@ -230,17 +233,19 @@ def make_forward_partial(


class ModelTrainerJaxMLP:
"""Class for training JaxMLP models."""

def __init__(
self,
train_config=None,
model=None,
train_dl=None,
valid_dl=None,
allow_abs_path_folder_generation=False,
pin_memory=False,
seed=None,
):
"""Class for training JaxMLP models.
train_config: dict,
model: MLPJax,
train_dl: Any,
valid_dl: Any,
allow_abs_path_folder_generation: bool = False,
pin_memory: bool = False,
seed: int | None = None,
) -> None:
"""Initialize class for training JaxMLP models.
Arguments
---------
@@ -266,7 +271,7 @@ def __init__(
"""
if "loss_dict" not in train_config.keys():
self.loss_dict = {
self.loss_dict: dict[str, dict] = {
"huber": {"fun": optax.huber_loss, "kwargs": {"delta": 1}},
"mse": {"fun": optax.l2_loss, "kwargs": {}},
"bcelogit": {"fun": optax.sigmoid_binary_cross_entropy, "kwargs": {}},
@@ -276,7 +281,7 @@ def __init__(

if "lr_dict" not in train_config.keys():
# Todo: Add more schedules (for now warmup_cosine_decay_schedule)
self.lr_dict = {
self.lr_dict: dict[str, float] = {
"init_value": 0.0002,
"peak_value": 0.02,
"end_value": 0.0,
@@ -294,31 +299,33 @@ def __init__(
else:
self.seed = seed
self.allow_abs_path_folder_generation = allow_abs_path_folder_generation
self.wandb_on = 0
self.wandb_on: int = 0

self.__get_loss()
self.apply_model_train = self.__make_apply_model(train=True)
self.apply_model_eval = self.__make_apply_model(train=False)
self.update_model = self.__make_update_model()

self.training_history = (
self.training_history: str = (
"Please run training for this attribute to be specified!"
)
self.state = "Please run training for this attribute to be specified!"
self.state: str = "Please run training for this attribute to be specified!"

def __get_loss(self):
def __get_loss(self) -> None:
"""Define loss function."""
self.loss = partial(
self.loss_dict[self.train_config["loss"]]["fun"],
**self.loss_dict[self.train_config["loss"]]["kwargs"],
)
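
To make the loss plumbing concrete: __get_loss turns the configured loss_dict entry into a partial that is applied element-wise and then reduced to a scalar. A self-contained sketch for the "huber" case, with arbitrary example values:

from functools import partial

import optax
from jax import numpy as jnp

loss_fn = partial(optax.huber_loss, delta=1)  # equivalent of self.loss for "huber"
pred = jnp.array([0.1, -0.3, 0.7])
target = jnp.array([0.0, 0.0, 1.0])
loss = jnp.mean(loss_fn(pred, target))        # element-wise losses reduced to a scalar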

def __make_apply_model(self, train=True):
def __make_apply_model(self, train: bool = True) -> Callable:
"""Compile forward pass with loss aplication"""

@jax.jit
def apply_model_core(state, features, labels):
def loss_fn(params):
def apply_model_core(
state: train_state.TrainState, features: jnp.ndarray, labels: jnp.ndarray
) -> tuple[Any, float] | float:
def loss_fn(params: dict) -> tuple[float, jnp.ndarray]:
pred = state.apply_fn(params, features)
loss = self.loss(pred, labels)
loss = jnp.mean(loss)
@@ -334,18 +341,23 @@ def loss_fn(params):

return apply_model_core
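
Part of apply_model_core is collapsed in this hunk; a loss_fn returning (loss, pred) is the usual setup for jax.value_and_grad(..., has_aux=True). A sketch of that standard pattern, written as a stand-alone toy factory rather than the literal hidden lines:

import jax
from jax import numpy as jnp

def make_apply_model(loss_elementwise):
    """Toy stand-in for __make_apply_model(train=True)."""

    @jax.jit
    def apply_model_core(state, features, labels):
        def loss_fn(params):
            pred = state.apply_fn(params, features)
            loss = jnp.mean(loss_elementwise(pred, labels))
            return loss, pred                     # aux output carried with the scalar loss

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)  # has_aux since loss_fn returns a pair
        (loss, _pred), grads = grad_fn(state.params)
        return grads, loss                        # matches the tuple[Any, float] train branch

    return apply_model_core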

def __make_update_model(self):
def __make_update_model(self) -> Callable:
"""Compile gradient application"""

@jax.jit
def update_model(state, grads):
def update_model(
state: train_state.TrainState, grads: dict
) -> train_state.TrainState:
return state.apply_gradients(grads=grads)

return update_model

def __try_wandb(
self, wandb_project_id="projectid", file_id="fileid", run_id="runid"
):
self,
wandb_project_id: str = "projectid",
file_id: str = "fileid",
run_id: str = "runid",
) -> None:
"""Helper function to initialize wandb
Arguments
@@ -375,11 +387,9 @@ def __try_wandb(
except ModuleNotFoundError:
print("No wandb found, proceeding without logging")

def create_train_state(self, rng):
def create_train_state(self, rng: jax.random.PRNGKey) -> train_state.TrainState:
"""Create initial train state"""
params = self.model.init(
rng, jnp.ones((1, self.train_dl.dataset.input_dim))
) # self.train_config['input_size'])))
params = self.model.init(rng, jnp.ones((1, self.train_dl.dataset.input_dim)))
lr_schedule = optax.warmup_cosine_decay_schedule(
init_value=self.lr_dict["init_value"],
peak_value=self.lr_dict["peak_value"],
@@ -392,7 +402,14 @@ def create_train_state(self, rng):
apply_fn=self.model.apply, params=params, tx=tx
)
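
For reference, a self-contained sketch of the warmup_cosine_decay_schedule → optimizer → TrainState.create chain that this hunk condenses; the step counts and the choice of adamw are assumptions, since the collapsed lines do not show them:

import optax
from flax.training import train_state

lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0002,   # lr_dict["init_value"]
    peak_value=0.02,     # lr_dict["peak_value"]
    end_value=0.0,       # lr_dict["end_value"]
    warmup_steps=500,    # assumed; likely derived from the dataloader length
    decay_steps=10_000,  # assumed
)
tx = optax.adamw(learning_rate=lr_schedule)     # optimizer choice assumed
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=tx  # model/params as in the init sketch above
)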

def run_epoch(self, state, train=True, verbose=1, epoch=0, max_epochs=0):
def run_epoch(
self,
state: train_state.TrainState,
train: bool = True,
verbose: int = 1,
epoch: int = 0,
max_epochs: int = 0,
) -> tuple[train_state.TrainState, float]:
"""Run one epoch of training or validation
Arguments
---------
@@ -487,18 +504,18 @@ def run_epoch(self, state, train=True, verbose=1, epoch=0, max_epochs=0):

def train_and_evaluate(
self,
output_folder="data/",
output_file_id="fileid",
run_id="runid",
wandb_on=True,
wandb_project_id="projectid",
save_history=True,
save_model=True,
save_config=True,
save_all=True,
save_data_details=True,
verbose=1,
):
output_folder: str = "data/",
output_file_id: str = "fileid",
run_id: str = "runid",
wandb_on: bool = True,
wandb_project_id: str = "projectid",
save_history: bool = True,
save_model: bool = True,
save_config: bool = True,
save_all: bool = True,
save_data_details: bool = True,
verbose: int = 1,
) -> train_state.TrainState:
"""Train and evaluate JAXMLP model.
Arguments
---------
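Finally, a hedged end-to-end sketch of how the retyped trainer might be driven, using only names visible in this diff; the train_config contents and the dataloader objects are placeholders whose exact structure is not shown here:

from lanfactory.trainers.jax_mlp import MLPJaxFactory, ModelTrainerJaxMLP

def fit_network(train_config: dict, train_dl, valid_dl):
    """Hypothetical driver; dataloader and config structure are assumptions."""
    model = MLPJaxFactory(network_config="data/network_config.pickle", train=True)
    trainer = ModelTrainerJaxMLP(
        train_config=train_config,
        model=model,
        train_dl=train_dl,
        valid_dl=valid_dl,
        seed=42,
    )
    return trainer.train_and_evaluate(
        output_folder="data/",
        output_file_id="fileid",
        run_id="runid",
        wandb_on=False,   # skip wandb logging
        verbose=1,
    )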
