Skip to content

Commit

Permalink
Add optional logging of text output to EvalOutputLogging (#1283)
Browse files — browse the repository at this point in the history
---------

Co-authored-by: Mihir Patel <[email protected]>
Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
3 people authored Jul 1, 2024
1 parent 8604bba commit 68c2625
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 30 deletions.
77 changes: 53 additions & 24 deletions llmfoundry/callbacks/eval_output_logging_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

import warnings
from copy import deepcopy
from typing import Any, Dict, List, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Union

import torch
from composer.core import Callback, State
from composer.loggers import ConsoleLogger, Logger
from composer.models import HuggingFaceModel
from composer.utils.dist import all_gather_object


Expand All @@ -24,51 +25,85 @@ class EvalOutputLogging(Callback):
into `batch_keys_to_log`. It will do so after every eval batch.
"""

def __init__(self, log_tokens: bool = False, *args: Any, **kwargs: Any):
def __init__(
    self,
    log_tokens: bool = False,
    log_output_text: Optional[bool] = None,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Configure what the callback records for each eval batch.

    Args:
        log_tokens: When true, ``eval_batch_end`` also logs the raw
            ``input_ids`` as ``input_tokens`` (and, for non-generate
            batches with tensor outputs, ``label_tokens``).
        log_output_text: Tri-state switch that is resolved in ``init``.
            ``False`` disables output-text logging, ``True`` requires a
            HuggingFace model and a dataset exposing a tokenizer
            (otherwise ``init`` raises ``ValueError``), and ``None``
            enables it only when those requirements are met.
        *args: Extra positional arguments forwarded to the base class.
        **kwargs: Extra keyword arguments forwarded to the base class.
    """
    # `super()` is already bound to this instance; passing `self` again
    # would hand the instance to the base initializer as a stray extra
    # positional argument.
    super().__init__(*args, **kwargs)
    self.log_tokens = log_tokens
    # Table state starts empty here; where it is populated is not visible
    # in this view — presumably during evaluation (TODO confirm).
    self.columns = None
    self.name = None
    self.rows = []
    self.log_output_text = log_output_text

def init(self, state: State, logger: Logger) -> None:
    """Resolve the ``log_output_text`` setting against the current run.

    Output text is considered available only when the model is a
    ``HuggingFaceModel`` and the state's dataloader has a dataset with a
    ``tokenizer`` attribute.

    Args:
        state: Training state providing ``model`` and ``dataloader``.
        logger: Unused here; part of the callback hook signature.

    Raises:
        ValueError: If ``log_output_text`` was explicitly set to ``True``
            but the model/dataset combination cannot provide output text.
    """
    requested = self.log_output_text
    if requested is False:
        # Explicitly switched off: nothing to detect or validate.
        return

    # Short-circuit order matters: the dataloader is only inspected once
    # the model type has been confirmed.
    can_decode = (
        isinstance(state.model, HuggingFaceModel)
        and state.dataloader is not None
        and hasattr(
            state.dataloader.dataset,  # pyright: ignore[reportGeneralTypeIssues]
            'tokenizer',
        )
    )

    if requested is None:
        # Unset by the user: enable text logging only when it is supported.
        self.log_output_text = can_decode
    elif requested is True and not can_decode:
        raise ValueError(
            '`log_output_text=True` is only supported for HuggingFace models and datasets with tokenizers.',
        )

def eval_batch_end(self, state: State, logger: Logger) -> None:
if not isinstance(state.batch, Dict):
warnings.warn(
f'''EvalOutputLogging only supports batches that are dictionary. \
f"""EvalOutputLogging only supports batches that are dictionary. \
Found batch for type {type(state.batch)}. \
Not logging eval outputs.''',
Not logging eval outputs.""",
)
return

assert state.outputs is not None
assert state.metric_outputs is not None
logging_dict: Dict[str, Union[List[Any], torch.Tensor,
Sequence[torch.Tensor]]] = deepcopy(
state.metric_outputs,
)

# If batch mode is not generate, outputs will be logits
if state.batch['mode'] == 'generate':
logging_dict: Dict[str,
Union[List[Any], torch.Tensor,
Sequence[torch.Tensor]],
] = deepcopy(
state.metric_outputs,
)

if state.batch.get('mode') == 'generate':
# Outputs are already detokenized
logging_dict['outputs'] = state.outputs
elif self.log_output_text and isinstance(state.outputs, torch.Tensor):
# If batch mode is not generate, outputs will be logits
logging_dict['outputs'] = state.outputs.argmax(dim=-1)

input_ids = state.batch['input_ids']
logged_input = []
assert state.dataloader is not None
dataset = state.dataloader.dataset # pyright: ignore[reportGeneralTypeIssues]
tokenizer = dataset.tokenizer # pyright: ignore[reportGeneralTypeIssues]
pad_token_id = getattr(
dataset,
'pad_tok_id',
dataset.tokenizer.pad_token_id,
)

# Depad and decode input_ids
for input_list in input_ids.tolist():
dataset = state.dataloader.dataset # pyright: ignore[reportGeneralTypeIssues]
depadded_input = [
tok for tok in input_list if tok != dataset.pad_tok_id
]
logged_input.append(dataset.tokenizer.decode(depadded_input))
depadded_input = [tok for tok in input_list if tok != pad_token_id]
logged_input.append(tokenizer.decode(depadded_input))
logging_dict['input'] = logged_input

# Log token indices if toggled
if self.log_tokens:
logging_dict['input_tokens'] = input_ids.tolist()
if not state.batch['mode'] == 'generate':
if not state.batch.get('mode') == 'generate':
if isinstance(state.outputs, torch.Tensor): # pyright
logging_dict['label_tokens'] = state.outputs.tolist()

Expand All @@ -85,15 +120,9 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
for key, value in logging_dict.items():
# All types in list are the same
if isinstance(value[0], torch.Tensor):
logging_dict[key] = [
state.dataloader.dataset. # pyright: ignore[reportGeneralTypeIssues]
tokenizer.decode( # pyright: ignore[reportGeneralTypeIssues]
t,
) for t in value
]
logging_dict[key] = [tokenizer.decode(t) for t in value]
elif isinstance(value[0], list):
if isinstance(value[0][0], torch.Tensor):
tokenizer = state.dataloader.dataset.tokenizer # pyright: ignore[reportGeneralTypeIssues]
logging_dict[key] = [[
tokenizer.decode(choice) for choice in t
] for t in value]
Expand Down
Loading

0 comments on commit 68c2625

Please sign in to comment.