
Commit

Merge branch 'soft_cap_attn' of github.com:ShashankMosaicML/llm-foundry into soft_cap_attn
ShashankMosaicML committed Jul 23, 2024
2 parents b06adf6 + c72458a commit 258e048
Showing 20 changed files with 1,339 additions and 870 deletions.
26 changes: 12 additions & 14 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -435,8 +435,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
 
 cpu_offload = True
 
-# Add a dtensor->cpu tensor hook to avoid CUDA OOM
-def dtensor_to_tensor_hook(
+# Add hook to move tensors to cpu to avoid CUDA OOM
+def tensor_hook(
     module: nn.Module,
     state_dict: Dict[str, Any],
     prefix: str,
@@ -449,20 +449,23 @@ def dtensor_to_tensor_hook(
             dtensor_fqns.append(fqn)
             tensor = tensor.full_tensor()  # type: ignore
             if dist.get_global_rank() == 0:
                 # Offload any DTensors to CPU
                 if cpu_offload:
                     tensor = tensor.cpu()
                 state_dict[fqn] = tensor
             else:
                 state_dict[fqn] = None
+        # Convert the state dict to the requested precision
+        if isinstance(tensor, torch.Tensor):
+            state_dict[fqn] = tensor.to(dtype=self.dtype)
+        del tensor
     if dist.get_global_rank() != 0:
-        for fqn in dtensor_fqns:
-            del state_dict[fqn]
+        state_dict = {}
     return state_dict
 
 hooks = []
 for _, module in state_dict_model.named_modules():
     if isinstance(module, FSDP):
-        hooks.append(
-            module._register_state_dict_hook(dtensor_to_tensor_hook),
-        )
+        hooks.append(module._register_state_dict_hook(tensor_hook),)
 
 state_dict = get_model_state_dict(
     state_dict_model,
@@ -474,11 +477,6 @@ def dtensor_to_tensor_hook(
 for hook in hooks:
     hook.remove()
 
-# Convert the state dict to the requested precision
-for k, v in state_dict.items():
-    if isinstance(v, torch.Tensor):
-        state_dict[k] = v.to(dtype=self.dtype)
-
 new_model_instance = None  # Need this for pyright because variable could be unbound
 
 if dist.get_global_rank() == 0:
@@ -537,7 +535,7 @@ def dtensor_to_tensor_hook(
     original_tokenizer.save_pretrained(temp_save_dir)
 
 # Only need to edit files for MPT because it has custom code
-if original_model.config.model_type == 'mpt':
+if new_model_instance.config.model_type == 'mpt':
     log.debug('Editing MPT files for HuggingFace compatibility')
     edit_files_for_hf_compatibility(
         temp_save_dir,
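For readers unfamiliar with the mechanism this file leans on: `_register_state_dict_hook` attaches a post-hook that runs while `state_dict()` is being assembled, and (in the PyTorch versions this code targets) a dict returned from the hook replaces the one under construction, which is why `tensor_hook` can return an empty dict on non-zero ranks. A minimal, self-contained sketch of the same pattern, using a hypothetical `cast_hook` that only performs the dtype cast (illustrative, not part of this commit):

import torch
import torch.nn as nn

# Hypothetical post-hook: cast every tensor to bfloat16 as the state dict is built.
def cast_hook(module: nn.Module, state_dict: dict, prefix: str, *args):
    for fqn, tensor in state_dict.items():
        if isinstance(tensor, torch.Tensor):
            state_dict[fqn] = tensor.to(dtype=torch.bfloat16)
    return state_dict  # a returned dict replaces the one being assembled

model = nn.Linear(4, 4)
handle = model._register_state_dict_hook(cast_hook)
print({k: v.dtype for k, v in model.state_dict().items()})  # all torch.bfloat16
handle.remove()  # the checkpointer likewise removes its hooks after get_model_state_dict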
26 changes: 26 additions & 0 deletions llmfoundry/cli/data_prep_cli.py
@@ -1,6 +1,7 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 from typing import Annotated, Optional
 
 import psutil
@@ -9,6 +10,7 @@
 from llmfoundry.command_utils import (
     convert_dataset_hf_from_args,
     convert_dataset_json_from_args,
+    convert_delta_to_json_from_args,
     convert_finetuning_dataset_from_args,
     convert_text_to_mds_from_args,
 )
@@ -240,3 +242,27 @@ def convert_text_to_mds(
         trust_remote_code=trust_remote_code,
         logging_level=logging_level,
     )
+
+
+@app.command(name='convert_delta_to_json')
+def convert_delta_to_json_cli(
+    delta_table_name: Annotated[str, Option(..., help='UC table <catalog>.<schema>.<table name>')],
+    json_output_folder: Annotated[str, Option(..., help='Local path to save the converted json')],
+    http_path: Annotated[Optional[str], Option(help='If set, dbsql method is used')] = None,
+    batch_size: Annotated[int, Option(help='Row chunks to transmit a time to avoid OOM')] = 1 << 30,
+    processes: Annotated[int, Option(help='Number of processes allowed to use')] = os.cpu_count(),  # type: ignore
+    cluster_id: Annotated[Optional[str], Option(help='Cluster ID with runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.')] = None,
+    use_serverless: Annotated[bool, Option(help='Use serverless or not. Make sure the workspace is entitled with serverless')] = False,
+    json_output_filename: Annotated[str, Option(help='The name of the combined final jsonl that combines all partitioned jsonl')] = 'train-00000-of-00001.jsonl',
+):
+    """Convert a Delta table into JSON files."""
+    convert_delta_to_json_from_args(
+        delta_table_name=delta_table_name,
+        json_output_folder=json_output_folder,
+        http_path=http_path,
+        batch_size=batch_size,
+        processes=processes,
+        cluster_id=cluster_id,
+        use_serverless=use_serverless,
+        json_output_filename=json_output_filename,
+    )
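The new Typer command is a thin wrapper, so the same conversion can presumably be driven from Python through the function it delegates to. A sketch with a placeholder table name and output paths, assuming Databricks credentials are already configured; the keyword arguments mirror the CLI options above:

from llmfoundry.command_utils import convert_delta_to_json_from_args

# Placeholder UC table (<catalog>.<schema>.<table>) and local output folder.
# http_path / cluster_id / use_serverless select how the table is read, as in the CLI options.
convert_delta_to_json_from_args(
    delta_table_name='main.my_schema.my_table',
    json_output_folder='/tmp/delta_json',
    http_path=None,
    batch_size=1 << 30,
    processes=8,
    cluster_id=None,
    use_serverless=False,
    json_output_filename='train-00000-of-00001.jsonl',
)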
6 changes: 6 additions & 0 deletions llmfoundry/command_utils/__init__.py
@@ -8,6 +8,10 @@
     convert_dataset_json,
     convert_dataset_json_from_args,
 )
+from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
+    convert_delta_to_json_from_args,
+    fetch_DT,
+)
 from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import (
     convert_finetuning_dataset,
     convert_finetuning_dataset_from_args,
@@ -44,4 +48,6 @@
     'convert_finetuning_dataset',
     'convert_text_to_mds',
     'convert_text_to_mds_from_args',
+    'convert_delta_to_json_from_args',
+    'fetch_DT',
 ]
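A quick check of the new re-exports, assuming an environment with llmfoundry installed at this commit or later:

# Both names now resolve from the package-level namespace as well as the full
# module path llmfoundry.command_utils.data_prep.convert_delta_to_json.
from llmfoundry.command_utils import convert_delta_to_json_from_args, fetch_DT

print(convert_delta_to_json_from_args.__module__)
print(fetch_DT.__module__)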
(The remaining 17 changed files are not shown.)
