-
Consider the simplified example where we have 2 hidden layer network with 10 units each. I would like to optimize different groups of parameters with different optimizers, e.g., all the 2D parameters (i.e., weights) with Adam and the rest (i.e., biases) with SGD. I was experimenting with initializing the optimizer parameters with a custom filter function, but then was getting key errors with the update function of the model. Perhaps I'm missing a more intuitive method... in PyTorch we could do the following: param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
adam_params, sgd_params = [], []
for n, p in param_dict.items():
if p.dim() == 2:
adam_params.append(p)
else:
sgd_params.append(p)
optimizer1 = torch.optim.Adam(adam_params)
optimizer2 = torch.optim.SGD(sgd_params)
optimizers = [optimizer1, optimizer2]
# training...
for epoch in range(epochs):
for opt in optimizers:
opt.zero_grad()
loss = criterion(model(X), T)
loss.backward()
for opt in optimizers:
opt.step() Our step function in MLX is a bit different and I'm not sure what exactly needs to change. Below is the boilerplate I'm working with. It has the potential (breaking) changes to optimize only a subset of parameters commented out. Any input or guidance on how we can use multiple optimizers on the same model would be greatly appreciated! import mlx.nn as nn
import mlx.core as mx
import mlx.optimizers as optim
from tqdm import tqdm
from functools import partial
from typing import List
class MLP(nn.Module):
    """Simple multilayer perceptron: Tanh-activated hidden layers, linear head.

    Args:
        n_inputs: flattened input dimensionality.
        n_hiddens: width of each hidden layer, in order.
        n_outputs: output dimensionality.
    """

    def __init__(self, n_inputs: int, n_hiddens: List[int], n_outputs: int):
        super().__init__()
        self.layers = []
        ni = n_inputs
        # Fix: the original used `enumerate` but discarded the index —
        # iterate the hidden widths directly.
        for n_units in n_hiddens:
            self.layers.append(nn.Linear(ni, n_units))
            self.layers.append(nn.Tanh())
            ni = n_units
        self.layers.append(nn.Linear(ni, n_outputs))

    def __call__(self, x):
        # Collapse all trailing dimensions so any input shape works.
        x = x.reshape(x.shape[0], -1)
        for l in self.layers:
            x = l(x)
        return x
class Manager:
    """Mini-batch training loop driving one model with one optimizer.

    Records per-epoch mean training loss in ``train_error_trace``.
    """

    def __init__(self, model: nn.Module, optimizer: optim.Optimizer):
        self.model = model
        self.optimizer = optimizer
        self.batch_size = None
        self.train_error_trace = []

    def _make_batches(self, X, T):
        # A batch_size of -1 means full-batch training.
        step_size = X.shape[0] if self.batch_size == -1 else self.batch_size
        start = 0
        while start < X.shape[0]:
            yield X[start:start + step_size], T[start:start + step_size]
            start += step_size

    def eval_fn(self, X, T):
        return nn.losses.mse_loss(self.model(X), T, reduction='mean')

    def train(self, data, epochs: int, batch_size: int = 64):
        self.batch_size = batch_size
        # The compiled step mutates model and optimizer state, so both must
        # be declared as compile inputs AND outputs.
        state = [self.model.state, self.optimizer.state]

        @partial(mx.compile, inputs=state, outputs=state)
        def step(X, T):
            train_step_fn = nn.value_and_grad(self.model, self.eval_fn)
            loss, grads = train_step_fn(X, T)
            self.optimizer.update(self.model, grads)
            return loss

        epoch_bar = tqdm(range(epochs), desc='Training', unit='epoch')
        self.model.train()
        for _ in epoch_bar:
            # Reshuffle samples every epoch.
            perm = mx.random.permutation(data[0].shape[0])
            data = [v[perm] for v in data]
            running = 0
            for X, T in self._make_batches(*data):
                loss = step(X, T)
                mx.eval(state)
                running += loss.item() * X.shape[0]
            running /= data[0].shape[0]
            self.train_error_trace.append(running)
            epoch_bar.set_postfix({'loss': f'{running:.3f}'})
if __name__ == '__main__':
    # Toy regression data: predict sin(X) from noisy observations.
    X = mx.random.normal((256, 8))
    T = mx.sin(X) + 1e-1 * mx.random.normal(X.shape)

    model = MLP(X.shape[1], [10, 10], X.shape[1])
    optimizer = optim.Adam(learning_rate=0.001)

    # NOTE(review): a previous attempt initialized the optimizer on a filtered
    # subset of parameters (2-D weights only, via a custom
    # trainable_parameter_filter passed to model.filter_and_map) — but
    # optimizer.update() then raised KeyError for the entries missing from
    # the optimizer state.

    Manager(model, optimizer).train((X, T), epochs=100)
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
One way to do it is with multiple optimizers. Using your code as an example: import mlx.nn as nn
import mlx.core as mx
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_unflatten
from tqdm import tqdm
from functools import partial
from typing import List
class MLP(nn.Module):
    """Feed-forward network: Tanh hidden layers followed by a linear output."""

    def __init__(self, n_inputs: int, n_hiddens: List[int], n_outputs: int):
        super().__init__()
        widths = list(n_hiddens) + [n_outputs]
        self.layers = []
        fan_in = n_inputs
        for j, width in enumerate(widths):
            self.layers.append(nn.Linear(fan_in, width))
            # Tanh after every layer except the output head.
            if j < len(widths) - 1:
                self.layers.append(nn.Tanh())
            fan_in = width

    def __call__(self, x):
        # Flatten everything past the batch dimension.
        h = x.reshape(x.shape[0], -1)
        for layer in self.layers:
            h = layer(h)
        return h
class Manager:
    """Training loop that routes gradients to two optimizers by tensor rank.

    ``optimizers`` must contain exactly two entries:
    ``optimizers[0]`` updates the 2-D parameters (weights),
    ``optimizers[1]`` updates everything else (biases, etc.).
    Records per-epoch mean training loss in ``train_error_trace``.
    """

    def __init__(self, model: nn.Module, optimizers: List[optim.Optimizer]):
        self.model = model
        self.optimizers = optimizers  # [weights_optimizer, other_optimizer]
        self.batch_size = None
        self.train_error_trace = []

    def _make_batches(self, X, T):
        # batch_size == -1 means full-batch training.
        bs = self.batch_size if self.batch_size != -1 else X.shape[0]
        for i in range(0, X.shape[0], bs):
            yield X[i:i+bs], T[i:i+bs]

    def eval_fn(self, X, T):
        return nn.losses.mse_loss(self.model(X), T, reduction='mean')

    def train(self, data, epochs: int, batch_size: int = 64):
        self.batch_size = batch_size
        # BUG FIX: the original read the *global* `model` (not self.model),
        # breaking the class outside this script, and put the module object
        # (rather than its state) into the compile state list. Capture
        # self.model.state plus every optimizer's state instead.
        state = [self.model.state] + [o.state for o in self.optimizers]

        def split_grads(grads):
            # Partition the gradient tree: 2-D arrays (weights) vs the rest.
            # Using `ndim != 2` (instead of `ndim == 1`) so no parameter is
            # silently dropped if a module ever has 0-D or >2-D parameters.
            flat = tree_flatten(grads)
            weights = tree_unflatten([(k, v) for k, v in flat if v.ndim == 2])
            others = tree_unflatten([(k, v) for k, v in flat if v.ndim != 2])
            return weights, others

        @partial(mx.compile, inputs=state, outputs=state)
        def step(X, T):
            train_step_fn = nn.value_and_grad(self.model, self.eval_fn)
            loss, grads = train_step_fn(X, T)
            weights, others = split_grads(grads)
            # Each update only touches the parameters present in its subtree.
            self.optimizers[0].update(self.model, weights)
            self.optimizers[1].update(self.model, others)
            return loss

        epoch_bar = tqdm(range(epochs), desc='Training', unit='epoch')
        self.model.train()
        for _ in epoch_bar:
            # Reshuffle samples every epoch.
            inds = mx.random.permutation(data[0].shape[0])
            data = [v[inds] for v in data]
            total_loss = 0
            for X, T in self._make_batches(*data):
                loss = step(X, T)
                mx.eval(state)
                total_loss += loss.item() * X.shape[0]
            total_loss /= data[0].shape[0]
            self.train_error_trace.append(total_loss)
            epoch_bar.set_postfix({'loss': f'{total_loss:.3f}'})
if __name__ == '__main__':
    # Toy regression data: predict sin(X) from noisy observations.
    X = mx.random.normal((256, 8))
    T = mx.sin(X) + 1e-1 * mx.random.normal(X.shape)
    model = MLP(X.shape[1], [10, 10], X.shape[1])
    # Adam drives the 2-D weights; SGD drives everything else (biases).
    optimizers = [optim.Adam(learning_rate=0.001), optim.SGD(learning_rate=0.1)]
    Manager(model, optimizers).train((X, T), epochs=100)
|
Beta Was this translation helpful? Give feedback.
One way to do it is with multiple optimizers. Using your code as an example: