Skip to content

Commit

Permalink
Train parameters exclusively in specific ranges (axolotl-ai-cloud#1390)
Browse files Browse the repository at this point in the history
* Train parameters exclusively in specific ranges

* Fix the style and update docs

* Update yaml example
  • Loading branch information
seungduk-yanolja committed Mar 19, 2024
1 parent 371e294 commit d3e0a07
Show file tree
Hide file tree
Showing 4 changed files with 492 additions and 19 deletions.
12 changes: 6 additions & 6 deletions examples/mistral/mixtral.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ output_dir: ./qlora-out

## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
# - lm_head.*
# - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
# - ^lm_head.weight$
# - ^model.embed_tokens.weight$[:32000]
# - model.layers.2[0-9]+.block_sparse_moe.gate
# - model.layers.2[0-9]+.block_sparse_moe.experts
# - model.layers.3[0-9]+.block_sparse_moe.gate
# - model.layers.3[0-9]+.block_sparse_moe.experts

model_config:
output_router_logits: true
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer

Expand Down Expand Up @@ -99,7 +99,7 @@ def train(
safe_serialization = cfg.save_safetensors is True

if cfg.unfrozen_parameters:
freeze_parameters_except(model, cfg.unfrozen_parameters)
freeze_layers_except(model, cfg.unfrozen_parameters)

trainer = setup_trainer(
cfg,
Expand Down
210 changes: 199 additions & 11 deletions src/axolotl/utils/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,224 @@
"""
import logging
import re
from typing import Callable, List, Tuple

from axolotl.utils.distributed import is_main_process

LOG = logging.getLogger("axolotl.utils.freeze")


def freeze_parameters_except(model, regex_patterns):
def freeze_layers_except(model, regex_patterns):
"""
Freezes all layers of the given model except for the layers that match given regex patterns.
Periods in the patterns are treated as literal periods, not as wildcard characters.
Parameters:
- model (nn.Module): The PyTorch model to be modified.
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.
Also, to match the entire layer name, the pattern should start with "^" and end with "$", otherwise it will match any part of the layer name.
The range pattern part is optional and it is not compiled as a regex pattern which means you must put "$" before the range pattern if you want to match the entire layer name.
E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]
Returns:
None; the model is modified in place.
"""
# Escape periods and compile the regex patterns
compiled_patterns = [
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
]
if isinstance(regex_patterns, str):
regex_patterns = [regex_patterns]

# First, freeze all parameters in the model
for param in model.parameters():
param.requires_grad = False
patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]

# Unfreeze layers that match the regex patterns
for name, param in model.named_parameters():
if any(pattern.match(name) for pattern in compiled_patterns):
if is_main_process():
LOG.debug(f"unfreezing {name}")
param.requires_grad = False
unfrozen_ranges = []
for pattern in patterns:
if not pattern.match(name):
continue

param.requires_grad = True

if pattern.range is not None:
unfrozen_ranges.append(pattern.range)

merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))

if param.requires_grad and is_main_process():
unfrozen_ranges = (
f" with ranges {merged_unfrozen_ranges}"
if merged_unfrozen_ranges
else ""
)
LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")

if not merged_unfrozen_ranges:
continue

# The range list we need is actually the inverted of the merged ranges
ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))

param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))

if is_main_process() and all(
not param.requires_grad for param in model.parameters()
):
LOG.warning("All parameters are frozen. Model will not be trained.")


def _invert_ranges(
given_ranges: List[Tuple[int, int]], layer_size: int
) -> List[Tuple[int, int]]:
"""
Inverts a list of ranges to obtain the ranges not covered by the given ranges.
Parameters:
- given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
Returns:
- List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
"""
if not given_ranges:
return [(0, layer_size)]

inverted_ranges = []
current_start = 0

for start, end in sorted(given_ranges):
if start > current_start:
inverted_ranges.append((current_start, start))
current_start = max(current_start, end)

# Handle the case where the last given range does not reach the end of the total_size
if current_start < layer_size:
inverted_ranges.append((current_start, layer_size))

return inverted_ranges


def _merge_ranges(
given_ranges: List[Tuple[int, int | None]], layer_size: int
) -> List[Tuple[int, int]]:
"""
Merges overlapping ranges and sorts the given ranges.
This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
as tuples, where the first element is the start index (inclusive) and the second element is the end
index (exclusive). The end index can be None, indicating that the range extends to the end of the
sequence.
Parameters:
- given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
Returns:
- List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
"""
# End of each range can be determined now since we have the total size
processed_ranges = [
(start, end if end is not None else layer_size) for start, end in given_ranges
]

# No need to merge if there's only one or no ranges
if len(processed_ranges) <= 1:
return processed_ranges

sorted_ranges = sorted(processed_ranges)

merged_ranges = [sorted_ranges[0]]
for start, end in sorted_ranges[1:]:
prev_start, prev_end = merged_ranges[-1]
if start <= prev_end:
merged_ranges[-1] = (prev_start, max(prev_end, end))
else:
merged_ranges.append((start, end))

return merged_ranges


def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
"""
Create a hook to freeze parameters in specified ranges by setting their gradients to zero.
This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
two integers representing the start and end indices of the range.
Parameters:
- ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.
Returns:
- Callable: A hook function to be used with `register_hook` on parameters.
Example usage:
```
ranges_to_freeze = [(0, 10), (20, 30)]
hook = _create_freeze_parameters_hook(ranges_to_freeze)
model.register_hook(hook)
```
"""

def freeze_parameters_hook(gradients):
for start, end in ranges_to_freeze:
gradients[start:end].zero_()

return freeze_parameters_hook


class LayerNamePattern:
"""
Represents a regex pattern for layer names, potentially including a parameter index range.
"""

def __init__(self, pattern: str):
"""
Initializes a new instance of the LayerNamePattern class.
Parameters:
- pattern (str): The regex pattern for layer names, potentially including a parameter index range.
"""
self.raw_pattern = pattern
name_pattern, self.range = self._parse_pattern(pattern)
self.name_regex = re.compile(name_pattern.replace(".", "\\."))

def match(self, name: str) -> bool:
"""
Checks if the given layer name matches the regex pattern.
Parameters:
- name (str): The layer name to check.
Returns:
- bool: True if the layer name matches the pattern, False otherwise.
"""
return self.name_regex.match(name) is not None

def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
"""
Extracts the range pattern from the given pattern.
Parameters:
- pattern (str): The pattern to extract the range from.
Returns:
- Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
"""
match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
if not match:
return pattern, None

base_pattern, start_part, end_part = match.groups()

if end_part is None and start_part.isdecimal():
index = int(start_part)
return base_pattern, (index, index + 1)

# [:end] or [start:] or [start:end]
start = int(start_part) if start_part else 0
end = int(end_part) if end_part else None

if end is not None and start >= end:
raise ValueError(
f"Invalid range in layer name pattern: {pattern}."
"End of range must be greater than start."
)
return base_pattern, (start, end)
Loading

0 comments on commit d3e0a07

Please sign in to comment.