Commit

wip
dakinggg committed Jan 23, 2024
1 parent 7cb401b commit 0757eab
Showing 5 changed files with 62 additions and 59 deletions.
10 changes: 7 additions & 3 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -264,14 +264,18 @@ def _autoset_attn_implementation_monkeypatch(
'peft_config',
must_exist=False,
convert=True)

if peft_config is not None:
peft_type = peft_config.get('peft_type', None)
if peft_type.upper() != 'LORA':
raise ValueError(f'Only LORA is supported for peft_type, but got {peft_type}.')
raise ValueError(
f'Only LORA is supported for peft_type, but got {peft_type}.'
)
task_type = peft_config.get('task_type', None)
if task_type.upper() != 'CAUSAL_LM':
raise ValueError(f'Only CAUSAL_LM is supported for task_type, but got {task_type}.')
raise ValueError(
f'Only CAUSAL_LM is supported for task_type, but got {task_type}.'
)
peft_config = LoraConfig(**peft_config)

composer_model = super().__init__(
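
For context on the peft_config handling above, a minimal sketch of a config dict that would pass these checks and be turned into a LoraConfig. Only the 'peft_type' and 'task_type' keys come from the code in this hunk; the remaining keys and the 'Wqkv' target module name are illustrative assumptions.

from peft import LoraConfig

# Hypothetical peft_config: only 'peft_type' and 'task_type' are validated by
# the code above; the remaining keys are typical LoRA hyperparameters.
peft_config = {
    'peft_type': 'LORA',
    'task_type': 'CAUSAL_LM',
    'r': 16,
    'lora_alpha': 32,
    'lora_dropout': 0.05,
    'target_modules': ['Wqkv'],  # depends on the base model architecture
}

# Mirrors the checks in the hunk above before constructing the LoraConfig.
assert peft_config.get('peft_type', '').upper() == 'LORA'
assert peft_config.get('task_type', '').upper() == 'CAUSAL_LM'
lora_config = LoraConfig(**peft_config)
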
93 changes: 48 additions & 45 deletions llmfoundry/models/hf/hf_fsdp.py
@@ -5,7 +5,7 @@
# which is MIT licensed

import functools
from typing import Any, Iterable, List, Optional, Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union

import torch
from transformers import PreTrainedModel
@@ -165,52 +165,55 @@ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel,
'follow common layer/weight naming conventions.')
block_type = type(model_block[0])
# TODO: delete this
if init_device == 'mixed':
# For FSDP with models with different device initializations, `mixed`, which
# initializes the model on rank 0 on `cpu` and on all other ranks on `meta,``
# we need to tag all child modules that are torch.nn.Modules with `_fsdp_wrap`.
for child in model.children():
if isinstance(child, type(causal_base_model)):
continue
if isinstance(child, torch.nn.Module):
child._fsdp_wrap = True

for child in causal_base_model.children():
if isinstance(child, torch.nn.ModuleList):
continue
if isinstance(child, torch.nn.Module):
child._fsdp_wrap = True

if model.config.tie_word_embeddings and not model.config.model_type == 'mpt':
raise ValueError(
'The passed in HuggingFaceModel has tied word embeddings ' +
'and the passed in initialization device is `mixed.` ' +
'In order to support this initialization scheme, we would need to break '
+
'the weight tying. As a result, either use a different initialization scheme '
+ 'or in the model config set `tie_word_embeddings=False.`')
else:
# When using the HF LM models,
# the weights of the self.lm_head and self.transformer.wte are tied.
# This tying occurs inside the `self.post_init()` function.
# This is a hurdle for FSDP because they need to be in the same FSDP block
# These lines ensures that both modules stay together in the top-most block when
# the model has this tying enabled (almost all do; this property defaults to True)
if model.config.tie_word_embeddings:
causal_base_model._fsdp_wrap = False
tied_embeddings._fsdp_wrap = False
lm_head._fsdp_wrap = False
# if init_device == 'mixed':
# # For FSDP with models with different device initializations, `mixed`, which
# # initializes the model on rank 0 on `cpu` and on all other ranks on `meta,``
# # we need to tag all child modules that are torch.nn.Modules with `_fsdp_wrap`.
# for child in model.children():
# if isinstance(child, type(causal_base_model)):
# continue
# if isinstance(child, torch.nn.Module):
# child._fsdp_wrap = True

# for child in causal_base_model.children():
# if isinstance(child, torch.nn.ModuleList):
# continue
# if isinstance(child, torch.nn.Module):
# child._fsdp_wrap = True

# if model.config.tie_word_embeddings and not model.config.model_type == 'mpt':
# raise ValueError(
# 'The passed in HuggingFaceModel has tied word embeddings ' +
# 'and the passed in initialization device is `mixed.` ' +
# 'In order to support this initialization scheme, we would need to break '
# +
# 'the weight tying. As a result, either use a different initialization scheme '
# + 'or in the model config set `tie_word_embeddings=False.`')
# else:
# # When using the HF LM models,
# # the weights of the self.lm_head and self.transformer.wte are tied.
# # This tying occurs inside the `self.post_init()` function.
# # This is a hurdle for FSDP because they need to be in the same FSDP block
# # These lines ensures that both modules stay together in the top-most block when
# # the model has this tying enabled (almost all do; this property defaults to True)
if model.config.tie_word_embeddings:
causal_base_model._fsdp_wrap = False
tied_embeddings._fsdp_wrap = False
lm_head._fsdp_wrap = False

if hasattr(model, 'peft_type'):
peft_type = model.peft_type.lower()
active_adapters = [adapter.lower() for adapter in model.active_adapters]
for name, module in model.named_modules():
if peft_type in name.lower() and any(
adapter in name.lower() for adapter in active_adapters):
has_parameters = any(True for _ in module.parameters())
has_buffers = any(True for _ in module.buffers())
if has_parameters or has_buffers:
module._fsdp_wrap = True
if model.peft_type is not None:
peft_type = model.peft_type.lower()
active_adapters = [
adapter.lower() for adapter in model.active_adapters
]
for name, module in model.named_modules():
if peft_type in name.lower() and any(
adapter in name.lower() for adapter in active_adapters):
has_parameters = any(True for _ in module.parameters())
has_buffers = any(True for _ in module.buffers())
if has_parameters or has_buffers:
module._fsdp_wrap = True

# FSDP Wrap and Activation Checkpoint every model block
model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
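
The _fsdp_wrap attributes set above and the fsdp_wrap_fn assigned at the end of the hunk are the hooks Composer's FSDP integration reads when deciding which submodules to shard as separate units. A rough, illustrative sketch of how such tags could be consumed (not Composer's actual implementation; block_type is the transformer block class resolved earlier in this function):

import torch

def should_wrap(module: torch.nn.Module, block_type: type) -> bool:
    # An explicit tag set by prepare_hf_causal_lm_model_for_fsdp takes priority:
    # True requests a separate FSDP unit, False keeps the module in its parent
    # (used above to keep tied embeddings and lm_head in the same unit).
    if hasattr(module, '_fsdp_wrap'):
        return bool(module._fsdp_wrap)
    # Otherwise wrap every transformer block, matching
    # model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
    return isinstance(module, block_type)
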
2 changes: 1 addition & 1 deletion llmfoundry/models/hf/model_wrapper.py
@@ -6,7 +6,7 @@
from __future__ import annotations

from collections import UserDict
from typing import List, Mapping, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, List, Mapping, Optional

import torch
import transformers
8 changes: 5 additions & 3 deletions scripts/data_prep/convert_delta_to_json.py
@@ -315,7 +315,7 @@ def fetch(
sparkSession: Optional[SparkSession] = None,
dbsql: Optional[Connection] = None,
) -> None:
"""Fetch UC delta table with databricks-connnect as JSONL.
"""Fetch UC delta table with databricks-connect as JSONL.
Args:
method (str): dbconnect or dbsql
@@ -405,8 +405,10 @@ def validate_and_get_cluster_info(cluster_id: str,
f'Cluster id {cluster_id} does not exist. Check cluster id and try again!'
)
stripped_runtime = re.sub(
r'[a-zA-Z]', '',
res.spark_version.split('-scala')[0].replace('x-snapshot', ''))
r'[a-zA-Z]',
'',
res.spark_version.split('-scala')[0].replace( # type: ignore
'x-snapshot', ''))
runtime_version = re.sub(r'[.-]*$', '', stripped_runtime)
if version.parse(runtime_version) < version.parse(
MINIMUM_SQ_CONNECT_DBR_VERSION):
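
As a standalone trace of the version parsing reformatted above, with a made-up runtime string (the real value comes from res.spark_version, and MINIMUM_SQ_CONNECT_DBR_VERSION is defined elsewhere in the script):

import re

from packaging import version

spark_version = '14.1.x-gpu-ml-scala2.12'  # hypothetical Databricks runtime string

stripped_runtime = re.sub(
    r'[a-zA-Z]', '',
    spark_version.split('-scala')[0].replace('x-snapshot', ''))
runtime_version = re.sub(r'[.-]*$', '', stripped_runtime)

print(runtime_version)  # '14.1'
parsed = version.parse(runtime_version)  # comparable against the parsed
                                         # MINIMUM_SQ_CONNECT_DBR_VERSION
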
8 changes: 1 addition & 7 deletions scripts/train/train.py
@@ -22,8 +22,7 @@
from omegaconf import OmegaConf as om
from transformers import PreTrainedTokenizerBase

from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM,
MPTForCausalLM)
from llmfoundry import COMPOSER_MODEL_REGISTRY
from llmfoundry.callbacks import AsyncEval
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
@@ -210,11 +209,6 @@ def main(cfg: DictConfig) -> Trainer:
must_exist=False,
default_value=None,
convert=True)
lora_config: Optional[Dict[str, Any]] = pop_config(cfg,
'lora',
must_exist=False,
default_value=None,
convert=True)
eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config(
cfg, 'eval_loader', must_exist=False, default_value=None)
icl_tasks_config: Optional[Union[ListConfig,
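
The hunk above removes the standalone top-level 'lora' key from the training config, along with the imports that are no longer used in train.py; together with the 'peft_config' pop added in hf_causal_lm.py, this suggests LoRA settings now live inside the model config. A hypothetical before/after sketch written as Python dicts; only the key names 'lora' and 'peft_config' come from this diff, the rest is illustrative.

# Hypothetical YAML-equivalent configs, written as Python dicts.
old_cfg = {
    'model': {'name': 'hf_causal_lm'},
    'lora': {},  # previously popped by train.py via pop_config(cfg, 'lora', ...)
}

new_cfg = {
    'model': {
        'name': 'hf_causal_lm',
        # consumed by ComposerHFCausalLM via the 'peft_config' pop in hf_causal_lm.py
        'peft_config': {'peft_type': 'LORA', 'task_type': 'CAUSAL_LM'},
    },
}
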
