Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress committed Apr 2, 2024
1 parent 6a7746e commit 0e6f3bf
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions llmfoundry/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,27 @@

import logging
import tempfile
from io import StringIO
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np
import torch
from omegaconf import DictConfig
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
from transformers import PreTrainedTokenizerBase

log = logging.getLogger(__name__)


class LogIO(StringIO):

def __init__(self, log: logging.Logger):
self.log = log

def write(self, message: str):
self.log.debug(message)


class BinPackCollator:
"""Utility collator for packing to reduce padding."""

Expand Down Expand Up @@ -452,10 +461,11 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
return padding_percent, waste_percent

log.info('Profiling packing ratios')
with logging_redirect_tqdm(loggers=[log]):
for packing_ratio, raw_batch_size in (pbar := tqdm(
zip(packing_ratios, raw_batch_sizes),
desc='Profiling packing ratios')):
pbar.set_description_str(f'Profiling packing ratio {packing_ratio}')
padding, waste = profile(raw_batch_size)
yield (packing_ratio, padding, waste)
for packing_ratio, raw_batch_size in (pbar :=
tqdm(zip(packing_ratios,
raw_batch_sizes),
desc='Profiling packing ratios',
file=LogIO(log))):
pbar.set_description_str(f'Profiling packing ratio {packing_ratio}')
padding, waste = profile(raw_batch_size)
yield (packing_ratio, padding, waste)

0 comments on commit 0e6f3bf

Please sign in to comment.