fix hanging due to NanotronParameter.__repr__ (param.data == NanotronParameter)
xrsrke committed Nov 29, 2024
1 parent b764b97 commit afdfbf1
Showing 3 changed files with 23 additions and 13 deletions.
src/nanotron/fp8/utils.py: 8 changes (7 additions, 1 deletion)
@@ -375,8 +375,14 @@ def convert_model_to_fp8(model: NanotronModel, config: FP8Args) -> NanotronModel
# NOTE: convert it to the residual stream's dtype
# for p in module.parameters():
# p.data = p.data.to(self.config.model.dtype)
module.to(dtype=config.resid_dtype)
# for p in module.parameters():
# p.data = p.data.to(dtype=config.resid_dtype) if p.data
# pass
# assert module.weight.data.__class__ == torch.Tensor
# module.to(dtype=config.resid_dtype)
# pass
# assert module.weight.data.__class__ == torch.Tensor
# NOTE: this causes param.data == NanotronParameter
assert config.resid_dtype == torch.float32, "not support datatype conversion, because of error 8"

return model
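The commented-out attempts above all circle around converting each parameter's data tensor individually rather than calling module.to(dtype=...), which is the call the NOTE blames for leaving param.data as a NanotronParameter. A minimal sketch of that per-parameter idea, assuming plain torch modules (convert_residual_dtype is a hypothetical helper, not part of this commit):

import torch
from torch import nn

def convert_residual_dtype(module: nn.Module, dtype: torch.dtype) -> None:
    # Hypothetical sketch: convert only plain torch.Tensor parameter data,
    # skipping wrapper subclasses (e.g. FP8-backed data), without going
    # through nn.Module.to() / Module._apply().
    for p in module.parameters():
        if type(p.data) is torch.Tensor:
            p.data = p.data.to(dtype)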
src/nanotron/parallel/parameters.py: 13 changes (7 additions, 6 deletions)
@@ -264,8 +264,9 @@ def is_sharded(self) -> bool:
self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME
)

# def __repr__(self):
# return f"NanotronParameter({super().__repr__()})"
def __repr__(self):
# return f"NanotronParameter({super().__repr__()})"
return "NanotronParameter()"

@property
def data(self):
@@ -293,13 +294,13 @@ def data(self, data):
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
from nanotron.fp8.tensor import FP8Tensor

print(f"__torch_dispatch__ called with func: {func}, args: {args}, kwargs: {kwargs}")
# print(f"__torch_dispatch__ called with func: {func}, args: {args}, kwargs: {kwargs}")

if func in {torch._tensor_str._str, repr}:
return super().__torch_dispatch__(func, types, args, kwargs)
# if func in {torch._tensor_str._str, repr}:
# return super().__torch_dispatch__(func, types, args, kwargs)

def unwrap(e):
print(f"Unwrapping: {e} (type: {type(e)})")
# print(f"Unwrapping: {e} (type: {type(e)})")
return e._data if e.__class__ == NanotronParameter else e

def wrap(e):
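Why the default __repr__ could hang: super().__repr__() goes through torch._tensor_str._str, which runs real tensor ops to format the values, and each of those ops re-enters __torch_dispatch__; if unwrapping there hands back the wrapper itself (the param.data == NanotronParameter situation named in the commit message), formatting can recurse without terminating. A minimal, self-contained sketch of the pattern (illustrative WrapperParam class, not Nanotron's actual implementation), where a constant-string __repr__ never enters the dispatch path:

import torch
from torch.utils._pytree import tree_map

class WrapperParam(torch.Tensor):
    # Illustrative wrapper-tensor subclass in the same spirit as NanotronParameter.
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __repr__(self):
        # A constant string triggers no tensor ops, so it cannot re-enter
        # __torch_dispatch__ and hang.
        return "WrapperParam()"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        unwrap = lambda e: e._data if isinstance(e, WrapperParam) else e
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

p = WrapperParam(torch.randn(2))
print(repr(p))  # "WrapperParam()" -- no dispatch, no hang
print(p + 1)    # arithmetic unwraps to the inner tensor and returns a plain result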
tests/fp8/test_fp8_model.py: 15 changes (9 additions, 6 deletions)
@@ -44,12 +44,15 @@ def _test_initialize_fp8_model(parallel_context: ParallelContext, fp8_config: FP
assert all(
p.dtype == fp8_config.resid_dtype for p in module.parameters()
), f"name: {name}, __class__: {module.weight.data.__class__}"
try:
assert all(
p.data.__class__ == nn.Parameter for p in module.parameters()
), f"name: {name}, __class__: {module.weight.data.__class__}"
except:
assert 1 == 1
# try:
# assert all(
# p.data.__class__ == nn.Parameter for p in module.parameters()
# ), f"name: {name}, __class__: {module.weight.data.__class__}"
# except:
# assert 1 == 1
assert all(
p.data.__class__ == nn.Parameter for p in module.parameters()
), f"name: {name}, __class__: {module.weight.data.__class__}"
else:
assert all(
isinstance(p.data.__class__, FP8Tensor) for p in module.parameters()
