MAML & LLMs #827
-
Is mlx flexible enough to fine-tune an LLM in the style of MAML? In other words, is it possible to fine-tune an LLM in mlx via bi-level gradient descent?
-
Yes for sure you can compose `vjp`/`value_and_grad`/`grad` to any depth and it will work. So to do a bilevel thing you would do something like:

```python
def step(outer_w, inner_w, x, y):
    def loss(inner_w, x, y):
        return nn.losses.mse_loss(inner_w @ x, y)
    dloss_dinner_w = mx.grad(loss)(inner_w, x, y)
    inner_w = inner_w + (outer_w @ x) * dloss_dinner_w
    return loss(inner_w, x, y)

dstep_douter_w = mx.grad(step)(outer_w, inner_w, x, y)
```

(Super simple + untested but just to give you the flavor of how that could go).
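Not part of the answer above, but here is a self-contained variant of that sketch which actually runs, assuming toy random data and treating `outer_w` as a learned scale on the inner update; the shapes and names are illustrative only:

```python
# Minimal runnable bilevel (MAML-style) sketch with toy data.
# `outer_w` acts as a learnable inner-step scale purely for illustration.
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((8, 4))   # toy inputs
y = mx.random.normal((8, 1))   # toy targets

def inner_loss(inner_w, x, y):
    return nn.losses.mse_loss(x @ inner_w, y)

def outer_loss(outer_w, inner_w, x, y):
    # Inner step: one gradient update of inner_w, scaled by outer_w.
    dinner = mx.grad(inner_loss)(inner_w, x, y)
    adapted_w = inner_w - outer_w * dinner
    # Outer objective: loss of the adapted weights. Taking mx.grad of this
    # function w.r.t. outer_w differentiates through the inner mx.grad call,
    # i.e. a bilevel / second-order gradient.
    return inner_loss(adapted_w, x, y)

outer_w = mx.array(0.1)
inner_w = mx.random.normal((4, 1))

douter_w = mx.grad(outer_loss)(outer_w, inner_w, x, y)
print(douter_w.item())
```

For actual MAML you would instead differentiate the adapted loss with respect to the shared initialization (and sum over tasks), but the composition of `mx.grad` calls works the same way.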
-
Thanks @awni for your answer & example. I'm now exploring whether it's possible to apply `mx.vmap` to a BERT model. The first step is to process the text correctly. I understand why the following doesn't work, but I haven't yet come up with an alternative.
-
No you probably can't vmap that out of the box. If you want to vmap over the call of BERT you'd have to do something like:

```python
model = BERT()

def forward(params, x):
    model.update(params)
    return model(x)

vmapfn = mx.vmap(forward)
y = vmapfn(model.parameters(), x)
```
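For reference, here is a runnable version of that pattern with `mlx.nn.Linear` standing in for BERT, and a batch of per-task parameter sets stacked along the vmapped axis; the stacking with `tree_map` and the shapes are illustrative assumptions, not part of the suggestion above:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(4, 2)  # stand-in for BERT()

def forward(params, x):
    # Load this task's parameters into the module, then run it.
    model.update(params)
    return model(x)

num_tasks = 3
# Simulate per-task parameters by stacking copies of the current parameters
# along a new leading axis; axis 0 is the axis mx.vmap maps over.
stacked_params = tree_map(lambda p: mx.stack([p] * num_tasks), model.parameters())
x = mx.random.normal((num_tasks, 5, 4))  # one (5, 4) input batch per task

vmapfn = mx.vmap(forward)      # maps over axis 0 of both arguments
y = vmapfn(stacked_params, x)
print(y.shape)                 # (3, 5, 2)
```

With real per-task adapted weights (e.g. after the inner updates above), you would stack those instead of copies of the same initialization.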
-
@awni I followed your suggestion on how to define