set states to reset for 8bit optimizers and handle quantile runtime error for embeddings
winglian committed Dec 3, 2024
1 parent 9c8557d commit e0b26f0
Showing 1 changed file with 15 additions and 6 deletions.
src/axolotl/monkeypatch/relora.py: 21 changes (15 additions & 6 deletions)
@@ -46,10 +46,10 @@ def reset_optimizer(
     *,
     reset_params: List[str],  # where str is the key to a torch.nn.Parameter
     optimizer_state_keys: List[str],
-    prune_ratio: float = 0.9,
+    optimizer_magnitude_pruning: float = 0.9,
 ):
     # pylint:disable=unused-argument
-    pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
+    pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
     n_zeros = 0
     n_total = 0

@@ -64,9 +64,15 @@ def reset_optimizer(
                 if key not in optimizer_state_keys:
                     continue
                 if torch.is_tensor(value):
-                    pruning_fn(value)
-                    n_total += value.numel()
-                    n_zeros += torch.sum(value == 0).item()
+                    try:
+                        pruning_fn(value)
+                        n_total += value.numel()
+                        n_zeros += torch.sum(value == 0).item()
+                    except RuntimeError as exc:
+                        if "quantile() input tensor is too large" in str(exc):
+                            pass
+                        else:
+                            raise exc

     _zeroed = n_zeros / (1e-7 + n_total) * 100
     LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
@@ -130,6 +136,9 @@ def on_step_begin(

         if "adam" in args.optim.lower():
             optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
+            if "8bit" in args.optim.lower():
+                optimizer_state_keys.append("state1")
+                optimizer_state_keys.append("state2")
         else:
             raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
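
Note on the hunk above: bitsandbytes' 8-bit Adam variants keep their quantized first and second moments under "state1" and "state2" rather than torch.optim's "exp_avg" and "exp_avg_sq", which is why both sets of keys are now considered for pruning on 8-bit runs. A quick check, assuming bitsandbytes is installed and a CUDA device is available (the parameter shape and learning rate are arbitrary):

    import bitsandbytes as bnb
    import torch

    # Build a toy parameter large enough for the 8-bit state path and take one step.
    param = torch.nn.Parameter(torch.randn(256, 64, device="cuda"))
    optim = bnb.optim.AdamW8bit([param], lr=1e-4)
    param.grad = torch.randn_like(param)
    optim.step()

    # Expected to include "state1"/"state2" (the quantized moments) plus quantization
    # bookkeeping tensors, not torch.optim's "exp_avg"/"exp_avg_sq".
    print(sorted(optim.state[param].keys()))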

@@ -161,7 +170,7 @@ def on_step_begin(
                 optimizer,
                 reset_params=lora_params,
                 optimizer_state_keys=optimizer_state_keys,
-                prune_ratio=args.relora_prune_ratio,
+                optimizer_magnitude_pruning=args.relora_prune_ratio,
             )

             if self.quantized: