diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index ef85332047..062aa41bf4 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -19,7 +19,7 @@ defaults: working-directory: . jobs: code-quality: - runs-on: ubuntu-20.04 + runs-on: linux-ubuntu-latest timeout-minutes: 30 strategy: matrix: diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index fc511d7e60..cf3581f716 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -8,7 +8,7 @@ on: jobs: coverage: timeout-minutes: 5 - runs-on: ubuntu-latest + runs-on: linux-ubuntu-latest steps: - name: Checkout Repo uses: actions/checkout@v3 diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index e4e6f83551..17bb976a5d 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -24,13 +24,6 @@ jobs: base_image: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04-aws dep_groups: "[gpu]" steps: - - name: Maximize Build Space on Worker - uses: easimon/maximize-build-space@v4 - with: - overprovision-lvm: true - remove-dotnet: true - remove-android: true - remove-haskell: true - name: Checkout uses: actions/checkout@v3 @@ -47,6 +40,13 @@ jobs: username: ${{ secrets.DOCKER_HUB_USERNAME }} password: ${{ secrets.DOCKER_HUB_PASSWORD }} + - name: Login to GHCR + uses: docker/login-action@v2 + with: + username: ${{ secrets.GHCR_USERNAME }} + password: ${{ secrets.GHCR_TOKEN }} + registry: ghcr.io + - name: Calculate Docker Image Variables run: | set -euxo pipefail @@ -60,13 +60,17 @@ jobs: if [ "${{ github.event_name }}" == "pull_request" ]; then echo "Triggered by pull_request event." STAGING_REPO="mosaicml/ci-staging" - IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}" + GHCR_STAGING_REPO="ghcr.io/databricks-mosaic/ci-staging" + GHCR_IMAGE_TAG="${GHCR_STAGING_REPO}:${{matrix.name}}-${GIT_SHA}" + IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA},${GHCR_IMAGE_TAG}" IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache" else # Triggered by push or workflow_dispatch event echo "Triggered by ${{ github.event_name }} event, releasing to prod" PROD_REPO="mosaicml/llm-foundry" - IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest" + GHCR_PROD_REPO="ghcr.io/databricks-mosaic/llm-foundry" + GHCR_IMAGE_TAG="${GHCR_PROD_REPO}:${{matrix.name}}-${GIT_SHA},${GHCR_PROD_REPO}:${{matrix.name}}-latest" + IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest,${GHCR_IMAGE_TAG}" IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache" fi diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 2dd1c0edab..2c85719756 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -15,23 +15,28 @@ concurrency: cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} jobs: pytest-cpu: - uses: mosaicml/ci-testing/.github/workflows/pytest-cpu.yaml@v0.0.9 + name: ${{ matrix.name }} + runs-on: ubuntu-latest strategy: matrix: include: - name: "cpu-2.3.1" + pip_deps: "[all-cpu]" container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04 markers: "not gpu" pytest_command: "coverage run -m pytest" - name: ${{ matrix.name }} - if: github.repository_owner == 'mosaicml' - with: - container: ${{ matrix.container }} - name: ${{ matrix.name }} - pip_deps: "[all-cpu]" - pytest-command: ${{ matrix.pytest_command }} - pytest-markers: ${{ matrix.markers }} - safe_directory: llm-foundry + 
steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Run PR CPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-cpu@v0.1.0 + with: + name: ${{ matrix.name }} + container: ${{ matrix.container }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + safe_directory: llm-foundry coverage: uses: ./.github/workflows/coverage.yaml name: Coverage Results diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index c5638e403d..ba1a4f9ba4 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -9,12 +9,15 @@ on: - main - release/** workflow_dispatch: +# Cancel old runs when a new commit is pushed to the same branch if not on main or dev concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} jobs: pytest-gpu-1: - uses: mosaicml/ci-testing/.github/workflows/pytest-gpu.yaml@v0.0.9 + name: ${{ matrix.name }} + if: github.repository_owner == 'mosaicml' + runs-on: linux-ubuntu-latest strategy: fail-fast: false matrix: @@ -22,24 +25,28 @@ jobs: - name: "gpu-2.3.1-1" container: mosaicml/llm-foundry:2.3.1_cu121-latest markers: "gpu" - pytest_command: "coverage run -m pytest" pip_deps: "[all]" + pytest_command: "coverage run -m pytest" + ci_repo_gpu_test_ref: v0.1.0 + steps: + - name: Run PR GPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + with: + container: ${{ matrix.container }} + git_repo: mosaicml/llm-foundry + mcloud_timeout: 1800 + name: ${{ matrix.name }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + python_version: 3.9 + gpu_num: 1 + mcloud_api_key: ${{ secrets.MCLOUD_API_KEY }} + ci_repo_gpu_test_ref: ${{ matrix.ci_repo_gpu_test_ref }} + pytest-gpu-2: name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' - with: - container: ${{ matrix.container }} - git_repo: mosaicml/llm-foundry - mcloud-timeout: 1800 - name: ${{ matrix.name }} - pip_deps: ${{ matrix.pip_deps }} - pytest-command: ${{ matrix.pytest_command }} - pytest-markers: ${{ matrix.markers }} - python-version: 3.9 - gpu_num: 1 - secrets: - mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }} - pytest-gpu-2: - uses: mosaicml/ci-testing/.github/workflows/pytest-gpu.yaml@v0.0.9 + runs-on: linux-ubuntu-latest strategy: fail-fast: false matrix: @@ -47,24 +54,28 @@ jobs: - name: "gpu-2.3.1-2" container: mosaicml/llm-foundry:2.3.1_cu121-latest markers: "gpu" - pytest_command: "coverage run -m pytest" pip_deps: "[all]" + pytest_command: "coverage run -m pytest" + ci_repo_gpu_test_ref: v0.1.0 + steps: + - name: Run PR GPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + with: + container: ${{ matrix.container }} + git_repo: mosaicml/llm-foundry + mcloud_timeout: 1800 + name: ${{ matrix.name }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + python_version: 3.9 + gpu_num: 2 + mcloud_api_key: ${{ secrets.MCLOUD_API_KEY }} + ci_repo_gpu_test_ref: ${{ matrix.ci_repo_gpu_test_ref }} + pytest-gpu-4: name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' - with: - container: ${{ matrix.container }} - git_repo: mosaicml/llm-foundry - mcloud-timeout: 1800 - name: ${{ matrix.name }} - pip_deps: ${{ matrix.pip_deps }} - pytest-command: ${{ matrix.pytest_command }} - pytest-markers: ${{ matrix.markers }} - 
python-version: 3.9 - gpu_num: 2 - secrets: - mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }} - pytest-gpu-4: - uses: mosaicml/ci-testing/.github/workflows/pytest-gpu.yaml@v0.0.9 + runs-on: linux-ubuntu-latest strategy: fail-fast: false matrix: @@ -72,19 +83,21 @@ jobs: - name: "gpu-2.3.1-4" container: mosaicml/llm-foundry:2.3.1_cu121-latest markers: "gpu" - pytest_command: "coverage run -m pytest" pip_deps: "[all]" - name: ${{ matrix.name }} - if: github.repository_owner == 'mosaicml' - with: - container: ${{ matrix.container }} - git_repo: mosaicml/llm-foundry - mcloud-timeout: 1800 - name: ${{ matrix.name }} - pip_deps: ${{ matrix.pip_deps }} - pytest-command: ${{ matrix.pytest_command }} - pytest-markers: ${{ matrix.markers }} - python-version: 3.9 - gpu_num: 4 - secrets: - mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }} + pytest_command: "coverage run -m pytest" + ci_repo_gpu_test_ref: v0.1.0 + steps: + - name: Run PR GPU Tests + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + with: + container: ${{ matrix.container }} + git_repo: mosaicml/llm-foundry + mcloud_timeout: 1800 + name: ${{ matrix.name }} + pip_deps: ${{ matrix.pip_deps }} + pytest_command: ${{ matrix.pytest_command }} + pytest_markers: ${{ matrix.markers }} + python_version: 3.9 + gpu_num: 4 + mcloud_api_key: ${{ secrets.MCLOUD_API_KEY }} + ci_repo_gpu_test_ref: ${{ matrix.ci_repo_gpu_test_ref }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 144e3f1ad3..c09f9bb7a5 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -14,7 +14,7 @@ jobs: name: Build and Publish llm-foundry PyPI Package needs: - code-quality - runs-on: ubuntu-latest + runs-on: linux-ubuntu-latest steps: - name: Checkout source uses: actions/checkout@v3 diff --git a/.github/workflows/smoketest.yaml b/.github/workflows/smoketest.yaml index 2163111710..d38849cddc 100644 --- a/.github/workflows/smoketest.yaml +++ b/.github/workflows/smoketest.yaml @@ -18,7 +18,7 @@ defaults: working-directory: . jobs: smoketest: - runs-on: ubuntu-20.04 + runs-on: linux-ubuntu-latest timeout-minutes: 20 strategy: matrix: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc2e3f55cd..b45021dd8c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -77,17 +77,6 @@ repos: hooks: - id: docformatter args: [--in-place, --wrap-summaries=80, --wrap-descriptions=80] -- repo: https://github.com/PyCQA/pydocstyle - hooks: - - id: pydocstyle - name: pydocstyle - entry: pydocstyle - language: python - types: [python] - exclude: (.ci|.github) - additional_dependencies: - - toml - rev: 6.1.1 - repo: https://github.com/adrienverge/yamllint.git rev: v1.28.0 hooks: diff --git a/README.md b/README.md index 0299e43710..e8a6708c5a 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ DBRX is a state-of-the-art open source LLM trained by Databricks Mosaic team. It | DBRX Base | 32768 | https://huggingface.co/databricks/dbrx-base | | DBRX Instruct | 32768 | https://huggingface.co/databricks/dbrx-instruct | -Our model weights and code are licensed for both researchers and commercial entities. The Databricks Open Source License can be found at [LICENSE](https://github.com/databricks/dbrx/LICENSE), and our Acceptable Use Policy can be found [here](https://www.databricks.com/legal/acceptable-use-policy-open-model). +Our model weights and code are licensed for both researchers and commercial entities. 
The Databricks Open Source License can be found at [LICENSE](https://github.com/databricks/dbrx/blob/main/LICENSE), and our Acceptable Use Policy can be found [here](https://www.databricks.com/legal/acceptable-use-policy-open-model). For more information about the DBRX models, see https://github.com/databricks/dbrx. @@ -309,10 +309,15 @@ dependencies = [ "llm-foundry", ] +# Note: Even though in python code, this would be llmfoundry.registry.loggers, +# when specified in the entry_points, it has to be "llmfoundry_loggers". That is, +# the segments of the name should be joined by an _ in the entry_points section. [project.entry-points."llmfoundry_loggers"] my_logger = "foundry_registry.loggers:MyLogger" ``` +If developing new components via entrypoints, it is important to note that Python entrypoints are global to the Python environment. This means that if you have multiple packages that register components with the same key, the last one installed will be the one used. This can be useful for overriding components in LLM Foundry, but can also lead to unexpected behavior if not careful. Additionally, if you change the pyproject.toml, you will need to reinstall the package for the changes to take effect. You can do this quickly by installing with `pip install -e . --no-deps` to avoid reinstalling dependencies. + ### Direct call to register You can also register a component directly in your code: diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index 8dbd180c0a..b851aaa559 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -50,6 +50,7 @@ tokenizers, utils, ) +from llmfoundry._version import __version__ from llmfoundry.data import StreamingFinetuningDataset, StreamingTextDataset from llmfoundry.eval import InContextLearningDataset, InContextLearningMetric from llmfoundry.models.hf import ComposerHFCausalLM @@ -63,6 +64,7 @@ from llmfoundry.optim import DecoupledLionW __all__ = [ + '__version__', 'StreamingFinetuningDataset', 'StreamingTextDataset', 'InContextLearningDataset', @@ -87,5 +89,3 @@ 'tokenizers', 'utils', ] - -__version__ = '0.11.0.dev0' diff --git a/llmfoundry/_version.py b/llmfoundry/_version.py new file mode 100644 index 0000000000..4c11746b43 --- /dev/null +++ b/llmfoundry/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""The LLM Foundry Version.""" + +__version__ = '0.11.0.dev' diff --git a/llmfoundry/callbacks/async_eval_callback.py b/llmfoundry/callbacks/async_eval_callback.py index 646d86c8d3..1b3c31e861 100644 --- a/llmfoundry/callbacks/async_eval_callback.py +++ b/llmfoundry/callbacks/async_eval_callback.py @@ -557,7 +557,8 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run: installation_path = i['path'] if not found_llm_foundry: - from llmfoundry import __version__ as latest_foundry_version + from llmfoundry._version import \ + __version__ as latest_foundry_version # If github integration is not found, foundry is likely installed # through the run command. 
In this case, we'll add the integration diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 98a672f8db..449ab338bc 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -128,18 +128,17 @@ def after_load(self, state: State, logger: Logger): self._validate_dataloader(state.train_dataloader) # If checkpoint was saved before iteration was incremented, we need to increment it now + duration = self._schedule[self._schedule_index]['duration'] if (( - self._schedule[self._schedule_index]['duration'].unit - == TimeUnit.TOKEN and state.timestamp.token_in_iteration >= - self._schedule[self._schedule_index]['duration'].value + duration.unit == TimeUnit.TOKEN and + state.timestamp.token_in_iteration >= duration.value ) or ( - self._schedule[self._schedule_index]['duration'].unit - == TimeUnit.EPOCH and state.timestamp.epoch_in_iteration >= - self._schedule[self._schedule_index]['duration'].value + duration.unit == TimeUnit.EPOCH and + state.timestamp.epoch_in_iteration >= duration.value )): log.warning(( - 'The CurriculumLearning callback has detected that the previous run did not correctly ' - 'increment the iteration.' + 'The CurriculumLearning callback has detected that the ' + 'previous run did not correctly increment the iteration.' )) self._schedule_index += 1 state.timestamp = state.timestamp.to_next_iteration() @@ -199,24 +198,13 @@ def load_state_dict(self, state: dict[str, Any]): f'Expected {saved_loader} but got {current_loader}', )) - # Ensure that the current datamix duration is greater than timestamp + # Ensure that the current datamix duration is in the correct units duration = self._schedule[self._schedule_index]['duration'] if duration.unit != TimeUnit.TOKEN and duration.unit != TimeUnit.EPOCH: raise ValueError(( f'Duration must be in terms of tokens or epochs, but got ', f'{duration.unit}.', )) - if (( - duration.unit == TimeUnit.TOKEN and - duration > state['timestamp'].token_in_iteration - ) or ( - duration.unit == TimeUnit.EPOCH and - duration > state['timestamp'].epoch_in_iteration - )): - raise ValueError(( - 'The duration of the current datamix must be less or equal to ' - 'than the saved timestamp.' 
- )) def _build_train_loader( self, diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 4de7f9f2c6..79dc73de98 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -18,7 +18,6 @@ import torch import torch.nn as nn from composer.core import Callback, Event, Precision, State, Time, TimeUnit -from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger from composer.models import HuggingFaceModel from composer.utils import ( @@ -29,7 +28,12 @@ ) from composer.utils.misc import create_interval_scheduler from mlflow.transformers import _fetch_model_card, _write_license_information -from packaging import version +from torch.distributed._tensor import DTensor +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import ( PretrainedConfig, PreTrainedModel, @@ -179,6 +183,7 @@ def __init__( 'bfloat16': torch.bfloat16, }[precision] self.flatten_imports = flatten_imports + self.using_peft = False # mlflow config setup self.mlflow_registered_model_name = mlflow_registered_model_name @@ -212,6 +217,14 @@ def __init__( ) self.mlflow_logging_config = mlflow_logging_config + if 'metadata' in self.mlflow_logging_config: + self.pretrained_model_name = self.mlflow_logging_config[ + 'metadata'].get( + 'pretrained_model_name', + None, + ) + else: + self.pretrained_model_name = None self.huggingface_folder_name_fstr = os.path.join( 'huggingface', @@ -274,6 +287,15 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set( '1GB', ) + + # Check if the model is using PEFT + if state.is_model_ddp: + composer_model = state.model.module + elif isinstance(state.model.model, FSDP): + composer_model = state.model + else: + composer_model = state.model + self.using_peft = composer_model.using_peft elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. timeout = 3600 @@ -362,6 +384,34 @@ def transform_config( copied_config.ffn_config['moe_world_size'] = 1 return copied_config + def pre_register_edit(self, local_save_path: str): + """Edit the model before registering with MLflow. + + This allows a subclass to modify the model before registering with MLflow. The base class implementation will + make no modifications. + + Args: + local_save_path (str): The path to the model to be transformed. + """ + pass + + def transform_model_pre_registration( + self, + model: PreTrainedModel, + ) -> PreTrainedModel: + """Transform the model before registering with MLflow. + + This allows a subclass to modify the model before registering with MLflow. The base class implementation will + make no modifications. + + Args: + model (PreTrainedModel): The model to be transformed. + + Returns: + PreTrainedModel: The transformed model. 
+ """ + return model + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -388,85 +438,63 @@ def _save_checkpoint(self, state: State, logger: Logger): temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir log.debug('Gathering state dict') - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP if state.is_model_ddp: - composer_model = state.model.module original_model: PreTrainedModel = state.model.module.model state_dict_model = state.model.module.model original_tokenizer = state.model.module.tokenizer elif isinstance(state.model.model, FSDP): - composer_model = state.model original_model: PreTrainedModel = state.model.model.module state_dict_model = state.model.model original_tokenizer = state.model.tokenizer else: - composer_model = state.model original_model: PreTrainedModel = state.model.model state_dict_model = state.model.model original_tokenizer = state.model.tokenizer - if version.parse(torch.__version__) > version.parse('2.2.9'): - from torch.distributed._tensor import DTensor - from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - get_model_state_dict, - ) - cpu_offload = True - - # Add a dtensor->cpu tensor hook to avoid CUDA OOM - def dtensor_to_tensor_hook( - module: nn.Module, - state_dict: Dict[str, Any], - prefix: str, - *args: Any, - ) -> Dict[str, Any]: - dtensor_fqns = [] - for fqn in state_dict.keys(): - tensor = state_dict[fqn] - if isinstance(tensor, DTensor): - dtensor_fqns.append(fqn) - tensor = tensor.full_tensor() # type: ignore - if dist.get_global_rank() == 0: - if cpu_offload: - tensor = tensor.cpu() - state_dict[fqn] = tensor - if dist.get_global_rank() != 0: - for fqn in dtensor_fqns: - del state_dict[fqn] - return state_dict - - hooks = [] - for _, module in state_dict_model.named_modules(): - if isinstance(module, FSDP): - hooks.append( - module. 
- _register_state_dict_hook(dtensor_to_tensor_hook), - ) - - state_dict = get_model_state_dict( - state_dict_model, - options=StateDictOptions( - full_state_dict=True, - cpu_offload=cpu_offload, - ), - ) - for hook in hooks: - hook.remove() - else: - state_dict_context = fsdp_state_dict_type_context( - original_model, - state_dict_type='full', - ) if ((not state.is_model_ddp) and - isinstance(state_dict_model, - FSDP)) else contextlib.nullcontext() - with state_dict_context: - state_dict = state_dict_model.state_dict() - - # Convert the state dict to the requested precis - for k, v in state_dict.items(): - if isinstance(v, torch.Tensor): - state_dict[k] = v.to(dtype=self.dtype) + cpu_offload = True + + # Add hook to move tensors to cpu to avoid CUDA OOM + def tensor_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> Dict[str, Any]: + dtensor_fqns = [] + for fqn in state_dict.keys(): + tensor = state_dict[fqn] + if isinstance(tensor, DTensor): + dtensor_fqns.append(fqn) + tensor = tensor.full_tensor() # type: ignore + if dist.get_global_rank() == 0: + # Offload any DTensors to CPU + if cpu_offload: + tensor = tensor.cpu() + state_dict[fqn] = tensor + else: + state_dict[fqn] = None + + if isinstance(state_dict[fqn], torch.Tensor): + state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype) + del tensor + if dist.get_global_rank() != 0: + state_dict = {} + return state_dict + + hooks = [] + for _, module in state_dict_model.named_modules(): + hooks.append(module._register_state_dict_hook(tensor_hook),) + + state_dict = get_model_state_dict( + state_dict_model, + options=StateDictOptions( + full_state_dict=True, + cpu_offload=cpu_offload, + ), + ) + for hook in hooks: + hook.remove() new_model_instance = None # Need this for pyright because variable could be unbound @@ -480,22 +508,19 @@ def dtensor_to_tensor_hook( log.debug(f'Creating new model instance') - if composer_model.using_peft: - # We don't use meta here because the state dict does not contain the full - # model, only the adapter weights. - active_adapter = original_model.active_adapter - base_model = original_model.get_base_model() - new_base_model_instance = type(base_model)(new_config) - - new_model_instance = type(original_model)( - new_base_model_instance, - original_model.peft_config[active_adapter], - ) - new_model_instance.to(dtype=self.dtype) - else: - # First create the model instance on meta device to avoid the - # initialization cost. - with init_empty_weights(): + # First create the model instance on meta device to avoid the + # initialization cost. + with init_empty_weights(): + if self.using_peft: + active_adapter = original_model.active_adapter + base_model = original_model.get_base_model() + new_base_model_instance = type(base_model)(new_config) + + new_model_instance = type(original_model)( + new_base_model_instance, + original_model.peft_config[active_adapter], + ) + else: new_model_instance = type(original_model)(new_config) new_model_instance.generation_config.update( **original_model.generation_config.to_dict(), @@ -512,6 +537,16 @@ def dtensor_to_tensor_hook( original_tokenizer, ) + # Ensure that the pretrained model name is correctly set on the saved HF checkpoint. 
+ if self.pretrained_model_name is not None: + new_model_instance.name_or_path = self.pretrained_model_name + if self.using_peft: + new_model_instance.base_model.name_or_path = self.pretrained_model_name + for k in new_model_instance.peft_config.keys(): + new_model_instance.peft_config[ + k + ].base_model_name_or_path = self.pretrained_model_name + log.debug('Saving Hugging Face checkpoint to disk') # This context manager casts the TE extra state in io.BytesIO format to tensor format # Needed for proper hf ckpt saving. @@ -529,7 +564,7 @@ def dtensor_to_tensor_hook( original_tokenizer.save_pretrained(temp_save_dir) # Only need to edit files for MPT because it has custom code - if original_model.config.model_type == 'mpt': + if new_model_instance.config.model_type == 'mpt': log.debug('Editing MPT files for HuggingFace compatibility') edit_files_for_hf_compatibility( temp_save_dir, @@ -556,6 +591,11 @@ def dtensor_to_tensor_hook( if dist.get_global_rank() == 0: if self.mlflow_registered_model_name and self._is_last_batch(state): + + new_model_instance = self.transform_model_pre_registration( + new_model_instance, + ) + components = {'model': new_model_instance} if original_tokenizer is not None: components['tokenizer'] = original_tokenizer @@ -575,7 +615,7 @@ def dtensor_to_tensor_hook( model_saving_kwargs: Dict[str, Any] = { 'path': local_save_path, } - if composer_model.using_peft: + if self.using_peft: model_saving_kwargs['flavor'] = 'peft' model_saving_kwargs['save_pretrained_dir' ] = temp_save_dir @@ -591,15 +631,18 @@ def dtensor_to_tensor_hook( ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext( ) with context_manager: + # Add the pip requirements directly to avoid mlflow + # attempting to run inference on the model + model_saving_kwargs['pip_requirements'] = [ + 'transformers', + 'torch', + ] mlflow_logger.save_model(**model_saving_kwargs) # Upload the license file generated by mlflow during the model saving. license_filename = _maybe_get_license_filename( local_save_path, - self.mlflow_logging_config['metadata'].get( - 'pretrained_model_name', - None, - ), + self.pretrained_model_name, ) if license_filename is not None: mlflow_logger._mlflow_client.log_artifact( @@ -607,6 +650,8 @@ def dtensor_to_tensor_hook( os.path.join(local_save_path, license_filename), ) + self.pre_register_edit(local_save_path,) + # Spawn a new process to register the model. 
process = SpawnProcess( target=_register_model_with_run_id_multiprocess, diff --git a/llmfoundry/cli/data_prep_cli.py b/llmfoundry/cli/data_prep_cli.py index 3ca53f4104..130e0a6585 100644 --- a/llmfoundry/cli/data_prep_cli.py +++ b/llmfoundry/cli/data_prep_cli.py @@ -1,6 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import os from typing import Annotated, Optional import psutil @@ -9,6 +10,8 @@ from llmfoundry.command_utils import ( convert_dataset_hf_from_args, convert_dataset_json_from_args, + convert_delta_to_json_from_args, + convert_finetuning_dataset_from_args, convert_text_to_mds_from_args, ) @@ -106,6 +109,97 @@ def convert_dataset_json( ) +@app.command(name='convert_finetuning_dataset') +def convert_finetuning_dataset_cli( + dataset: Annotated[ + str, + Option( + ..., + help= + 'Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`).', + )], + data_subset: Annotated[ + Optional[str], + Option(help='(Optional) subset of data to use.',)] = None, + splits: Annotated[str, + Option(help='Comma-separated list of dataset splits'), + ] = 'train,validation', + preprocessor: Annotated[ + Optional[str], + Option( + help= + 'Name or import path of function used to preprocess (reformat) the dataset.', + )] = None, + data_files: Annotated[ + str, Option(help='Data file for each split. Comma-separated.')] = '', + skip_preprocessing: Annotated[ + bool, Option(help='Whether to skip preprocessing.')] = False, + out_root: Annotated[ + str, + Option( + ..., + help= + 'Root path of output directory where MDS shards will be stored. Can be a remote URI.', + )] = '', + local: Annotated[ + Optional[str], + Option( + help= + '(Optional) root path of local directory if you want to keep a local copy when out_root is remote.', + )] = None, + compression: Annotated[ + Optional[str], + Option(help='(Optional) name of compression algorithm to use.')] = None, + num_workers: Annotated[Optional[int], + Option(help='Number of workers.')] = None, + tokenizer: Annotated[Optional[str], + Option(help='Tokenizer used for processing.')] = None, + tokenizer_kwargs: Annotated[ + Optional[str], + Option( + help= + 'Keyword arguments for tokenizer initialization in JSON format.', + )] = None, + max_seq_len: Annotated[int, Option(help='Maximum sequence length.')] = 2048, + target_prompts: Annotated[ + str, + Option(help='Policy for when to use prompts as training targets.'), + ] = 'none', + target_responses: Annotated[ + str, + Option(help='Policy for which responses to treat as training targets.'), + ] = 'last', + encoder_decoder: Annotated[ + bool, + Option( + help= + 'Set if the data are intended to be used to train an encoder-decoder model.', + )] = False, +): + """Convert a Finetuning Dataset to MDS streaming format.""" + # Convert comma-separated args + splits_list = splits.split(',') if splits else [] + data_files_list = data_files.split(',') if data_files else [] + convert_finetuning_dataset_from_args( + dataset=dataset, + data_subset=data_subset, + splits=splits_list, + preprocessor=preprocessor, + data_files=data_files_list, + skip_preprocessing=skip_preprocessing, + out_root=out_root, + local=local, + compression=compression, + num_workers=num_workers, + tokenizer=tokenizer, + tokenizer_kwargs=tokenizer_kwargs, + max_seq_len=max_seq_len, + target_prompts=target_prompts, + target_responses=target_responses, + encoder_decoder=encoder_decoder, + ) + + @app.command(name='convert_text_to_mds') def convert_text_to_mds( 
output_folder: Annotated[str, Option(..., help='The folder to write output to')], @@ -148,3 +242,27 @@ def convert_text_to_mds( trust_remote_code=trust_remote_code, logging_level=logging_level, ) + + +@app.command(name='convert_delta_to_json') +def convert_delta_to_json_cli( + delta_table_name: Annotated[str, Option(..., help='UC table ..')], + json_output_folder: Annotated[str, Option(..., help='Local path to save the converted json')], + http_path: Annotated[Optional[str], Option(help='If set, dbsql method is used')] = None, + batch_size: Annotated[int, Option(help='Row chunks to transmit a time to avoid OOM')] = 1 << 30, + processes: Annotated[int, Option(help='Number of processes allowed to use')] = os.cpu_count(), # type: ignore + cluster_id: Annotated[Optional[str], Option(help='Cluster ID with runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.')] = None, + use_serverless: Annotated[bool, Option(help='Use serverless or not. Make sure the workspace is entitled with serverless')] = False, + json_output_filename: Annotated[str, Option(help='The name of the combined final jsonl that combines all partitioned jsonl')] = 'train-00000-of-00001.jsonl', +): + """Convert a Delta table into JSON files.""" + convert_delta_to_json_from_args( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + batch_size=batch_size, + processes=processes, + cluster_id=cluster_id, + use_serverless=use_serverless, + json_output_filename=json_output_filename, + ) diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py index 995c5345e7..0226c4f408 100644 --- a/llmfoundry/command_utils/__init__.py +++ b/llmfoundry/command_utils/__init__.py @@ -8,6 +8,14 @@ convert_dataset_json, convert_dataset_json_from_args, ) +from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( + convert_delta_to_json_from_args, + fetch_DT, +) +from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import ( + convert_finetuning_dataset, + convert_finetuning_dataset_from_args, +) from llmfoundry.command_utils.data_prep.convert_text_to_mds import ( convert_text_to_mds, convert_text_to_mds_from_args, @@ -36,6 +44,10 @@ 'convert_dataset_hf_from_args', 'convert_dataset_json', 'convert_dataset_json_from_args', + 'convert_finetuning_dataset_from_args', + 'convert_finetuning_dataset', 'convert_text_to_mds', 'convert_text_to_mds_from_args', + 'convert_delta_to_json_from_args', + 'fetch_DT', ] diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_json.py b/llmfoundry/command_utils/data_prep/convert_dataset_json.py index 9f174d1aaf..35d7e637e6 100644 --- a/llmfoundry/command_utils/data_prep/convert_dataset_json.py +++ b/llmfoundry/command_utils/data_prep/convert_dataset_json.py @@ -34,7 +34,7 @@ def build_hf_dataset( """Build an IterableDataset over the HF C4 or pile source data. Args: - dataset_name (str): Dataset name + path (str): Dataset name split (str): Split name. 
mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS
         max_length (int): The length of concatenated tokens
diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
new file mode 100644
index 0000000000..635efd54d4
--- /dev/null
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
@@ -0,0 +1,762 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import re
+import time
+import urllib.parse
+from collections import namedtuple
+from concurrent.futures import ProcessPoolExecutor
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
+from uuid import uuid4
+
+import google.protobuf.any_pb2 as any_pb2
+import pandas as pd
+import pyarrow as pa
+import requests
+from composer.utils import retry
+from packaging import version
+
+from llmfoundry.utils.exceptions import (
+    ClusterDoesNotExistError,
+    FailedToConnectToDatabricksError,
+    FailedToCreateSQLConnectionError,
+)
+
+if TYPE_CHECKING:
+    import pyspark.sql.connect.proto as pb2
+    from databricks.sql.client import Connection as Connection
+    from databricks.sql.client import Cursor as Cursor
+    from pyspark.sql import SparkSession
+    from pyspark.sql.connect.client.core import SparkConnectClient
+    from pyspark.sql.connect.dataframe import DataFrame
+    from pyspark.sql.dataframe import DataFrame as SparkDataFrame
+    from pyspark.sql.types import Row
+
+try:
+    from pyspark.sql.connect.client.core import SparkConnectClient
+    spark_connect_client_installed = True
+except ImportError:
+    spark_connect_client_installed = False
+
+try:
+    from pyspark.sql.connect.dataframe import DataFrame
+    data_frame_installed = True
+except ImportError:
+    data_frame_installed = False
+
+MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'
+MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2'
+
+TABLENAME_PATTERN = re.compile(r'(\S+)\.(\S+)\.(\S+)')
+
+log = logging.getLogger(__name__)
+
+Result = namedtuple(
+    'Result',
+    [
+        'url',
+        'row_count',
+        'compressed_size',
+        'uncompressed_size',
+    ],
+)  # pyright: ignore
+
+# ``collect_as_cf`` is an add-on feature monkey-patched on top of the DB Connect package.
+# It allows the client to fetch the results in different formats from the server.
+# To be able to use this code, make sure this module is not overridden by DB Connect classes.
+
+
+def to_cf(self: 'SparkConnectClient',
+          plan: 'pb2.Plan',
+          type: str = 'json') -> Tuple[List[Result], int, bool]:
+    """Executes the query plan and returns the results as presigned URLs for cloud fetch.
+
+    It can handle the current output formats that are supported by the server.
+    In contrast to the regular API methods of the client, this method does not
+    return the schema and drops all other responses.
+
+    Args:
+        self (SparkConnectClient): The SparkConnectClient we are processing.
+        plan (pb2.Plan): The plan object to be executed by Spark.
+        type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'.
+
+    Returns:
+        Tuple[List[Result], int, bool]: A tuple containing:
+            - A list of Result namedtuples, each containing a URL, row count, compressed size,
+              and uncompressed size of the part of the result.
+            - Total row count of all parts of the result.
+            - A boolean indicating whether the result has been truncated.
+ """ + req = self._execute_plan_request_with_metadata() + req.plan.CopyFrom(plan) + + import pyspark.sql.connect.proto as pb2 + import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 + + # Add the request options + if type == 'json': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON + elif type == 'csv': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV + elif type == 'arrow': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW + else: + raise ValueError( + f'Only formats json, csv, and arrow are supported. Got invalid type {type}', + ) + + ro = cloud_pb2.ResultOptions( + type=cloud_pb2.ResultOptions.TYPE_CLOUD, + cloudOptions=cloud_pb2.ResultOptions.CloudOptions( + format=format, + useCompression=False, + ), + ) + cloud_option = any_pb2.Any() + cloud_option.Pack(ro) + req.request_options.append( + pb2.ExecutePlanRequest.RequestOption(extension=cloud_option), + ) + + # Create the iterator + from pyspark.sql.connect.client.reattach import ( + ExecutePlanResponseReattachableIterator, + ) + iterator = ExecutePlanResponseReattachableIterator( + req, + self._stub, + self._retry_policy, + self._builder.metadata(), + ) + # Iterate over the response + result = [] + row_count = 0 + is_overflow = False + + for response in iterator: + if response.HasField('extension') and response.extension.Is( + cloud_pb2.CloudResultBatch.DESCRIPTOR, + ): + batch = cloud_pb2.CloudResultBatch() + if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): + raise ValueError( + 'Response extension is not of type CloudResultBatch.', + ) + response.extension.Unpack(batch) + result += [ + Result( + b.url, + b.row_count, + b.compressed_size, + b.uncompressed_size, + ) for b in batch.results + ] + row_count += sum(result.row_count for result in batch.results) + is_overflow |= batch.truncated + return result, row_count, is_overflow + + +if spark_connect_client_installed: + SparkConnectClient.to_cf = to_cf # pyright: ignore + + +def collect_as_cf(self: 'DataFrame', + type: str = 'json') -> Tuple[List[Result], int, bool]: + """Collects DataFrame execution plan as presigned URLs. + + This method is a wrapper around the `to_cf` method of SparkConnectClient. It takes the + execution plan of the current DataFrame, converts it to a protocol buffer format, and then + uses the `to_cf` method to execute the plan and fetch results as presigned URLs. + + Args: + self (pd.DataFrame): The dataframe we are processing. + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result is truncated or overflowed. + """ + query = self._plan.to_proto(self._session.client) # pyright: ignore + return self._session.client.to_cf(query, type) # pyright: ignore + + +if data_frame_installed: + DataFrame.collect_cf = collect_as_cf # pyright: ignore + + +def iterative_combine_jsons(json_directory: str, output_file: str) -> None: + """Combine jsonl files in json_directory into one big jsonl file. + + This function does not work for nested subdirectories. 
+
+    Args:
+        json_directory (str): directory containing the JSONL files
+        output_file (str): path to the output combined JSONL file
+    """
+    json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')]
+    with open(output_file, 'w') as outfile:
+        for file_name in json_files:
+            with open(os.path.join(json_directory, file_name), 'r') as infile:
+                for line in infile:
+                    outfile.write(line)
+    log.info('JSON files have been combined into a JSONL file.')
+
+
+def run_query(
+    query: str,
+    method: str,
+    cursor: Optional['Cursor'] = None,
+    spark: Optional['SparkSession'] = None,
+    collect: bool = True,
+) -> Optional[Union[List['Row'], 'DataFrame', 'SparkDataFrame']]:
+    """Run a SQL query via databricks-connect or databricks-sql.
+
+    Args:
+        query (str): SQL query to run
+        method (str): one of 'dbsql' or 'dbconnect'
+        cursor (Optional[Cursor]): connection.cursor, required for the dbsql method
+        spark (Optional[SparkSession]): spark session, required for the dbconnect method
+        collect (bool): whether to collect the underlying data from the spark dataframe
+    """
+    if method == 'dbsql':
+        if cursor is None:
+            raise ValueError(f'cursor cannot be None if using method dbsql')
+        cursor.execute(query)
+        if collect:
+            return cursor.fetchall()
+    elif method == 'dbconnect':
+        if spark == None:
+            raise ValueError(f'sparkSession is required for dbconnect')
+        df = spark.sql(query)
+        if collect:
+            return df.collect()
+        return df
+    else:
+        raise ValueError(f'Unrecognized method: {method}')
+
+
+def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable:
+    for i, r in enumerate(signed):
+        yield (i, r.url, json_output_folder, columns)
+
+
+def download(
+    ipart: int,
+    url: str,
+    json_output_folder: str,
+    columns: Optional[List] = None,
+    resp_format: str = 'arrow',
+    compressed: bool = False,
+) -> None:
+    """Download a presigned URL in a worker and save it locally as JSONL.
+
+    Args:
+        ipart (int): presigned URL id
+        url (str): presigned URL
+        json_output_folder (str): directory to save the ipart-th segment of the dataframe
+        columns (list): schema to save to json
+        resp_format (str): whether to use arrow or json when collecting
+        compressed (bool): whether the data was compressed before downloading; it is decompressed if compressed=True.
+    """
+    resp = requests.get(url)
+    if resp.status_code == 200:
+        if resp_format == 'json':
+            data = resp.json()
+            pd.DataFrame(data, columns=columns).to_json(
+                os.path.join(
+                    json_output_folder,
+                    'part_' + str(ipart) + '.jsonl',
+                ),
+                orient='records',
+                lines=True,
+            )
+            return
+
+        # When resp_format is arrow:
+        if compressed:
+            # The data is lz4 compressed arrow format.
+            # Decompress the data
+            import lz4.frame
+            decompressed_data = lz4.frame.decompress(resp.content)
+            # Convert the decompressed data into a PyArrow table
+            reader = pa.ipc.open_stream(decompressed_data)
+        else:
+            reader = pa.ipc.open_stream(resp.content)
+        table = reader.read_all()
+
+        # Convert the PyArrow table into a pandas DataFrame
+        df = table.to_pandas()
+        df.to_json(
+            os.path.join(json_output_folder, 'part_' + str(ipart) + '.jsonl'),
+            orient='records',
+            lines=True,
+            force_ascii=False,
+        )
+
+
+def download_starargs(args: Tuple) -> None:
+    return download(*args)
+
+
+def format_tablename(table_name: str) -> str:
+    """Escape catalog, schema and table names with backticks.
+
+    This needs to be done when running SQL queries/setting spark sessions to prevent invalid identifier errors.
+ + Args: + table_name (str): catalog.scheme.tablename on UC + """ + match = re.match(TABLENAME_PATTERN, table_name) + + if match is None: + return table_name + + formatted_identifiers = [] + for i in range(1, 4): + identifier = f'`{match.group(i)}`' + formatted_identifiers.append(identifier) + + return '.'.join(formatted_identifiers) + + +def fetch_data( + method: str, + cursor: Optional['Cursor'], + sparkSession: Optional['SparkSession'], + start: int, + end: int, + order_by: str, + tablename: str, + columns_str: str, + json_output_folder: str, +) -> None: + """Fetches a specified range of rows from a given table to a json file. + + This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, + from a specified table and column set. The fetched data is then exported as a JSON file. + + Args: + method (str): The method to use for fetching data, either 'dbconnect' or 'dbsql'. + cursor (Optional[Cursor]): The cursor object for executing queries in 'dbsql' method. + sparkSession (Optional[SparkSession]): The Spark session object for executing queries in 'dbconnect' method. + start (int): The starting index for row fetching. + end (int): The ending index for row fetching. + order_by (str): The column name to use for ordering the rows. + tablename (str): The name of the table from which to fetch the data. + columns_str (str): The string representation of the columns to select from the table. + json_output_folder (str): The file path where the resulting JSON file will be saved. + + Returns: + None: The function doesn't return any value, but writes the result to a JSONL file. + """ + query = f""" + WITH NumberedRows AS ( + SELECT + *, + ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn + FROM + {tablename} + ) + SELECT {columns_str} + FROM NumberedRows + WHERE rn BETWEEN {start+1} AND {end}""" + + if method == 'dbconnect': + spark_df = run_query(query, method, cursor, sparkSession, collect=False) + if spark_df is None: + raise RuntimeError( + f'Expect spark dataframe with {query} but got None', + ) + pdf = spark_df.toPandas() # pyright: ignore + else: # method == 'dbsql': + ans = run_query(query, method, cursor, sparkSession, collect=True) + if ans is None: + raise RuntimeError(f'Got empty results with {query}') + records = [r.asDict() for r in ans] # pyright: ignore + pdf = pd.DataFrame.from_dict(records) + + pdf.to_json( + os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'), + orient='records', + lines=True, + ) + + +@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5) +def get_total_rows( + tablename: str, + method: str, + cursor: Optional['Cursor'], + sparkSession: Optional['SparkSession'], +): + ans = run_query( + f'SELECT COUNT(*) FROM {tablename}', + method, + cursor, + sparkSession, + ) + nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore + log.info(f'total_rows = {nrows}') + return nrows + + +@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5) +def get_columns_info( + tablename: str, + method: str, + cursor: Optional['Cursor'], + sparkSession: Optional['SparkSession'], +): + ans = run_query( + f'SHOW COLUMNS IN {tablename}', + method, + cursor, + sparkSession, + ) + columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore + order_by = columns[0] + columns_str = ','.join(columns) + log.info(f'order by column {order_by}') + return columns, order_by, columns_str + + +def fetch( + method: str, + tablename: str, + json_output_folder: str, + batch_size: int = 1 << 30, + 
processes: int = 1, + sparkSession: Optional['SparkSession'] = None, + dbsql: Optional['Connection'] = None, +) -> None: + """Fetch UC delta table with databricks-connect as JSONL. + + Args: + method (str): dbconnect or dbsql + tablename (str): catalog.scheme.tablename on UC + json_output_folder (str): path to write the result json file to + batch_size (int): number of rows that dbsql fetches each time to avoid OOM + processes (int): max number of processes to use to parallelize the fetch + sparkSession (pyspark.sql.sparksession): spark session + dbsql (databricks.sql.connect): dbsql session + """ + cursor = dbsql.cursor() if dbsql is not None else None + try: + nrows = get_total_rows( + tablename, + method, + cursor, + sparkSession, + ) + except Exception as e: + raise RuntimeError( + f'Error in get rows from {tablename}. Restart sparkSession and try again', + ) from e + + try: + columns, order_by, columns_str = get_columns_info( + tablename, + method, + cursor, + sparkSession, + ) + except Exception as e: + raise RuntimeError( + f'Error in get columns from {tablename}. Restart sparkSession and try again', + ) from e + + if method == 'dbconnect' and sparkSession is not None: + log.info(f'{processes=}') + df = sparkSession.table(tablename) + + # Running the query and collecting the data as arrow or json. + signed, _, _ = df.collect_cf('arrow') # pyright: ignore + log.info(f'len(signed) = {len(signed)}') + + args = get_args(signed, json_output_folder, columns) + + # Stopping the SparkSession to avoid spilling connection state into the subprocesses. + sparkSession.stop() + + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_starargs, args)) + + elif method == 'dbsql' and cursor is not None: + for start in range(0, nrows, batch_size): + log.warning(f'batch {start}') + end = min(start + batch_size, nrows) + fetch_data( + method, + cursor, + sparkSession, + start, + end, + order_by, + tablename, + columns_str, + json_output_folder, + ) + + if cursor is not None: + cursor.close() + + +def validate_and_get_cluster_info( + cluster_id: Optional[str], + databricks_host: str, + databricks_token: str, + http_path: Optional[str], + use_serverless: bool = False, +) -> tuple: + """Validate and get cluster info for running the Delta to JSONL conversion. 
+ + Args: + cluster_id (str): cluster id to validate and fetch additional info for + databricks_host (str): databricks host name + databricks_token (str): databricks auth token + http_path (Optional[str]): http path to use for sql connect + use_serverless (bool): whether to use serverless or not + """ + method = 'dbsql' + dbsql = None + sparkSession = None + + if use_serverless: + method = 'dbconnect' + else: + if not cluster_id: + raise ValueError( + 'cluster_id is not set, however use_serverless is False', + ) + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + res = w.clusters.get(cluster_id=cluster_id) + if res is None: + raise ClusterDoesNotExistError(cluster_id) + + assert res.spark_version is not None + stripped_runtime = re.sub( + r'[a-zA-Z]', + '', + res.spark_version.split('-scala') + [0].replace( # type: ignore + 'x-snapshot', '', + ), + ) + runtime_version = re.sub(r'[.-]*$', '', stripped_runtime) + if version.parse( + runtime_version, + ) < version.parse(MINIMUM_SQ_CONNECT_DBR_VERSION): + raise ValueError( + f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}', + ) + + if http_path is None and version.parse( + runtime_version, + ) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): + method = 'dbconnect' + + if method == 'dbconnect': + from databricks.connect import DatabricksSession + try: + if use_serverless: + session_id = str(uuid4()) + sparkSession = DatabricksSession.builder.host( + databricks_host, + ).token( + databricks_token, + ).header('x-databricks-session-id', session_id).getOrCreate() + + else: + if not cluster_id: + raise ValueError('cluster_id is needed for dbconnect.',) + sparkSession = DatabricksSession.builder.remote( + host=databricks_host, + token=databricks_token, + cluster_id=cluster_id, + ).getOrCreate() + + except Exception as e: + raise FailedToConnectToDatabricksError() from e + else: + try: + from databricks import sql + dbsql = sql.connect( + server_hostname=re.compile(r'^https?://').sub( + '', databricks_host).strip( + ), # sqlconnect hangs if hostname starts with https + http_path=http_path, + access_token=databricks_token, + ) + except Exception as e: + raise FailedToCreateSQLConnectionError() from e + return method, dbsql, sparkSession + + +def fetch_DT( + delta_table_name: str, + json_output_folder: str, + http_path: Optional[str], + cluster_id: Optional[str], + use_serverless: bool, + DATABRICKS_HOST: str, + DATABRICKS_TOKEN: str, + batch_size: int = 1 << 30, + processes: int = os.cpu_count(), # type: ignore + json_output_filename: str = 'train-00000-of-00001.jsonl', +) -> None: + """Fetch UC Delta Table to local as jsonl.""" + log.info(f'Start .... Convert delta to json') + + obj = urllib.parse.urlparse(json_output_folder) + if obj.scheme != '': + raise ValueError( + 'Check the json_output_folder and verify it is a local path!', + ) + + if os.path.exists(json_output_folder): + if not os.path.isdir(json_output_folder) or os.listdir( + json_output_folder, + ): + raise RuntimeError( + f'Output folder {json_output_folder} already exists and is not empty. 
Please remove it and retry.', + ) + + os.makedirs(json_output_folder, exist_ok=True) + + if not json_output_filename.endswith('.jsonl'): + raise ValueError('json_output_filename needs to be a jsonl file') + + log.info(f'Directory {json_output_folder} created.') + + # validate_and_get_cluster_info allows cluster_id to be None if use_serverless is True + method, dbsql, sparkSession = validate_and_get_cluster_info( + cluster_id=cluster_id, + databricks_host=DATABRICKS_HOST, + databricks_token=DATABRICKS_TOKEN, + http_path=http_path, + use_serverless=use_serverless, + ) + + formatted_delta_table_name = format_tablename(delta_table_name) + + fetch( + method, + formatted_delta_table_name, + json_output_folder, + batch_size, + processes, + sparkSession, + dbsql, + ) + + if dbsql is not None: + dbsql.close() + + # combine downloaded jsonl into one big jsonl for IFT + iterative_combine_jsons( + json_output_folder, + os.path.join(json_output_folder, json_output_filename), + ) + + +def _check_imports(): + try: + import lz4.frame + _ = lz4.frame + except ImportError as e: + raise ImportError('lz4 is not installed.') from e + + try: + from databricks.connect import DatabricksSession + _ = DatabricksSession + except ImportError as e: + raise ImportError( + 'databricks-connect is not installed or improperly configured.', + ) from e + + try: + from databricks import sql + from databricks.sdk import WorkspaceClient + from databricks.sql.client import Connection as Connection + from databricks.sql.client import Cursor as Cursor + _ = WorkspaceClient, Connection, Cursor, sql + except ImportError as e: + raise ImportError( + 'databricks-sdk is not installed or improperly configured.', + ) from e + + try: + import pyspark.sql.connect.proto as pb2 + import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 + from pyspark.sql import SparkSession + from pyspark.sql.connect.client.core import SparkConnectClient + from pyspark.sql.connect.client.reattach import ( + ExecutePlanResponseReattachableIterator, + ) + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.dataframe import DataFrame as SparkDataFrame + from pyspark.sql.types import Row + _ = ( + pb2, + cloud_pb2, + SparkSession, + SparkConnectClient, + ExecutePlanResponseReattachableIterator, + DataFrame, + SparkDataFrame, + Row, + ) + except ImportError as e: + raise ImportError( + 'pyspark is not installed or improperly configured.', + ) from e + + +def convert_delta_to_json_from_args( + delta_table_name: str, + json_output_folder: str, + http_path: Optional[str], + cluster_id: Optional[str], + use_serverless: bool, + batch_size: int, + processes: int, + json_output_filename: str, +) -> None: + """A wrapper for `convert_dataset_json` that parses arguments. + + Args: + delta_table_name (str): UC table ..
+ json_output_folder (str): Local path to save the converted json + http_path (Optional[str]): If set, dbsql method is used + batch_size (int): Row chunks to transmit a time to avoid OOM + processes (int): Number of processes allowed to use + cluster_id (Optional[str]): Cluster ID with runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect. + use_serverless (bool): Use serverless or not. Make sure the workspace is entitled with serverless + json_output_filename (str): The name of the combined final jsonl that combines all partitioned jsonl + """ + _check_imports() + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + DATABRICKS_HOST = w.config.host + DATABRICKS_TOKEN = w.config.token + + tik = time.time() + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + batch_size=batch_size, + processes=processes, + cluster_id=cluster_id, + use_serverless=use_serverless, + json_output_filename=json_output_filename, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + ) + log.info(f'Elapsed time {time.time() - tik}') diff --git a/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py b/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py new file mode 100644 index 0000000000..94cd79815b --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_finetuning_dataset.py @@ -0,0 +1,346 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +import platform +import warnings +from typing import Any, Callable, Dict, Iterable, Optional, Union + +import datasets as hf_datasets +import psutil +from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict +from streaming import MDSWriter +from torch.utils.data import DataLoader +from tqdm import tqdm + +from llmfoundry.data.finetuning.collator import validate_target_settings +from llmfoundry.data.finetuning.tasks import ( + _get_example_type, + dataset_constructor, + is_valid_ift_example, + tokenize_formatted_example, +) +from llmfoundry.utils.builders import build_tokenizer + +HFDataset = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset] + + +def build_dataloader( + dataset: HFDataset, + batch_size: int, + num_workers: Optional[int] = None, +) -> DataLoader: + if num_workers is None: + # Multiple workers is only supported on linux machines + if 'linux' in platform.platform().lower(): + num_workers = max(1, psutil.cpu_count()) + else: + num_workers = 0 + + # If using multiple workers, configure each worker to prefetch as many samples as it can, up to + # the aggregate device batch size + # If not using workers, the torch DataLoader expects the default value for prefetch_factor, + # which non-intuitively must be 2. + # If on macOS, PyTorch requires prefetch_factor set to None since num_workers is always zero + if 'macos' in platform.platform().lower() and num_workers == 0: + prefetch_factor = None + else: + prefetch_factor = max( + 1, + 2 * batch_size // num_workers, + ) if num_workers > 0 else 2 + + return DataLoader( + dataset=dataset, + sampler=None, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + ) + + +def generate_samples( + loader: DataLoader, + truncate_num_samples: Optional[int] = None, +) -> Iterable[Dict[str, bytes]]: + """Generator over samples of a dataloader. 
+ + Args: + loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} + truncate_num_samples (Optional[int]): An optional # of samples to stop at. + + Yields: + Sample dicts. + """ + n_samples = 0 + for batch in loader: + keys = list(batch.keys()) + current_bs = len(batch[keys[0]]) + for idx in range(current_bs): + if truncate_num_samples is not None and n_samples == truncate_num_samples: + return + n_samples += 1 + yield {k: v[idx] for k, v in batch.items()} + + +def get_columns_and_format( + dataset: HFDataset, + tokenizing: bool, + preprocessing_fn: Callable, +): + ex = preprocessing_fn(next(iter(dataset))) + example_type = _get_example_type(ex) + if tokenizing: + return {'turns': 'json'}, example_type + if example_type == 'chat': + # Chat format + return {'messages': 'json'}, example_type + else: + # Prompt-response format + return {'prompt': 'str', 'response': 'str'}, example_type + + +def convert_finetuning_dataset( + dataset: str, + data_subset: Optional[str], + splits: list[str], + preprocessor: Optional[str], + data_files: list[str], + skip_preprocessing: bool, + out_root: str, + local: Optional[str], + compression: Optional[str], + num_workers: Optional[int], + tokenizer: Optional[str], + tokenizer_kwargs: dict[str, Any], + max_seq_len: int, + target_prompts: str, + target_responses: str, + encoder_decoder: bool, +) -> None: + """Converts Finetuning datasets to MDS format. + + Args: + dataset (str): Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`). + data_subset (Optional[str]): Subset of data to use. + splits (list[str]): Comma-separated list of dataset splits + preprocessor (Optional[str]): Name or import path of function used to preprocess (reformat) the dataset. + data_files (list[str]): Data file for each split. Comma-separated. + skip_preprocessing (bool): Whether to skip preprocessing. + out_root (str): Root path of output directory where MDS shards will be stored. Can be a remote URI. + local (Optional[str]): Root path of local directory if you want to keep a local copy when out_root is remote. + compression (Optional[str]): Name of compression algorithm to use. + num_workers (Optional[int]): Number of workers. + tokenizer (Optional[str]): Tokenizer used for processing. + tokenizer_kwargs (dict[str, Any]): Keyword arguments for tokenizer initialization. + max_seq_len (int): Maximum sequence length. + target_prompts (str): Policy for when to use prompts as training targets. + target_responses (str): Policy for which responses to treat as training targets. + encoder_decoder (bool): Set if the data are intended to be used to train an encoder-decoder model + + Raises: + ValueError: If the target settings are invalid. + """ + if skip_preprocessing: + preprocessing_fn = lambda x: x # Just an identity function + else: + preprocessor_str = preprocessor + preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( + preprocessor=preprocessor_str, + dataset_name=dataset, + ) + if preprocessing_fn is None: + raise ValueError( + '`preprocessor` was not set and no preprocessing function ' +\ + 'has been registered for `dataset`. 
If this was intentional ' +\ + '(e.g., because your dataset is already correctly formatted), ' +\ + 'include the "--skip-preprocessing" flag to avoid this error.', + ) + + # Make sure the target settings are valid + validate_target_settings( + target_prompts=target_prompts, + target_responses=target_responses, + decoder_only_format=not encoder_decoder, + ) + + tokenizer = None + tokenizer_kwargs = tokenizer_kwargs + tokenizer_kwargs.update({'model_max_length': max_seq_len}) + if tokenizer: + tokenizer = build_tokenizer(tokenizer, tokenizer_kwargs) + + for i, split_name in enumerate(splits): + data_file = None + if len(data_files) > 0: + data_file = data_files[i] + loaded_dataset = hf_datasets.load_dataset( + path=dataset, + name=data_subset, + split=split_name, + data_files=data_file, + streaming=True, + ) + # Determine the output columns + columns, example_type = get_columns_and_format( + dataset=loaded_dataset, + tokenizing=tokenizer is not None, + preprocessing_fn=preprocessing_fn, + ) + # Prepare the iterables + if example_type == 'chat': + samples = iter(loaded_dataset) + else: + loader = build_dataloader( + dataset=loaded_dataset, + batch_size=512, + num_workers=num_workers, + ) + samples = generate_samples(loader) + + # Write samples + print(f'Converting {split_name} to MDS format...') + out = os.path.join(out_root, split_name) + if local is not None: + out = (os.path.join(local, split_name), out) + keep_local = True + else: + keep_local = False + with MDSWriter( + columns=columns, + out=out, + compression=compression, + keep_local=keep_local, + ) as out: + examples_removed = 0 + for sample in tqdm(samples, desc=split_name): + formatted_sample = preprocessing_fn(sample) + assert isinstance(formatted_sample, dict) + + # Use the _get_example_type utility to confirm that the formatted sample + # can be interpreted by the tokenization code + try: + example_type = _get_example_type(formatted_sample) + except Exception as e: + raise ValueError( + 'Encountered an error when checking example for proper formatting. 
' +\ + f'example={formatted_sample}', + ) from e + if tokenizer is not None: + sample = tokenize_formatted_example( + formatted_sample, + tokenizer=tokenizer, + ) + if not is_valid_ift_example( + max_seq_len, + target_prompts=target_prompts, + target_responses=target_responses, + decoder_only_format=not encoder_decoder, + example=sample, + ): + examples_removed += 1 + continue + + sample_to_write = {'turns': []} + for turn in sample['turns']: + turn_to_write = {} + for key in ['input_ids', 'labels']: + turn_to_write[key] = list(turn[key]) + sample_to_write['turns'].append(turn_to_write) + out.write(sample_to_write) + else: + if example_type == 'prompt_response': + encoded_sample = {} + for key in ['prompt', 'response']: + value = formatted_sample[key] + assert isinstance(value, str) + encoded_sample[key] = value.encode('utf-8') + out.write(encoded_sample) + else: + out.write(formatted_sample) + + if tokenizer is not None and examples_removed > 0: + warnings.warn( + f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, ' + + + 'the prompt or response was empty, or the response was all padding tokens.', + ) + + +def convert_finetuning_dataset_from_args( + dataset: str, + data_subset: Optional[str], + splits: list[str], + preprocessor: Optional[str], + data_files: list[str], + skip_preprocessing: bool, + out_root: str, + local: Optional[str], + compression: Optional[str], + num_workers: Optional[int], + tokenizer: Optional[str], + tokenizer_kwargs: Optional[str], + max_seq_len: int, + target_prompts: str, + target_responses: str, + encoder_decoder: bool, +): + """A wrapper for `convert_finetuning_dataset` to parse arguments. + + Args: + dataset (str): Name of the dataset (e.g., first argument to `datasets.load_dataset`, for jsonl data format, it is `json`). + data_subset (Optional[str]): Subset of data to use. + splits (list[str]): Comma-separated list of dataset splits + preprocessor (Optional[str]): Name or import path of function used to preprocess (reformat) the dataset. + data_files (list[str]): Data file for each split. Comma-separated. + skip_preprocessing (bool): Whether to skip preprocessing. + out_root (str): Root path of output directory where MDS shards will be stored. Can be a remote URI. + local (Optional[str]): Root path of local directory if you want to keep a local copy when out_root is remote. + compression (Optional[str]): Name of compression algorithm to use. + num_workers (Optional[int]): Number of workers. + tokenizer (Optional[str]): Tokenizer used for processing. + tokenizer_kwargs (Optional[str]): Keyword arguments for tokenizer initialization in JSON format. + max_seq_len (int): Maximum sequence length. + target_prompts (str): Policy for when to use prompts as training targets. + target_responses (str): Policy for which responses to treat as training targets. + encoder_decoder (bool): Set if the data are intended to be used to train an encoder-decoder model. + + Raises: + ValueError: If the target settings are invalid. + ValueError: If the output directory already contains the requested splits. 
+ """ + if os.path.isdir(out_root) and len( + set(os.listdir(out_root)).intersection(set(splits)), + ) > 0: + raise ValueError( + f'--out_root={out_root} contains {os.listdir(out_root)} which cannot overlap with the requested splits {splits}.', + ) + + if tokenizer_kwargs is not None: + parsed_tokenizer_kwargs = json.loads(tokenizer_kwargs) + else: + parsed_tokenizer_kwargs = {} + + if len(data_files) > 0 and len(data_files,) != len(splits): + raise ValueError( + f'If data_files is set, data_files and splits must have the same length. Got {len(data_files)=} while {len(splits)=}', + ) + convert_finetuning_dataset( + dataset=dataset, + data_subset=data_subset, + splits=splits, + preprocessor=preprocessor, + data_files=data_files, + skip_preprocessing=skip_preprocessing, + out_root=out_root, + local=local, + compression=compression, + num_workers=num_workers, + tokenizer=tokenizer, + tokenizer_kwargs=parsed_tokenizer_kwargs, + max_seq_len=max_seq_len, + target_prompts=target_prompts, + target_responses=target_responses, + encoder_decoder=encoder_decoder, + ) diff --git a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py index 14afe279fd..336c82a5e7 100644 --- a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py @@ -394,6 +394,13 @@ def convert_text_to_mds( reprocess (bool): Whether to always reprocess the given folder of text files trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer """ + # Load the tokenizer once on the main process so that the files are cached to avoid race conditions + # in the Hugging Face load code + AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + is_remote_output = is_remote_path(output_folder) log.info(f'Output is remote: {is_remote_output}') diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py index 7d8306c0a0..bddd592dba 100644 --- a/llmfoundry/command_utils/eval.py +++ b/llmfoundry/command_utils/eval.py @@ -175,6 +175,54 @@ def evaluate_model( return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) +def allow_toplevel_keys(cfg: Dict[str, Any]) -> Dict[str, Any]: + """Transform the config to allow top-level keys for model configuration. + + This function allows users to use the 'train.py' syntax in 'eval.py'. + It converts a config with top-level 'model', 'tokenizer', and (optionally) 'load_path' keys + into the nested 'models' list format required by 'eval.py'. 
+ + Input config format (train.py style): + ```yaml + model: + + load_path: /path/to/checkpoint + tokenizer: + + ``` + + Output config format (eval.py style): + ```yaml + models: + - model: + + tokenizer: + + load_path: /path/to/checkpoint + ``` + """ + if 'model' in cfg: + if 'models' in cfg: + raise ValueError( + 'Please specify either model or models in the config, not both', + ) + default_name = cfg.get('model').get('name') # type: ignore + model_cfg = { + 'model': cfg.pop('model'), + 'tokenizer': cfg.pop('tokenizer', None), + 'model_name': cfg.pop('model_name', default_name), + } + if 'tokenizer' not in model_cfg or model_cfg['tokenizer'] is None: + raise ValueError( + 'When specifying model, "tokenizer" must be provided in the config', + ) + if 'load_path' in cfg: + model_cfg['load_path'] = cfg.pop('load_path') + cfg['models'] = [model_cfg] + + return cfg + + def evaluate(cfg: DictConfig) -> Tuple[list[Trainer], pd.DataFrame]: # Run user provided code if specified for code_path in cfg.get('code_paths', []): @@ -184,6 +232,7 @@ def evaluate(cfg: DictConfig) -> Tuple[list[Trainer], pd.DataFrame]: cfg, EvalConfig, EVAL_CONFIG_KEYS, + transforms=[allow_toplevel_keys], icl_tasks_required=True, ) diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index feed1e9fb1..c925e6e586 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -260,7 +260,7 @@ def train(cfg: DictConfig) -> Trainer: if fsdp_config is not None: if 'load_planner' in fsdp_config: - load_planners = fsdp_config['load_planner'].items() + load_planners = list(fsdp_config['load_planner'].items()) if len(load_planners) > 1: raise ValueError( 'Only one load planner can be specified in the config.', @@ -272,7 +272,7 @@ def train(cfg: DictConfig) -> Trainer: ) if 'save_planner' in fsdp_config: - save_planners = fsdp_config['save_planner'].items() + save_planners = list(fsdp_config['save_planner'].items()) if len(save_planners) > 1: raise ValueError( 'Only one save planner can be specified in the config.', @@ -544,6 +544,7 @@ def train(cfg: DictConfig) -> Trainer: dist_timeout=train_cfg.dist_timeout, profiler=profiler, compile_config=compile_config, + spin_dataloaders=train_cfg.spin_dataloaders, ) # Optionally just save an HF checkpoint diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 11104ac706..d9450bc657 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -64,9 +64,12 @@ def build_finetuning_dataloader( on which you intend to use, as explained below. Args: - name (str): The type of dataloader to build. Must = "finetuning". - --- - *** HuggingFace dataset config fields *** + tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to + prepare the data from raw text. Any missing sentinel tokens will + be added by the collator. + device_batch_size (int, float): The size of the batches (number of examples) + that the dataloader will produce. + dataset (Dict[str, Any]): A HuggingFace dataset config which contains the following fields: dataset.hf_name (str, optional): The name of the HuggingFace dataset to use. Can also be a remote http(s) directory or object store bucket containing the file {split}.jsonl in the format (prompt, response), @@ -130,16 +133,32 @@ def build_finetuning_dataloader( The script `scripts/misc/profile_packing.py` can help you choose the best packing_ratio. dataset.shuffle (bool): Whether to shuffle the dataset. 
- ___ See :class:`StreamingFinetuningDataset` for info on other standard config options within `dataset` that will be passed as kwargs if using the streaming codepath. - --- - tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to - prepare the data from raw text. Any missing sentinel tokens will - be added by the collator. - device_batch_size (int, float): The size of the batches (number of examples) - that the dataloader will produce. + num_workers (int, optional): How many subprocesses to use for data loading. + 0 means that the data will be loaded in the main process. The default is 0. + This argument is passed directly to the pytorch :class:`DataLoader`. + drop_last (bool, optional): If true, drop the last incomplete batch, if the dataset + size is not divisible by the batch size. If False and the size of dataset is + not divisible by the batch size, then the last batch will be smaller. The + default is False. This argument is passed directly to the pytorch :class:`DataLoader`. + pin_memory (bool, optional): If True, the data loader will copy Tensors into device/CUDA + pinned memory before returning them. If your data elements are a custom type, or your + `collate_fn` returns a batch that is a custom type. This argument is passed directly to + the pytorch :class:`DataLoader`. + prefetch_factor (int, optional): Number of batches loaded in advance by each worker. + 2 means there will be a total of 2 * num_workers batches prefetched across all workers. + (default value depends on the set value for num_workers. If value of num_workers=0 default + is None. Otherwise, if value of num_workers > 0 default is 2). This argument is passed + directly to the pytorch :class:`DataLoader`. + persistent_workers (bool, optional): If True, the data loader will not shut down the worker + processes after a dataset has been consumed once. This allows to maintain the workers + Dataset instances alive. The default is False. This argument is passed directly to the + pytorch :class:`DataLoader`. + timeout (int, optional): If positive, the timeout value for collecting a batch from workers. + Should always be non-negative. The default is 0. This argument is passed directly to the + pytorch :class:`DataLoader`. See :class:`DataLoader` for standard argument options to the pytorch dataloader, such as `drop_last`, `num_workers`, etc. @@ -357,7 +376,50 @@ def _validate_config( the other. Args: - dataset_cfg (DictConfig): The dataset configuration to be validated. + max_seq_len (int): The maximum length of sequences + in the batch. See :class:`Seq2SeqFinetuningCollator` docstring + for details. + decoder_only_format (bool): Whether to format the + examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator` + docstring for details. + hf_name (str, optional): The name of the HuggingFace dataset + to use. Can also be a remote http(s) directory or object store bucket + containing the file {split}.jsonl in the format (prompt, response), + in which case the builder will create a HuggingFace dataset. + local (str, optional): Local path where remote data + will be streamed to. Only valid if `cfg.dataset.remote` has + also been set. + remote (str, optional): Location of a MDS-formatted + streaming dataset to use. Setting this will tell the builder + to create a streaming dataset rather than a HuggingFace dataset. + hf_kwargs (DictConfig, optional): Additional kwargs to + pass to `datasets.load_dataset`, which can be used to load + a dataset from local files. 
+ preprocessing_fn (str, optional): The name/import path of + the preprocessing function to use for formatting the data examples. + If ``None`` (default), the builder will use the preprocessing function + registered under `hf_name` (see `tasks.py`), if one exists, + otherwise it will skip preprocessing. + If `preprocessing_fn` corresponds to a registered preprocessing + function in `tasks.py`, the builder will use that. + Otherwise, it will interpret `preprocessing_fn` as a + "import.path:function_name" import path; e.g., it will call + `from import.path import function_name` and use the imported + function as the preprocessing function. + safe_load (bool, optional): Whether to enforce safe loading of the dataset. + If `None`, will default to not applying any safe loading. + streams (Dict[str, Any], optional): A dictionary with multiple data streams. + If `None`, will assume no streams. + target_prompts (str): Which prompts are used as training targets. + Defaults to "none", meaning prompts are never used as training targets. + See :class:`Seq2SeqFinetuningCollator` docstring for details. + target_responses (str): Which responses are used as training targets. + Defaults to "last", meaning only the final response in multi-turn examples + will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for + details. + kwargs (DictConfig, optional): Additional kwargs to + pass to `datasets.load_dataset`, which can be used to load + a dataset from local files. Raises: ValueError: If the dataset configuration does not meet the requirements. @@ -504,7 +566,7 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: completed, the function removes the signal file. Args: - hf_name (str): The path of the HuggingFace dataset to download. + remote_path (str): The path of the HuggingFace dataset to download. split (str): The dataset split to download (e.g., 'train', 'validation', 'test'). Returns: diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 78bfb9c74c..397b619e73 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -162,7 +162,7 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: Args: dirpath (str): Directory path to check. - Returns + Returns: True if directory is empty or non-existent. False otherwise. """ return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 @@ -820,9 +820,33 @@ def build_from_hf( Note: This function will drop examples where the prompt is longer than the max_seq_len Args: - cfg (DictConfig): The dataset configuration. - max_seq_len (int): The maximum sequence length. Examples with prompts longer than this will be dropped. - tokenizer (Tokenizer): The tokenizer to be used for tokenizing the dataset. + dataset_name (str): The name of the HuggingFace dataset + to use. Can also be a remote http(s) directory or object store bucket + containing the file {split}.jsonl in the format (prompt, response), + in which case the builder will create a HuggingFace dataset. + split (str): The split of the HuggingFace dataset. + safe_load (bool, optional): Whether to enforce safe loading of the dataset. + If `None`, will default to not applying any safe loading. + max_seq_len (int): The maximum length of sequences + in the batch. See :class:`Seq2SeqFinetuningCollator` docstring + for details. + preprocessing_fn (Callable, optional): The preprocessing function to use for + formatting the data examples. 
+ tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for tokenizing + the HuggingFace dataset. + target_prompts (str): Which prompts are used as training targets. + Defaults to "none", meaning prompts are never used as training targets. + See :class:`Seq2SeqFinetuningCollator` docstring for details. + target_responses (str): Which responses are used as training targets. + Defaults to "last", meaning only the final response in multi-turn examples + will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for + details. + decoder_only_format (bool): Whether to format the + examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator` + docstring for details. + hf_kwargs (DictConfig, optional): Additional kwargs to + pass to `datasets.load_dataset`, which can be used to load + a dataset from local files. Returns: Dataset: The tokenized dataset. diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index a6fdf34953..5579066f89 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -337,7 +337,7 @@ def auto_packing_ratio( dataloader_cfg (DictConfig): The dataloader configuration for profiling. tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling. device_batch_size (int): The size of the batches (number of examples) per device. - num_packing_ratio (int): The number of packing ratios to try. + num_packing_ratios (int): The number of packing ratios to try. Returns: A packing ratio that minimizes padding while maintaining zero waste. diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py index 8a8b9de551..4e49be3fba 100644 --- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py +++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py @@ -251,8 +251,9 @@ def read_dataset( """ from datasets import \ Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues] - from datasets import \ - load_dataset # pyright: ignore[reportGeneralTypeIssues] + from datasets import ( # pyright: ignore[reportGeneralTypeIssues] + load_dataset, + ) if 'hf://' in dataset_uri: dataset_uri = dataset_uri.replace('hf://', '') if hf_loading_vars is None: @@ -363,6 +364,7 @@ def get_answer_from_example( Args: example (Dict): The example from which to retrieve the answer + in_context (bool): Whether this is an in-context example. Default to False. Returns: str: The answer in the example @@ -712,6 +714,7 @@ def get_answer_from_example( Args: example (Dict): The example from which to retrieve the answer + in_context (bool): Whether this is an in-context example. Default to False. Returns: str: The answer in from the example with chain of thought and delimiter if needed @@ -731,7 +734,7 @@ def tokenize_example( Args: prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context - ctx (str): The specific example's derived context + ctxt (str): The specific example's derived context example (Dict): The example as a dictionary. Returns: @@ -1035,6 +1038,7 @@ def get_answer_from_example( Args: example (Dict): The example from which to retrieve the answer + in_context (bool): Whether this is an in-context example. Default to False. 
Returns: str: The full string of the correct answer based on the 'gold' key @@ -1053,7 +1057,7 @@ def tokenize_example( Args: prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context - ctx (str): The specific example's derived context + ctxt (str): The specific example's derived context example (Dict): The example as a dictionary. Returns: @@ -1129,6 +1133,7 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: since the batch may consist of multiple questions, the choice_groupings indicates which contiguous sequences of elements in the batch correspond to which question gold_indices indicates which of the [0, N-1] choices is the correct one for each question. + Args: data (List): List of tokenized datapoints (dicts returned by self._tokenize_example) @@ -1168,6 +1173,7 @@ def split_batch(self, batch: Any, and real example, which refers to one possible continuation. As example count and microbatch_size are tracked in logical example, we split logical attributes by microbatch_size and real attributes by microbatch_size * num_choices. + Args: batch (Dict): Batch of data microbatch_size (int | float): Size of microbatches @@ -1419,7 +1425,7 @@ def tokenize_example( Args: prompt_and_fewshot (str): The collection of the prompt and fewshot examples that belongs before the example's context - ctx (str): The specific example's derived context + context_options (str): A list of contexts for this specific example. example (Dict): The example as a dictionary. Returns: @@ -1548,6 +1554,10 @@ def partition_dataset_by_category( Args: dataset_uri (str): Location of dataset. destination_path (str): Base destination path, we will write a separate partition off this URI for each category. + hf_loading_vars (Dict): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. + hf_parsing_map (Dict): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. + Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset. + Raises: MissingConditionalImportError: If datasets not installed raise exception. @@ -1643,8 +1653,7 @@ def get_icl_task_dataloader( # At this point, hf_model is randomly initialized composer_model = HuggingFaceModel(hf_model, hf_tokenizer) - Example: - + Example: .. testcode:: @@ -1685,8 +1694,8 @@ def get_icl_task_dataloader( hf_loading_vars (Dict, default = None): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF. hf_parsing_map (Dict, default = None): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key1]}. Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset. - kwargs (Dict[str, Any], default=None): Dictionary containing a mapping - from ICL dataset constructor's parameter names and their desired values. + destination_path: Where the dataloader will be saved. + kwargs (Dict[str, Any], default=None): Dictionary containing a mapping from ICL dataset constructor's parameter names and their desired values. Returns: DataLoader: A dataloader used for performing in-context learning evaluation on the dataset provided. 
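
For reference, the `hf_parsing_map` behavior documented in the hunk above (listed HF columns concatenated with a single space per ICL key) amounts to the following. This is a standalone sketch with invented column names, not the dataset class's actual implementation:

```python
# Sketch of how an hf_parsing_map is applied to one HF row (illustrative only;
# the column names "passage", "question", and "answer_text" are made up).
hf_parsing_map = {"context": ["passage", "question"], "answer": ["answer_text"]}
hf_row = {
    "passage": "The sky is blue.",
    "question": "What color is the sky?",
    "answer_text": "blue",
}

# Each ICL key is the space-joined contents of its mapped HF columns.
icl_example = {
    icl_key: " ".join(str(hf_row[col]) for col in hf_cols)
    for icl_key, hf_cols in hf_parsing_map.items()
}
# icl_example == {"context": "The sky is blue. What color is the sky?", "answer": "blue"}
```
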
diff --git a/llmfoundry/eval/datasets/utils.py b/llmfoundry/eval/datasets/utils.py index 1ce249437d..c19ae15dd9 100644 --- a/llmfoundry/eval/datasets/utils.py +++ b/llmfoundry/eval/datasets/utils.py @@ -130,7 +130,7 @@ def make_padded_input( Args: context_enc (List): The encoded input to the model continuation_enc (List): The encoded desired output for the example - max_seq_list (int): Maximum length sequences can be + max_seq_len (int): Maximum length sequences can be pad_tok_id (int): The token id we pad with padding_side (str): Which side to pad the context on. Can be 'right' or 'left diff --git a/llmfoundry/eval/metrics/nlp.py b/llmfoundry/eval/metrics/nlp.py index 3ee30ebf5e..f0fbba3ece 100644 --- a/llmfoundry/eval/metrics/nlp.py +++ b/llmfoundry/eval/metrics/nlp.py @@ -80,7 +80,7 @@ def update( Args: batch (dict): Batch must consist minimally of `input_ids` as well as any other structure needed to compute the metric. - output_logits (torch.Tensor): The model outputs evaluated on the batch `input_ids` + outputs (torch.Tensor): The model outputs evaluated on the batch `input_ids`. labels (torch.Tensor): The correct outputs. Raises: diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 536cd0257d..34ce22d694 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -11,7 +11,6 @@ Any, Dict, List, - Mapping, Optional, Tuple, Union, @@ -23,7 +22,6 @@ from transformers import ( AutoConfig, AutoModelForCausalLM, - PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase, ) @@ -36,7 +34,7 @@ from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.utils import init_empty_weights -from llmfoundry.utils.config_utils import get_hf_config_value +from llmfoundry.utils.config_utils import set_config_overrides if TYPE_CHECKING: from peft import PeftConfig, PeftModel @@ -105,9 +103,13 @@ def __init__( config_overrides=config_overrides, load_in_8bit=load_in_8bit, pretrained=pretrained, - prepare_for_fsdp=True, + prepare_for_fsdp=False, ) + model = self.transform_model(model) + + ComposerHFCausalLM.prepare_inner_model(model, init_device) + train_metrics, eval_metrics = ComposerHFCausalLM.build_metrics( use_train_metrics=use_train_metrics, additional_train_metrics=additional_train_metrics, @@ -121,7 +123,7 @@ def __init__( peft_config_object = None if peft_config is not None: - peft_config_object = self._get_peft_config(peft_config) + peft_config_object = self.get_peft_config(peft_config) # Set up config args for the model construction and base classes super().__init__( @@ -135,6 +137,17 @@ def __init__( should_save_peft_only=should_save_peft_only, ) + def transform_model(self, model: PreTrainedModel) -> PreTrainedModel: + """Transforms the model after initialization. + + Args: + model (PreTrainedModel): The model to transform. + + Returns: + PreTrainedModel: The transformed model. + """ + return model + @staticmethod def build_metrics( use_train_metrics: bool, @@ -192,6 +205,7 @@ def build_inner_model( use_auth_token (bool): Whether to use an authentication token. config_overrides (Dict[str, Any]): The configuration overrides. load_in_8bit (bool): Whether to load in 8-bit. + pretrained (bool): Whether the model is pretrained. prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: False. 
Returns: @@ -216,6 +230,22 @@ def build_inner_model( + 'Please `pip install llm-foundry[gpu]`.', ) + # Hugging Face copies the modules into the + # transformers modules cache. On particular systems, this operation seems to cause contention between + # the different processes. To avoid this contention, we first create the config on local rank + # zero. This will set up the transformers module cache and avoid the future contention. + if dist.get_local_rank() == 0: + AutoConfig.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, + attn_implementation=requested_attention_implementation, + use_cache= + False, # Necessary due to https://github.com/huggingface/transformers/issues/28056 + ) + + dist.barrier() + # Construct the Hugging Face config to use config = AutoConfig.from_pretrained( pretrained_model_name_or_path, @@ -243,50 +273,7 @@ def _autoset_attn_implementation_monkeypatch( _autoset_attn_implementation_monkeypatch, ) - # set config overrides - for k, v in config_overrides.items(): - if not hasattr(config, k): - raise ValueError( - f'config does not have attribute "{k}" to override ({k}: {v}).', - ) - - attr = getattr(config, k) - # attempt to disallow typos in nested configs - if isinstance(attr, Mapping): - extra_keys = [_k for _k in v.keys() if _k not in attr.keys()] - if extra_keys: - raise ValueError( - f'Config dict override got unknown keys. ' + - f'Extra keys: {extra_keys}. ' + - f'Expected (a subset of) keys: {list(attr.keys())}.', - ) - getattr(config, k).update(v) - # necessary case to allow for rope_scaling to be overriden in llama config - elif attr is None and isinstance(v, Mapping): - setattr(config, k, {}) - getattr(config, k).update(v) - elif isinstance(attr, PretrainedConfig): - if not isinstance(v, Mapping): - raise ValueError( - f'Expected a dictionary for config override {k}, but got {v}.', - ) - - for _k, _v in v.items(): - if not hasattr(attr, _k): - raise ValueError( - f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).', - ) - setattr(attr, _k, _v) - else: - setattr(config, k, v) - - if hasattr(config, 'attn_config') and get_hf_config_value( - config.attn_config, - 'seq_parallel_world_size', - ) is not None: - raise NotImplementedError( - 'Sequence Parallelism is not supported for HuggingFace models.', - ) + set_config_overrides(config, config_overrides) # We need to have all non-zero local ranks be not-pretrained # Rank 0 will still be pretrained, and distribute the weights appropriately @@ -298,7 +285,7 @@ def _autoset_attn_implementation_monkeypatch( # the different processes. To avoid this contention, we first create the model (on meta device) on local rank # zero. This will set up the transformers model cache and avoid the future contention. 
if dist.get_local_rank() == 0: - if os.path.isdir(pretrained_model_name_or_path): + if pretrained and os.path.isdir(pretrained_model_name_or_path): with init_empty_weights(include_buffers=False): with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) @@ -379,10 +366,10 @@ def _autoset_attn_implementation_monkeypatch( if prepare_for_fsdp: ComposerHFCausalLM.prepare_inner_model(model, init_device) + return model - @staticmethod - def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig': + def get_peft_config(self, peft_config_dict: Dict[str, Any]) -> 'PeftConfig': if peft_installed: from peft import LoraConfig peft_type = peft_config_dict.get('peft_type', '') diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 8e740be2b3..3e365edc47 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -415,6 +415,7 @@ def __init__( softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -520,6 +521,7 @@ def __init__( self.q_ln = build_norm( name=norm_type.lower(), normalized_shape=norm_size, + eps=norm_eps, device=device, ) if self.reuse_kv_layer_idx is None: @@ -528,6 +530,7 @@ def __init__( self.k_ln = build_norm( name=norm_type.lower(), normalized_shape=norm_size, + eps=norm_eps, device=device, ) @@ -603,6 +606,7 @@ def get_qkv( Args: x (torch.Tensor): The input tensor. + prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer. Returns: query (torch.Tensor): The query tensor. @@ -796,6 +800,7 @@ def __init__( softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -814,6 +819,7 @@ def __init__( softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, + norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, @@ -841,6 +847,7 @@ def __init__( softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -859,6 +866,7 @@ def __init__( softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, + norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index c6988b7bd7..92735cc489 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -42,6 +42,7 @@ def __init__( ffn_config: Optional[Dict] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, no_bias: bool = False, @@ -84,6 +85,7 @@ def __init__( fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, + norm_eps=norm_eps, device=device, no_bias=no_bias, ) @@ -99,6 +101,7 @@ def __init__( self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.attn = build_attention_layer( @@ -117,6 +120,7 @@ def __init__( self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) @@ -260,6 +264,7 @@ def 
__init__( fc_type: Optional[dict[str, Any]] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, device: Optional[str] = None, no_bias: bool = False, **kwargs: Any, @@ -283,6 +288,7 @@ def __init__( self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.attn = build_attention_layer( @@ -302,6 +308,7 @@ def __init__( self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.resid_attn_dropout = nn.Dropout(resid_pdrop) diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index a28725ee0f..f5d6d67040 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -53,6 +53,19 @@ } +def quickgelu_activation(input: torch.Tensor) -> torch.Tensor: + """Applies GELU approximation that is fast but somewhat inaccurate. + + Args: + input (torch.Tensor): Input tensor of shape(*), where * means any + number of dimensions + + Returns: + torch.Tensor: Tensor with same shape as input tensor + """ + return input * torch.sigmoid(1.702 * input) + + def resolve_ffn_act_fn( config: Optional[dict] = None, ) -> Callable[[torch.Tensor], torch.Tensor]: @@ -70,10 +83,13 @@ def resolve_ffn_act_fn( config = _FFN_ACT_FN_DEFAULT config = deepcopy(config) name = config.pop('name') - if not hasattr(torch.nn.functional, name): - raise ValueError(f'Unrecognized activation function name ({name}).') - act = getattr(torch.nn.functional, name) - return partial(act, **config) + if name == 'quick_gelu': + return quickgelu_activation + else: + if not hasattr(torch.nn.functional, name): + raise ValueError(f'Unrecognized activation function name ({name}).') + act = getattr(torch.nn.functional, name) + return partial(act, **config) _DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT) @@ -413,6 +429,7 @@ def set_ffn_device_mesh( ffn (nn.Module): The FFN module. moe_world_size (int): The MoE world size. device_mesh (DeviceMesh): The full device mesh. + get_fsdp_submesh (Callable[[DeviceMesh], DeviceMesh]): A function to get the fsdp submesh. Raises: RuntimeError: If the device mesh is 3D. diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 69d2059bad..d5fd1d37d4 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -26,10 +26,12 @@ def build_norm( name: str, normalized_shape: Union[int, List[int], torch.Size], + eps: Optional[float] = 1e-5, device: Optional[str] = None, ): kwargs = { 'normalized_shape': normalized_shape, + 'eps': eps, 'device': device, } diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 3de3744745..9671eb6ed5 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -44,6 +44,7 @@ def __init__( no_bias: bool = False, embedding_fraction: float = 1.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, use_cache: bool = False, init_config: Optional[Dict] = None, fc_type: Union[str, Dict] = 'torch', @@ -101,6 +102,7 @@ def __init__( no_bias (bool): Whether to use bias in all layers. embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 
norm_type (str): choose type of norm to use + norm_eps (float): epsilon value for norm layer use_cache (bool): Whether or not the model should return the last key/values attentions init_config (Dict): A dictionary used to configure the model initialization: init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', @@ -145,6 +147,7 @@ def __init__( reuse_kv_layer: attn_config: reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse + kwargs (Any): Other relevant keyword arguments. """ self.d_model = d_model self.n_heads = n_heads @@ -168,6 +171,7 @@ def __init__( self.no_bias = no_bias self.embedding_fraction = embedding_fraction self.norm_type = norm_type + self.norm_eps = norm_eps self.use_cache = use_cache self.init_config = init_config if init_config is not None else copy.deepcopy( init_config_defaults, @@ -306,6 +310,7 @@ def _validate_config(self) -> None: 'no_scaling', 'linear', 'dynamic', + 'llama3', ]: raise ValueError( 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".', diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 3b2744f867..6f9b6bf806 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -49,12 +49,10 @@ BaseModelOutputWithPast, CausalLMOutputWithPast, ) -from transformers.models.llama.modeling_llama import \ - LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as HFRotaryEmbedding +from transformers.models.llama.modeling_llama import ( + LlamaConfig, + LlamaRotaryEmbedding, +) from llmfoundry.layers_registry import norms, param_init_fns from llmfoundry.models.layers.attention import ( @@ -88,14 +86,62 @@ log = logging.getLogger(__name__) +class InvalidConfigAccessError(KeyError): + pass + + +_ALLOWED_LLAMA_CONFIG_KEYS = { + # These are the only config keys that are set and are safe to read from + 'rope_scaling', + 'rope_theta', + 'max_position_embeddings', + 'hidden_size', + 'num_attention_heads', + + # Not set but llama modeling code tries to read this attribute + 'partial_rotary_factor', + + # Benign transformers attributes needed for __init__ + '_get_generation_defaults', + 'label2id', + 'id2label', + 'torch_dtype', + 'problem_type', + '__class__', +} + + +class PartialLlamaConfig(LlamaConfig): + """Holds the rope config for Llama models and throws. + + an `InvalidConfigAccessError` if any other config elements are read. This + class is necessary because the `LlamaRotaryEmbedding` class takes a full + `LlamaConfig` now instead of the old keyword arguments. 
+ """ + + def __getattribute__(self, key: str): + if key not in _ALLOWED_LLAMA_CONFIG_KEYS: + raise InvalidConfigAccessError(key) + + return super().__getattribute__(key) + + def __getitem__(self, key: str): + if key not in _ALLOWED_LLAMA_CONFIG_KEYS: + raise InvalidConfigAccessError(key) + + return super().__getitem__(key) + + def gen_rotary_embedding( - rope_head_dim: int, rope_impl: str, rope_theta: int, rope_dail_config: dict, rope_hf_config: dict, max_seq_len: int, + d_model: int, + n_heads: int, ): + rope_head_dim = d_model // n_heads if rope_impl == 'dail': return DAILRotaryEmbedding( dim=rope_head_dim, @@ -108,32 +154,21 @@ def gen_rotary_embedding( 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu ) elif rope_impl == 'hf': + llama_rope_config = {**rope_hf_config} + llama_rope_config['rope_type'] = llama_rope_config.pop('type') + if llama_rope_config['rope_type'] == 'no_scaling': + llama_rope_config['rope_type'] = 'default' + partial_llama_config = PartialLlamaConfig( + rope_scaling=llama_rope_config, + rope_theta=rope_theta, + max_position_embeddings=max_seq_len, + hidden_size=d_model, + num_attention_heads=n_heads, + ) if rope_hf_config['type'] == 'no_scaling': - return HFRotaryEmbeddingFoundry( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'linear': - return HFLinearScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) - elif rope_hf_config['type'] == 'dynamic': - return HFDynamicNTKScalingRotaryEmbedding( - rope_head_dim, - max_position_embeddings=max_seq_len, - base=rope_theta, - scaling_factor=rope_hf_config['factor'], - device= - 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu - ) + return LlamaRotaryEmbeddingFoundry(config=partial_llama_config) + elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}: + return LlamaRotaryEmbedding(config=partial_llama_config) raise ValueError('rope_impl needs to be either dail or hf') @@ -306,7 +341,7 @@ def apply_sequence_id( return attn_bias -class HFRotaryEmbeddingFoundry(HFRotaryEmbedding): +class LlamaRotaryEmbeddingFoundry(LlamaRotaryEmbedding): @torch.no_grad() def forward( @@ -391,6 +426,7 @@ def __init__(self, config: MPTConfig): self.norm_f = build_norm( name=config.norm_type.lower(), normalized_shape=config.d_model, + eps=config.norm_eps, device=config.init_device, ) @@ -399,12 +435,13 @@ def __init__(self, config: MPTConfig): if self.rope: self.rope_impl = config.attn_config['rope_impl'] self.rotary_embedding = gen_rotary_embedding( - rope_head_dim=config.d_model // config.n_heads, rope_impl=self.rope_impl, rope_theta=config.attn_config['rope_theta'], rope_dail_config=config.attn_config['rope_dail_config'], rope_hf_config=config.attn_config['rope_hf_config'], max_seq_len=self.config.max_seq_len, + d_model=config.d_model, + n_heads=config.n_heads, ) if config.init_device != 'meta': @@ -1285,6 +1322,40 @@ def _reorder_cache( return reordered_past +def get_targets(labels: torch.Tensor) -> torch.Tensor: + targets = torch.roll(labels, shifts=-1) + targets[:, -1] = -100 + return targets + + +def compute_loss_from_logits( + outputs: CausalLMOutputWithPast, + shift_labels: bool, + labels: 
torch.Tensor, + loss_fn: nn.Module, + sample_weighing_factor: Optional[torch.Tensor] = None, +) -> torch.Tensor: + targets = get_targets(labels) if shift_labels else labels + + losses = loss_fn( + outputs.logits.view(-1, outputs.logits.size(-1)), + targets.view(-1), + ) + + if torch.all(targets == loss_fn.ignore_index): + loss = losses.sum() + else: + loss = losses.sum() / (targets != loss_fn.ignore_index).sum() + if sample_weighing_factor is not None: + if sample_weighing_factor.shape[0] > 1: + raise ValueError( + 'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.', + ) + loss = loss * sample_weighing_factor[0].item() + + return loss + + class ComposerMPTCausalLM(HuggingFaceModel): def __init__( @@ -1362,9 +1433,7 @@ def config_class(self) -> Type[MPTConfig]: return MPTConfig def get_targets(self, batch: Mapping) -> torch.Tensor: - targets = torch.roll(batch['labels'], shifts=-1) - targets[:, -1] = -100 - return targets + return get_targets(batch['labels']) def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: @@ -1385,27 +1454,14 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> Union[dict, torch.Tensor]: - if self.shift_labels: - targets = self.get_targets(batch) - else: - targets = batch['labels'] - - losses = self.loss_fn( - outputs.logits.view(-1, outputs.logits.size(-1)), - targets.view(-1), + loss = compute_loss_from_logits( + outputs, + self.shift_labels, + batch['labels'], + self.loss_fn, + batch.get('sample_weighing_factor', None), ) - if torch.all(targets == self.loss_fn.ignore_index): - loss = losses.sum() - else: - loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum() - if 'sample_weighing_factor' in batch: - if batch['sample_weighing_factor'].shape[0] > 1: - raise ValueError( - 'Sample weighing factor is not supported when batch["sample_weighing_factor"].shape[0] > 1.', - ) - loss = loss * batch['sample_weighing_factor'][0].item() - if self.config.ffn_config['ffn_type'] in ffns_with_megablocks: # MegaBlocks MoE load balancing loss try: # Add try/catch to avoid transformers complaining and raising errors @@ -1420,7 +1476,6 @@ def loss(self, outputs: CausalLMOutputWithPast, 'loss': loss, 'lbl': lbl, } - return loss @cached_property diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index e31840d3fb..3f0163ff01 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -155,6 +155,19 @@ description=_schedulers_description, ) +_tokenizers_description = ( + 'The tokenizers registry is used to register tokenizers that implement the transformers.PreTrainedTokenizerBase interface. ' + + + 'The tokenizer will be passed to the build_dataloader() and build_composer_model() methods in train.py.' +) +tokenizers = create_registry( + 'llmfoundry', + 'tokenizers', + generic_type=Type[PreTrainedTokenizerBase], + entry_points=True, + description=_tokenizers_description, +) + _models_description = ( """The models registry is used to register classes that implement the ComposerModel interface. 
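
The new `tokenizers` registry above lets `build_tokenizer` resolve a tokenizer by its registered name rather than special-casing tiktoken. A sketch of registering a custom tokenizer, following the same `register` pattern used for `TiktokenTokenizerWrapper` below; `MyTokenizer` and the `'my-tokenizer'` name are hypothetical:

```python
# Sketch: registering a custom tokenizer so build_tokenizer() can resolve it by name.
# Mirrors the pattern used for 'tiktoken'; MyTokenizer and 'my-tokenizer' are placeholders.
from transformers import PreTrainedTokenizerBase

from llmfoundry.registry import tokenizers


class MyTokenizer(PreTrainedTokenizerBase):
    """A stand-in tokenizer implementing the PreTrainedTokenizerBase interface."""


tokenizers.register('my-tokenizer', func=MyTokenizer)

# A train/eval config could then select it by the registered name, e.g.:
# tokenizer:
#   name: my-tokenizer
#   kwargs: {}
```
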
@@ -383,6 +396,7 @@ 'optimizers', 'algorithms', 'schedulers', + 'tokenizers', 'models', 'dataset_replication_validators', 'collators', diff --git a/llmfoundry/tokenizers/__init__.py b/llmfoundry/tokenizers/__init__.py index 1703ed8862..d37c12a555 100644 --- a/llmfoundry/tokenizers/__init__.py +++ b/llmfoundry/tokenizers/__init__.py @@ -1,8 +1,11 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from llmfoundry.registry import tokenizers from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper +tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper) + __all__ = [ 'TiktokenTokenizerWrapper', ] diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index f087664344..fd0fc5948a 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -90,6 +90,7 @@ def __init__( errors (str, optional): Paradigm to follow when decoding bytes to UTF-8. See [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. Defaults to `"replace"`. + kwargs (Any): Other relevant keyword arguments. """ try: import tiktoken diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 0437736f74..000155f1a4 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -35,9 +35,9 @@ from llmfoundry import registry from llmfoundry.callbacks import EvalGauntlet from llmfoundry.data.dataloader import build_dataloader -from llmfoundry.eval.datasets.in_context_learning_evaluation import \ - get_icl_task_dataloader -from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper +from llmfoundry.eval.datasets.in_context_learning_evaluation import ( + get_icl_task_dataloader, +) from llmfoundry.utils.config_utils import to_dict_container, to_list_container from llmfoundry.utils.registry_utils import construct_from_registry @@ -192,7 +192,8 @@ def build_load_planner(name: str, **kwargs: Any) -> LoadPlanner: """Builds a load planner from the registry. Args: - name: Name of the load planner to build. + name (str): Name of the load planner to build. + kwargs (Any): Other relevant keyword arguments. Returns: LoadPlanner: The load planner. @@ -211,7 +212,8 @@ def build_save_planner(name: str, **kwargs: Any) -> SavePlanner: """Builds a save planner from the registry. Args: - name: Name of the save planner to build. + name (str): Name of the save planner to build. + kwargs (Any): Other relevant keyword arguments. Returns: savePlanner: The save planner. @@ -506,8 +508,15 @@ def build_tokenizer( with dist.local_rank_zero_download_and_wait(signal_file_path): pass - if tokenizer_name.startswith('tiktoken'): - tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs) + if tokenizer_name in registry.tokenizers: + tokenizer = construct_from_registry( + name=tokenizer_name, + registry=registry.tokenizers, + partial_function=True, + pre_validation_function=PreTrainedTokenizerBase, + post_validation_function=None, + kwargs=tokenizer_kwargs, + ) else: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, diff --git a/llmfoundry/utils/checkpoint_conversion_helpers.py b/llmfoundry/utils/checkpoint_conversion_helpers.py index 905afd6edb..5c65a7475e 100644 --- a/llmfoundry/utils/checkpoint_conversion_helpers.py +++ b/llmfoundry/utils/checkpoint_conversion_helpers.py @@ -177,6 +177,7 @@ def _convert_weight_to_ft_each( tensor_name (str): Name of the weight tensor. Used in naming the output file. config (Dict[str, Any]): Configuration for the model. 
This is used in getting model specific parameters. data (np.ndarray): Tensor data in np.ndarray format. + np_weight_data_type (np.dtype): Data type of the numpy array `data`. Returns: None: Writes to a file in `save_dir`. File name is based on the `tensor_name` diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 1b5b23cb9f..f99139b5e1 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -167,6 +167,7 @@ class TrainConfig: # Dataloader device_train_microbatch_size: Union[str, int, float] = 'auto' global_train_batch_size: Optional[int] = None + spin_dataloaders: bool = True # Eval dataloader eval_subset_num_batches: int = -1 @@ -340,8 +341,6 @@ def make_dataclass_and_log_config( transforms, ) - logged_cfg.update(unstructured_config, merge=True) - arg_config_keys = set(unstructured_config.keys()) extraneous_keys = set.difference(arg_config_keys, dataclass_fields) @@ -467,6 +466,9 @@ def update_config_with_batch_size_info( Args: cfg (Dict[str, Any]): The config to update. + device_train_batch_size (Union[int, float]): The batch size of the training dataset for each device. + device_train_microbatch_size (Union[int, float, Literal['auto']]): The microbatch size of the training dataset for each device. + device_train_grad_accum (Union[int, Literal['auto']]): The gradient accumulation settings for each device. Returns: Dict[str, Any]: The updated config. @@ -531,7 +533,6 @@ def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]): fsdp_config['sync_module_states'] = True # Set defaults for mixed initialization - fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) # Set ffn_config.device_mesh to fsdp_config.device_mesh @@ -812,3 +813,45 @@ def _verify_uc_path(path: str) -> bool: f'but your `UCVolumeDatasetSource` might be invalid.', ) return False + + +def set_config_overrides( + config: PretrainedConfig, + config_overrides: Dict[str, Any], +): + # set config overrides + for k, v in config_overrides.items(): + if not hasattr(config, k): + raise ValueError( + f'config does not have attribute "{k}" to override ({k}: {v}).', + ) + + attr = getattr(config, k) + # attempt to disallow typos in nested configs + if isinstance(attr, Mapping): + extra_keys = [_k for _k in v.keys() if _k not in attr.keys()] + if extra_keys: + raise ValueError( + f'Config dict override got unknown keys. ' + + f'Extra keys: {extra_keys}. ' + + f'Expected (a subset of) keys: {list(attr.keys())}.', + ) + getattr(config, k).update(v) + # necessary case to allow for rope_scaling to be overriden in llama config + elif attr is None and isinstance(v, Mapping): + setattr(config, k, {}) + getattr(config, k).update(v) + elif isinstance(attr, PretrainedConfig): + if not isinstance(v, Mapping): + raise ValueError( + f'Expected a dictionary for config override {k}, but got {v}.', + ) + + for _k, _v in v.items(): + if not hasattr(attr, _k): + raise ValueError( + f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).', + ) + setattr(attr, _k, _v) + else: + setattr(config, k, v) diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py index dde8240d8b..9609982fda 100644 --- a/llmfoundry/utils/model_download_utils.py +++ b/llmfoundry/utils/model_download_utils.py @@ -69,7 +69,7 @@ def download_from_hf_hub( Safetensors weights will be downloaded unless `prefer_safetensors` is set to False. Args: - repo_id (str): The Hugging Face Hub repo ID. 
+ model (str): The Hugging Face Hub repo ID. save_dir (str, optional): The local path to the directory where the model files will be downloaded. prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are available. Defaults to True. @@ -157,7 +157,7 @@ def _recursive_download( Args: session: A requests.Session through which to make requests to the remote server. - url (str): The base URL where the files are located. + base_url (str): The base URL where the files are located. path (str): The path from the base URL to the files to download. The full URL for the download is equal to '/'. save_dir (str): The directory to save downloaded files to. diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 3ea7cc58a7..f96e72b3a2 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -127,6 +127,7 @@ def construct_from_registry( before constructing the item to return. This should throw an exception if validation fails. Defaults to None. post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after constructing the item to return. This should throw an exception if validation fails. Defaults to None. + kwargs (Optional[Dict[str, Any]]): Other relevant keyword arguments. Raises: ValueError: If the validation functions failed or the registered item is invalid @@ -176,6 +177,7 @@ def import_file(loc: Union[str, Path]) -> ModuleType: """Import module from a file. Used to run arbitrary python code. + Args: name (str): Name of module to load. loc (str / Path): Path to the file. diff --git a/pyproject.toml b/pyproject.toml index 53007cafaf..fdbabfff96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,23 +11,44 @@ skip = [ "env", "wandb", "runs", "build", "node_modules" ] include_trailing_comma = true split_on_trailing_comma = true +# Ruff global +[tool.ruff] +exclude = [ + "build/**", + "docs/**", + "node_modules/**", +] + +# Ruff linter [tool.ruff.lint] select = [ "C4", - # TODO port pydocstyle - # "D", # pydocstyle "LOG", "PERF", "PLE", "COM812", + "D", # pydocstyle ] -[tool.ruff] -exclude = [ - "build/**", - "docs/**", - "node_modules/**", + +extend-select = ["D404"] # pydocstyle + +ignore = [ + "D100", + "D101", + "D102", + "D103", + "D104", + "D105", + "D107", + "D400", + "D401", + "D415", ] +[tool.ruff.lint.pydocstyle] +convention = "google" + + # Coverage [tool.coverage.run] parallel = true @@ -79,7 +100,7 @@ reportMissingImports = "none" # Pytest [tool.pytest.ini_options] # By default, skip gpu tests -addopts = "--tb=short -m 'not gpu'" +addopts = "--tb=short -m 'not gpu' --color=yes" markers = [ # For distributed testing @@ -506,8 +527,3 @@ ignore_patterns = [ "wandb/**/*.py", "build/**/*.py", ] - -[tool.pydocstyle] -convention="google" -add_ignore="D100,D101,D102,D103,D104,D105,D107,D400,D401,D415" -add_select="D404" diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 3b88ba668f..277a8c1ffc 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -4,41 +4,12 @@ import logging import os import re -import time -import urllib.parse -from argparse import ArgumentParser, Namespace -from collections import namedtuple -from concurrent.futures import ProcessPoolExecutor -from typing import Iterable, List, Optional, Tuple, Union -from uuid import uuid4 +from argparse import ArgumentParser -import google.protobuf.any_pb2 as any_pb2 
-import lz4.frame -import pandas as pd -import pyarrow as pa -import pyspark.sql.connect.proto as pb2 -import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 -import requests -from composer.utils import retry -from databricks import sql -from databricks.connect import DatabricksSession -from databricks.sdk import WorkspaceClient from databricks.sql.client import Connection as Connection from databricks.sql.client import Cursor as Cursor -from packaging import version -from pyspark.sql import SparkSession -from pyspark.sql.connect.client.core import SparkConnectClient -from pyspark.sql.connect.client.reattach import \ - ExecutePlanResponseReattachableIterator -from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.dataframe import DataFrame as SparkDataFrame -from pyspark.sql.types import Row -from llmfoundry.utils.exceptions import ( - ClusterDoesNotExistError, - FailedToConnectToDatabricksError, - FailedToCreateSQLConnectionError, -) +from llmfoundry.command_utils import convert_delta_to_json_from_args MINIMUM_DB_CONNECT_DBR_VERSION = '14.1' MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2' @@ -47,617 +18,6 @@ log = logging.getLogger(__name__) -Result = namedtuple( - 'Result', - [ - 'url', - 'row_count', - 'compressed_size', - 'uncompressed_size', - ], -) # pyright: ignore - -# ``collect_as_cf`` is an addon new feature monkey patch on top of the DB Connect package. -# It allows the client to fetch the results in different formats from the server. -# To be able to use the code make sure this module is not overriden by DB Connect classes. - - -def to_cf(self: SparkConnectClient, - plan: pb2.Plan, - type: str = 'json') -> Tuple[List[Result], int, bool]: - """Executes the query plans and return as presigned URLS for cloud fetch. - - It can handle the current output formats that are supported by the server. - In contrast to the regular API methods of the client, this method does not - return the schema and drops all other responses. - - Args: - plan (pb2.Plan): The plan object to be executed by spark. - type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. - - Returns: - Tuple[List[Result], int, bool]: A tuple containing: - - A list of Result namedtuples, each containing a URL, row count, compressed size, - and uncompressed size of the part of the result. - - Total row count of all parts of the result. - - A boolean indicating whether the result has been truncated. - """ - log.info(f'Executing query plan with format: {type}') - - req = self._execute_plan_request_with_metadata() - req.plan.CopyFrom(plan) - - # Add the request options - if type == 'json': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON - elif type == 'csv': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV - elif type == 'arrow': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW - else: - raise ValueError( - f'Only formats json, csv, and arrow are supported. 
Got invalid type {type}', - ) - - ro = cloud_pb2.ResultOptions( - type=cloud_pb2.ResultOptions.TYPE_CLOUD, - cloudOptions=cloud_pb2.ResultOptions.CloudOptions( - format=format, - useCompression=False, - ), - ) - cloud_option = any_pb2.Any() - cloud_option.Pack(ro) - req.request_options.append( - pb2.ExecutePlanRequest.RequestOption(extension=cloud_option), - ) - - # Create the iterator - iterator = ExecutePlanResponseReattachableIterator( - req, - self._stub, - self._retry_policy, - self._builder.metadata(), - ) - # Iterate over the response - result = [] - row_count = 0 - is_overflow = False - - for response in iterator: - if response.HasField('extension') and response.extension.Is( - cloud_pb2.CloudResultBatch.DESCRIPTOR, - ): - batch = cloud_pb2.CloudResultBatch() - if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): - raise ValueError( - 'Response extension is not of type CloudResultBatch.', - ) - response.extension.Unpack(batch) - result += [ - Result( - b.url, - b.row_count, - b.compressed_size, - b.uncompressed_size, - ) for b in batch.results - ] - row_count += sum(result.row_count for result in batch.results) - is_overflow |= batch.truncated - return result, row_count, is_overflow - - -SparkConnectClient.to_cf = to_cf # pyright: ignore - - -def collect_as_cf(self: DataFrame, - type: str = 'json') -> Tuple[List[Result], int, bool]: - """Collects DataFrame execution plan as presigned URLs. - - This method is a wrapper around the `to_cf` method of SparkConnectClient. It takes the - execution plan of the current DataFrame, converts it to a protocol buffer format, and then - uses the `to_cf` method to execute the plan and fetch results as presigned URLs. - - Args: - type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. - - Returns: - Tuple[List[Result], int, bool]: A tuple containing: - - A list of Result namedtuples, each containing a URL, row count, compressed size, - and uncompressed size of the part of the result. - - Total row count of all parts of the result. - - A boolean indicating whether the result is truncated or overflowed. - """ - log.info(f'Collecting DataFrame as cloud fetch with format: {type}') - query = self._plan.to_proto(self._session.client) # pyright: ignore - return self._session.client.to_cf(query, type) # pyright: ignore - - -DataFrame.collect_cf = collect_as_cf # pyright: ignore - - -def iterative_combine_jsons(json_directory: str, output_file: str) -> None: - """Combine jsonl files in json_directory into one big jsonl file. - - This function does not work for nested subdirectories. 
- - Args: - json_directory(str): directory containing the JSONL files - output_file(str): path to the output combined JSONL file - """ - log.info( - f'Starting to combine JSON files from {json_directory} into {output_file}', - ) - json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')] - log.info(f'Found {len(json_files)} JSON files to combine') - with open(output_file, 'w') as outfile: - for file_name in json_files: - log.debug(f'Processing file: {file_name}') - with open(os.path.join(json_directory, file_name), 'r') as infile: - for line in infile: - outfile.write(line) - log.info('JSON files have been successfully combined into a JSONL file.') - - -def run_query( - query: str, - method: str, - cursor: Optional[Cursor] = None, - spark: Optional[SparkSession] = None, - collect: bool = True, -) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]: - """Run SQL query via databricks-connect or databricks-sql. - - Args: - query (str): sql query - method (str): select from dbsql and dbconnect - cursor (Optional[Cursor]): connection.cursor - spark (Optional[SparkSession]): spark session - collect (bool): whether to get the underlying data from spark dataframe - """ - log.info(f'Executing query using method: {method}') - log.debug(f'Query: {query}') - - if method == 'dbsql': - if cursor is None: - raise ValueError(f'cursor cannot be None if using method dbsql') - cursor.execute(query) - if collect: - return cursor.fetchall() - elif method == 'dbconnect': - if spark == None: - raise ValueError(f'sparkSession is required for dbconnect') - df = spark.sql(query) - if collect: - return df.collect() - return df - else: - raise ValueError(f'Unrecognized method: {method}') - - -def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable: - for i, r in enumerate(signed): - yield (i, r.url, json_output_folder, columns) - - -def download( - ipart: int, - url: str, - json_output_folder: str, - columns: Optional[List] = None, - resp_format: str = 'arrow', - compressed: bool = False, -) -> None: - """Thread download presigned url and save to jsonl locally. - - Args: - ipart (int): presigned url id - url (str): presigned url - json_output_folder (str): directory to save the ipart_th segment of dataframe - columns (list): schema to save to json - resp_format (str): whether to use arrow or json when collect - compressed (bool): if data is compressed before downloading. Need decompress if compressed=True. - """ - log.info(f'Downloading part {ipart} from URL: {url}') - - resp = requests.get(url) - if resp.status_code == 200: - if resp_format == 'json': - data = resp.json() - pd.DataFrame(data, columns=columns).to_json( - os.path.join( - json_output_folder, - 'part_' + str(ipart) + '.jsonl', - ), - orient='records', - lines=True, - ) - return - - # When resp_format is arrow: - if compressed: - # The data is lz4 compressed arrow format. 
- # Decompress the data - decompressed_data = lz4.frame.decompress(resp.content) - # Convert the decompressed data into a PyArrow table - reader = pa.ipc.open_stream(decompressed_data) - else: - reader = pa.ipc.open_stream(resp.content) - table = reader.read_all() - - # Convert the PyArrow table into a pandas DataFrame - df = table.to_pandas() - df.to_json( - os.path.join(json_output_folder, 'part_' + str(ipart) + '.jsonl'), - orient='records', - lines=True, - force_ascii=False, - ) - - -def download_starargs(args: Tuple) -> None: - return download(*args) - - -def format_tablename(table_name: str) -> str: - """Escape catalog, schema and table names with backticks. - - This needs to be done when running SQL queries/setting spark sessions to prevent invalid identifier errors. - - Args: - table_name (str): catalog.scheme.tablename on UC - """ - log.debug(f'Formatting table name: {table_name}') - match = re.match(TABLENAME_PATTERN, table_name) - - if match is None: - return table_name - - formatted_identifiers = [] - for i in range(1, 4): - identifier = f'`{match.group(i)}`' - formatted_identifiers.append(identifier) - - return '.'.join(formatted_identifiers) - - -def fetch_data( - method: str, - cursor: Optional[Cursor], - sparkSession: Optional[SparkSession], - start: int, - end: int, - order_by: str, - tablename: str, - columns_str: str, - json_output_folder: str, -) -> None: - """Fetches a specified range of rows from a given table to a json file. - - This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, - from a specified table and column set. The fetched data is then exported as a JSON file. - - Args: - method (str): The method to use for fetching data, either 'dbconnect' or 'dbsql'. - cursor (Optional[Cursor]): The cursor object for executing queries in 'dbsql' method. - sparkSession (Optional[SparkSession]): The Spark session object for executing queries in 'dbconnect' method. - start (int): The starting index for row fetching. - end (int): The ending index for row fetching. - order_by (str): The column name to use for ordering the rows. - tablename (str): The name of the table from which to fetch the data. - columns_str (str): The string representation of the columns to select from the table. - json_output_folder (str): The file path where the resulting JSON file will be saved. - - Returns: - None: The function doesn't return any value, but writes the result to a JSONL file. 
- """ - log.info(f'Fetching data from {start} to {end} using method: {method}') - query = f""" - WITH NumberedRows AS ( - SELECT - *, - ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn - FROM - {tablename} - ) - SELECT {columns_str} - FROM NumberedRows - WHERE rn BETWEEN {start+1} AND {end}""" - - if method == 'dbconnect': - spark_df = run_query(query, method, cursor, sparkSession, collect=False) - if spark_df is None: - raise RuntimeError( - f'Expect spark dataframe with {query} but got None', - ) - pdf = spark_df.toPandas() # pyright: ignore - else: # method == 'dbsql': - ans = run_query(query, method, cursor, sparkSession, collect=True) - if ans is None: - raise RuntimeError(f'Got empty results with {query}') - records = [r.asDict() for r in ans] # pyright: ignore - pdf = pd.DataFrame.from_dict(records) - - pdf.to_json( - os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'), - orient='records', - lines=True, - ) - - -@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5) -def get_total_rows( - tablename: str, - method: str, - cursor: Optional[Cursor], - sparkSession: Optional[SparkSession], -): - ans = run_query( - f'SELECT COUNT(*) FROM {tablename}', - method, - cursor, - sparkSession, - ) - nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore - log.info(f'total_rows = {nrows}') - return nrows - - -@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5) -def get_columns_info( - tablename: str, - method: str, - cursor: Optional[Cursor], - sparkSession: Optional[SparkSession], -): - ans = run_query( - f'SHOW COLUMNS IN {tablename}', - method, - cursor, - sparkSession, - ) - columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore - order_by = columns[0] - columns_str = ','.join(columns) - log.info(f'order by column {order_by}') - return columns, order_by, columns_str - - -def fetch( - method: str, - tablename: str, - json_output_folder: str, - batch_size: int = 1 << 30, - processes: int = 1, - sparkSession: Optional[SparkSession] = None, - dbsql: Optional[Connection] = None, -) -> None: - """Fetch UC delta table with databricks-connect as JSONL. - - Args: - method (str): dbconnect or dbsql - tablename (str): catalog.scheme.tablename on UC - json_output_folder (str): path to write the result json file to - batch_size (int): number of rows that dbsql fetches each time to avoid OOM - processes (int): max number of processes to use to parallelize the fetch - sparkSession (pyspark.sql.sparksession): spark session - dbsql (databricks.sql.connect): dbsql session - """ - log.info(f'Starting data fetch for table: {tablename}') - log.info( - f'Method: {method}, Batch size: {batch_size}, Processes: {processes}', - ) - - cursor = dbsql.cursor() if dbsql is not None else None - try: - nrows = get_total_rows( - tablename, - method, - cursor, - sparkSession, - ) - except Exception as e: - raise RuntimeError( - f'Error in get rows from {tablename}. Restart sparkSession and try again', - ) from e - - try: - columns, order_by, columns_str = get_columns_info( - tablename, - method, - cursor, - sparkSession, - ) - except Exception as e: - raise RuntimeError( - f'Error in get columns from {tablename}. Restart sparkSession and try again', - ) from e - - if method == 'dbconnect' and sparkSession is not None: - log.info(f'{processes=}') - df = sparkSession.table(tablename) - - # Running the query and collecting the data as arrow or json. 
- signed, _, _ = df.collect_cf('arrow') # pyright: ignore - log.info(f'len(signed) = {len(signed)}') - - args = get_args(signed, json_output_folder, columns) - - # Stopping the SparkSession to avoid spilling connection state into the subprocesses. - sparkSession.stop() - - with ProcessPoolExecutor(max_workers=processes) as executor: - list(executor.map(download_starargs, args)) - - elif method == 'dbsql' and cursor is not None: - for start in range(0, nrows, batch_size): - log.warning(f'batch {start}') - end = min(start + batch_size, nrows) - fetch_data( - method, - cursor, - sparkSession, - start, - end, - order_by, - tablename, - columns_str, - json_output_folder, - ) - - if cursor is not None: - cursor.close() - - -def validate_and_get_cluster_info( - cluster_id: str, - databricks_host: str, - databricks_token: str, - http_path: Optional[str], - use_serverless: bool = False, -) -> tuple: - """Validate and get cluster info for running the Delta to JSONL conversion. - - Args: - cluster_id (str): cluster id to validate and fetch additional info for - databricks_host (str): databricks host name - databricks_token (str): databricks auth token - http_path (Optional[str]): http path to use for sql connect - use_serverless (bool): whether to use serverless or not - """ - log.info('Validating cluster information and getting connection details') - log.debug( - f'Cluster ID: {cluster_id}, Host: {databricks_host}, Use Serverless: {use_serverless}', - ) - - method = 'dbsql' - dbsql = None - sparkSession = None - - if use_serverless: - method = 'dbconnect' - else: - w = WorkspaceClient() - res = w.clusters.get(cluster_id=cluster_id) - if res is None: - raise ClusterDoesNotExistError(cluster_id) - - assert res.spark_version is not None - stripped_runtime = re.sub( - r'[a-zA-Z]', - '', - res.spark_version.split('-scala') - [0].replace( # type: ignore - 'x-snapshot', '', - ), - ) - runtime_version = re.sub(r'[.-]*$', '', stripped_runtime) - if version.parse( - runtime_version, - ) < version.parse(MINIMUM_SQ_CONNECT_DBR_VERSION): - raise ValueError( - f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}', - ) - - if http_path is None and version.parse( - runtime_version, - ) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): - method = 'dbconnect' - - if method == 'dbconnect': - try: - if use_serverless: - session_id = str(uuid4()) - sparkSession = DatabricksSession.builder.host( - databricks_host, - ).token( - databricks_token, - ).header('x-databricks-session-id', session_id).getOrCreate() - - else: - sparkSession = DatabricksSession.builder.remote( - host=databricks_host, - token=databricks_token, - cluster_id=cluster_id, - ).getOrCreate() - - except Exception as e: - raise FailedToConnectToDatabricksError() from e - else: - try: - dbsql = sql.connect( - server_hostname=re.compile(r'^https?://').sub( - '', databricks_host).strip( - ), # sqlconnect hangs if hostname starts with https - http_path=http_path, - access_token=databricks_token, - ) - except Exception as e: - raise FailedToCreateSQLConnectionError() from e - return method, dbsql, sparkSession - - -def fetch_DT(args: Namespace) -> None: - """Fetch UC Delta Table to local as jsonl.""" - log.info(f'Start .... 
Convert delta to json') - log.info('Starting Delta Table to JSON conversion process') - log.info(f'Delta Table: {args.delta_table_name}') - log.info(f'Output Folder: {args.json_output_folder}') - log.info(f'Output Filename: {args.json_output_filename}') - - obj = urllib.parse.urlparse(args.json_output_folder) - if obj.scheme != '': - raise ValueError( - 'Check the json_output_folder and verify it is a local path!', - ) - - if os.path.exists(args.json_output_folder): - if not os.path.isdir(args.json_output_folder) or os.listdir( - args.json_output_folder, - ): - raise RuntimeError( - f'Output folder {args.json_output_folder} already exists and is not empty. Please remove it and retry.', - ) - - os.makedirs(args.json_output_folder, exist_ok=True) - - if not args.json_output_filename.endswith('.jsonl'): - raise ValueError('json_output_filename needs to be a jsonl file') - - log.info(f'Directory {args.json_output_folder} created.') - - method, dbsql, sparkSession = validate_and_get_cluster_info( - cluster_id=args.cluster_id, - databricks_host=args.DATABRICKS_HOST, - databricks_token=args.DATABRICKS_TOKEN, - http_path=args.http_path, - use_serverless=args.use_serverless, - ) - - args.delta_table_name = format_tablename(args.delta_table_name) - - fetch( - method, - args.delta_table_name, - args.json_output_folder, - args.batch_size, - args.processes, - sparkSession, - dbsql, - ) - - if dbsql is not None: - dbsql.close() - - # combine downloaded jsonl into one big jsonl for IFT - iterative_combine_jsons( - args.json_output_folder, - os.path.join(args.json_output_folder, args.json_output_filename), - ) - - log.info('Delta Table to JSON conversion completed successfully') - - if __name__ == '__main__': parser = ArgumentParser( description= @@ -719,11 +79,13 @@ def fetch_DT(args: Namespace) -> None: 'The name of the combined final jsonl that combines all partitioned jsonl', ) args = parser.parse_args() - w = WorkspaceClient() - args.DATABRICKS_HOST = w.config.host - args.DATABRICKS_TOKEN = w.config.token - - tik = time.time() - fetch_DT(args) - log.info(f'Elapsed time {time.time() - tik}') - log.info('Delta Table to JSON conversion script completed') + convert_delta_to_json_from_args( + delta_table_name=args.delta_table_name, + json_output_folder=args.json_output_folder, + http_path=args.http_path, + batch_size=args.batch_size, + processes=args.processes, + cluster_id=args.cluster_id, + use_serverless=args.use_serverless, + json_output_filename=args.json_output_filename, + ) diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py index 523d45093d..b28e25786b 100644 --- a/scripts/data_prep/convert_finetuning_dataset.py +++ b/scripts/data_prep/convert_finetuning_dataset.py @@ -1,28 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import json -import os -import platform -import warnings from argparse import ArgumentParser, Namespace -from typing import Callable, Dict, Iterable, Optional, Union +from typing import Union -import datasets as hf_datasets -import psutil from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict -from streaming import MDSWriter -from torch.utils.data import DataLoader -from tqdm import tqdm -from llmfoundry.data.finetuning.collator import validate_target_settings -from llmfoundry.data.finetuning.tasks import ( - _get_example_type, - dataset_constructor, - is_valid_ift_example, - tokenize_formatted_example, -) -from llmfoundry.utils.builders import 
build_tokenizer +from llmfoundry.command_utils import convert_finetuning_dataset_from_args HFDataset = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset] @@ -116,236 +100,9 @@ def parse_args() -> Namespace: ) parsed = parser.parse_args() - - if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set(parsed.splits)), - ) > 0: - raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', - ) - - if parsed.tokenizer_kwargs is not None: - parsed.tokenizer_kwargs = json.loads(parsed.tokenizer_kwargs) - else: - parsed.tokenizer_kwargs = {} - - if len(parsed.data_files) > 0 and len( - parsed.data_files, - ) != len(parsed.splits): - raise ValueError( - f'If data_files is set, data_files and splits must have the same length. Got {len(parsed.data_files)=} while {len(parsed.splits)=}', - ) - return parsed -def build_dataloader( - dataset: HFDataset, - batch_size: int, - num_workers: Optional[int] = None, -) -> DataLoader: - if num_workers is None: - # Multiple workers is only supported on linux machines - if 'linux' in platform.platform().lower(): - num_workers = max(1, psutil.cpu_count()) - else: - num_workers = 0 - - # If using multiple workers, configure each worker to prefetch as many samples as it can, up to - # the aggregate device batch size - # If not using workers, the torch DataLoader expects the default value for prefetch_factor, - # which non-intuitively must be 2. - # If on macOS, PyTorch requires prefetch_factor set to None since num_workers is always zero - if 'macos' in platform.platform().lower() and num_workers == 0: - prefetch_factor = None - else: - prefetch_factor = max( - 1, - 2 * batch_size // num_workers, - ) if num_workers > 0 else 2 - - return DataLoader( - dataset=dataset, - sampler=None, - batch_size=batch_size, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - ) - - -def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None, -) -> Iterable[Dict[str, bytes]]: - """Generator over samples of a dataloader. - - Args: - loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} - truncate_num_samples (Optional[int]): An optional # of samples to stop at. - - Yields: - Sample dicts. - """ - n_samples = 0 - for batch in loader: - keys = list(batch.keys()) - current_bs = len(batch[keys[0]]) - for idx in range(current_bs): - if truncate_num_samples is not None and n_samples == truncate_num_samples: - return - n_samples += 1 - yield {k: v[idx] for k, v in batch.items()} - - -def get_columns_and_format( - dataset: HFDataset, - tokenizing: bool, - preprocessing_fn: Callable, -): - ex = preprocessing_fn(next(iter(dataset))) - example_type = _get_example_type(ex) - if tokenizing: - return {'turns': 'json'}, example_type - if example_type == 'chat': - # Chat format - return {'messages': 'json'}, example_type - else: - # Prompt-response format - return {'prompt': 'str', 'response': 'str'}, example_type - - -def main(args: Namespace) -> None: - """Main: create a streaming dataset. - - Args: - args (Namespace): Commandline arguments. 
- """ - if args.skip_preprocessing: - preprocessing_fn = lambda x: x # Just an identity function - else: - preprocessor_str = args.preprocessor - preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( - preprocessor=preprocessor_str, - dataset_name=args.dataset, - ) - if preprocessing_fn is None: - raise ValueError( - '`args.preprocessor` was not set and no preprocessing function ' +\ - 'has been registered for `args.dataset`. If this was intentional ' +\ - '(e.g., because your dataset is already correctly formatted), ' +\ - 'include the "--skip-preprocessing" flag to avoid this error.', - ) - - # Make sure the target settings are valid - validate_target_settings( - target_prompts=args.target_prompts, - target_responses=args.target_responses, - decoder_only_format=not args.encoder_decoder, - ) - - tokenizer = None - tokenizer_kwargs = args.tokenizer_kwargs - tokenizer_kwargs.update({'model_max_length': args.max_seq_len}) - if args.tokenizer: - tokenizer = build_tokenizer(args.tokenizer, tokenizer_kwargs) - - for i, split_name in enumerate(args.splits): - data_file = None - if len(args.data_files) > 0: - data_file = args.data_files[i] - dataset = hf_datasets.load_dataset( - path=args.dataset, - name=args.data_subset, - split=split_name, - data_files=data_file, - streaming=True, - ) - # Determine the output columns - columns, example_type = get_columns_and_format( - dataset=dataset, - tokenizing=tokenizer is not None, - preprocessing_fn=preprocessing_fn, - ) - # Prepare the iterables - if example_type == 'chat': - samples = iter(dataset) - else: - loader = build_dataloader( - dataset=dataset, - batch_size=512, - num_workers=args.num_workers, - ) - samples = generate_samples(loader) - - # Write samples - print(f'Converting {split_name} to MDS format...') - out = os.path.join(args.out_root, split_name) - if args.local is not None: - out = (os.path.join(args.local, split_name), out) - keep_local = True - else: - keep_local = False - with MDSWriter( - columns=columns, - out=out, - compression=args.compression, - keep_local=keep_local, - ) as out: - examples_removed = 0 - for sample in tqdm(samples, desc=split_name): - formatted_sample = preprocessing_fn(sample) - assert isinstance(formatted_sample, dict) - - # Use the _get_example_type utility to confirm that the formatted sample - # can be interpreted by the tokenization code - try: - example_type = _get_example_type(formatted_sample) - except Exception as e: - raise ValueError( - 'Encountered an error when checking example for proper formatting. 
' +\ - f'example={formatted_sample}', - ) from e - if tokenizer is not None: - sample = tokenize_formatted_example( - formatted_sample, - tokenizer=tokenizer, - ) - if not is_valid_ift_example( - args.max_seq_len, - target_prompts=args.target_prompts, - target_responses=args.target_responses, - decoder_only_format=not args.encoder_decoder, - example=sample, - ): - examples_removed += 1 - continue - - sample_to_write = {'turns': []} - for turn in sample['turns']: - turn_to_write = {} - for key in ['input_ids', 'labels']: - turn_to_write[key] = list(turn[key]) - sample_to_write['turns'].append(turn_to_write) - out.write(sample_to_write) - else: - if example_type == 'prompt_response': - encoded_sample = {} - for key in ['prompt', 'response']: - value = formatted_sample[key] - assert isinstance(value, str) - encoded_sample[key] = value.encode('utf-8') - out.write(encoded_sample) - else: - out.write(formatted_sample) - - if tokenizer is not None and examples_removed > 0: - warnings.warn( - f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, ' - + - 'the prompt or response was empty, or the response was all padding tokens.', - ) - - if __name__ == '__main__': """Example for converting Muennighoff/P3: @@ -355,4 +112,22 @@ def main(args: Namespace) -> None: >>> --preprocessor llmfoundry.data.finetuning.tasks:p3_preprocessing_function \ >>> --out_root s3:///muennighoff-p3 """ - main(parse_args()) + args = parse_args() + convert_finetuning_dataset_from_args( + dataset=args.dataset, + data_subset=args.data_subset, + splits=args.splits, + preprocessor=args.preprocessor, + data_files=args.data_files, + skip_preprocessing=args.skip_preprocessing, + out_root=args.out_root, + local=args.local, + compression=args.compression, + num_workers=args.num_workers, + tokenizer=args.tokenizer, + tokenizer_kwargs=args.tokenizer_kwargs, + max_seq_len=args.max_seq_len, + target_prompts=args.target_prompts, + target_responses=args.target_responses, + encoder_decoder=args.encoder_decoder, + ) diff --git a/setup.py b/setup.py index 309d7d3372..185d3970f7 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ import copy import os -import re +from typing import Any, Dict, Mapping import setuptools from setuptools import setup @@ -15,17 +15,15 @@ _REPO_REAL_PATH = os.path.dirname(os.path.realpath(__file__)) _PACKAGE_REAL_PATH = os.path.join(_REPO_REAL_PATH, _PACKAGE_DIR) -# Read the repo version +# Read the llm-foundry version # We can't use `.__version__` from the library since it's not installed yet -with open(os.path.join(_PACKAGE_REAL_PATH, '__init__.py')) as f: +version_path = os.path.join(_PACKAGE_REAL_PATH, '_version.py') +with open(version_path, encoding='utf-8') as f: + version_globals: Dict[str, Any] = {} + version_locals: Mapping[str, object] = {} content = f.read() -# regex: '__version__', whitespace?, '=', whitespace, quote, version, quote -# we put parens around the version so that it becomes elem 1 of the match -expr = re.compile( - r"""^__version__\s*=\s*['"]([0-9]+\.[0-9]+\.[0-9]+(?:\.\w+)?)['"]""", - re.MULTILINE, -) -repo_version = expr.findall(content)[0] + exec(content, version_globals, version_locals) + repo_version = str(version_locals['__version__']) # Use repo README for PyPi description with open('README.md', 'r', encoding='utf-8') as fh: @@ -56,9 +54,9 @@ install_requires = [ 'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.23.4,<0.24', 'mlflow>=2.14.1,<2.15', - 'accelerate>=0.25,<0.26', # for HF inference `device_map` - 'transformers>=4.42.3,<4.43', - 
'mosaicml-streaming>=0.7.6,<0.8', + 'accelerate>=0.25,<0.34', # for HF inference `device_map` + 'transformers>=4.43.2,<4.44', + 'mosaicml-streaming>=0.8.0,<0.9', 'torch>=2.3.0,<2.4', 'datasets>=2.19,<2.20', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data @@ -70,7 +68,7 @@ 'onnx==1.16.1', 'onnxruntime==1.18.1', 'boto3>=1.21.45,<2', - 'huggingface-hub>=0.19.0,<0.24', + 'huggingface-hub>=0.19.0,<0.25', 'beautifulsoup4>=4.12.2,<5', # required for model download utils 'tenacity>=8.2.3,<9', 'catalogue>=2,<3', @@ -104,7 +102,7 @@ # Flash 2 group kept for backwards compatibility extra_deps['gpu-flash2'] = [ - 'flash-attn==2.5.8', + 'flash-attn>=2.5.8,<3', ] extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2']) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 83d6edeca2..e623467bf7 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -1,15 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -# copyright 2022 mosaicml llm foundry authors -# spdx-license-identifier: apache-2.0 - import unittest from argparse import Namespace from typing import Any from unittest.mock import MagicMock, mock_open, patch -from scripts.data_prep.convert_delta_to_json import ( +from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( download, fetch_DT, format_tablename, @@ -20,11 +17,19 @@ class TestConvertDeltaToJsonl(unittest.TestCase): - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sdk.WorkspaceClient', + ) def test_stream_delta_to_json( self, mock_workspace_client: Any, @@ -33,19 +38,15 @@ def test_stream_delta_to_json( mock_makedirs: Any, mock_sql_connect: Any, ): - - args = MagicMock() - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' - args.DATABRICKS_HOST = 'test_host' - args.DATABRICKS_TOKEN = 'test_token' - args.http_path = 'test_path' - args.batch_size = 1000 - args.partitions = 1 - args.cluster_id = '1234' - args.debug = False - args.use_serverless = False - args.json_output_filename = 'combined.jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' + DATABRICKS_HOST = 'test_host' + DATABRICKS_TOKEN = 'test_token' + http_path = 'test_path' + batch_size = 1000 + cluster_id = '1234' + use_serverless = False + json_output_filename = 'combined.jsonl' mock_cluster_get = MagicMock() mock_cluster_get.return_value = MagicMock( @@ -53,7 +54,17 @@ def test_stream_delta_to_json( ) mock_workspace_client.return_value.clusters.get = mock_cluster_get - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + 
batch_size=batch_size, + json_output_filename=json_output_filename, + ) mock_sql_connect.assert_called_once_with( server_hostname='test_host', http_path='test_path', @@ -66,7 +77,9 @@ def test_stream_delta_to_json( '/path/to/jsonl/combined.jsonl', ) - @patch('scripts.data_prep.convert_delta_to_json.os.listdir') + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.listdir', + ) @patch( 'builtins.open', new_callable=mock_open, @@ -102,7 +115,9 @@ def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any): """ self.assertEqual(mock_file().write.call_count, 2) - @patch('scripts.data_prep.convert_delta_to_json.SparkSession') + @patch( + 'pyspark.sql.SparkSession', + ) def test_run_query_dbconnect(self, mock_spark: Any): method = 'dbconnect' mock_cursor = None @@ -118,7 +133,9 @@ def test_run_query_dbconnect(self, mock_spark: Any): mock_spark.sql.assert_called_once_with('SELECT * FROM table') self.assertEqual(result, 'result') - @patch('scripts.data_prep.convert_delta_to_json.Cursor') + @patch( + 'databricks.sql.client.Cursor', + ) def test_run_query_dbsql(self, mock_cursor: Any): method = 'dbsql' mock_cursor.fetchall.return_value = 'result' @@ -134,14 +151,18 @@ def test_run_query_dbsql(self, mock_cursor: Any): mock_cursor.execute.assert_called_once_with('SELECT * FROM table') self.assertEqual(result, 'result') - @patch('scripts.data_prep.convert_delta_to_json.requests.get') - @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') @patch( - 'scripts.data_prep.convert_delta_to_json.os.path.join', + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.requests.get', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.pd.DataFrame.to_json', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.path.join', return_value='/fake/path/part_1.jsonl', ) @patch( - 'scripts.data_prep.convert_delta_to_json.time.sleep', + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.time.sleep', ) # Mock sleep to speed up the test def test_download_success( self, @@ -174,12 +195,22 @@ def test_download_success( mock_get.assert_called_once_with('http://fakeurl.com/data') - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_dbconnect_called( self, mock_fetch: Any, @@ -189,17 +220,14 @@ def test_dbconnect_called( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = None - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path 
= None + cluster_id = '1234' + DATABRICKS_HOST = 'host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='14.1.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response @@ -209,19 +237,37 @@ def test_dbconnect_called( ) # Mock return value for getOrCreate mock_databricks_session.builder.remote.return_value = mock_remote - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_databricks_session.builder.remote.assert_called_once_with( - host=args.DATABRICKS_HOST, - token=args.DATABRICKS_TOKEN, - cluster_id=args.cluster_id, + host=DATABRICKS_HOST, + token=DATABRICKS_TOKEN, + cluster_id=cluster_id, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_sqlconnect_called_dbr13( self, mock_fetch: Any, @@ -231,34 +277,49 @@ def test_sqlconnect_called_dbr13( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='13.0.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_sql_connect.assert_called_once_with( - server_hostname=args.DATABRICKS_HOST, - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN, + server_hostname=DATABRICKS_HOST, + http_path=http_path, + access_token=DATABRICKS_TOKEN, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 
'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_sqlconnect_called_dbr14( self, mock_fetch: Any, @@ -268,34 +329,49 @@ def test_sqlconnect_called_dbr14( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_sql_connect.assert_called_once_with( - server_hostname=args.DATABRICKS_HOST, - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN, + server_hostname=DATABRICKS_HOST, + http_path=http_path, + access_token=DATABRICKS_TOKEN, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_sqlconnect_called_https( self, mock_fetch: Any, @@ -305,34 +381,49 @@ def test_sqlconnect_called_https( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'https://test-host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = False + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'https://test-host' + DATABRICKS_TOKEN = 'token' + use_serverless = False mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, 
+ DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) mock_sql_connect.assert_called_once_with( server_hostname='test-host', - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN, + http_path=http_path, + access_token=DATABRICKS_TOKEN, ) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') - @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch( + 'databricks.sql.connect', + ) + @patch( + 'databricks.connect.DatabricksSession', + ) + @patch( + 'databricks.sdk.WorkspaceClient', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.os.makedirs', + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_json.iterative_combine_jsons', + ) + @patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch') def test_serverless( self, mock_fetch: Any, @@ -342,22 +433,27 @@ def test_serverless( mock_databricks_session: Any, mock_sql_connect: Any, ): - - args = MagicMock() - - args.delta_table_name = 'test_table' - args.json_output_folder = '/path/to/jsonl' + delta_table_name = 'test_table' + json_output_folder = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path = 'test_path' - args.cluster_id = '1234' - args.DATABRICKS_HOST = 'https://test-host' - args.DATABRICKS_TOKEN = 'token' - args.use_serverless = True + http_path = 'test_path' + cluster_id = '1234' + DATABRICKS_HOST = 'https://test-host' + DATABRICKS_TOKEN = 'token' + use_serverless = True mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response - fetch_DT(args) + fetch_DT( + delta_table_name=delta_table_name, + json_output_folder=json_output_folder, + http_path=http_path, + cluster_id=cluster_id, + DATABRICKS_HOST=DATABRICKS_HOST, + DATABRICKS_TOKEN=DATABRICKS_TOKEN, + use_serverless=use_serverless, + ) assert not mock_sql_connect.called assert not mock_databricks_session.builder.remote.called diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 2ef458fece..cd47b2df7c 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -383,6 +383,14 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' mlflow_logger_mock._run_id = 'mlflow-run-id' + mlflow_logger_mock._enabled = True + mlflow_logger_mock.run_url = 'fake-url' + checkpointer_callback.transform_model_pre_registration = MagicMock( + wraps=checkpointer_callback.transform_model_pre_registration, + ) + checkpointer_callback.pre_register_edit = MagicMock( + wraps=checkpointer_callback.pre_register_edit, + ) trainer = Trainer( model=original_model, device='gpu', @@ -406,9 +414,14 @@ def test_huggingface_conversion_callback_interval( task='llm/v1/completions', input_example=ANY, metadata={}, + pip_requirements=ANY, ) + assert checkpointer_callback.transform_model_pre_registration.call_count == 1 + assert checkpointer_callback.pre_register_edit.call_count == 1 assert 
mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: + assert checkpointer_callback.transform_model_pre_registration.call_count == 0 + assert checkpointer_callback.pre_register_edit.call_count == 0 assert mlflow_logger_mock.save_model.call_count == 0 assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 @@ -582,6 +595,7 @@ def _assert_mlflow_logger_calls( 'task': 'llm/v1/completions', 'input_example': default_input_example, 'metadata': {}, + 'pip_requirements': ANY, } mlflow_logger_mock.save_model.assert_called_with(**expectation) assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 21d73c0d34..8e92658194 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -22,6 +22,8 @@ from streaming.base.util import clean_stale_shared_memory from llmfoundry.command_utils import convert_dataset_hf +from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import \ + get_columns_and_format from llmfoundry.data import build_dataloader, build_finetuning_dataloader from llmfoundry.data.finetuning.collator import ( _HF_IGNORE_INDEX, @@ -55,8 +57,6 @@ NotEnoughDatasetSamplesError, UnknownExampleTypeError, ) -# yapf: enable -from scripts.data_prep.convert_finetuning_dataset import get_columns_and_format from tests.data_utils import ( make_tiny_conversation_ft_dataset, make_tiny_ft_dataset, diff --git a/tests/models/hf/test_hf_transform.py b/tests/models/hf/test_hf_transform.py new file mode 100644 index 0000000000..f479b50f73 --- /dev/null +++ b/tests/models/hf/test_hf_transform.py @@ -0,0 +1,76 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Optional + +import pytest +from composer.models.huggingface import maybe_get_underlying_model +from peft import PeftConfig, PeftModel +from transformers import LlamaForCausalLM, PreTrainedModel + +from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM +from llmfoundry.models.utils import init_empty_weights + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'peft_config', + [ + None, + { + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + 'lora_alpha': 32, + 'r': 2, + 'target_modules': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + }, + ], +) +def test_hf_transform(peft_config: Optional[dict]): + model_cfg = { + 'pretrained_model_name_or_path': 'codellama/CodeLlama-7b-hf', + 'config_overrides': { + 'num_hidden_layers': 2, + 'hidden_size': 32, + 'intermediate_size': 64, + }, + 'pretrained': False, + 'peft_config': peft_config, + 'init_device': 'meta', + 'tokenizer': 'codellama/CodeLlama-7b-hf', + } + + class TransformedHFCausalLM(ComposerHFCausalLM): + + def transform_model(self, model: PreTrainedModel) -> PreTrainedModel: + assert isinstance(model, LlamaForCausalLM) + with init_empty_weights(): + model.config.num_hidden_layers = 1 + new_model = type(model)(model.config) + return new_model + + def get_peft_config( + self, + peft_config_dict: Dict[str, Any], + ) -> PeftConfig: + peft_config_dict['target_modules'] = ['o_proj'] + return super().get_peft_config(peft_config_dict) + + composer_model = TransformedHFCausalLM(**model_cfg) + model = composer_model.model + inner_model = maybe_get_underlying_model(model) + + if peft_config: + peft_model = composer_model.model + assert isinstance(peft_model, PeftModel) + + target_modules = peft_model.peft_config[peft_model.active_adapter + ].target_modules + assert list(target_modules) == 
['o_proj'] + + assert isinstance(inner_model, LlamaForCausalLM) + assert inner_model.config.num_hidden_layers == 1 diff --git a/tests/models/layers/test_ffn.py b/tests/models/layers/test_ffn.py new file mode 100644 index 0000000000..bb78763f58 --- /dev/null +++ b/tests/models/layers/test_ffn.py @@ -0,0 +1,73 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn + +from llmfoundry.models.layers.ffn import quickgelu_activation +from llmfoundry.models.layers.layer_builders import build_ffn + + +@pytest.mark.gpu +def test_quickgelu_activation(): + d_model = 32 + expansion_ratio = 1 + no_bias = True + ffn_config = { + 'ffn_act_fn': { + 'name': 'quick_gelu', + }, + 'ffn_type': 'mptmlp', + } + rank: int = dist.get_rank() + device_str = f'cuda:{rank}' + device: torch.device = torch.device(device_str) + + ffn1 = build_ffn( + name=ffn_config['ffn_type'], + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device_str, + bias=not no_bias, + ffn_kwargs=ffn_config, + ) + assert ( + ffn1.act == quickgelu_activation + ), f'Expected quick_gelu activation function, got {ffn1.act}' + + ffn_config = { + 'ffn_act_fn': { + 'name': 'gelu', + }, + 'ffn_type': 'mptmlp', + } + ffn2 = build_ffn( + name=ffn_config['ffn_type'], + d_model=d_model, + expansion_ratio=expansion_ratio, + device=device_str, + bias=not no_bias, + ffn_kwargs=ffn_config, + ) + + def num_params(model: nn.Module) -> int: + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + return sum([p.numel() for p in model_parameters]) + + ffn1_numparams = num_params(ffn1) + ffn2_numparams = num_params(ffn2) + assert ( + ffn1_numparams == ffn2_numparams + ), 'Only activation paths should have changed, re-check modeling!' + + input_ = torch.rand(1, d_model, device=device) + output1 = ffn1(input_) + output2 = ffn2(input_) + assert ( + output1.numel() == output2.numel() + ), 'Only activation paths should have changed, re-check modeling!' + assert ( + not torch.allclose(output1, output2) + ), 'Functions are different, outputs should not match!' 
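The new tests/models/layers/test_ffn.py above exercises the quick_gelu activation through build_ffn directly. As a hedged sketch (not part of the diff), the same activation could be selected through an MPT model config, assuming ffn_config['ffn_act_fn'] is forwarded to build_ffn the same way the test's ffn_kwargs are; all sizes and the torch attention choice below are toy placeholders for a CPU-only illustration.

# Minimal sketch, assuming ffn_config['ffn_act_fn'] reaches build_ffn as in the
# test above. Field values are illustrative placeholders, not from the diff.
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

config = MPTConfig(
    d_model=32,
    n_heads=4,
    n_layers=1,
    expansion_ratio=1,
    max_seq_len=128,
    vocab_size=128,
    attn_config={'attn_impl': 'torch'},  # keep the sketch runnable without flash-attn
    ffn_config={
        'ffn_type': 'mptmlp',
        'ffn_act_fn': {'name': 'quick_gelu'},
    },
)
model = MPTForCausalLM(config)  # FFN blocks should now use quickgelu_activation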
diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py
index 01d982052f..4bfdfb84dc 100644
--- a/tests/models/layers/test_flash_torch.py
+++ b/tests/models/layers/test_flash_torch.py
@@ -251,12 +251,13 @@ def gen_bias(attn_impl: str):
     rotary_emb_w_meta_info = None
     if rope:
         rotary_embedding = gen_rotary_embedding(
-            rope_head_dim=cfg.d_model // cfg.n_heads,
             rope_impl=pos_emb_config['rope_impl'],
             rope_theta=pos_emb_config['rope_theta'],
             rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
             rope_hf_config=pos_emb_config.get('rope_hf_config', {}),
             max_seq_len=s,
+            d_model=cfg.d_model,
+            n_heads=cfg.n_heads,
         ).to(device)
         pos = torch.arange(s).unsqueeze(0).to(device=device)
         # adjust the position indices to account for padding tokens
@@ -664,12 +665,13 @@ def gen_bias(attn_impl: str):
     rotary_emb_w_meta_info = None
     if rope:
         rotary_embedding = gen_rotary_embedding(
-            rope_head_dim=cfg['d_model'] // cfg['n_heads'],
             rope_impl=pos_emb_config['rope_impl'],
             rope_theta=pos_emb_config['rope_theta'],
             rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
             rope_hf_config=pos_emb_config.get('rope_hf_config', {}),
             max_seq_len=s,
+            d_model=cfg['d_model'],
+            n_heads=cfg['n_heads'],
         ).to(device)
         pos = torch.arange(s).unsqueeze(0).to(device=device)
         # adjust the position indices to account for padding tokens
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 45378e42bd..ed40e7a88a 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -35,8 +35,7 @@
 )
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.bloom.modeling_bloom import build_alibi_tensor
-from transformers.models.llama.modeling_llama import \
-    LlamaRotaryEmbedding as HFRotaryEmbedding
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 
 from llmfoundry import ComposerHFCausalLM
 from llmfoundry.layers_registry import norms
@@ -48,7 +47,7 @@
 )
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel
-from llmfoundry.models.mpt.modeling_mpt import HFRotaryEmbeddingFoundry
+from llmfoundry.models.mpt.modeling_mpt import LlamaRotaryEmbeddingFoundry
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
 from llmfoundry.utils.config_utils import to_dict_container
@@ -2924,7 +2923,7 @@ def test_hf_rotary_child_class_builds():
         list(range(max_seq_len)),
     ] * bsz)
 
-    rot_emb_mp = HFRotaryEmbeddingFoundry(
+    rot_emb_mp = LlamaRotaryEmbeddingFoundry(
         rope_head_dim,
         max_seq_len,
         rope_theta,
@@ -2932,7 +2931,7 @@
     )
     cos_mp, sin_mp = rot_emb_mp(value, position_ids)
 
-    rot_emb = HFRotaryEmbedding(
+    rot_emb = LlamaRotaryEmbedding(
         rope_head_dim,
         max_seq_len,
         rope_theta,
diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py
index 6a41e64f48..34fb23f670 100644
--- a/tests/models/test_rope_dail_vs_hf.py
+++ b/tests/models/test_rope_dail_vs_hf.py
@@ -77,12 +77,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     }
 
     dail_rope = gen_rotary_embedding(
-        rope_head_dim=cfg.d_model // cfg.n_heads,
         rope_impl=dail_rope_config['rope_impl'],
         rope_theta=dail_rope_config['rope_theta'],
         rope_dail_config=dail_rope_config['rope_dail_config'],
         rope_hf_config={},
         max_seq_len=seq_len,
+        d_model=cfg.d_model,
+        n_heads=cfg.n_heads,
     ).to('cuda')
     dail_rope_w_meta_info = {
         'impl': 'dail',
@@ -92,12 +93,13 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     }
 
     hf_rope = gen_rotary_embedding(
-        rope_head_dim=cfg.d_model // cfg.n_heads,
         rope_impl=hf_rope_config['rope_impl'],
         rope_theta=hf_rope_config['rope_theta'],
         rope_dail_config={},
         rope_hf_config=hf_rope_config['rope_hf_config'],
         max_seq_len=seq_len,
+        d_model=cfg.d_model,
+        n_heads=cfg.n_heads,
    ).to('cuda')
     pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda')
     # adjust the position indices to account for padding tokens
diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py
new file mode 100644
index 0000000000..484ac2b23a
--- /dev/null
+++ b/tests/models/test_rope_scaling.py
@@ -0,0 +1,35 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+
+from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding
+
+rope_config = {
+    'rope_theta': 500000.0,
+    'rope_impl': 'hf',
+    'rope_hf_config': {
+        'factor': 8.0,
+        'low_freq_factor': 1.0,
+        'high_freq_factor': 4.0,
+        'original_max_position_embeddings': 8192,
+        'type': 'llama3',
+    },
+}
+
+rope_dail_config = {}
+
+
+def test_rope_scaling():
+    d_model = 128
+    n_heads = 32
+    max_seq_len = 65536
+
+    embedding = gen_rotary_embedding(
+        d_model=d_model,
+        n_heads=n_heads,
+        rope_dail_config=rope_dail_config,
+        max_seq_len=max_seq_len,
+        **rope_config,
+    )
+
+    assert isinstance(embedding, LlamaRotaryEmbedding)
diff --git a/tests/test_registry.py b/tests/test_registry.py
index aa0c93ee13..c4d1a1bcd5 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -24,6 +24,7 @@ def test_expected_registries_exist():
         'loggers',
         'optimizers',
         'schedulers',
+        'tokenizers',
         'callbacks',
         'algorithms',
         'callbacks_with_config',
diff --git a/tests/tokenizers/test_registry.py b/tests/tokenizers/test_registry.py
new file mode 100644
index 0000000000..920c207a64
--- /dev/null
+++ b/tests/tokenizers/test_registry.py
@@ -0,0 +1,35 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, Optional
+
+from transformers import PreTrainedTokenizer
+
+from llmfoundry.registry import tokenizers
+from llmfoundry.utils import build_tokenizer
+
+
+class DummyTokenizer(PreTrainedTokenizer):
+    """A dummy tokenizer that inherits from ``PreTrainedTokenizer``."""
+
+    def __init__(
+        self,
+        model_name: Optional[str] = 'dummy',
+        **kwargs: Optional[Dict[str, Any]],
+    ):
+        """Dummy constructor that has no real purpose."""
+        super().__init__(
+            model_name=model_name,
+            eos_token='0',
+            pad_token='1',
+            **kwargs,
+        )
+
+    def get_vocab(self) -> Dict[str, int]:
+        return {}
+
+
+def test_tokenizer_registry():
+    tokenizers.register('dummy', func=DummyTokenizer)
+    tokenizer = build_tokenizer(tokenizer_name='dummy', tokenizer_kwargs={})
+    assert type(tokenizer) == DummyTokenizer