
Commit 96b4829

Fix code quality
irenedea committed Oct 24, 2023
1 parent 044bb00 commit 96b4829
Showing 7 changed files with 99 additions and 75 deletions.
3 changes: 1 addition & 2 deletions llmfoundry/data/__init__.py
@@ -2,15 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.data.denoising import (MixtureOfDenoisersCollator,
build_text_denoising_dataloader)
from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator,
build_finetuning_dataloader)
from llmfoundry.data.text_data import (StreamingTextDataset,
build_text_dataloader)

from llmfoundry.data.dataloader import build_dataloader

__all__ = [
'MixtureOfDenoisersCollator',
'build_text_denoising_dataloader',
4 changes: 1 addition & 3 deletions llmfoundry/data/dataloader.py
@@ -7,11 +7,9 @@
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.text_data import build_text_dataloader

from llmfoundry.data.denoising import build_text_denoising_dataloader

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.data.text_data import build_text_dataloader


def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
42 changes: 25 additions & 17 deletions llmfoundry/data/packing.py
@@ -2,13 +2,13 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple
from composer import DataSpec

import numpy as np
import torch
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase


class BinPackCollator:
"""Utility collator for packing to reduce padding."""

@@ -57,9 +57,11 @@ def efficiency(self) -> float:

def __call__(
self,
examples: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
batch = self.base_collator(examples)
return self.pack(batch)

def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
assert 'attention_mask' in batch
assert 'input_ids' in batch

@@ -93,14 +95,14 @@ def __call__(

# Re-pad to max_seq_len and batch
batch = _repad(packed_examples,
max_seq_len=self.max_seq_len,
pad_token_id=self.pad_token_id,
padding_side=self.padding_side)
max_seq_len=self.max_seq_len,
pad_token_id=self.pad_token_id,
padding_side=self.padding_side)
return batch


def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
example = {k: v[idx] for k, v in batch.items()}

keep = example['attention_mask'] == 1
@@ -225,7 +227,7 @@ def _first_fit_bin_packing(


def _repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int,
pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:

def pad_tensor(tensor: torch.Tensor, pad_value: int):
if len(tensor) == max_seq_len:
@@ -286,19 +288,22 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,
if waste > 0:
break
packing_ratio = packing_ratio_candidate

# Select the minimum packing ratio across all ranks.
if torch.cuda.is_available() and dist.is_available() and dist.is_initialized():
if torch.cuda.is_available() and dist.is_available(
) and dist.is_initialized():
device = get_device('gpu')
packing_ratio_tensor = device.tensor_to_device(torch.tensor(packing_ratio))
packing_ratio_tensor = device.tensor_to_device(
torch.tensor(packing_ratio))
dist.all_reduce(packing_ratio_tensor, reduce_operation='MIN')
packing_ratio = packing_ratio_tensor.item()
return packing_ratio

def profile_packing(dataloader_cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase, min_ratio: float,
max_ratio: float, num_packing_ratios: int,
device_batch_size: int) -> Iterable[Tuple[float, float, float]]:

def profile_packing(
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
min_ratio: float, max_ratio: float, num_packing_ratios: int,
device_batch_size: int) -> Iterable[Tuple[float, float, float]]:
"""Generator function that profiles example packing across packing ratios.
Args:
@@ -313,10 +318,12 @@ def profile_packing(dataloader_cfg: DictConfig,
An iterable of tuples of packing ratio, padding, and waste.
"""
import copy

from llmfoundry.data.dataloader import build_dataloader

max_seq_len = dataloader_cfg.dataset.get('max_seq_len')
max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', None)
max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep',
None)

# Turn off packing for the dataloader (we want raw, pre-packed examples)
dataloader_cfg = copy.deepcopy(dataloader_cfg)
@@ -340,7 +347,8 @@ def profile_packing(dataloader_cfg: DictConfig,

n_profile_examples = max(raw_batch_sizes) * 100

train_dataspec = build_dataloader(dataloader_cfg, tokenizer, n_profile_examples)
train_dataspec = build_dataloader(dataloader_cfg, tokenizer,
n_profile_examples)
train_dataloader = train_dataspec.dataloader

# Get a bunch of raw examples
@@ -370,7 +378,7 @@ def profile(raw_batch_size: int) -> Tuple[float, float]:
for batch in split_big_batch(raw_batch_size):
if batch['input_ids'].shape[0] < device_batch_size:
continue
_ = packer(batch)
_ = packer.pack(batch)

# Return the padding / waste stats over that bunch of data
padding_percent = 100 * (1 - packer.efficiency)
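
Aside from formatting-only re-wraps, the substantive change in packing.py is that BinPackCollator.__call__ now accepts a list of example dicts and delegates to a new pack() method, so callers that already hold a collated batch (as profile_packing does above) can pack it directly. A minimal sketch of that shape, assuming a simplified, hypothetical constructor (the real __init__ in llmfoundry/data/packing.py takes additional packing-related arguments):

from typing import Callable, Dict, List

import torch


class BinPackCollator:
    """Utility collator for packing to reduce padding (simplified sketch)."""

    def __init__(self, base_collator: Callable, max_seq_len: int,
                 pad_token_id: int, padding_side: str):
        # Hypothetical trimmed-down constructor; see the full class in
        # llmfoundry/data/packing.py for the real argument list.
        self.base_collator = base_collator
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        self.padding_side = padding_side

    def __call__(
            self,
            examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Dataloader path: collate raw examples first, then pack the batch.
        batch = self.base_collator(examples)
        return self.pack(batch)

    def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # New standalone entry point: callers with an already-collated batch
        # (e.g. profile_packing in this commit) call pack() directly.
        assert 'attention_mask' in batch
        assert 'input_ids' in batch
        # ... trim padding, first-fit bin pack, re-pad to max_seq_len ...
        return batch
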
2 changes: 1 addition & 1 deletion scripts/misc/profile_packing.py
@@ -2,11 +2,11 @@
# SPDX-License-Identifier: Apache-2.0

"""Script to profile example packing."""
import os
from typing import Any, Dict

from llmfoundry.data.packing import profile_packing


if __name__ == '__main__':
from argparse import ArgumentParser, Namespace

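
This script drives profile_packing end to end; per its docstring above, the function yields one (packing ratio, padding, waste) tuple per candidate ratio. A hedged usage sketch of the function itself, where the config keys, dataset name, and tokenizer choice are illustrative assumptions rather than a verified schema:

from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.data.packing import profile_packing

# Illustrative dataloader config: in practice, pass the same dataloader
# section used for training (see the script above for the full setup).
dataloader_cfg = OmegaConf.create({
    'name': 'finetuning',
    'dataset': {
        'hf_name': 'my-org/my-finetuning-dataset',  # hypothetical dataset
        'max_seq_len': 2048,
    },
    'drop_last': False,
    'num_workers': 8,
})

# Any PreTrainedTokenizerBase; the choice here is purely illustrative.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

for packing_ratio, padding, waste in profile_packing(
        dataloader_cfg, tokenizer, min_ratio=1.0, max_ratio=10.0,
        num_packing_ratios=10, device_batch_size=8):
    print(f'packing_ratio={packing_ratio}: padding={padding}, waste={waste}')
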
4 changes: 2 additions & 2 deletions scripts/train/train.py
@@ -24,6 +24,7 @@

from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM,
MPTForCausalLM)
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.utils.builders import (build_algorithm, build_callback,
build_icl_data_and_gauntlet,
build_logger, build_optimizer,
@@ -32,8 +33,6 @@
process_init_device,
update_batch_size_info)

from llmfoundry.data.dataloader import build_dataloader


def validate_config(cfg: DictConfig):
"""Validates compatible model and dataloader selection."""
@@ -167,6 +166,7 @@ def print_trainable_parameters(model: torch.nn.Module) -> None:
f'trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}'
)


def main(cfg: DictConfig) -> Trainer:
# Filter deprecation warning from torch internal usage
warnings.filterwarnings(
1 change: 0 additions & 1 deletion tests/test_dataloader.py
@@ -8,7 +8,6 @@
import sys
import tempfile
from argparse import Namespace

from typing import Any, Optional
from unittest.mock import MagicMock

