by default, do not quantize the first and last layer
xrsrke committed Nov 30, 2024
1 parent afdfbf1 commit 79341ea
Showing 5 changed files with 71 additions and 102 deletions.
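
In short: when no per-module FP8 config is supplied, every eligible linear in decoder layers 1 through num_hidden_layers - 2 is quantized with the default FP8LM_LINEAR_RECIPE, while the first and last decoder layers keep their original precision. A minimal sketch of that selection rule, using only the module-name patterns and index range that appear in the utils.py diff below (the helper names here are illustrative, not nanotron APIs):

def quantized_layer_indices(num_hidden_layers: int) -> list:
    """Decoder layers whose linears get FP8: everything except the first and the last."""
    assert num_hidden_layers > 2, "num_hidden_layers must be greater than 2"
    return list(range(1, num_hidden_layers - 1))


def should_quantize(module_name: str, num_hidden_layers: int) -> bool:
    """True if module_name is one of the FP8-eligible linears in a middle decoder layer."""
    patterns = [
        "model.decoder.{}.pp_block.attn.qkv_proj",
        "model.decoder.{}.pp_block.attn.o_proj",
        "model.decoder.{}.pp_block.mlp.down_proj",
        "model.decoder.{}.pp_block.mlp.gate_up_proj",
    ]
    return any(
        pattern.format(idx) in module_name
        for idx in quantized_layer_indices(num_hidden_layers)
        for pattern in patterns
    )
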
10 changes: 5 additions & 5 deletions src/nanotron/fp8/constants.py
@@ -57,11 +57,11 @@
weight_grad=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=1),
output_grad=FP8TensorRecipe(dtype=DTypes.FP8E5M2, margin=0, interval=16),
# NOTE: tested, and it works
# split_accumulator=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True),
# accumulate=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True),
# NOTE: passes the test with 4% speed up relative to the above
split_accumulator=FP8SplitAccumulator(output=False, input_grad=True, weight_grad=True),
accumulate=FP8SplitAccumulator(output=False, input_grad=False, weight_grad=True),
split_accumulator=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True),
accumulate=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True),
# # NOTE: passes the test with 4% speed up relative to the above
# split_accumulator=FP8SplitAccumulator(output=False, input_grad=True, weight_grad=True),
# accumulate=FP8SplitAccumulator(output=False, input_grad=False, weight_grad=True),
)

FP8LM_OPTIM_RECIPE = FP8OptimRecipe(
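
The constants.py change re-enables split accumulation and accumulation for all three GEMMs of an FP8 linear (output, input gradient, weight gradient) and keeps the previously active, roughly 4%-faster variant only as a comment. The sketch below shows one way such flags could be threaded into the backward matmuls; fp8_matmul is a hypothetical wrapper, not a nanotron or Transformer Engine function, and the exact semantics of the two flags depend on the FP8 GEMM backend:

def backward_gemms(recipe, grad_output, weight, activation, fp8_matmul):
    # dL/dX = dL/dY @ W
    grad_input = fp8_matmul(
        grad_output, weight,
        split_accumulator=recipe.split_accumulator.input_grad,
        accumulate=recipe.accumulate.input_grad,
    )
    # dL/dW = dL/dY^T @ X
    grad_weight = fp8_matmul(
        grad_output.t(), activation,
        split_accumulator=recipe.split_accumulator.weight_grad,
        accumulate=recipe.accumulate.weight_grad,
    )
    return grad_input, grad_weight
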
70 changes: 13 additions & 57 deletions src/nanotron/fp8/optim.py
@@ -55,50 +55,6 @@ def __init__(
self.master_weight_dtype = recipe.master_weight_dtype
self.optim_accum_dtype = recipe.accum_dtype

# NOTE: torch.Tensor is bias
# self.fp8_weights: List[Union[FP8Parameter, torch.Tensor]] = []

# NOTE: move to gradient accumulator
# # NOTE: create master weights for FP8 Parameter
# self.mappping_fp8_to_master_weight: Dict[str, Union[FP16Tensor, torch.Tensor]] = {}

# for group in self.param_groups:
# for p in group["params"]:
# # data = get_data_from_param(p)
# # if p.data.__class__ != FP8Tensor:
# # continue
# # # NOTE: this parameter we don't convert to FP8, so no need master weight
# # if not isinstance(p.data, FP8Tensor):
# # continue

# assert 1 == 1
# # if p._is_future_fp8 is not True:
# # continue
# if not isinstance(p.data, FP8Tensor):
# continue

# # assert p.dtype == data.dtype

# # if isinstance(p, NanotronParameter):
# # raw_data = p.data.orig_data if hasattr(p.data, "orig_data") else p.data
# # else:
# # raw_data = p.orig_data if hasattr(p, "orig_data") else p.data
# # assert raw_data.dtype in [torch.float32], f"raw_data.dtype={raw_data.dtype}"

# assert p.data.dtype in [torch.float32], f"raw_data.dtype={p.data.dtype}"
# self.mappping_fp8_to_master_weight[hash(p)] = self._create_master_weight(p.data)

# # self.fp8_weights.append(p.data)

# # delete_tensor_from_memory(raw_data)

# # p.orig_data = None
# # if hasattr(p.data, "orig_data"):
# # p.data.orig_data = None

# # assert len(self.mappping_fp8_to_master_weight) == len(self.fp8_weights)
# # TODO(xrsrke): auto free fp32 weights from memory

self.loggings = []
self._is_overflow = False

@@ -278,6 +234,19 @@ def step(self, closure=None):

assert fp32_grad.dtype == self.optim_accum_dtype

if is_overflow_underflow_nan(fp32_grad):
self._is_overflow = True

if constants.CONFIG.fp8.skip_param_update_if_nan is True:
log_rank(
f"[Optim] param_name, skipping update due to overflow/underflow/nan", # noqa
logger=logger,
level=logging.INFO,
)
continue
else:
raise ValueError("Overflow, underflow, or NaN detected in the gradients")

if isinstance(p.data, FP8Tensor):
assert p.data.dtype in FP8_DTYPES
assert hash(p) in self.mappping_fp8_to_master_weight, "Can't find master weight for FP8 parameter"
@@ -295,19 +264,6 @@

assert fp32_data.dtype == self.optim_accum_dtype

if is_overflow_underflow_nan(fp32_grad):
self._is_overflow = True

if constants.CONFIG.fp8.skip_param_update_if_nan is True:
log_rank(
f"[Optim] param_name, skipping update due to overflow/underflow/nan", # noqa
logger=logger,
level=logging.INFO,
)
continue
else:
raise ValueError("Overflow, underflow, or NaN detected in the gradients")

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

fp32_exp_avg = self._dequantize_optim_state(exp_avg)
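
In optim.py, the commit deletes a block of dead, commented-out master-weight bookkeeping from __init__ and moves the gradient sanity check to the top of the per-parameter loop in step(), so an overflow, underflow, or NaN is handled before the FP8 master weight is dequantized. For reference, a check of this kind typically boils down to the sketch below; it is an assumption, not the actual nanotron.fp8.utils.is_overflow_underflow_nan, and explicit underflow detection is omitted:

import torch


def is_overflow_underflow_nan(tensor: torch.Tensor) -> bool:
    # True if any element is +/-inf or NaN; the optimizer then either skips the
    # update (skip_param_update_if_nan) or raises.
    return bool(torch.isinf(tensor).any() or torch.isnan(tensor).any())
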
65 changes: 34 additions & 31 deletions src/nanotron/fp8/utils.py
Expand Up @@ -10,7 +10,6 @@
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.linear import FP8Linear
from nanotron.fp8.meta import FP8Meta
from nanotron.logging import log_rank
from nanotron.models.base import NanotronModel

logger = logging.get_logger(__name__)
@@ -236,20 +235,22 @@ def find_fp8_config_by_module_name(target_module_name: str, config: FP8Args) ->
# assert config.fp8.model is not None or config.fp8.is_quant_all_except_first_and_last is not None

# TODO(xrsrke): remove config.is_quant_all_except_first_and_last
from nanotron.fp8.constants import FP8LM_LINEAR_RECIPE

if config.model is not None:
for layer_args in config.model:
if layer_args.module_name == target_module_name:
return layer_args
elif config.is_quant_all_except_first_and_last:
# elif config.is_quant_all_except_first_and_last:
else:

def match_layer_pattern(name, layer_idxs):
patterns = [
"model.decoder.{}.attn.qkv_proj",
"model.decoder.{}.attn.o_proj",
"model.decoder.{}.mlp.down_proj",
"model.decoder.{}.pp_block.attn.qkv_proj",
"model.decoder.{}.pp_block.attn.o_proj",
"model.decoder.{}.pp_block.mlp.down_proj",
# "model.decoder.{}.mlp.up_proj",
"model.decoder.{}.mlp.gate_up_proj",
"model.decoder.{}.pp_block.mlp.gate_up_proj",
]

for idx in layer_idxs:
@@ -259,38 +260,40 @@ def match_layer_pattern(name, layer_idxs):

return False

num_layers = config.model.model_config.num_hidden_layers
from nanotron import constants

num_layers = constants.CONFIG.model.model_config.num_hidden_layers
assert num_layers > 2, "num_hidden_layers must be greater than 2"
assert config.fp8_linear_config_temp is not None
# assert config.fp8_linear_config_temp is not None

quant_layer_idxs = list(range(1, num_layers - 1))
if match_layer_pattern(target_module_name, quant_layer_idxs) is True:
from copy import deepcopy

config_temp = deepcopy(config.fp8_linear_config_temp)
config_temp.module_name = target_module_name
# config_temp = deepcopy(config.fp8_linear_config_temp)
config_temp = deepcopy(FP8LM_LINEAR_RECIPE)
# config_temp.module_name = target_module_name
return config_temp
else:
from nanotron.fp8.constant_recipe import MODULE_NAMES_THAT_NOT_FP8
from nanotron.fp8.constants import FP8LM_LINEAR_RECIPE

if any(module_name in target_module_name for module_name in MODULE_NAMES_THAT_NOT_FP8):
return None
else:
# NOTE: return default recipe
# NOTE: based on the global setting smooth_quant to decide whether to do smooth quantization
# or not
recipe = FP8LM_LINEAR_RECIPE
recipe.smooth_quant = config.smooth_quant
log_rank(
f"target_module_name={target_module_name}, smooth_quant={recipe.smooth_quant}",
logger=logger,
level=logging.INFO,
rank=0,
)

return recipe
return None
# else:
# from nanotron.fp8.constant_recipe import MODULE_NAMES_THAT_NOT_FP8

# if any(module_name in target_module_name for module_name in MODULE_NAMES_THAT_NOT_FP8):
# return None
# else:
# # NOTE: return default recipe
# # NOTE: based on the global setting smooth_quant to decide whether to do smooth quantization
# # or not
# recipe = FP8LM_LINEAR_RECIPE
# recipe.smooth_quant = config.smooth_quant
# log_rank(
# f"target_module_name={target_module_name}, smooth_quant={recipe.smooth_quant}",
# logger=logger,
# level=logging.INFO,
# rank=0,
# )

# return recipe
# return None


def get_modules_not_in_fp16():
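
With the rewritten fallback branch in utils.py, find_fp8_config_by_module_name returns a deep copy of FP8LM_LINEAR_RECIPE for matching linears in the middle decoder layers and None otherwise whenever config.model gives no explicit per-module entry. A hedged usage sketch; fp8_args stands for an FP8Args instance with model=None, and num_hidden_layers is assumed to be greater than 2:

from nanotron.fp8.utils import find_fp8_config_by_module_name

recipe = find_fp8_config_by_module_name("model.decoder.0.pp_block.attn.qkv_proj", fp8_args)
assert recipe is None  # first decoder layer is left un-quantized

recipe = find_fp8_config_by_module_name("model.decoder.1.pp_block.attn.qkv_proj", fp8_args)
assert recipe is not None  # middle decoder layer gets a copy of FP8LM_LINEAR_RECIPE
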
25 changes: 17 additions & 8 deletions src/nanotron/optim/gradient_accumulator.py
Expand Up @@ -285,19 +285,28 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
# except AssertionError:
# assert 1 == 1
assert half_param.grad is not None, f"Expected param {name} to have gradient."
from nanotron.fp8.tensor import convert_tensor_from_fp8

if isinstance(half_param.data, FP8Tensor):
grad = convert_tensor_from_fp8(half_param.grad, half_param.grad.fp8_meta, torch.float32)
else:
grad = half_param.grad

from nanotron.fp8.utils import is_overflow_underflow_nan

assert is_overflow_underflow_nan(grad) is False

fp32_grad = self.get_grad_buffer(name=name)

if self._is_accumulation_sync_step is False:
# WARNING: We assume fp32_grad_bucket is already zeroed
if not isinstance(half_param.data, FP8Tensor):
fp32_grad.add_(half_param.grad)
else:
from nanotron.fp8.tensor import convert_tensor_from_fp8

assert half_param.grad.dtype in [torch.int8, torch.uint8]
# TODO(xrsrke): move .convert_tensor_from_fp8 to .to(dtype), so we have an unified API
fp32_grad.add_(convert_tensor_from_fp8(half_param.grad, half_param.grad.fp8_meta, torch.float32))
# if not isinstance(half_param.data, FP8Tensor):
# fp32_grad.add_(grad)
# else:
# assert grad.dtype in [torch.int8, torch.uint8]
# # TODO(xrsrke): move .convert_tensor_from_fp8 to .to(dtype), so we have an unified API
# fp32_grad.add_(grad)
fp32_grad.add_(grad)
# In case _is_accumulation_sync_step = True: no need to add half gradients, because it's done in the allreduce hook

# TODO @thomasw21: Is it better to set to zero instead?
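
The accumulator now dequantizes an FP8 gradient once at the top of _accumulate_grad, asserts it is finite, and reuses the resulting fp32 tensor for the buffer update instead of branching again inside the accumulation step. Condensed, the non-sync-step path amounts to the sketch below (helper names and import paths as in the hunk above; the wrapper function itself is illustrative, and the import location of FP8Tensor is assumed):

import torch
from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8
from nanotron.fp8.utils import is_overflow_underflow_nan


def accumulate_grad_sketch(half_param, fp32_grad_buffer: torch.Tensor) -> None:
    if isinstance(half_param.data, FP8Tensor):
        grad = convert_tensor_from_fp8(half_param.grad, half_param.grad.fp8_meta, torch.float32)
    else:
        grad = half_param.grad
    assert is_overflow_underflow_nan(grad) is False
    # the fp32 buffer is assumed to be zeroed before the first accumulation step
    fp32_grad_buffer.add_(grad)
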
3 changes: 2 additions & 1 deletion src/nanotron/trainer.py
@@ -36,7 +36,7 @@
)
from nanotron.constants import MODEL_CONFIG_FILE_NAME
from nanotron.dataloader import sanity_check_dataloader
from nanotron.fp8.utils import convert_model_to_fp8
from nanotron.fp8.utils import convert_model_to_fp8, is_overflow_underflow_nan
from nanotron.helpers import (
_vocab_size_with_padding,
compute_remain_train_steps_of_a_data_stage_from_ckp,
@@ -649,6 +649,7 @@ def training_step(
assert p.grad.dtype in [torch.uint8, torch.int8], f"got {p.grad.dtype}"
else:
assert p.grad.dtype == constants.CONFIG.fp8.resid_dtype
assert is_overflow_underflow_nan(p.grad) is False

# NOTE: sanity check that parameters has gradient
assert 1 == 1
