Fix optimizer logging (#597)
* fix optimizer logging

* lint
mvpatel2000 authored Sep 14, 2023
1 parent 30544f0 commit c9dda15
Showing 2 changed files with 28 additions and 80 deletions.
58 changes: 14 additions & 44 deletions llmfoundry/optim/adaptive_lion.py
@@ -206,28 +206,10 @@ def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
 
     def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
         """Preprocess metrics to reduce across ranks correctly."""
-        # Sort L2 norms first so they are squared before other metrics, which depend on squared values
-        metrics = optimizer_metrics.keys()
-        metrics = sorted(metrics,
-                         key=lambda metric: 0 if 'l2_norm' in metric else 1)
-        for metric in metrics:
-            if metric.startswith('l2_norm'):
-                # L2 norms need to be squared, before they are reduced via summation
-                optimizer_metrics[metric] = optimizer_metrics[metric]**2
-            elif metric.startswith('cosine'):
-                _, vectors, layer = tuple(metric.split('/'))
-
-                A, B = tuple(vectors.split('_'))
-
-                # L2 norm would've been squared in previous branch
-                A_rank_subset_norm = math.sqrt(
-                    optimizer_metrics[f'l2_norm/{A}/{layer}'])
-                B_rank_subset_norm = math.sqrt(
-                    optimizer_metrics[f'l2_norm/{B}/{layer}'])
-
-                optimizer_metrics[
-                    metric] *= A_rank_subset_norm * B_rank_subset_norm
-
+        # Only L2 norm metric keys are present, can skip sorting at this stage
+        for metric in optimizer_metrics:
+            # L2 norms need to be squared, before they are reduced via summation
+            optimizer_metrics[metric] = optimizer_metrics[metric]**2
         return optimizer_metrics
 
     def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
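Note on the simplified pre_reduce_metrics: once the cosine metrics are gone, only l2_norm/* keys remain, so every value can be squared unconditionally and no key ordering matters. Squaring is what makes the later SUM all-reduce correct, because the global norm of a sharded parameter is the square root of the summed squared shard norms. A minimal single-process sketch of that identity (the shard shapes below are made up for illustration, not taken from the repo):

import math

import torch

# Hypothetical shards of one flattened parameter, e.g. as FSDP might split it.
shards = [torch.randn(5), torch.randn(7), torch.randn(3)]

# pre_reduce_metrics step: square each rank-local L2 norm.
local_sq_norms = [torch.linalg.vector_norm(s)**2 for s in shards]

# What the SUM all-reduce produces, followed by the sqrt in dist_reduce_metrics.
global_norm = math.sqrt(sum(local_sq_norms))

# Matches the norm of the unsharded parameter.
expected = torch.linalg.vector_norm(torch.cat(shards)).item()
assert abs(global_norm - expected) < 1e-5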
@@ -287,14 +269,6 @@ class DecoupledClipLion(Optimizer):
         'l2_norm/grad':
             lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
                 param.grad),
-        'cosine/update_grad':
-            lambda param, optim_state, step_tensor: torch.nn.functional.
-            cosine_similarity(
-                param.grad.flatten(), step_tensor.flatten(), dim=0),
-        'cosine/moment_grad':
-            lambda param, optim_state, step_tensor: torch.nn.functional.
-            cosine_similarity(
-                param.grad.flatten(), optim_state['exp_avg'].flatten(), dim=0),
     }
 
     def __init__(self,
@@ -384,26 +358,22 @@ def step(self, closure: Optional[Callable] = None):
         return loss
 
     def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
-        for metric in optimizer_metrics:
+        local_keys = list(optimizer_metrics.keys())
+        all_gathered_keys = dist.all_gather_object(local_keys)
+        all_keys = set()
+        for keys in all_gathered_keys:
+            all_keys.update(keys)
+
+        # Sort keys to ensure every rank has the same keys order
+        # Only L2 norm metric keys are present, can apply regular sort
+        all_keys = sorted(all_keys)
+        for metric in all_keys:
             if metric.startswith('l2_norm'):
                 reduced = optimizer_metrics[metric]
                 if dist.get_world_size() > 1:
                     dist.all_reduce(reduced, reduce_operation='SUM')
 
                 optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced))
-            elif metric.startswith('cosine'):
-                reduced = optimizer_metrics[metric]
-                if dist.get_world_size() > 1:
-                    dist.all_reduce(reduced, reduce_operation='SUM')
-
-                _, vectors, layer = tuple(metric.split('/'))
-
-                A, B = tuple(vectors.split('_'))
-
-                A_reduced_norm = optimizer_metrics[f'l2_norm/{A}/{layer}']
-                B_reduced_norm = optimizer_metrics[f'l2_norm/{B}/{layer}']
-                optimizer_metrics[metric] = reduced / (A_reduced_norm *
-                                                       B_reduced_norm)
             elif metric.startswith('clipped_batches'):
                 continue
             else:
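Note on the dist_reduce_metrics change, which is the core of this fix: collective operations are matched purely by call order, so every rank must issue the same all_reduce calls in the same order. Iterating only the rank-local metric keys can desynchronize ranks whose key sets differ; gathering every rank's key list and sorting the union keeps them in lockstep. Below is a hedged sketch of the same pattern written against torch.distributed directly (the diff uses the Composer dist utilities, as the reduce_operation='SUM' keyword suggests, and their all_gather_object returns the gathered list; an initialized process group is assumed here, and the zero default for missing keys is an illustrative choice, not something the diff does):

from typing import Dict, List

import torch
import torch.distributed as dist


def reduce_l2_metrics(
        optimizer_metrics: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Sum-reduce squared L2 norms so every rank issues identical collectives."""
    local_keys: List[str] = list(optimizer_metrics.keys())

    # Gather every rank's key list so all ranks agree on the full key set.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_keys)

    # Sorted union -> identical iteration order on every rank.
    all_keys = sorted({key for keys in gathered for key in keys})

    for key in all_keys:
        # Zero is the neutral element for SUM, so ranks that do not hold this
        # key still participate in the matching collective.
        value = optimizer_metrics.get(key, torch.tensor(0.0))
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        optimizer_metrics[key] = value.sqrt()  # squared norms were summed upstream
    return optimizer_metrics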
50 changes: 14 additions & 36 deletions llmfoundry/optim/lion.py
@@ -99,26 +99,22 @@ def step(self, closure: Optional[Callable] = None):
         return loss
 
     def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
-        for metric in optimizer_metrics:
+        local_keys = list(optimizer_metrics.keys())
+        all_gathered_keys = dist.all_gather_object(local_keys)
+        all_keys = set()
+        for keys in all_gathered_keys:
+            all_keys.update(keys)
+
+        # Sort keys to ensure every rank has the same keys order
+        # Only L2 norm metric keys are present, can apply regular sort
+        all_keys = sorted(all_keys)
+        for metric in all_keys:
             if metric.startswith('l2_norm'):
                 reduced = optimizer_metrics[metric]
                 if dist.get_world_size() > 1:
                     dist.all_reduce(reduced, reduce_operation='SUM')
 
                 optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced))
-            elif metric.startswith('cosine'):
-                reduced = optimizer_metrics[metric]
-                if dist.get_world_size() > 1:
-                    dist.all_reduce(reduced, reduce_operation='SUM')
-
-                _, vectors, layer = tuple(metric.split('/'))
-
-                A, B = tuple(vectors.split('_'))
-
-                A_reduced_norm = optimizer_metrics[f'l2_norm/{A}/{layer}']
-                B_reduced_norm = optimizer_metrics[f'l2_norm/{B}/{layer}']
-                optimizer_metrics[metric] = reduced / (A_reduced_norm *
-                                                       B_reduced_norm)
             else:
                 reduced = optimizer_metrics[metric]
                 if dist.get_world_size() > 1:
@@ -129,28 +125,10 @@ def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
 
     def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
         """Preprocess metrics to reduce across ranks correctly."""
-        # Sort L2 norms first so they are squared before other metrics, which depend on squared values
-        metrics = optimizer_metrics.keys()
-        metrics = sorted(metrics,
-                         key=lambda metric: 0 if 'l2_norm' in metric else 1)
-        for metric in metrics:
-            if metric.startswith('l2_norm'):
-                # L2 norms need to be squared, before they are reduced via summation
-                optimizer_metrics[metric] = optimizer_metrics[metric]**2
-            elif metric.startswith('cosine'):
-                _, vectors, layer = tuple(metric.split('/'))
-
-                A, B = tuple(vectors.split('_'))
-
-                # L2 norm would've been squared in previous branch
-                A_rank_subset_norm = math.sqrt(
-                    optimizer_metrics[f'l2_norm/{A}/{layer}'])
-                B_rank_subset_norm = math.sqrt(
-                    optimizer_metrics[f'l2_norm/{B}/{layer}'])
-
-                optimizer_metrics[
-                    metric] *= A_rank_subset_norm * B_rank_subset_norm
-
+        # Only L2 norm metric keys are present, can skip sorting at this stage
+        for metric in optimizer_metrics:
+            # L2 norms need to be squared, before they are reduced via summation
+            optimizer_metrics[metric] = optimizer_metrics[metric]**2
         return optimizer_metrics
 
     def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
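For reference, the deleted cosine branches in both files reconstructed a global cosine similarity from per-rank pieces: pre_reduce_metrics multiplied each local cosine by the two local norms (recovering the local dot product), the SUM all-reduce added the dot products and the squared norms across ranks, and dist_reduce_metrics divided by the product of the recovered global norms, using cos(A, B) = (sum over ranks of A_r . B_r) / (||A|| * ||B||). A small single-process check of that identity (the shard sizes are made up; this is not code from the repo):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
a_shards = [torch.randn(4), torch.randn(6)]  # e.g. the update, split across 2 ranks
b_shards = [torch.randn(4), torch.randn(6)]  # e.g. the gradient, split the same way

# pre_reduce step: local cosine times local norms recovers the local dot product.
local_dots = [
    F.cosine_similarity(a, b, dim=0) * torch.linalg.vector_norm(a) *
    torch.linalg.vector_norm(b) for a, b in zip(a_shards, b_shards)
]

# The SUM all-reduce adds dot products and squared norms across ranks.
dot = sum(local_dots)
a_norm = sum(torch.linalg.vector_norm(a)**2 for a in a_shards).sqrt()
b_norm = sum(torch.linalg.vector_norm(b)**2 for b in b_shards).sqrt()

# dist_reduce step: divide by the product of the global norms.
reconstructed = dot / (a_norm * b_norm)
direct = F.cosine_similarity(torch.cat(a_shards), torch.cat(b_shards), dim=0)
assert torch.isclose(reconstructed, direct, atol=1e-5)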
