diff --git a/llmfoundry/optim/adaptive_lion.py b/llmfoundry/optim/adaptive_lion.py index 06110bab23..0ce76e905e 100644 --- a/llmfoundry/optim/adaptive_lion.py +++ b/llmfoundry/optim/adaptive_lion.py @@ -206,28 +206,10 @@ def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): """Preprocess metrics to reduce across ranks correctly.""" - # Sort L2 norms first so they are squared before other metrics, which depend on squared values - metrics = optimizer_metrics.keys() - metrics = sorted(metrics, - key=lambda metric: 0 if 'l2_norm' in metric else 1) - for metric in metrics: - if metric.startswith('l2_norm'): - # L2 norms need to be squared, before they are reduced via summation - optimizer_metrics[metric] = optimizer_metrics[metric]**2 - elif metric.startswith('cosine'): - _, vectors, layer = tuple(metric.split('/')) - - A, B = tuple(vectors.split('_')) - - # L2 norm would've been squared in previous branch - A_rank_subset_norm = math.sqrt( - optimizer_metrics[f'l2_norm/{A}/{layer}']) - B_rank_subset_norm = math.sqrt( - optimizer_metrics[f'l2_norm/{B}/{layer}']) - - optimizer_metrics[ - metric] *= A_rank_subset_norm * B_rank_subset_norm - + # Only L2 norm metric keys are present, can skip sorting at this stage + for metric in optimizer_metrics: + # L2 norms need to be squared, before they are reduced via summation + optimizer_metrics[metric] = optimizer_metrics[metric]**2 return optimizer_metrics def report_per_parameter_metrics(self, param: torch.Tensor, name: str, @@ -287,14 +269,6 @@ class DecoupledClipLion(Optimizer): 'l2_norm/grad': lambda param, optim_state, step_tensor: torch.linalg.vector_norm( param.grad), - 'cosine/update_grad': - lambda param, optim_state, step_tensor: torch.nn.functional. - cosine_similarity( - param.grad.flatten(), step_tensor.flatten(), dim=0), - 'cosine/moment_grad': - lambda param, optim_state, step_tensor: torch.nn.functional. - cosine_similarity( - param.grad.flatten(), optim_state['exp_avg'].flatten(), dim=0), } def __init__(self, @@ -384,26 +358,22 @@ def step(self, closure: Optional[Callable] = None): return loss def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): - for metric in optimizer_metrics: + local_keys = list(optimizer_metrics.keys()) + all_gathered_keys = dist.all_gather_object(local_keys) + all_keys = set() + for keys in all_gathered_keys: + all_keys.update(keys) + + # Sort keys to ensure every rank has the same keys order + # Only L2 norm metric keys are present, can apply regular sort + all_keys = sorted(all_keys) + for metric in all_keys: if metric.startswith('l2_norm'): reduced = optimizer_metrics[metric] if dist.get_world_size() > 1: dist.all_reduce(reduced, reduce_operation='SUM') optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced)) - elif metric.startswith('cosine'): - reduced = optimizer_metrics[metric] - if dist.get_world_size() > 1: - dist.all_reduce(reduced, reduce_operation='SUM') - - _, vectors, layer = tuple(metric.split('/')) - - A, B = tuple(vectors.split('_')) - - A_reduced_norm = optimizer_metrics[f'l2_norm/{A}/{layer}'] - B_reduced_norm = optimizer_metrics[f'l2_norm/{B}/{layer}'] - optimizer_metrics[metric] = reduced / (A_reduced_norm * - B_reduced_norm) elif metric.startswith('clipped_batches'): continue else: diff --git a/llmfoundry/optim/lion.py b/llmfoundry/optim/lion.py index cc171290b7..0caa7d2877 100644 --- a/llmfoundry/optim/lion.py +++ b/llmfoundry/optim/lion.py @@ -99,26 +99,22 @@ def step(self, closure: Optional[Callable] = None): return loss def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): - for metric in optimizer_metrics: + local_keys = list(optimizer_metrics.keys()) + all_gathered_keys = dist.all_gather_object(local_keys) + all_keys = set() + for keys in all_gathered_keys: + all_keys.update(keys) + + # Sort keys to ensure every rank has the same keys order + # Only L2 norm metric keys are present, can apply regular sort + all_keys = sorted(all_keys) + for metric in all_keys: if metric.startswith('l2_norm'): reduced = optimizer_metrics[metric] if dist.get_world_size() > 1: dist.all_reduce(reduced, reduce_operation='SUM') optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced)) - elif metric.startswith('cosine'): - reduced = optimizer_metrics[metric] - if dist.get_world_size() > 1: - dist.all_reduce(reduced, reduce_operation='SUM') - - _, vectors, layer = tuple(metric.split('/')) - - A, B = tuple(vectors.split('_')) - - A_reduced_norm = optimizer_metrics[f'l2_norm/{A}/{layer}'] - B_reduced_norm = optimizer_metrics[f'l2_norm/{B}/{layer}'] - optimizer_metrics[metric] = reduced / (A_reduced_norm * - B_reduced_norm) else: reduced = optimizer_metrics[metric] if dist.get_world_size() > 1: @@ -129,28 +125,10 @@ def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): """Preprocess metrics to reduce across ranks correctly.""" - # Sort L2 norms first so they are squared before other metrics, which depend on squared values - metrics = optimizer_metrics.keys() - metrics = sorted(metrics, - key=lambda metric: 0 if 'l2_norm' in metric else 1) - for metric in metrics: - if metric.startswith('l2_norm'): - # L2 norms need to be squared, before they are reduced via summation - optimizer_metrics[metric] = optimizer_metrics[metric]**2 - elif metric.startswith('cosine'): - _, vectors, layer = tuple(metric.split('/')) - - A, B = tuple(vectors.split('_')) - - # L2 norm would've been squared in previous branch - A_rank_subset_norm = math.sqrt( - optimizer_metrics[f'l2_norm/{A}/{layer}']) - B_rank_subset_norm = math.sqrt( - optimizer_metrics[f'l2_norm/{B}/{layer}']) - - optimizer_metrics[ - metric] *= A_rank_subset_norm * B_rank_subset_norm - + # Only L2 norm metric keys are present, can skip sorting at this stage + for metric in optimizer_metrics: + # L2 norms need to be squared, before they are reduced via summation + optimizer_metrics[metric] = optimizer_metrics[metric]**2 return optimizer_metrics def report_per_parameter_metrics(self, param: torch.Tensor, name: str,