Commit: train, train_from_config

homerjed committed Oct 8, 2024
1 parent 04f5d2b commit a9ed119
Showing 4 changed files with 260 additions and 24 deletions.
sbgm/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
from . import sde as sde
from . import _ode as ode
-from . import _train as train
+from . import _train as train
from . import _sample as sample
from . import _misc as utils
from . import _shard as shard
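
With `_train` re-exported as `sbgm.train`, both entry points touched by this commit are reachable from the package root. A minimal sketch (the `...` stand in for the arguments documented in `_train.py` below):

    import sbgm

    sbgm.train.train(...)              # new config-free entry point
    sbgm.train.train_from_config(...)  # renamed from the old `train`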
sbgm/_misc.py (12 changes: 7 additions & 5 deletions)
@@ -50,11 +50,13 @@ def count_params(model):
)


-def plot_model_sample(eu_sample, ode_sample, dataset, config, filename):
+def plot_model_sample(eu_sample, ode_sample, dataset, cmap, filename):

def plot_sample(samples, mode):
fig, ax = plt.subplots(dpi=300)
-samples_onto_ax(samples, fig, ax, vs=None, cmap=config.cmap)
+samples_onto_ax(
+samples, fig, ax, vs=None, cmap=cmap if cmap is not None else "gray_r"
+)
plt.savefig(filename + "_" + mode, bbox_inches="tight")
plt.close()

@@ -102,10 +104,10 @@ def plot_sde(sde, filename):
plt.close()


-def make_dirs(root_dir, config):
+def make_dirs(root_dir, dataset_name):
# Make experiment and image save directories
-img_dir = os.path.join(root_dir, "exps/", config.dataset_name + "/")
-exp_dir = os.path.join(root_dir, "imgs/", config.dataset_name + "/")
+exp_dir = os.path.join(root_dir, "exps/", dataset_name + "/")
+img_dir = os.path.join(exp_dir, "imgs/")
for _dir in [img_dir, exp_dir]:
if not os.path.exists(_dir):
os.makedirs(_dir, exist_ok=True)
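
This also fixes a path mix-up: previously the "exps/" and "imgs/" labels were crossed and the two directories were siblings; now the image directory is nested inside the experiment directory. A sketch of the resulting layout, assuming `root_dir="./runs"` and `dataset_name="mnist"`:

    exp_dir, img_dir = make_dirs("./runs", "mnist")
    # exp_dir == "./runs/exps/mnist/"
    # img_dir == "./runs/exps/mnist/imgs/"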
sbgm/_train.py (251 changes: 240 additions & 11 deletions)
@@ -8,10 +8,10 @@
import jax.random as jr
import jax.tree_util as jtu
import equinox as eqx
-from jaxtyping import PyTree, Key, Array
+from jaxtyping import Key, Array
from ml_collections import ConfigDict
import optax
-from tqdm import trange
+from tqdm.auto import trange

from .sde import SDE
from ._sample import get_eu_sample_fn, get_ode_sample_fn
Expand Down Expand Up @@ -58,7 +58,7 @@ def single_loss_fn(
key: Key
) -> Array:
key_noise, key_apply = jr.split(key)
-mean, std = sde.marginal_prob(x, t) # std = jnp.sqrt(jnp.maximum(std, 1e-5))
+mean, std = sde.marginal_prob(x, t)
noise = jr.normal(key_noise, x.shape)
y = mean + std * noise
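# y ~ N(mean, std^2): the data point diffused to time t by the forward SDE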
y_ = model(t, y, q=q, a=a, key=key_apply) # Inference is true in validation
@@ -108,7 +108,7 @@ def make_step(
model = eqx.nn.inference_mode(model, False)
loss_fn = eqx.filter_value_and_grad(batch_loss_fn)
loss, grads = loss_fn(model, sde, x, q, a, key)
-updates, opt_state = opt_update(grads, opt_state)
+updates, opt_state = opt_update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)
key, _ = jr.split(key)
return loss, model, key, opt_state
Expand All @@ -132,7 +132,7 @@ def get_opt(config: ConfigDict):
return getattr(optax, config.opt)(config.lr, **config.opt_kwargs)
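
`get_opt` resolves the optimiser by name from the `optax` namespace, so the config needs only the three fields read above. For example:

    from ml_collections import ConfigDict

    config = ConfigDict()
    config.opt = "adamw"                       # any optax optimiser name
    config.lr = 1e-4
    config.opt_kwargs = {"weight_decay": 1e-5}

    opt = get_opt(config)                      # optax.adamw(1e-4, weight_decay=1e-5)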


-def train(
+def train_from_config(
key: Key,
# Diffusion model and SDE
model: eqx.Module,
@@ -145,14 +145,14 @@ def train(
reload_opt_state: bool = False,
# Sharding of devices to run on
sharding: Optional[jax.sharding.Sharding] = None,
*,
# Location to save model, figs, etc. in
save_dir: Optional[str] = None,
plot_train_data: bool = False
):
"""
Trains a diffusion model built from a score network (`model`) using a stochastic
-differential equation (SDE, `sde`) with a given dataset, with support for various
-configurations and saving options.
+differential equation (SDE, `sde`) with a given dataset. Requires a config object.
Parameters:
-----------
@@ -200,10 +200,10 @@
training from where it left off.
"""

print(f"Training SGM with {config.sde.sde} SDE on {config.dataset_name} dataset.")
print(f"Training SGM with a {config.sde.sde} SDE on {config.dataset_name} dataset.")

# Experiment and image save directories
-exp_dir, img_dir = make_dirs(save_dir, config)
+exp_dir, img_dir = make_dirs(save_dir, config.dataset_name)

# Model and optimiser save filenames
model_filename = os.path.join(
@@ -287,7 +287,7 @@
}
)

-if (step % config.print_every) == 0 or step == config.n_steps - 1:
+if (step % config.sample_and_save_every) == 0 or step == config.n_steps - 1:
# Sample model
key_Q, key_sample = jr.split(sample_key) # Fixed key
sample_keys = jr.split(key_sample, config.sample_size ** 2)
@@ -315,7 +315,7 @@
eu_sample,
ode_sample,
dataset,
-config,
+cmap=config.cmap,
filename=os.path.join(img_dir, f"samples_{step:06d}"),
)

@@ -335,4 +335,233 @@
# Plot losses etc
plot_metrics(train_losses, valid_losses, step, exp_dir)

return model
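
A sketch of the intended call, assuming `model`, `sde`, `dataset`, and `config` are built elsewhere (their constructors, and the part of the signature collapsed in this view, are not shown in the diff):

    import jax.random as jr

    key = jr.PRNGKey(0)
    model = train_from_config(
        key, model, sde, dataset, config,
        reload_opt_state=False,
        save_dir="./runs",
        plot_train_data=True,
    )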


def train(
key: Key,
# Diffusion model and SDE
model: eqx.Module,
sde: SDE,
# Dataset
dataset: dataclass,
# Training
opt: optax.GradientTransformation,
n_steps: int,
batch_size: int,
use_ema: bool = True,
sample_and_save_every: int = 1_000,
# Sampling
sample_size: int = 1,
eu_sample: bool = False,
ode_sample: bool = False,
# Reload optimiser or not
reload_opt_state: bool = False,
# Sharding of devices to run on
sharding: Optional[jax.sharding.Sharding] = None,
*,
# Location to save model, figs, etc. in
save_dir: Optional[str] = None,
# Plotting
plot_train_data: bool = False,
cmap: str = "gray_r"
):
"""
Trains a diffusion model using a stochastic differential equation (SDE) based on
the provided score network model and dataset, with support for optimizer state reloading,
Exponential Moving Average (EMA), and periodic model sampling and evaluation.
Parameters:
-----------
`key` : `Key`
A JAX random key for sampling and model initialization.
`model` : `eqx.Module`
The score network model to be trained.
`sde` : `SDE`
The stochastic differential equation (SDE) defining the forward and reverse diffusion processes.
`dataset` : `dataclass`
A dataset object containing the data loaders for training and validation.
`opt` : `optax.GradientTransformation`
The optimizer transformation function (from Optax) for updating model parameters.
`n_steps` : `int`
The total number of training steps.
`batch_size` : `int`
The size of the mini-batches to be used for each training step.
`use_ema` : `bool`, default: `True`
Whether to use Exponential Moving Average (EMA) of model parameters for validation and sampling.
`sample_and_save_every` : `int`, default: `1_000`
The interval, in training steps, at which the model is sampled and checkpointed.
`sample_size` : `int`, default: `1`
Side length of the grid of samples generated at each sampling step (`sample_size ** 2` samples are drawn).
`eu_sample` : `bool`, default: `False`
If `True`, draw samples with the Euler-Maruyama solver at each sampling step.
`ode_sample` : `bool`, default: `False`
If `True`, draw samples by solving the probability-flow ODE at each sampling step.
`reload_opt_state` : `bool`, default: `False`
Whether to reload the model and optimizer state from a previous checkpoint to continue training.
`sharding` : `Optional[jax.sharding.Sharding]`, default: `None`
Optional sharding scheme to distribute training across multiple devices.
`save_dir` : `Optional[str]`, default: `None`
Directory path to save the model, optimizer state, and training logs. If `None`, a default path is generated.
`plot_train_data` : `bool`, default: `False`
If `True`, plots a sample of the training data at the beginning of training.
`cmap` : `str`, default: `"gray_r"`
The colormap to be used for plotting sampled data. Ignored for non-image data.
Returns:
--------
`model` : `eqx.Module`
The trained model after completing the specified number of training steps.
Notes:
------
- The function trains a model using a diffusion process governed by an SDE and saves checkpoints
and intermediate samples during the training process.
- Supports both Euler-Maruyama (EU) and probability-flow ODE sampling, controlled by the `eu_sample` and `ode_sample` flags.
- If `use_ema` is enabled, EMA of the model parameters is applied for better stability and performance.
- Supports restarting the training process by reloading the optimizer state and model from previously saved checkpoints.
"""

print(f"Training SGM with a {sde.__class__.__name__} on {dataset.name} dataset.")

# Experiment and image save directories
exp_dir, img_dir = make_dirs(save_dir, dataset.name)

# Model and optimiser save filenames
model_type = model.__class__.__name__
model_filename = os.path.join(
exp_dir, f"sgm_{dataset.name}_{model_type}.eqx"
)
state_filename = os.path.join(
exp_dir, f"state_{dataset.name}_{model_type}.obj"
)

# Plot SDE over time
plot_sde(sde, filename=os.path.join(exp_dir, "sde.png"))

# Plot a sample of training data
if plot_train_data:
plot_train_sample(
dataset,
sample_size=sample_size,
cmap=cmap,
vs=None,
filename=os.path.join(img_dir, "data.png")
)

# Reload optimiser and state if so desired
if not reload_opt_state:
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))
start_step = 0
else:
state = load_opt_state(filename=state_filename)
model = load_model(model, model_filename)

opt, opt_state, start_step = state.values()

print("Loaded model and optimiser state.")

train_key, sample_key, valid_key = jr.split(key, 3)

train_total_value = 0
valid_total_value = 0
train_total_size = 0
valid_total_size = 0
train_losses = []
valid_losses = []

if use_ema:
ema_model = deepcopy(model)

with trange(start_step, n_steps, colour="red") as steps:
for step, train_batch, valid_batch in zip(
steps,
dataset.train_dataloader.loop(batch_size),
dataset.valid_dataloader.loop(batch_size)
):
# Train
x, q, a = shard_batch(train_batch, sharding)
_Lt, model, train_key, opt_state = make_step(
model, sde, x, q, a, train_key, opt_state, opt.update
)

train_total_value += _Lt.item()
train_total_size += 1
train_losses.append(train_total_value / train_total_size)

if use_ema:
ema_model = apply_ema(ema_model, model)

# Validate
x, q, a = shard_batch(valid_batch, sharding)
_Lv = evaluate(
ema_model if use_ema else model, sde, x, q, a, valid_key
)

valid_total_value += _Lv.item()
valid_total_size += 1
valid_losses.append(valid_total_value / valid_total_size)

steps.set_postfix(
{"Lt" : f"{train_losses[-1]:.3E}", "Lv" : f"{valid_losses[-1]:.3E}"}
)

if (step % sample_and_save_every) == 0 or step == n_steps - 1:
# Sample model
key_Q, key_sample = jr.split(sample_key) # Fixed key
sample_keys = jr.split(key_sample, sample_size ** 2)

# Sample random labels or use parameter prior for labels
Q, A = dataset.label_fn(key_Q, sample_size ** 2)

# EU sampling (bind results to new names; rebinding the boolean
# `eu_sample`/`ode_sample` flags to arrays would break these checks
# on later iterations)
eu_samples = ode_samples = None
if eu_sample:
sample_fn = get_eu_sample_fn(
ema_model if use_ema else model, sde, dataset.data_shape
)
eu_samples = jax.vmap(sample_fn)(sample_keys, Q, A)

# ODE sampling
if ode_sample:
sample_fn = get_ode_sample_fn(
ema_model if use_ema else model, sde, dataset.data_shape
)
ode_samples = jax.vmap(sample_fn)(sample_keys, Q, A)

# Sample images and plot
if eu_sample or ode_sample:
plot_model_sample(
eu_samples,
ode_samples,
dataset,
cmap=cmap,
filename=os.path.join(img_dir, f"samples_{step:06d}"),
)

# Save model
save_model(
ema_model if use_ema else model, model_filename
)

# Save optimiser state
save_opt_state(
opt,
opt_state,
i=step,
filename=state_filename
)

# Plot losses etc
plot_metrics(train_losses, valid_losses, step, exp_dir)

return model
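
A sketch of a call to the new config-free entry point, again assuming `model`, `sde`, and `dataset` are constructed elsewhere:

    import jax.random as jr
    import optax

    key = jr.PRNGKey(0)
    opt = optax.adamw(1e-4)

    model = train(
        key, model, sde, dataset,
        opt=opt,
        n_steps=100_000,
        batch_size=256,
        use_ema=True,
        sample_and_save_every=1_000,
        sample_size=4,        # 4 x 4 grid of samples
        eu_sample=True,
        ode_sample=True,
        save_dir="./runs",
        cmap="gray_r",
    )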
