diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py
index 753281fcb6..7f389e00a9 100644
--- a/llmfoundry/data/packing.py
+++ b/llmfoundry/data/packing.py
@@ -3,18 +3,29 @@
 
 import logging
 import tempfile
+from io import StringIO
 from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
 from omegaconf import DictConfig
 from tqdm import tqdm
-from tqdm.contrib.logging import logging_redirect_tqdm
 from transformers import PreTrainedTokenizerBase
 
 log = logging.getLogger(__name__)
 
 
+class LogIO(StringIO):
+    """File-like sink that forwards tqdm's terminal writes to a logger."""
+
+    def __init__(self, log: logging.Logger):
+        super().__init__()  # required: an uninitialized _io.StringIO raises ValueError on flush()
+        self.log = log
+
+    def write(self, message: str):
+        self.log.debug(message)
+
+
 class BinPackCollator:
     """Utility collator for packing to reduce padding."""
 
@@ -452,10 +463,11 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
         return padding_percent, waste_percent
 
     log.info('Profiling packing ratios')
-    with logging_redirect_tqdm(loggers=[log]):
-        for packing_ratio, raw_batch_size in (pbar := tqdm(
-                zip(packing_ratios, raw_batch_sizes),
-                desc='Profiling packing ratios')):
-            pbar.set_description_str(f'Profiling packing ratio {packing_ratio}')
-            padding, waste = profile(raw_batch_size)
-            yield (packing_ratio, padding, waste)
+    for packing_ratio, raw_batch_size in (pbar :=
+                                          tqdm(zip(packing_ratios,
+                                                   raw_batch_sizes),
+                                               desc='Profiling packing ratios',
+                                               file=LogIO(log))):
+        pbar.set_description_str(f'Profiling packing ratio {packing_ratio}')
+        padding, waste = profile(raw_batch_size)
+        yield (packing_ratio, padding, waste)