
Commit fbbbf4d

refactor + add test for fp8 initialization
xrsrke committed Nov 28, 2024
1 parent a4d6f15 commit fbbbf4d
Showing 31 changed files with 1,154 additions and 806 deletions.
2 changes: 1 addition & 1 deletion src/nanotron/fp8/distributed.py
@@ -6,7 +6,7 @@

from nanotron.distributed import *
from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8
from nanotron.parallel.parameters import NanotronParameter, get_data_from_param
from nanotron.parallel.parameters import NanotronParameter


def all_reduce(
90 changes: 47 additions & 43 deletions src/nanotron/fp8/optim.py
@@ -214,10 +214,7 @@ def step(self, closure=None):
if not isinstance(p.data, FP8Tensor) and p.requires_grad is False:
continue

try:
assert p.grad is not None
except:
assert 1 == 1
assert p.grad is not None

# p_name = self.params_id_to_param_names[id(p)]
# loggings[p] = {}
@@ -231,46 +228,53 @@ def step(self, closure=None):
# NOTE: if use gradient accumulation, after the backward pass
# we set the param.grad to None, so we need to retrieve it from accumulator

if constants.CONFIG.optimizer.accumulate_grad_in_fp32 is True:
# fp32_grad = self.grad_accumulator.get_grad_buffer(name=p_name)

# if "model.decoder.8.pp_block.attn_layer_scale" in p_name:
# assert 1 == 1

# if constants.CONFIG.fp8.is_save_grad_for_accum_debugging is True:
# from nanotron.helpers import create_folder_and_save_tensor

# create_folder_and_save_tensor(
# fp32_grad,
# f"/fsx/phuc/temp/temp3_env_for_fp8/nanotron/debug_accum/{constants.CONFIG.general.run}/aggr_grads/{p_name}.pt",
# )
raise NotImplementedError("accumulate_grad_in_fp32 is not implemented")
# if constants.CONFIG.optimizer.accumulate_grad_in_fp32 is True:
# # fp32_grad = self.grad_accumulator.get_grad_buffer(name=p_name)

# # if "model.decoder.8.pp_block.attn_layer_scale" in p_name:
# # assert 1 == 1

# # if constants.CONFIG.fp8.is_save_grad_for_accum_debugging is True:
# # from nanotron.helpers import create_folder_and_save_tensor

# # create_folder_and_save_tensor(
# # fp32_grad,
# # f"/fsx/phuc/temp/temp3_env_for_fp8/nanotron/debug_accum/{constants.CONFIG.general.run}/aggr_grads/{p_name}.pt",
# # )
# raise NotImplementedError("accumulate_grad_in_fp32 is not implemented")
# else:
# if isinstance(p.data, FP8Tensor):
# if constants.CONFIG.fp8.is_directly_keep_accum_grad_of_fp8 is True:
# # fp32_grad = constants.ACCUM_GRADS[p_name]
# # grad = get_accum_grad(p_name)
# # fp32_grad = (
# # grad.to(self.optim_accum_dtype) if grad.dtype != self.optim_accum_dtype else grad
# # )
# # assert fp32_grad.dtype == torch.float32

# # # constants.ACCUM_GRADS[p_name] = None
# # set_accum_grad(p_name, None)
# raise NotImplementedError("is_directly_keep_accum_grad_of_fp8 is not implemented")
# else:
# assert p.grad.dtype in FP8_DTYPES
# fp32_grad = convert_tensor_from_fp8(p.grad, p.grad.fp8_meta, self.optim_accum_dtype)
# else:
# # grad = get_grad_from_parameter(p)

# # assert grad is not None
# assert p.grad.dtype == non_fp8_accum_dtype

# fp32_grad = p.grad.to(self.optim_accum_dtype)

# NOTE: Case 1: With gradient accumulator => the grad is already in the correct dtype
# Case 2: Without gradient accumulator =>
# 2.1 Non-FP8 parameter => cast the grad to the correct dtype
# 2.2 FP8 parameter => dequantize the grad to the correct dtype
grad = p.grad
if isinstance(p.data, FP8Tensor):
fp32_grad = convert_tensor_from_fp8(grad, grad.fp8_meta, self.optim_accum_dtype)
else:
if isinstance(p.data, FP8Tensor):
if constants.CONFIG.fp8.is_directly_keep_accum_grad_of_fp8 is True:
# fp32_grad = constants.ACCUM_GRADS[p_name]
# grad = get_accum_grad(p_name)
# fp32_grad = (
# grad.to(self.optim_accum_dtype) if grad.dtype != self.optim_accum_dtype else grad
# )
# assert fp32_grad.dtype == torch.float32

# # constants.ACCUM_GRADS[p_name] = None
# set_accum_grad(p_name, None)
raise NotImplementedError("is_directly_keep_accum_grad_of_fp8 is not implemented")
else:
assert p.grad.dtype in FP8_DTYPES
fp32_grad = convert_tensor_from_fp8(p.grad, p.grad.fp8_meta, self.optim_accum_dtype)
else:
# grad = get_grad_from_parameter(p)

# assert grad is not None
try:
assert p.grad.dtype == non_fp8_accum_dtype
except:
assert 1 == 1

fp32_grad = p.grad.to(self.optim_accum_dtype)
fp32_grad = grad.to(self.optim_accum_dtype)

assert fp32_grad.dtype == self.optim_accum_dtype

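The refactor above collapses the old try/except-guarded branches into a single path for the "without gradient accumulator" cases: an FP8 gradient is dequantized into the optimizer's accumulation dtype, any other gradient is simply cast. Below is a minimal, self-contained sketch of that dtype-normalization logic; `dequantize_fp8`, `is_fp8`, and `scale` are stand-ins invented for illustration, whereas the real code calls `convert_tensor_from_fp8(grad, grad.fp8_meta, self.optim_accum_dtype)` using the FP8 scaling metadata attached to the tensor.

```python
import torch

def dequantize_fp8(grad_int8: torch.Tensor, scale: float, accum_dtype: torch.dtype) -> torch.Tensor:
    # Stand-in for convert_tensor_from_fp8: a real FP8 tensor carries scaling metadata (fp8_meta).
    return grad_int8.to(accum_dtype) * scale

def grad_to_accum_dtype(grad: torch.Tensor, is_fp8: bool, scale: float = 1.0,
                        accum_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    if is_fp8:
        # FP8 parameter: dequantize the 8-bit gradient into the accumulation dtype
        fp32_grad = dequantize_fp8(grad, scale, accum_dtype)
    else:
        # non-FP8 parameter: a plain cast is enough
        fp32_grad = grad.to(accum_dtype)
    assert fp32_grad.dtype == accum_dtype
    return fp32_grad

# bf16 grad is cast, int8 "FP8" grad is dequantized with its scale
print(grad_to_accum_dtype(torch.randn(4, dtype=torch.bfloat16), is_fp8=False).dtype)                          # torch.float32
print(grad_to_accum_dtype(torch.randint(-127, 127, (4,), dtype=torch.int8), is_fp8=True, scale=0.01).dtype)   # torch.float32
```
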
168 changes: 109 additions & 59 deletions src/nanotron/fp8/utils.py
@@ -5,13 +5,13 @@
from torch import nn

from nanotron import logging
from nanotron.config import Config
from nanotron.config.fp8_config import FP8LayerArgs
from nanotron.config.fp8_config import FP8Args, FP8LayerArgs
from nanotron.fp8.constants import FP8_GPU_NAMES, FP8LM_RECIPE, QTYPE_TO_DTYPE
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.linear import FP8Linear
from nanotron.fp8.meta import FP8Meta
from nanotron.logging import log_rank
from nanotron.models.base import NanotronModel

logger = logging.get_logger(__name__)

@@ -231,64 +231,65 @@ def convert_logs_to_flat_logs(logs, prefix):
return flat_logs


def find_fp8_config_by_module_name(config: Config, target_module_name: str) -> Optional[FP8LayerArgs]:
if hasattr(config, "fp8") and hasattr(config.fp8, "model"):
# NOTE: either model or is_quant_all_except_first_and_last must be specified, not both
# assert config.fp8.model is not None or config.fp8.is_quant_all_except_first_and_last is not None

if config.fp8.model is not None:
for layer_args in config.fp8.model:
if layer_args.module_name == target_module_name:
return layer_args
elif config.fp8.is_quant_all_except_first_and_last:

def match_layer_pattern(name, layer_idxs):
patterns = [
"model.decoder.{}.attn.qkv_proj",
"model.decoder.{}.attn.o_proj",
"model.decoder.{}.mlp.down_proj",
# "model.decoder.{}.mlp.up_proj",
"model.decoder.{}.mlp.gate_up_proj",
]

for idx in layer_idxs:
for pattern in patterns:
if name == pattern.format(idx):
return True

return False

num_layers = config.model.model_config.num_hidden_layers
assert num_layers > 2, "num_hidden_layers must be greater than 2"
assert config.fp8.fp8_linear_config_temp is not None

quant_layer_idxs = list(range(1, num_layers - 1))
if match_layer_pattern(target_module_name, quant_layer_idxs) is True:
from copy import deepcopy

config_temp = deepcopy(config.fp8.fp8_linear_config_temp)
config_temp.module_name = target_module_name
return config_temp
else:
from nanotron.fp8.constant_recipe import MODULE_NAMES_THAT_NOT_FP8
from nanotron.fp8.constants import FP8LM_LINEAR_RECIPE
def find_fp8_config_by_module_name(target_module_name: str, config: FP8Args) -> Optional[FP8LayerArgs]:
# NOTE: either model or is_quant_all_except_first_and_last must be specified, not both
# assert config.fp8.model is not None or config.fp8.is_quant_all_except_first_and_last is not None

if any(module_name in target_module_name for module_name in MODULE_NAMES_THAT_NOT_FP8):
return None
else:
# NOTE: return default recipe
# NOTE: based on the global setting smooth_quant to decide whether to do smooth quantization
# or not
recipe = FP8LM_LINEAR_RECIPE
recipe.smooth_quant = config.fp8.smooth_quant
log_rank(
f"target_module_name={target_module_name}, smooth_quant={recipe.smooth_quant}",
logger=logger,
level=logging.INFO,
rank=0,
)

return recipe
# TODO(xrsrke): remove config.is_quant_all_except_first_and_last

if config.model is not None:
for layer_args in config.model:
if layer_args.module_name == target_module_name:
return layer_args
elif config.is_quant_all_except_first_and_last:

def match_layer_pattern(name, layer_idxs):
patterns = [
"model.decoder.{}.attn.qkv_proj",
"model.decoder.{}.attn.o_proj",
"model.decoder.{}.mlp.down_proj",
# "model.decoder.{}.mlp.up_proj",
"model.decoder.{}.mlp.gate_up_proj",
]

for idx in layer_idxs:
for pattern in patterns:
if name == pattern.format(idx):
return True

return False

num_layers = config.model.model_config.num_hidden_layers
assert num_layers > 2, "num_hidden_layers must be greater than 2"
assert config.fp8_linear_config_temp is not None

quant_layer_idxs = list(range(1, num_layers - 1))
if match_layer_pattern(target_module_name, quant_layer_idxs) is True:
from copy import deepcopy

config_temp = deepcopy(config.fp8_linear_config_temp)
config_temp.module_name = target_module_name
return config_temp
else:
from nanotron.fp8.constant_recipe import MODULE_NAMES_THAT_NOT_FP8
from nanotron.fp8.constants import FP8LM_LINEAR_RECIPE

if any(module_name in target_module_name for module_name in MODULE_NAMES_THAT_NOT_FP8):
return None
else:
# NOTE: return default recipe
# NOTE: based on the global setting smooth_quant to decide whether to do smooth quantization
# or not
recipe = FP8LM_LINEAR_RECIPE
recipe.smooth_quant = config.smooth_quant
log_rank(
f"target_module_name={target_module_name}, smooth_quant={recipe.smooth_quant}",
logger=logger,
level=logging.INFO,
rank=0,
)

return recipe
return None
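
The refactored `find_fp8_config_by_module_name` now takes the `FP8Args` config directly: an explicit `config.model` list wins, `is_quant_all_except_first_and_last` applies the templated `fp8_linear_config_temp` to the listed projection layers of every decoder block except the first and last, and otherwise modules named in `MODULE_NAMES_THAT_NOT_FP8` are left unquantized while the rest fall back to `FP8LM_LINEAR_RECIPE` with the global `smooth_quant` setting. A standalone sketch of just the layer-pattern matching is shown below; the patterns come from the diff, while `num_hidden_layers=12` is an arbitrary example value.

```python
# Patterns taken from match_layer_pattern above; up_proj is commented out there as well.
QUANT_PATTERNS = [
    "model.decoder.{}.attn.qkv_proj",
    "model.decoder.{}.attn.o_proj",
    "model.decoder.{}.mlp.down_proj",
    "model.decoder.{}.mlp.gate_up_proj",
]

def is_quantized_module(name: str, num_hidden_layers: int = 12) -> bool:
    # The first (0) and last (num_hidden_layers - 1) decoder layers stay in high precision.
    quant_layer_idxs = range(1, num_hidden_layers - 1)
    return any(name == pattern.format(idx) for idx in quant_layer_idxs for pattern in QUANT_PATTERNS)

print(is_quantized_module("model.decoder.3.attn.qkv_proj"))   # True
print(is_quantized_module("model.decoder.0.attn.qkv_proj"))   # False (first layer is excluded)
print(is_quantized_module("model.decoder.11.mlp.down_proj"))  # False (last layer is excluded)
```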


@@ -330,3 +331,52 @@ def is_convert_to_fp16(module) -> bool:
IS_CONVERT_TO_FLOAT16 = True

return IS_CONVERT_TO_FLOAT16


def convert_model_to_fp8(model: NanotronModel, config: FP8Args) -> NanotronModel:
from nanotron.fp8.utils import get_leaf_modules

assert 1 == 1
# NOTE: convert to FP8
from nanotron.fp8.tensor import FP8Tensor

# from nanotron import constants
from nanotron.fp8.utils import find_fp8_config_by_module_name
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.tensor_parallel.nn import (
FP8TensorParallelColumnLinear,
FP8TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelRowLinear,
)

TP_LINEAR_CLS_TO_FP8_LINEAR_CLS = {
TensorParallelColumnLinear: FP8TensorParallelColumnLinear,
TensorParallelRowLinear: FP8TensorParallelRowLinear,
}
for name, module in get_leaf_modules(model):
if any(p.numel() > 0 for p in module.parameters()) is False:
continue

recipe = find_fp8_config_by_module_name(name, config)

# if isinstance(module, (TensorParallelColumnLinear, TensorParallelRowLinear)):
if recipe is not None:
print(f"Converting {name} to FP8")
module.__class__ = TP_LINEAR_CLS_TO_FP8_LINEAR_CLS[module.__class__]
# TODO(xrsrke): retrieve custom recipe
module._set_and_quantize_weights(module.weight.data)

assert isinstance(module.weight, NanotronParameter)
assert isinstance(module.weight.data, FP8Tensor)
assert module.weight.data.dtype in [
torch.uint8,
torch.int8,
], f"got {module.weight.data.dtype}, name: {name}"
else:
# NOTE: convert it to the residual stream's dtype
# for p in module.parameters():
# p.data = p.data.to(self.config.model.dtype)
module.to(dtype=config.resid_dtype)

return model
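
`convert_model_to_fp8` swaps each eligible module's class in place and re-quantizes the weight it already holds, rather than rebuilding the model. The toy below illustrates only that class-swap pattern; the classes and the int8 quantization are made up for demonstration, and in nanotron the real work is done by `FP8TensorParallelColumnLinear`/`FP8TensorParallelRowLinear` via `_set_and_quantize_weights`.

```python
import torch
from torch import nn

class PlainLinear(nn.Linear):
    pass

class FakeFP8Linear(nn.Linear):
    def _set_and_quantize_weights(self, w: torch.Tensor) -> None:
        # Stand-in for FP8 quantization: keep an int8 view of the weight plus its scale.
        scale = w.abs().max() / 127.0
        self.register_buffer("weight_int8", torch.round(w / scale).to(torch.int8))
        self.register_buffer("weight_scale", scale)

CLS_MAP = {PlainLinear: FakeFP8Linear}

m = PlainLinear(4, 4)
m.__class__ = CLS_MAP[type(m)]              # swap behaviour without re-creating parameters
m._set_and_quantize_weights(m.weight.data)  # quantize the existing weight in place
print(m.weight_int8.dtype)                  # torch.int8
```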
