
Merge branch 'main' into tessa/callibration-script
tbarton16 authored Feb 2, 2024
2 parents f986175 + 15ee0ac commit 577fe7c
Showing 11 changed files with 628 additions and 361 deletions.
3 changes: 1 addition & 2 deletions llmfoundry/callbacks/eval_gauntlet_callback.py
@@ -117,8 +117,7 @@ def __init__(self,
     elif self.weighting == Weighting.SAMPLE_SZ:
         weight = cumulative_samples
     elif self.weighting == Weighting.LOG_SAMPLE_SZ:
-        weight = max(math.log(cumulative_samples, 2), 1)
-
+        weight = max(math.log2(cumulative_samples), 1)
     assert weight is not None
     benchmark['weighting'] = weight
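
As an aside on the fix: math.log2(x) equals math.log(x, 2) but is computed directly instead of as a ratio of natural logs, so it is marginally faster and more precise. A quick sketch of the LOG_SAMPLE_SZ weighting behavior (illustrative sample counts, not from the repo):

    import math

    # A benchmark's weight grows with the log of its cumulative sample
    # count, clamped below at 1 so small benchmarks still contribute.
    for cumulative_samples in (1, 2, 100, 10_000):
        weight = max(math.log2(cumulative_samples), 1)
        print(cumulative_samples, weight)
    # 1 -> 1 (clamped), 2 -> 1.0, 100 -> ~6.64, 10_000 -> ~13.29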

52 changes: 41 additions & 11 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -9,7 +9,7 @@
 import re
 import tempfile
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence, Union

 import torch
 from composer.core import Callback, Event, State, Time, TimeUnit
@@ -203,14 +203,17 @@ def _save_checkpoint(self, state: State, logger: Logger):
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

     if state.is_model_ddp:
+        composer_model = state.model.module
         original_model: PreTrainedModel = state.model.module.model
         state_dict_model = state.model.module.model
         original_tokenizer = state.model.module.tokenizer
     elif isinstance(state.model.model, FSDP):
+        composer_model = state.model
         original_model: PreTrainedModel = state.model.model.module
         state_dict_model = state.model.model
         original_tokenizer = state.model.tokenizer
     else:
+        composer_model = state.model
         original_model: PreTrainedModel = state.model.model
         state_dict_model = state.model.model
         original_tokenizer = state.model.tokenizer
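
The three branches reflect where each wrapper sits: DDP wraps the entire ComposerModel, so everything is reached through state.model.module, while FSDP wraps only the inner HuggingFace model, leaving the ComposerModel itself at state.model. Note that state_dict_model keeps pointing at the FSDP wrapper (so state_dict() can gather the sharded weights), while original_model is the fully unwrapped PreTrainedModel used for its type and config. A minimal sketch of the unwrapping idea, assuming only that both wrappers expose the wrapped model as .module (hypothetical helper, not part of the commit):

    import torch.nn as nn

    def unwrap(model: nn.Module) -> nn.Module:
        # DDP and FSDP both store the model they wrap in a `.module`
        # attribute, so peeling it repeatedly reaches the innermost model.
        while hasattr(model, 'module'):
            model = model.module
        return model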
@@ -237,10 +240,25 @@ def _save_checkpoint(self, state: State, logger: Logger):
     copied_config.init_device = 'cpu'

     log.debug(f'Creating new model instance')
-    # First create the model instance on meta device to avoid the
-    # initialization cost.
-    with init_empty_weights():
-        new_model_instance = type(original_model)(copied_config)
+
+    if composer_model.using_peft:
+        # We don't use meta here because the state dict does not contain the full
+        # model, only the adapter weights.
+        active_adapter = original_model.active_adapter
+        base_model = original_model.get_base_model()
+        new_base_model_instance = type(base_model)(copied_config)
+
+        new_model_instance = type(original_model)(
+            new_base_model_instance,
+            original_model.peft_config[active_adapter])
+    else:
+        # First create the model instance on meta device to avoid the
+        # initialization cost.
+        with init_empty_weights():
+            new_model_instance = type(original_model)(copied_config)

     new_model_instance.to(dtype=self.dtype)
     new_model_instance.load_state_dict(state_dict)

     # Then load the state dict in with "assign" so that the state dict
     # is loaded properly even though the model is initially on meta device.
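
The branch exists because init_empty_weights (from accelerate) constructs the module on the meta device with no storage allocated, which only works when the incoming state dict supplies every parameter; a PEFT state dict carries only the adapter weights, so the base model must be materialized for real before the adapter is attached. A minimal sketch of the meta-device path, using an illustrative 'gpt2' config rather than anything from the commit:

    from accelerate import init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    # Build the architecture without allocating weight storage; every
    # parameter lives on the 'meta' device until real tensors are loaded.
    config = AutoConfig.from_pretrained('gpt2')
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    assert all(p.device.type == 'meta' for p in model.parameters())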
@@ -295,12 +313,24 @@ def _save_checkpoint(self, state: State, logger: Logger):
     # TODO: Remove after mlflow fixes the bug that makes this necessary
     import mlflow
     mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
-    mlflow_logger.save_model(
-        flavor='transformers',
-        transformers_model=components,
-        path=local_save_path,
-        **self.mlflow_logging_config,
-    )
+    model_saving_kwargs: Dict[str, Any] = {
+        'path': local_save_path
+    }
+    if composer_model.using_peft:
+        model_saving_kwargs['flavor'] = 'peft'
+        model_saving_kwargs['save_pretrained_dir'] = temp_save_dir
+        model_saving_kwargs['metadata'] = self.mlflow_logging_config['metadata']
+    else:
+        model_saving_kwargs['flavor'] = 'transformers'
+        model_saving_kwargs['transformers_model'] = components
+        model_saving_kwargs.update(self.mlflow_logging_config)
+
+    mlflow_logger.save_model(**model_saving_kwargs)

     license_filename = _maybe_get_license_filename(
         local_save_path)
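
Building the kwargs up front keeps a single save_model call even though the two flavors take disjoint arguments: the peft flavor points MLflow at the directory where the adapter was just saved and forwards only the user-supplied metadata, while the transformers flavor hands over the in-memory pipeline components plus the full logging config. A standalone restatement of that dispatch (hypothetical helper, not part of the commit):

    from typing import Any, Dict

    def build_model_saving_kwargs(
        using_peft: bool,
        local_save_path: str,
        temp_save_dir: str,
        components: Dict[str, Any],
        mlflow_logging_config: Dict[str, Any],
    ) -> Dict[str, Any]:
        # Mirrors the branch above: only the key set differs by flavor;
        # the caller always ends with save_model(**kwargs).
        kwargs: Dict[str, Any] = {'path': local_save_path}
        if using_peft:
            kwargs['flavor'] = 'peft'
            kwargs['save_pretrained_dir'] = temp_save_dir
            kwargs['metadata'] = mlflow_logging_config['metadata']
        else:
            kwargs['flavor'] = 'transformers'
            kwargs['transformers_model'] = components
            kwargs.update(mlflow_logging_config)
        return kwargs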
