diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index e87d8dbb..e9ad04d8 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -149,7 +149,7 @@ def save( tensor=tied_param, pg=group, msg=lambda err: f"Tied {tied_info.name} are not synced {err}" ) if not optimizer.inherit_from(optim.ZeroDistributedOptimizer): - check_optim_state_in_sync(optimizer, parallel_context.dp_pg) + check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg) # SANITY CHECK: tied parameters have their optimizer states synchronized # Compute a mapping from id_ to index in the optimizer sense