From 68c262569ab7671525a5c3ac1ed55a7bf383c7a9 Mon Sep 17 00:00:00 2001
From: Sami Jawhar
Date: Mon, 1 Jul 2024 08:02:31 -0700
Subject: [PATCH] Add optional logging of text output to EvalOutputLogging
 (#1283)

---------

Co-authored-by: Mihir Patel
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
---
 .../callbacks/eval_output_logging_callback.py |  77 ++++++++----
 .../test_eval_output_logging_callback.py      | 113 +++++++++++++++++-
 2 files changed, 160 insertions(+), 30 deletions(-)

diff --git a/llmfoundry/callbacks/eval_output_logging_callback.py b/llmfoundry/callbacks/eval_output_logging_callback.py
index edcd6ed336..b84ea063d1 100644
--- a/llmfoundry/callbacks/eval_output_logging_callback.py
+++ b/llmfoundry/callbacks/eval_output_logging_callback.py
@@ -5,11 +5,12 @@
 import warnings
 from copy import deepcopy
-from typing import Any, Dict, List, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 import torch
 from composer.core import Callback, State
 from composer.loggers import ConsoleLogger, Logger
+from composer.models import HuggingFaceModel
 from composer.utils.dist import all_gather_object
 
@@ -24,51 +25,85 @@ class EvalOutputLogging(Callback):
     into `batch_keys_to_log`. It will do so after every eval batch.
     """
 
-    def __init__(self, log_tokens: bool = False, *args: Any, **kwargs: Any):
+    def __init__(
+        self,
+        log_tokens: bool = False,
+        log_output_text: Optional[bool] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
         super().__init__(self, *args, **kwargs)
         self.log_tokens = log_tokens
         self.columns = None
         self.name = None
         self.rows = []
+        self.log_output_text = log_output_text
+
+    def init(self, state: State, logger: Logger) -> None:
+        if self.log_output_text is False:
+            return
+
+        has_output_text = (
+            isinstance(state.model, HuggingFaceModel)
+            and state.dataloader is not None
+            and hasattr(
+                state.dataloader.dataset,  # pyright: ignore[reportGeneralTypeIssues]
+                'tokenizer',
+            )
+        )
+        if self.log_output_text is True and has_output_text is False:
+            raise ValueError(
+                '`log_output_text=True` is only supported for HuggingFace models and datasets with tokenizers.',
+            )
+        elif self.log_output_text is None:
+            self.log_output_text = has_output_text
 
     def eval_batch_end(self, state: State, logger: Logger) -> None:
         if not isinstance(state.batch, Dict):
             warnings.warn(
-                f'''EvalOutputLogging only supports batches that are dictionary. \
+                f"""EvalOutputLogging only supports batches that are dictionary. \
                 Found batch for type {type(state.batch)}. \
-                Not logging eval outputs.''',
+                Not logging eval outputs.""",
             )
             return
 
         assert state.outputs is not None
         assert state.metric_outputs is not None
-        logging_dict: Dict[str, Union[List[Any], torch.Tensor,
-                                      Sequence[torch.Tensor]]] = deepcopy(
-                                          state.metric_outputs,
-                                      )
-
-        # If batch mode is not generate, outputs will be logits
-        if state.batch['mode'] == 'generate':
+        logging_dict: Dict[str,
+                           Union[List[Any], torch.Tensor,
+                                 Sequence[torch.Tensor]],
+                          ] = deepcopy(
+                              state.metric_outputs,
+                          )
+
+        if state.batch.get('mode') == 'generate':
             # Outputs are already detokenized
             logging_dict['outputs'] = state.outputs
+        elif self.log_output_text and isinstance(state.outputs, torch.Tensor):
+            # If batch mode is not generate, outputs will be logits
+            logging_dict['outputs'] = state.outputs.argmax(dim=-1)
 
         input_ids = state.batch['input_ids']
         logged_input = []
         assert state.dataloader is not None
+        dataset = state.dataloader.dataset  # pyright: ignore[reportGeneralTypeIssues]
+        tokenizer = dataset.tokenizer  # pyright: ignore[reportGeneralTypeIssues]
+        pad_token_id = getattr(
+            dataset,
+            'pad_tok_id',
+            dataset.tokenizer.pad_token_id,
+        )
 
         # Depad and decode input_ids
         for input_list in input_ids.tolist():
-            dataset = state.dataloader.dataset  # pyright: ignore[reportGeneralTypeIssues]
-            depadded_input = [
-                tok for tok in input_list if tok != dataset.pad_tok_id
-            ]
-            logged_input.append(dataset.tokenizer.decode(depadded_input))
+            depadded_input = [tok for tok in input_list if tok != pad_token_id]
+            logged_input.append(tokenizer.decode(depadded_input))
         logging_dict['input'] = logged_input
 
         # Log token indices if toggled
         if self.log_tokens:
             logging_dict['input_tokens'] = input_ids.tolist()
-            if not state.batch['mode'] == 'generate':
+            if not state.batch.get('mode') == 'generate':
                 if isinstance(state.outputs, torch.Tensor):  # pyright
                     logging_dict['label_tokens'] = state.outputs.tolist()
@@ -85,15 +120,9 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
         for key, value in logging_dict.items():
             # All types in list are the same
             if isinstance(value[0], torch.Tensor):
-                logging_dict[key] = [
-                    state.dataloader.dataset.  # pyright: ignore[reportGeneralTypeIssues]
-                    tokenizer.decode(  # pyright: ignore[reportGeneralTypeIssues]
-                        t,
-                    ) for t in value
-                ]
+                logging_dict[key] = [tokenizer.decode(t) for t in value]
             elif isinstance(value[0], list):
                 if isinstance(value[0][0], torch.Tensor):
-                    tokenizer = state.dataloader.dataset.tokenizer  # pyright: ignore[reportGeneralTypeIssues]
                     logging_dict[key] = [[
                         tokenizer.decode(choice) for choice in t
                     ] for t in value]
diff --git a/tests/callbacks/test_eval_output_logging_callback.py b/tests/callbacks/test_eval_output_logging_callback.py
index 7778e39fe3..b5006f6fb2 100644
--- a/tests/callbacks/test_eval_output_logging_callback.py
+++ b/tests/callbacks/test_eval_output_logging_callback.py
@@ -1,13 +1,19 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import contextlib
 import json
+import re
+from typing import Any, Optional
+from unittest import mock
 
+import pytest
 import torch
 import transformers
 from composer.core.state import State
 from composer.core.time import Timestamp
 from composer.loggers import InMemoryLogger, Logger
+from composer.models import HuggingFaceModel
 from torch.utils.data import DataLoader
 from torchmetrics import Metric
 
@@ -50,6 +56,23 @@ def update_curr_eval(self, dataloader: DataLoader, dataloader_label: str):
         self._dataloader_label = dataloader_label
 
 
+class MockHFModel(HuggingFaceModel):
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        pass
+
+
+class RegexMatcher:
+
+    def __init__(self, pattern: str):
+        self.pattern = re.compile(pattern)
+
+    def __eq__(self, other: Any):
+        if not isinstance(other, str):
+            return False
+        return bool(self.pattern.match(other))
+
+
 def mock_lm_computation(
     metric: Metric,
     tokenizer: transformers.AutoTokenizer,
@@ -158,8 +181,45 @@ def mock_mc_computation(
     metric.compute()
 
 
+@pytest.mark.parametrize('is_hf_model', [True, False])
+@pytest.mark.parametrize('has_tokenizer', [True, False])
+@pytest.mark.parametrize('log_output_text', [True, False, None])
+def test_init(
+    is_hf_model: bool,
+    has_tokenizer: bool,
+    log_output_text: Optional[bool],
+):
+    state = MockState()
+    in_memory_logger = InMemoryLogger()
+    logger = Logger(state, in_memory_logger)
+
+    expected_error = log_output_text is True and not (
+        is_hf_model and has_tokenizer
+    )
+    expected_log_output_text = (
+        log_output_text is not False and is_hf_model and has_tokenizer
+    )
+
+    eval_output_logging = EvalOutputLogging(
+        loggers_to_use=['InMemoryLogger'],
+        log_output_text=log_output_text,
+    )
+
+    state = mock.Mock(model=MockHFModel() if is_hf_model else mock.Mock())
+    state.dataloader.dataset = mock.Mock(
+        spec=['tokenizer'] if has_tokenizer else [],
+    )
+    with pytest.raises(
+        ValueError,
+    ) if expected_error else contextlib.nullcontext():
+        eval_output_logging.init(state, logger)
+        assert eval_output_logging.log_output_text == expected_log_output_text
+
+
+@pytest.mark.parametrize('log_output_text', [True, False])
 def test_eval_output_logging_lm(
     tiny_gpt2_tokenizer: transformers.AutoTokenizer,
+    log_output_text: bool,
 ):
     # this test simulates an unrolled version of the eval loop occurring twice
     state = MockState()
@@ -170,7 +230,11 @@ def test_eval_output_logging_lm(
     state.add_metric('lm_acc', lm_metric)
 
     # Construct the callback
-    eval_output_logging = EvalOutputLogging(loggers_to_use=['InMemoryLogger'])
+    eval_output_logging = EvalOutputLogging(
+        loggers_to_use=['InMemoryLogger'],
+        log_output_text=log_output_text,
+    )
+    eval_output_logging.init(mock.Mock(model=MockHFModel()), logger)
 
     for _ in range(2):
         state.update_curr_eval(
@@ -193,23 +257,28 @@ def test_eval_output_logging_lm(
     assert f'lm_acc_step_0' in in_memory_logger.tables
     # Only want one table - we log once to a single step value during eval_end()
     assert len(in_memory_logger.tables) == 1
-    assert json.loads(in_memory_logger.tables[f'lm_acc_step_0'])['columns'] == [
+    logged_data = json.loads(in_memory_logger.tables[f'lm_acc_step_0'])
+    assert logged_data['columns'] == [
         'context',
         'label',
         'output',
         'result',
         'metric_name',
+        *(['outputs'] if log_output_text else []),
        'input',
         'run_name',
     ]
+
     # We use the same data in each batch
-    assert json.loads(in_memory_logger.tables[f'lm_acc_step_0'])['data'] == [
+    assert logged_data['data'] == [
         [
             'The dog is',
             ' furry',
             ' furry',
             1,
             'InContextLearningLMAccuracy',
+            *((RegexMatcher(r' dog is furry(\[PAD\])+I'),)
+              if log_output_text else []),
             'The dog is furry',
             'mock_name',
         ],
@@ -219,6 +288,8 @@ def test_eval_output_logging_lm(
             '[PAD]',
             0,
             'InContextLearningLMAccuracy',
+            *((RegexMatcher(r' love to eat(\[PAD\])+I'),)
+              if log_output_text else []),
             'I love to eat pie',
             'mock_name',
         ],
@@ -228,6 +299,8 @@ def test_eval_output_logging_lm(
             ' long lines',
             1,
             'InContextLearningLMAccuracy',
+            *((RegexMatcher(r' hate long lines(\[PAD\])+The'),)
+              if log_output_text else []),
             'I hate long lines',
             'mock_name',
         ],
@@ -237,6 +310,8 @@ def test_eval_output_logging_lm(
             ' snowy',
             1,
             'InContextLearningLMAccuracy',
+            *((RegexMatcher(r' weather is snowy(\[PAD\])+The'),)
+              if log_output_text else []),
             'The weather is snowy',
             'mock_name',
         ],
@@ -246,6 +321,8 @@ def test_eval_output_logging_lm(
             ' furry',
             1,
             'InContextLearningLMAccuracy',
+            *((RegexMatcher(r' dog is furry(\[PAD\])+I'),)
+              if log_output_text else []),
             'The dog is furry',
             'mock_name',
         ],
@@ -255,6 +332,8 @@ def test_eval_output_logging_lm(
             '[PAD]',
             0,
             'InContextLearningLMAccuracy',
+            *((RegexMatcher(r' love to eat(\[PAD\])+I'),)
+              if log_output_text else []),
             'I love to eat pie',
             'mock_name',
         ],
@@ -264,6 +343,8 @@ def test_eval_output_logging_lm(
             ' long lines',
             1,
             'InContextLearningLMAccuracy',
+            *((RegexMatcher(r' hate long lines(\[PAD\])+The'),)
+              if log_output_text else []),
             'I hate long lines',
             'mock_name',
         ],
@@ -273,6 +354,8 @@ def test_eval_output_logging_lm(
             ' snowy',
             1,
             'InContextLearningLMAccuracy',
+            *((RegexMatcher(r' weather is snowy(\[PAD\])+The'),)
+              if log_output_text else []),
             'The weather is snowy',
             'mock_name',
         ],
@@ -291,7 +374,11 @@ def test_eval_output_logging_mc(
     state.add_metric('mc_acc', mc_metric)
 
     # Construct the callback
-    eval_output_logging = EvalOutputLogging(loggers_to_use=['InMemoryLogger'])
+    eval_output_logging = EvalOutputLogging(
+        loggers_to_use=['InMemoryLogger'],
+        log_output_text=True,
+    )
+    eval_output_logging.init(mock.Mock(model=MockHFModel()), logger)
 
     for _ in range(2):
         state.update_curr_eval(
             MockDataLoader(tiny_gpt2_tokenizer),
@@ -314,7 +401,8 @@ def test_eval_output_logging_mc(
     assert f'mc_acc_step_0' in in_memory_logger.tables
     # Only want one table - we log once to a single step value during eval_end()
     assert len(in_memory_logger.tables) == 1
-    assert json.loads(in_memory_logger.tables[f'mc_acc_step_0'])['columns'] == [
+    logged_data = json.loads(in_memory_logger.tables[f'mc_acc_step_0'])
+    assert logged_data['columns'] == [
         'context',
         'correct_choice',
         'correct_choice_idx',
@@ -323,11 +411,12 @@ def test_eval_output_logging_mc(
         'selected_choice',
         'selected_choice_idx',
         'all_choices',
         'result',
         'metric_name',
+        'outputs',
         'input',
         'run_name',
     ]
     # We use the same data for each batch
-    assert json.loads(in_memory_logger.tables[f'mc_acc_step_0'])['data'] == [
+    assert logged_data['data'] == [
         [
             'Q: How do you cook a cake?',
             ' A: turn on the oven',
             0,
             ' A: turn on the oven',
             0,
             [
                 ' A: turn on the oven',
                 ' A: do a backflip',
             ],
             1,
             'InContextLearningMultipleChoiceAccuracy',
+            RegexMatcher(
+                r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q',
+            ),
             'Q: How do you cook a cake? A: turn on the oven',
             'mock_name',
         ],
         [
             'Q: How do you cook a cake?',
             ' A: turn on the oven',
             0,
             ' A: do a backflip',
             1,
             [
                 ' A: turn on the oven',
                 ' A: do a backflip',
             ],
             0,
             'InContextLearningMultipleChoiceAccuracy',
+            RegexMatcher(
+                r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q',
+            ),
             'Q: How do you cook a cake? A: do a backflip',
             'mock_name',
         ],
         [
             'Q: How do you cook a cake?',
             ' A: turn on the oven',
             0,
             ' A: turn on the oven',
             0,
             [
                 ' A: turn on the oven',
                 ' A: do a backflip',
             ],
             1,
             'InContextLearningMultipleChoiceAccuracy',
+            RegexMatcher(
+                r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q',
+            ),
             'Q: How do you cook a cake? A: turn on the oven',
             'mock_name',
         ],
         [
             'Q: How do you cook a cake?',
             ' A: turn on the oven',
             0,
             ' A: do a backflip',
             1,
             [
                 ' A: turn on the oven',
                 ' A: do a backflip',
             ],
             0,
             'InContextLearningMultipleChoiceAccuracy',
+            RegexMatcher(
+                r': How do you cook a cake\? A: turn on the oven(\[PAD\])+Q',
+            ),
             'Q: How do you cook a cake? A: do a backflip',
             'mock_name',
         ],
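For reviewers trying the feature out, here is a minimal usage sketch of the new flag (assuming the callback as merged in this patch; the `auto_cb`, `text_cb`, and `no_text_cb` names are illustrative, and the ValueError is raised when Composer calls `init()` at trainer startup, not at construction):

    from llmfoundry.callbacks import EvalOutputLogging

    # log_output_text=None (default): at Trainer init, the callback checks whether
    # the model is a HuggingFaceModel and the eval dataset carries a tokenizer,
    # and enables text output logging only when both hold.
    auto_cb = EvalOutputLogging()

    # log_output_text=True: force text logging; init() raises ValueError if the
    # run cannot support it.
    text_cb = EvalOutputLogging(log_output_text=True)

    # log_output_text=False: opt out; init() returns immediately, and non-generate
    # outputs are left as token ids rather than decoded into text.
    no_text_cb = EvalOutputLogging(log_output_text=False)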