diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 7c18e7098c..5ee3da9d65 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -16,12 +16,12 @@ output_dir: ./qlora-out

 ## You can optionally freeze the entire model and unfreeze a subset of parameters
 unfrozen_parameters:
-# - lm_head.*
-# - model.embed_tokens.*
-# - model.layers.2[0-9]+.block_sparse_moe.gate.*
-# - model.layers.2[0-9]+.block_sparse_moe.experts.*
-# - model.layers.3[0-9]+.block_sparse_moe.gate.*
-# - model.layers.3[0-9]+.block_sparse_moe.experts.*
+# - ^lm_head.weight$
+# - ^model.embed_tokens.weight$[:32000]
+# - model.layers.2[0-9]+.block_sparse_moe.gate
+# - model.layers.2[0-9]+.block_sparse_moe.experts
+# - model.layers.3[0-9]+.block_sparse_moe.gate
+# - model.layers.3[0-9]+.block_sparse_moe.experts

 model_config:
   output_router_logits: true
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 0b5ef76716..5c4eaf3d51 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -19,7 +19,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.freeze import freeze_parameters_except
+from axolotl.utils.freeze import freeze_layers_except
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer

@@ -99,7 +99,7 @@ def train(
     safe_serialization = cfg.save_safetensors is True

     if cfg.unfrozen_parameters:
-        freeze_parameters_except(model, cfg.unfrozen_parameters)
+        freeze_layers_except(model, cfg.unfrozen_parameters)

     trainer = setup_trainer(
         cfg,
diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py
index 05beda1caa..64b994f84d 100644
--- a/src/axolotl/utils/freeze.py
+++ b/src/axolotl/utils/freeze.py
@@ -3,13 +3,14 @@
 """
 import logging
 import re
+from typing import Callable, List, Tuple

 from axolotl.utils.distributed import is_main_process

 LOG = logging.getLogger("axolotl.utils.freeze")


-def freeze_parameters_except(model, regex_patterns):
+def freeze_layers_except(model, regex_patterns):
     """
     Freezes all layers of the given model except for the layers that match given regex patterns.
     Periods in the patterns are treated as literal periods, not as wildcard characters.
@@ -17,22 +18,209 @@ def freeze_parameters_except(model, regex_patterns):
     Parameters:
     - model (nn.Module): The PyTorch model to be modified.
     - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
+    Note that you cannot use a dot as a wildcard character in the patterns, since the dot is reserved for separating layer names.
+    Also, to match the entire layer name, the pattern should start with "^" and end with "$"; otherwise it will match any part of the layer name.
+    The range part is optional, and it is not compiled as part of the regex, so to match the entire layer name you must put the "$" before the range part.
+    E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]

     Returns:
     None; the model is modified in place.
""" - # Escape periods and compile the regex patterns - compiled_patterns = [ - re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns - ] + if isinstance(regex_patterns, str): + regex_patterns = [regex_patterns] - # First, freeze all parameters in the model - for param in model.parameters(): - param.requires_grad = False + patterns = [LayerNamePattern(pattern) for pattern in regex_patterns] # Unfreeze layers that match the regex patterns for name, param in model.named_parameters(): - if any(pattern.match(name) for pattern in compiled_patterns): - if is_main_process(): - LOG.debug(f"unfreezing {name}") + param.requires_grad = False + unfrozen_ranges = [] + for pattern in patterns: + if not pattern.match(name): + continue + param.requires_grad = True + + if pattern.range is not None: + unfrozen_ranges.append(pattern.range) + + merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param)) + + if param.requires_grad and is_main_process(): + unfrozen_ranges = ( + f" with ranges {merged_unfrozen_ranges}" + if merged_unfrozen_ranges + else "" + ) + LOG.debug(f"Unfrozen {name}{unfrozen_ranges}") + + if not merged_unfrozen_ranges: + continue + + # The range list we need is actually the inverted of the merged ranges + ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param)) + + param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze)) + + if is_main_process() and all( + not param.requires_grad for param in model.parameters() + ): + LOG.warning("All parameters are frozen. Model will not be trained.") + + +def _invert_ranges( + given_ranges: List[Tuple[int, int]], layer_size: int +) -> List[Tuple[int, int]]: + """ + Inverts a list of ranges to obtain the ranges not covered by the given ranges. + + Parameters: + - given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices. + - layer_size (int): The length of the layer. E.g., len(model.layer.weight) + Returns: + - List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices. + """ + if not given_ranges: + return [(0, layer_size)] + + inverted_ranges = [] + current_start = 0 + + for start, end in sorted(given_ranges): + if start > current_start: + inverted_ranges.append((current_start, start)) + current_start = max(current_start, end) + + # Handle the case where the last given range does not reach the end of the total_size + if current_start < layer_size: + inverted_ranges.append((current_start, layer_size)) + + return inverted_ranges + + +def _merge_ranges( + given_ranges: List[Tuple[int, int | None]], layer_size: int +) -> List[Tuple[int, int]]: + """ + Merges overlapping ranges and sorts the given ranges. + + This function takes a list of ranges and merges any overlapping ranges. The ranges are represented + as tuples, where the first element is the start index (inclusive) and the second element is the end + index (exclusive). The end index can be None, indicating that the range extends to the end of the + sequence. + + Parameters: + - given_ranges (List[Tuple[int, int | None]]): List of ranges to merge. + - layer_size (int): The length of the layer. E.g., len(model.layer.weight) + + Returns: + - List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices. 
+ """ + # End of each range can be determined now since we have the total size + processed_ranges = [ + (start, end if end is not None else layer_size) for start, end in given_ranges + ] + + # No need to merge if there's only one or no ranges + if len(processed_ranges) <= 1: + return processed_ranges + + sorted_ranges = sorted(processed_ranges) + + merged_ranges = [sorted_ranges[0]] + for start, end in sorted_ranges[1:]: + prev_start, prev_end = merged_ranges[-1] + if start <= prev_end: + merged_ranges[-1] = (prev_start, max(prev_end, end)) + else: + merged_ranges.append((start, end)) + + return merged_ranges + + +def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable: + """ + Create a hook to freeze parameters in specified ranges by setting their gradients to zero. + + This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain + two integers representing the start and end indices of the range. + + Parameters: + - ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze. + + Returns: + - Callable: A hook function to be used with `register_hook` on parameters. + + Example usage: + ``` + ranges_to_freeze = [(0, 10), (20, 30)] + hook = _create_freeze_parameters_hook(ranges_to_freeze) + model.register_hook(hook) + ``` + """ + + def freeze_parameters_hook(gradients): + for start, end in ranges_to_freeze: + gradients[start:end].zero_() + + return freeze_parameters_hook + + +class LayerNamePattern: + """ + Represents a regex pattern for layer names, potentially including a parameter index range. + """ + + def __init__(self, pattern: str): + """ + Initializes a new instance of the LayerNamePattern class. + + Parameters: + - pattern (str): The regex pattern for layer names, potentially including a parameter index range. + """ + self.raw_pattern = pattern + name_pattern, self.range = self._parse_pattern(pattern) + self.name_regex = re.compile(name_pattern.replace(".", "\\.")) + + def match(self, name: str) -> bool: + """ + Checks if the given layer name matches the regex pattern. + + Parameters: + - name (str): The layer name to check. + + Returns: + - bool: True if the layer name matches the pattern, False otherwise. + """ + return self.name_regex.match(name) is not None + + def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]: + """ + Extracts the range pattern from the given pattern. + + Parameters: + - pattern (str): The pattern to extract the range from. + + Returns: + - Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified. + """ + match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern) + if not match: + return pattern, None + + base_pattern, start_part, end_part = match.groups() + + if end_part is None and start_part.isdecimal(): + index = int(start_part) + return base_pattern, (index, index + 1) + + # [:end] or [start:] or [start:end] + start = int(start_part) if start_part else 0 + end = int(end_part) if end_part else None + + if end is not None and start >= end: + raise ValueError( + f"Invalid range in layer name pattern: {pattern}." + "End of range must be greater than start." 
+            )
+        return base_pattern, (start, end)
diff --git a/tests/test_freeze.py b/tests/test_freeze.py
new file mode 100644
index 0000000000..49d30ba5fa
--- /dev/null
+++ b/tests/test_freeze.py
@@ -0,0 +1,285 @@
+"""
+This module contains unit tests for the `freeze_layers_except` function.
+
+The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
+The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
+"""
+
+import unittest
+
+import torch
+from torch import nn
+
+from axolotl.utils.freeze import freeze_layers_except
+
+ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+
+
+class TestFreezeLayersExcept(unittest.TestCase):
+    """
+    A test case class for the `freeze_layers_except` function.
+    """
+
+    def setUp(self):
+        self.model = _TestModel()
+
+    def test_freeze_layers_with_dots_in_name(self):
+        freeze_layers_except(self.model, ["features.layer"])
+        self.assertTrue(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be trainable.",
+        )
+        self.assertFalse(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be frozen.",
+        )
+
+    def test_freeze_layers_without_dots_in_name(self):
+        freeze_layers_except(self.model, ["classifier"])
+        self.assertFalse(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be frozen.",
+        )
+        self.assertTrue(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be trainable.",
+        )
+
+    def test_freeze_layers_regex_patterns(self):
+        # The second pattern cannot match 'classifier': it only allows the characters 'a' to 'c' after the word 'class', but the next character is 'i'.
+        freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
+        self.assertTrue(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be trainable.",
+        )
+        self.assertFalse(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be frozen.",
+        )
+
+    def test_all_layers_frozen(self):
+        freeze_layers_except(self.model, [])
+        self.assertFalse(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be frozen.",
+        )
+        self.assertFalse(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be frozen.",
+        )
+
+    def test_all_layers_unfrozen(self):
+        freeze_layers_except(self.model, ["features.layer", "classifier"])
+        self.assertTrue(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be trainable.",
+        )
+        self.assertTrue(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be trainable.",
+        )
+
+    def test_freeze_layers_with_range_pattern_start_end(self):
+        freeze_layers_except(self.model, ["features.layer[1:5]"])
+        self.assertTrue(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be trainable.",
+        )
+        self.assertFalse(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be frozen.",
+        )
+
+        self._assert_gradient_output(
+            [
+                ZERO,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ZERO,
+                ZERO,
+                ZERO,
+                ZERO,
+                ZERO,
+            ]
+        )
+
+    def test_freeze_layers_with_range_pattern_single_index(self):
+        freeze_layers_except(self.model, ["features.layer[5]"])
+        self.assertTrue(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be trainable.",
+        )
+        self.assertFalse(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be frozen.",
+        )
+
+        self._assert_gradient_output(
+            [ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]
+        )
+
+    def test_freeze_layers_with_range_pattern_start_omitted(self):
+        freeze_layers_except(self.model, ["features.layer[:5]"])
+        self.assertTrue(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be trainable.",
+        )
+        self.assertFalse(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be frozen.",
+        )
+
+        self._assert_gradient_output(
+            [
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ZERO,
+                ZERO,
+                ZERO,
+                ZERO,
+                ZERO,
+            ]
+        )
+
+    def test_freeze_layers_with_range_pattern_end_omitted(self):
+        freeze_layers_except(self.model, ["features.layer[4:]"])
+        self.assertTrue(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be trainable.",
+        )
+        self.assertFalse(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be frozen.",
+        )
+
+        self._assert_gradient_output(
+            [
+                ZERO,
+                ZERO,
+                ZERO,
+                ZERO,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+            ]
+        )
+
+    def test_freeze_layers_with_range_pattern_merge_included(self):
+        freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
+        self.assertTrue(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be trainable.",
+        )
+        self.assertFalse(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be frozen.",
+        )
+
+        self._assert_gradient_output(
+            [
+                ZERO,
+                ZERO,
+                ZERO,
+                ZERO,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+            ]
+        )
+
+    def test_freeze_layers_with_range_pattern_merge_intersect(self):
+        freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
+        self.assertTrue(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be trainable.",
+        )
+        self.assertFalse(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be frozen.",
+        )
+
+        self._assert_gradient_output(
+            [
+                ZERO,
+                ZERO,
+                ZERO,
+                ZERO,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ONE_TO_TEN,
+                ZERO,
+                ZERO,
+            ]
+        )
+
+    def test_freeze_layers_with_range_pattern_merge_separate(self):
+        freeze_layers_except(
+            self.model,
+            ["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
+        )
+        self.assertTrue(
+            self.model.features.layer.weight.requires_grad,
+            "model.features.layer should be trainable.",
+        )
+        self.assertFalse(
+            self.model.classifier.weight.requires_grad,
+            "model.classifier should be frozen.",
+        )
+
+        self._assert_gradient_output(
+            [
+                ZERO,
+                ONE_TO_TEN,
+                ZERO,
+                ONE_TO_TEN,
+                ZERO,
+                ONE_TO_TEN,
+                ZERO,
+                ZERO,
+                ZERO,
+                ZERO,
+            ]
+        )
+
+    def _assert_gradient_output(self, expected):
+        input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)
+
+        self.model.features.layer.weight.grad = None  # Reset gradients
+        output = self.model.features.layer(input_tensor)
+        loss = output.sum()
+        loss.backward()
+
+        expected_grads = torch.tensor(expected)
+        torch.testing.assert_close(
+            self.model.features.layer.weight.grad, expected_grads
+        )
+
+
+class _SubLayerModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(10, 10)
+
+
+class _TestModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.features = _SubLayerModule()
+        self.classifier = nn.Linear(10, 2)
+
+
+if __name__ == "__main__":
+    unittest.main()
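For reviewers, a minimal end-to-end sketch of the new range syntax; the `ToyModel`, its layer names, and the pattern below are illustrative assumptions, not part of the patch. A range indexes the first dimension of the matched parameter, which is why `^model.embed_tokens.weight$[:32000]` in the Mixtral example confines updates to the first 32000 embedding rows:

```python
import torch
from torch import nn

from axolotl.utils.freeze import freeze_layers_except


class ToyModel(nn.Module):
    """Hypothetical stand-in for a real model; `embed` mimics an embedding table."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(10, 10, bias=False)
        self.head = nn.Linear(10, 2, bias=False)


model = ToyModel()

# Unfreeze only rows 0..4 of embed.weight; every other parameter stays frozen.
freeze_layers_except(model, ["^embed.weight$[:5]"])

loss = model.head(model.embed(torch.ones(1, 10))).sum()
loss.backward()

assert model.head.weight.grad is None  # fully frozen layers never receive gradients
assert model.embed.weight.grad[5:].eq(0).all()  # rows 5..9 are zeroed by the hook
assert model.embed.weight.grad[:5].ne(0).any()  # rows 0..4 remain trainable
```

Because the partially frozen rows are masked by a gradient hook rather than detached, the optimizer still sees a single trainable tensor; the frozen rows simply receive a zero gradient on every step.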