
Removing logging exception through update run metadata (#1292)
jjanezhang authored Jul 9, 2024
1 parent 7e20d81 commit 304bf28
Showing 7 changed files with 24 additions and 87 deletions.
11 changes: 2 additions & 9 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -13,7 +13,7 @@
 from composer import DataSpec
 from composer.core import State, Time, TimeUnit, ensure_time
-from composer.loggers import Logger, MosaicMLLogger
+from composer.loggers import Logger
 from streaming import StreamingDataset
 from streaming.base.util import clean_stale_shared_memory
 from torch.utils.data import DataLoader
@@ -23,7 +23,6 @@
     BaseContextualError,
     TrainDataLoaderLocation,
 )
-from llmfoundry.utils.mosaicml_logger_utils import no_override_excepthook
 
 log = logging.getLogger(__name__)
 
@@ -238,13 +237,7 @@ def _build_train_loader(
                 self.device_train_batch_size,
             )
         except BaseContextualError as e:
-            for destination in logger.destinations:
-                if (
-                    isinstance(destination, MosaicMLLogger) and
-                    no_override_excepthook()
-                ):
-                    e.location = TrainDataLoaderLocation
-                    destination.log_exception(e)
+            e.location = TrainDataLoaderLocation
             raise e
 
     def _validate_dataloader(self, train_loader: Any):
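
Note: after this change the callback only tags the exception with where it came from and re-raises; reporting is expected to happen once, centrally, rather than at every raise site. A minimal sketch of that pattern with stand-in names (ContextualError, TRAIN_DATALOADER, and the excepthook body are illustrative, not LLM Foundry's API):

import sys
from typing import Optional

class ContextualError(Exception):
    """Stand-in for an exception type that carries a failure location."""

    def __init__(self, message: str):
        super().__init__(message)
        self.location: Optional[str] = None

def build_train_loader() -> None:
    try:
        raise ContextualError('dataset shard missing')
    except ContextualError as e:
        e.location = 'TRAIN_DATALOADER'  # annotate only; no logging here
        raise

def reporting_excepthook(exc_type, exc, tb) -> None:
    # Single choke point: the failure location travels on the exception
    # and is reported once here, e.g. as run metadata.
    print(f"run failed at {getattr(exc, 'location', 'unknown')}: {exc}")
    sys.__excepthook__(exc_type, exc, tb)

sys.excepthook = reporting_excepthook

if __name__ == '__main__':
    build_train_loader()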
16 changes: 2 additions & 14 deletions llmfoundry/callbacks/run_timeout_callback.py
@@ -8,18 +8,12 @@
 from typing import Optional
 
 from composer import Callback, Logger, State
-from composer.loggers import MosaicMLLogger
 
 from llmfoundry.utils.exceptions import RunTimeoutError
-from llmfoundry.utils.mosaicml_logger_utils import no_override_excepthook
 
 log = logging.getLogger(__name__)
 
 
-def _timeout(timeout: int, mosaicml_logger: Optional[MosaicMLLogger] = None):
+def _timeout(timeout: int):
     log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',)
-    if mosaicml_logger is not None and no_override_excepthook():
-        mosaicml_logger.log_exception(RunTimeoutError(timeout=timeout))
     os.kill(os.getpid(), signal.SIGINT)
 
 
@@ -30,14 +24,8 @@ def __init__(
         timeout: int = 1800,
     ):
         self.timeout = timeout
-        self.mosaicml_logger: Optional[MosaicMLLogger] = None
         self.timer: Optional[threading.Timer] = None
 
-    def init(self, state: State, logger: Logger):
-        for callback in state.callbacks:
-            if isinstance(callback, MosaicMLLogger):
-                self.mosaicml_logger = callback
-
     def _reset(self):
         if self.timer is not None:
             self.timer.cancel()
@@ -48,7 +36,7 @@ def _timeout(self):
         self.timer = threading.Timer(
             self.timeout,
             _timeout,
-            [self.timeout, self.mosaicml_logger],
+            [self.timeout],
         )
         self.timer.daemon = True
         self.timer.start()
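
Note: with the logger plumbing gone, the callback reduces to a plain threading.Timer watchdog. A self-contained sketch of the same pattern (the 2-second timeout is only for demonstration, and SIGINT delivery assumes a POSIX main thread):

import os
import signal
import threading
import time

def _timeout(timeout: int) -> None:
    # Fires only if the timer is never reset or cancelled in time.
    print(f'Timeout after {timeout} seconds of inactivity.')
    os.kill(os.getpid(), signal.SIGINT)  # raises KeyboardInterrupt in main thread

timer = threading.Timer(2, _timeout, [2])
timer.daemon = True  # never keep the process alive just for the watchdog
timer.start()

try:
    time.sleep(10)  # simulated inactivity; the watchdog fires after ~2 s
except KeyboardInterrupt:
    print('Interrupted by the watchdog timer.')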
2 changes: 0 additions & 2 deletions llmfoundry/utils/__init__.py
@@ -46,7 +46,6 @@
     log_eval_analytics,
     log_train_analytics,
     maybe_create_mosaicml_logger,
-    no_override_excepthook,
 )
 from llmfoundry.utils.prompt_files import load_prompts, load_prompts_from_file
 from llmfoundry.utils.registry_utils import (
@@ -98,7 +97,6 @@
     'download_from_hf_hub',
     'download_from_oras',
     'maybe_create_mosaicml_logger',
-    'no_override_excepthook',
     'find_mosaicml_logger',
     'log_eval_analytics',
     'log_train_analytics',
11 changes: 0 additions & 11 deletions llmfoundry/utils/mosaicml_logger_utils.py
@@ -37,17 +37,6 @@ def maybe_create_mosaicml_logger() -> Optional[MosaicMLLogger]:
         return MosaicMLLogger()
 
 
-def no_override_excepthook() -> bool:
-    """Returns True if the excepthook flag is off.
-
-    This means we are not automatically catching exceptions for MosaicML runs.
-    """
-    return os.environ.get(
-        'OVERRIDE_EXCEPTHOOK',
-        'false',
-    ).lower() != 'true'
-
-
 def find_mosaicml_logger(
     loggers: List[LoggerDestination],
 ) -> Optional[MosaicMLLogger]:
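
Note: the deleted helper was a plain environment-variable gate: manual exception logging ran only while the platform's excepthook override was inactive. A self-contained sketch of its semantics:

import os

def no_override_excepthook() -> bool:
    """True when the OVERRIDE_EXCEPTHOOK flag is unset or not 'true'."""
    return os.environ.get('OVERRIDE_EXCEPTHOOK', 'false').lower() != 'true'

os.environ.pop('OVERRIDE_EXCEPTHOOK', None)
assert no_override_excepthook()  # flag off: call sites logged manually

os.environ['OVERRIDE_EXCEPTHOOK'] = 'TRUE'
assert not no_override_excepthook()  # override active: manual logging skipped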
24 changes: 6 additions & 18 deletions scripts/data_prep/convert_delta_to_json.py
@@ -34,10 +34,6 @@
 from pyspark.sql.dataframe import DataFrame as SparkDataFrame
 from pyspark.sql.types import Row
 
-from llmfoundry.utils import (
-    maybe_create_mosaicml_logger,
-    no_override_excepthook,
-)
 from llmfoundry.utils.exceptions import (
     ClusterDoesNotExistError,
     FailedToConnectToDatabricksError,
@@ -667,18 +663,10 @@ def fetch_DT(args: Namespace) -> None:
         'The name of the combined final jsonl that combines all partitioned jsonl',
     )
     args = parser.parse_args()
-    mosaicml_logger = maybe_create_mosaicml_logger()
-
-    try:
-        w = WorkspaceClient()
-        args.DATABRICKS_HOST = w.config.host
-        args.DATABRICKS_TOKEN = w.config.token
+    w = WorkspaceClient()
+    args.DATABRICKS_HOST = w.config.host
+    args.DATABRICKS_TOKEN = w.config.token
 
-        tik = time.time()
-        fetch_DT(args)
-        log.info(f'Elapsed time {time.time() - tik}')
-
-    except Exception as e:
-        if mosaicml_logger is not None and no_override_excepthook():
-            mosaicml_logger.log_exception(e)
-        raise e
+    tik = time.time()
+    fetch_DT(args)
+    log.info(f'Elapsed time {time.time() - tik}')
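
Note: with the try/except wrapper removed, any failure in fetch_DT now propagates straight to the interpreter's excepthook instead of being logged in-script first. A minimal sketch of the simplified entry point, assuming the databricks-sdk's default credential resolution; this fetch_DT is a stand-in, not the script's real implementation:

import time

from databricks.sdk import WorkspaceClient

def fetch_DT(host: str, token: str) -> None:
    # Stand-in for the real Delta-table export; any exception it raises
    # now surfaces unwrapped.
    print(f'would export Delta table from {host}')

if __name__ == '__main__':
    w = WorkspaceClient()  # fails fast if no valid Databricks credentials
    tik = time.time()
    fetch_DT(w.config.host, w.config.token)
    print(f'Elapsed time {time.time() - tik:.2f}s')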
40 changes: 14 additions & 26 deletions scripts/data_prep/convert_text_to_mds.py
@@ -24,10 +24,6 @@
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from llmfoundry.data.data import AbstractConcatTokensDataset
-from llmfoundry.utils import (
-    maybe_create_mosaicml_logger,
-    no_override_excepthook,
-)
 from llmfoundry.utils.data_prep_utils import (
     DownloadingIterable,
     download_file,
@@ -608,25 +604,17 @@ def _configure_logging(logging_level: str):
 if __name__ == '__main__':
     args = parse_args()
     _configure_logging(args.logging_level)
-
-    mosaicml_logger = maybe_create_mosaicml_logger()
-
-    try:
-        convert_text_to_mds(
-            tokenizer_name=args.tokenizer,
-            output_folder=args.output_folder,
-            input_folder=args.input_folder,
-            concat_tokens=args.concat_tokens,
-            eos_text=args.eos_text,
-            bos_text=args.bos_text,
-            no_wrap=args.no_wrap,
-            compression=args.compression,
-            processes=args.processes,
-            reprocess=args.reprocess,
-            trust_remote_code=args.trust_remote_code,
-            args_str=_args_str(args),
-        )
-    except Exception as e:
-        if mosaicml_logger is not None and no_override_excepthook():
-            mosaicml_logger.log_exception(e)
-        raise e
+    convert_text_to_mds(
+        tokenizer_name=args.tokenizer,
+        output_folder=args.output_folder,
+        input_folder=args.input_folder,
+        concat_tokens=args.concat_tokens,
+        eos_text=args.eos_text,
+        bos_text=args.bos_text,
+        no_wrap=args.no_wrap,
+        compression=args.compression,
+        processes=args.processes,
+        reprocess=args.reprocess,
+        trust_remote_code=args.trust_remote_code,
+        args_str=_args_str(args),
+    )
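
Note: once the log_exception call is gone, deleting the whole try/except is behavior-preserving, because "except Exception as e: raise e" re-raises the very same exception object; the only visible difference is an extra traceback frame at the re-raise line. A short demonstration:

import traceback

def flaky() -> None:
    raise ValueError('boom')

def old_shape() -> None:
    try:
        flaky()
    except Exception as e:
        raise e  # same exception propagates, plus this frame in the traceback

def new_shape() -> None:
    flaky()  # equivalent once the except block only re-raised

for fn in (old_shape, new_shape):
    try:
        fn()
    except ValueError:
        print(f'--- {fn.__name__} ---')
        traceback.print_exc()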
7 changes: 0 additions & 7 deletions scripts/train/train.py
@@ -56,7 +56,6 @@
     EvalDataLoaderLocation,
     TrainDataLoaderLocation,
 )
-from llmfoundry.utils.mosaicml_logger_utils import no_override_excepthook
 from llmfoundry.utils.registry_utils import import_file
 
 log = logging.getLogger(__name__)
@@ -398,8 +397,6 @@ def main(cfg: DictConfig) -> Trainer:
         )
     except BaseContextualError as e:
         e.location = TrainDataLoaderLocation
-        if mosaicml_logger is not None and no_override_excepthook():
-            mosaicml_logger.log_exception(e)
         raise e
 
     if mosaicml_logger is not None:
@@ -431,8 +428,6 @@ def main(cfg: DictConfig) -> Trainer:
         callbacks.append(eval_gauntlet_callback)
     except BaseContextualError as e:
         e.location = EvalDataLoaderLocation
-        if mosaicml_logger is not None and no_override_excepthook():
-            mosaicml_logger.log_exception(e)
         raise e
 
     if mosaicml_logger is not None:
@@ -481,8 +476,6 @@ def main(cfg: DictConfig) -> Trainer:
         )
     except BaseContextualError as e:
         e.location = EvalDataLoaderLocation
-        if mosaicml_logger is not None and no_override_excepthook():
-            mosaicml_logger.log_exception(e)
         raise e
 
     compile_config = train_cfg.compile_config
