Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output eval logging (batch level) #2977

Merged
Merged
Show file tree
Hide file tree
Changes from 214 commits
Commits
Show all changes
254 commits
Select commit Hold shift + click to select a range
ec6fc17
prelim commit
bmosaicml Sep 12, 2023
a59b644
fix max answer lengths for cot
bmosaicml Sep 12, 2023
97b1218
add output logger
bmosaicml Sep 12, 2023
7174e75
create eval output logger
bmosaicml Sep 12, 2023
fdbd53b
fix pyright; git push
bmosaicml Sep 12, 2023
909d07b
Merge branch 'dev' into error_logging_callback
bmosaicml Sep 13, 2023
9f4e3d2
change dist reduce fx
bmosaicml Sep 13, 2023
dce297c
Merge branch 'error_logging_callback' of github.com:bmosaicml/compose…
bmosaicml Sep 13, 2023
ea4e7ee
change dist reduce fx
bmosaicml Sep 13, 2023
5630c23
fix pyright
bmosaicml Sep 13, 2023
30623f7
Merge branch 'dev' into error_logging_callback
bmosaicml Sep 13, 2023
e161e33
Add nightly docker image (#2452)
j316chuck Aug 23, 2023
743fbe1
Fix local eval (#2465)
rishab-partha Aug 24, 2023
0c333b6
Add torch 2.1.0 args for github release-docker workflow
j316chuck Aug 24, 2023
da4e19f
Log system metrics on each event (#2412)
prithvikannan Aug 24, 2023
60d3dc6
Fix torch 2.1.0 docker tag (#2472)
j316chuck Aug 24, 2023
15385b2
Upstream Generate Callback (#2449)
irenedea Aug 25, 2023
ec59026
Upgrade torch nightly docker image for 0.18.3 NCCL version (#2476)
j316chuck Aug 25, 2023
a5ec1ac
Test pytorch 2.1.0 docker images on ci/cd (#2469)
j316chuck Aug 25, 2023
145aeb8
Fix huggingface tokenizer loading for slow tokenizers (#2483)
dakinggg Aug 28, 2023
816a61b
Deprecate Fused LayerNorm (#2475)
nik-mosaic Aug 28, 2023
de68763
Transformers upgrade (#2489)
dakinggg Aug 29, 2023
c4488b5
Update RTD build config with build.os (#2490)
bandish-shah Aug 29, 2023
d91fe4d
Upgrade torch docker version and github workflow tests (#2488)
j316chuck Aug 29, 2023
3a9706d
upgrade node version (#2492)
j316chuck Aug 29, 2023
ee67e99
Gating tying modules w/ FSDP for torch 2.0 (#2467)
bcui19 Aug 30, 2023
99b98ef
Removing min_params (#2494)
bcui19 Aug 30, 2023
91d961d
Fix torchmetrics backwards compatibility issue (#2468)
eracah Aug 31, 2023
6add304
Adding some fixes to FSDP tests (#2495)
bcui19 Aug 31, 2023
b5e0950
fail count (#2496)
mvpatel2000 Aug 31, 2023
8e106a6
Remove PR curve metrics from backward compatibility test and skip tor…
eracah Aug 31, 2023
fc6c995
filter warning (#2500)
mvpatel2000 Sep 1, 2023
ac60704
bump version (#2498)
mvpatel2000 Sep 1, 2023
9274a77
Skip metrics in state dict (#2501)
mvpatel2000 Sep 1, 2023
d7b49c7
Add peak memory stats (#2504)
mvpatel2000 Sep 1, 2023
c24d60d
fix sharded ckpt (#2505)
mvpatel2000 Sep 1, 2023
4e50192
Bump gitpython from 3.1.31 to 3.1.34 (#2509)
dependabot[bot] Sep 5, 2023
90e8bf2
Annotate `torch_prof_remote_file_name` as Optional (#2512)
srstevenson Sep 5, 2023
dac9054
fix: when there is no train_metrics, do not checkpoint (#2502)
furkanbiten Sep 5, 2023
8dfa2db
Remove metric saving (#2514)
mvpatel2000 Sep 7, 2023
c8f3ecd
Fix daily tests by removing gpu marker (#2515)
j316chuck Sep 7, 2023
284c1b7
Refactor mosaic_fsdp.py (#2506)
b-chu Sep 7, 2023
c507e30
fix pr (#2517)
mvpatel2000 Sep 7, 2023
303b7c3
Add custom sharding to ChunkShardingSpec (#2507)
b-chu Sep 8, 2023
3a19321
Update nightly docker image to torch nightly 09-03-23 (#2518)
j316chuck Sep 8, 2023
4ca8f5a
Update pre-commit in setup.py (#2522)
b-chu Sep 8, 2023
c1f87f7
Add FSDP custom wrap with torch 2.1 (#2460)
mvpatel2000 Sep 8, 2023
decf2b2
Fix GCSObjectStore bug where hmac keys auth doesn't work (#2519)
eracah Sep 9, 2023
b521207
prelim commit
bmosaicml Sep 12, 2023
3b09be7
add output logger
bmosaicml Sep 12, 2023
5697e1f
create eval output logger
bmosaicml Sep 12, 2023
2e01b89
change dist reduce fx
bmosaicml Sep 13, 2023
b3b1377
Bump gitpython from 3.1.34 to 3.1.35 (#2525)
dependabot[bot] Sep 12, 2023
1f5012b
Bump pytest from 7.4.0 to 7.4.2 (#2523)
dependabot[bot] Sep 12, 2023
6bdc53e
Upgrade to mlflow version 2.5.0 (#2528)
ngcgarcia Sep 12, 2023
1818b51
disable cifar daily (#2527)
mvpatel2000 Sep 12, 2023
13d411e
mosaicml logger robustness improvements (#2530)
mvpatel2000 Sep 12, 2023
51650ff
Fix metrics keys sort in DecoupledAdamW for OptimizerMonitor FSDP met…
m1kol Sep 12, 2023
c780740
Fix github actions for GCS integration testing (#2532)
mvpatel2000 Sep 13, 2023
17953f4
change dist reduce fx
bmosaicml Sep 13, 2023
cb0ce0e
fix pyright
bmosaicml Sep 13, 2023
f2dd81f
Fix GCS tests (#2535)
mvpatel2000 Sep 13, 2023
8fea658
merge
bmosaicml Sep 14, 2023
75260fc
finish error logging cb
bmosaicml Sep 14, 2023
8bb395f
Merge branch 'dev' into error_logging_callback
bmosaicml Sep 14, 2023
fada3b5
fix
bmosaicml Sep 14, 2023
0540383
Merge branch 'dev' into error_logging_callback
bmosaicml Sep 18, 2023
0e6e6d8
add import to init
bmosaicml Sep 18, 2023
3668c29
Merge branch 'error_logging_callback' of github.com:bmosaicml/compose…
bmosaicml Sep 18, 2023
c653090
add import to init
bmosaicml Sep 18, 2023
b8be3a2
add import to init
bmosaicml Sep 18, 2023
f309785
add file writing
bmosaicml Sep 18, 2023
33b35af
add file writing
bmosaicml Sep 18, 2023
1a0ef89
add file writing
bmosaicml Sep 18, 2023
8aa77f0
add file writing
bmosaicml Sep 18, 2023
1b7e6db
add file writing
bmosaicml Sep 18, 2023
9c75b53
move tensors to cpu
bmosaicml Sep 19, 2023
7a41a01
remove tensors
bmosaicml Sep 19, 2023
e5c8b61
remove tensors
bmosaicml Sep 19, 2023
a33cbd9
remove tensors
bmosaicml Sep 20, 2023
fa88a05
add prompt to qa
bmosaicml Sep 20, 2023
8111682
add prompt to qa
bmosaicml Sep 20, 2023
6e651fd
add prompt to qa
bmosaicml Sep 20, 2023
501bc0c
add prompt to qa
bmosaicml Sep 20, 2023
0116903
add prompt to qa
bmosaicml Sep 20, 2023
5fa5957
add prompt to qa
bmosaicml Sep 20, 2023
5ffb804
add prompt to qa
bmosaicml Sep 20, 2023
605f437
add prompt to qa
bmosaicml Sep 20, 2023
afaa437
add prompt to qa
bmosaicml Sep 20, 2023
b772029
add prompt to qa
bmosaicml Sep 20, 2023
6f8e0d7
add prompt to qa
bmosaicml Sep 20, 2023
92779c4
add prompt to qa
bmosaicml Sep 21, 2023
1ec300e
add prompt to qa
bmosaicml Sep 21, 2023
cf943f4
add prompt to qa
bmosaicml Sep 22, 2023
be859bb
try debugging dist sync issue
bmosaicml Sep 25, 2023
8ab2b04
nit
bmosaicml Sep 25, 2023
a6999aa
debugging
bmosaicml Sep 25, 2023
828ceec
debugging
bmosaicml Sep 25, 2023
b98ffe6
debugging
bmosaicml Sep 25, 2023
0c61063
debugging
jcd2020 Sep 26, 2023
72a4f2b
debugging
jcd2020 Sep 26, 2023
8cf8829
debugging
jcd2020 Sep 26, 2023
484510c
debugging
jcd2020 Sep 26, 2023
07cbebf
debugging
jcd2020 Sep 26, 2023
e6af285
debugging
jcd2020 Sep 26, 2023
6f55ff5
debugging
jcd2020 Sep 26, 2023
e0d80ab
debugging
jcd2020 Sep 26, 2023
177b935
debugging
jcd2020 Sep 26, 2023
dcfa6de
debugging
jcd2020 Sep 26, 2023
cce1fb0
debugging
jcd2020 Sep 26, 2023
10ab1ca
debugging
jcd2020 Sep 26, 2023
3b3fd26
fix syncing of non tensor state
jcd2020 Sep 26, 2023
ab6d797
Merge branch 'dev' into error_logging_callback
bmosaicml Nov 8, 2023
6266eeb
added gpu test
bmosaicml Nov 8, 2023
cd1fc58
merge
bmosaicml Nov 8, 2023
76882cb
fix error
bmosaicml Nov 9, 2023
2855e1f
finish testing callback
bmosaicml Nov 15, 2023
29a5803
fix all errors
bmosaicml Nov 16, 2023
f56c9de
Merge branch 'dev' into error_logging_callback
bmosaicml Nov 16, 2023
57133a4
test commit
tbarton16 Nov 21, 2023
7196028
roll back test commit
tbarton16 Nov 21, 2023
4410203
Merge branch 'dev' into error_logging_callback
bmosaicml Nov 27, 2023
21e322e
remove ranks
bmosaicml Nov 27, 2023
e4eb7ee
Merge branch 'error_logging_callback' of github.com:bmosaicml/compose…
bmosaicml Nov 27, 2023
61447a2
Merge branch 'error_logging_callback' of github.com:bmosaicml/compose…
bmosaicml Nov 27, 2023
d69bdba
re-tesT
bmosaicml Nov 27, 2023
d999b68
Merge branch 'error_logging_callback' of github.com:bmosaicml/compose…
bmosaicml Nov 27, 2023
c030717
Merge branch 'dev' into error_logging_callback
bmosaicml Dec 5, 2023
ca4d3c4
add custome gen kwargs and stopping on eos token
bmosaicml Dec 14, 2023
1e39623
modify test
bmosaicml Dec 14, 2023
09af753
modify test
bmosaicml Dec 14, 2023
47f3c91
Merge branch 'dev' into pass_on_custom_generation_kwargs
bmosaicml Dec 15, 2023
9f9a6bc
Merge branch 'pass_on_custom_generation_kwargs' into error_logging_ca…
bmosaicml Dec 15, 2023
a3501e9
finish
bmosaicml Dec 18, 2023
fadce0e
finish
bmosaicml Dec 18, 2023
92157da
finish
bmosaicml Dec 18, 2023
d137bbc
finish
bmosaicml Dec 18, 2023
b25da3f
Merge branch 'dev' into pass_on_custom_generation_kwargs
bmosaicml Dec 18, 2023
909ed63
finish pr
bmosaicml Dec 20, 2023
e263b5b
Merge branch 'pass_on_custom_generation_kwargs' of github.com:bmosaic…
bmosaicml Dec 20, 2023
32d6668
Merge branch 'dev' into pass_on_custom_generation_kwargs
bmosaicml Dec 20, 2023
4ff16b4
implement early stop
bmosaicml Dec 20, 2023
7ee0a72
Merge branch 'pass_on_custom_generation_kwargs' into add_custom_stopp…
bmosaicml Dec 20, 2023
bcf002e
Merge branch 'add_custom_stopping_criteria' into error_logging_callback
bmosaicml Dec 20, 2023
83a60b7
add tesT
bmosaicml Dec 20, 2023
a031772
Merge branch 'pass_on_custom_generation_kwargs' into add_custom_stopp…
bmosaicml Dec 20, 2023
e5943d6
merge update
bmosaicml Dec 20, 2023
67b4685
Merge branch 'add_custom_stopping_criteria' of github.com:bmosaicml/c…
bmosaicml Dec 20, 2023
be32781
Merge branch 'add_custom_stopping_criteria' into error_logging_callback
bmosaicml Dec 20, 2023
e512a21
merge
bmosaicml Dec 20, 2023
a1af91a
fix
bmosaicml Dec 22, 2023
5f23b3e
finish
bmosaicml Dec 23, 2023
42fb431
finish
bmosaicml Dec 23, 2023
aa05076
fix bug
bmosaicml Dec 23, 2023
076731d
Merge branch 'add_custom_stopping_criteria' into error_logging_callback
bmosaicml Dec 23, 2023
89669c6
finish
bmosaicml Dec 23, 2023
95a7d28
Merge branch 'error_logging_callback' of github.com:bmosaicml/compose…
bmosaicml Dec 23, 2023
dce4ef0
bug fix
bmosaicml Dec 23, 2023
cb3c69d
add keys
bmosaicml Dec 26, 2023
cea85d4
Merge branch 'add_custom_stopping_criteria' into error_logging_callback
bmosaicml Dec 26, 2023
7371e66
add correcT
bmosaicml Dec 26, 2023
c7f5198
modify sync
bmosaicml Dec 26, 2023
786c64c
diff split
bmosaicml Dec 26, 2023
559beee
Merge branch 'add_custom_stopping_criteria' into error_logging_callback
bmosaicml Dec 26, 2023
7f20954
fix typo
bmosaicml Dec 26, 2023
bd10cdd
Merge branch 'add_custom_stopping_criteria' into error_logging_callback
bmosaicml Dec 26, 2023
adf5bab
edit condition
bmosaicml Dec 26, 2023
3cc2442
broken wip
maxisawesome Jan 31, 2024
5c71eab
design demonstration commit
maxisawesome Jan 31, 2024
dd774fb
simplify pr
maxisawesome Feb 1, 2024
489e9c1
further simplify
maxisawesome Feb 1, 2024
af02fc6
wip
maxisawesome Feb 5, 2024
38aedb7
add comments
maxisawesome Feb 5, 2024
9c9fe9b
add other icl metrics
maxisawesome Feb 7, 2024
88f063d
wip
maxisawesome Feb 8, 2024
4e172fc
change dict method, add more stuff to logging
maxisawesome Feb 8, 2024
328627e
fix typos, change some comments
maxisawesome Feb 8, 2024
8d89708
Merge branch 'mosaicml_dev' into error_logging_callback_in_batch
maxisawesome Feb 8, 2024
33176f3
decode tensors, fix wrong dict key
maxisawesome Feb 8, 2024
fbc75ca
fix mc
maxisawesome Feb 8, 2024
7dffa47
1 to 0 lol
maxisawesome Feb 8, 2024
b38799a
wip linting
maxisawesome Feb 9, 2024
4dbccb6
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Feb 9, 2024
637247b
adjust to step logging
maxisawesome Feb 9, 2024
d97cd23
adjust logging names
maxisawesome Feb 12, 2024
12669d5
add mflow, rm batch keys
maxisawesome Feb 12, 2024
0230827
add comments, check for dict in huggingface model update_metric
maxisawesome Feb 15, 2024
13cf569
add user specified logging
maxisawesome Feb 15, 2024
4d79393
move metric_name duplication to update_metric
maxisawesome Feb 15, 2024
642663f
wip fix testing
maxisawesome Feb 16, 2024
0f8909f
fix input shape error
maxisawesome Feb 16, 2024
a1bc29d
rm init
maxisawesome Feb 16, 2024
7e1df22
rm eval_after_all
maxisawesome Feb 16, 2024
ed1fa1c
step=None
maxisawesome Feb 22, 2024
677e686
step=state.timestamp.batch.value
maxisawesome Feb 22, 2024
d8352c3
update name to include step
maxisawesome Feb 22, 2024
25cd65d
merge with dev
maxisawesome Feb 22, 2024
c08e2eb
Merge branch 'mosaicml_dev' into error_logging_callback_in_batch
maxisawesome Feb 22, 2024
8928121
linting, wip on test
maxisawesome Feb 22, 2024
a165e92
fix test
maxisawesome Feb 23, 2024
c2b71a4
pyright wip
maxisawesome Feb 23, 2024
3661ae1
Merge branch 'mosaicml_dev' into error_logging_callback_in_batch
maxisawesome Feb 23, 2024
3bccd00
add non-batch warning
maxisawesome Feb 23, 2024
9e23105
pyright
maxisawesome Feb 23, 2024
5328038
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Feb 23, 2024
d95dba5
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Feb 26, 2024
f1a8d41
debug
maxisawesome Feb 26, 2024
a7708fb
rm this commit that wasn't the right branch
maxisawesome Feb 26, 2024
6b9e720
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Feb 26, 2024
25f1872
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Feb 27, 2024
1f1700b
log at the end of training
maxisawesome Feb 28, 2024
eabee96
Merge branch 'error_logging_callback_in_batch' of github.com:maxisawe…
maxisawesome Feb 28, 2024
4ab7d95
rm silly wandb table logging
maxisawesome Feb 28, 2024
4f5fecb
add run_name
maxisawesome Feb 28, 2024
84f6982
add docstring
maxisawesome Feb 28, 2024
86d92bd
add debug logging
maxisawesome Feb 28, 2024
b13bf47
more logging
maxisawesome Feb 28, 2024
bfa9621
rm info logging
maxisawesome Feb 28, 2024
5fa6d54
improve comments
maxisawesome Feb 28, 2024
eea7df4
Update composer/callbacks/eval_output_logging_callback.py
maxisawesome Feb 28, 2024
eb7200c
rm logging bool
maxisawesome Feb 29, 2024
a3419af
Merge branch 'error_logging_callback_in_batch' of github.com:maxisawe…
maxisawesome Feb 29, 2024
d994296
fix logging for schema tasks
maxisawesome Mar 1, 2024
8c89376
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Mar 1, 2024
ae83f03
fix schema / mc tasks
maxisawesome Mar 1, 2024
454a174
yapf
maxisawesome Mar 1, 2024
8029ebc
rm reshape
maxisawesome Mar 1, 2024
95f81c8
fix tests
maxisawesome Mar 1, 2024
79fa8bb
cleanup test
maxisawesome Mar 1, 2024
8db97f5
pyright
maxisawesome Mar 1, 2024
b1da147
pyright
maxisawesome Mar 1, 2024
acc3e92
docstring
maxisawesome Mar 1, 2024
bfd76a7
pyright
maxisawesome Mar 1, 2024
04188fd
update tests
maxisawesome Mar 1, 2024
cf83165
rm attention mask requirement
maxisawesome Mar 1, 2024
db14ec5
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Mar 1, 2024
a89d12b
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Mar 2, 2024
7785168
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Mar 3, 2024
6fb823d
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Mar 4, 2024
2f63d99
Update composer/metrics/nlp.py
maxisawesome Mar 5, 2024
29951dc
Update composer/metrics/nlp.py
maxisawesome Mar 5, 2024
560e702
rm todo
maxisawesome Mar 6, 2024
8bccb7e
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Mar 6, 2024
e3e7b54
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Mar 7, 2024
685aa2a
Merge branch 'dev' into error_logging_callback_in_batch
maxisawesome Mar 8, 2024
68bb9d5
Merge branch 'error_logging_callback_in_batch' of github.com:maxisawe…
maxisawesome Mar 8, 2024
87e5b26
lint
mvpatel2000 Mar 8, 2024
dd50335
lint
mvpatel2000 Mar 8, 2024
df16843
lint
mvpatel2000 Mar 8, 2024
747fe35
more lint
mvpatel2000 Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from composer.callbacks.activation_monitor import ActivationMonitor
from composer.callbacks.checkpoint_saver import CheckpointSaver
from composer.callbacks.early_stopper import EarlyStopper
from composer.callbacks.eval_output_logging_callback import EvalOutputLogging
from composer.callbacks.export_for_inference import ExportForInferenceCallback
from composer.callbacks.free_outputs import FreeOutputs
from composer.callbacks.generate import Generate
Expand All @@ -35,6 +36,7 @@
'CheckpointSaver',
'MLPerfCallback',
'EarlyStopper',
'EvalOutputLogging',
'ExportForInferenceCallback',
'ThresholdStopper',
'ImageVisualizer',
Expand Down
81 changes: 81 additions & 0 deletions composer/callbacks/eval_output_logging_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Log model outputs and expected outputs during ICL evaluation."""

import warnings
from copy import deepcopy
from typing import Any, Dict, List, Sequence, Union

import torch

from composer.core import Callback, State
from composer.loggers import ConsoleLogger, Logger

# from torch.utils.data import DataLoader, Dataset


class EvalOutputLogging(Callback):
"""Logs eval outputs for each sample of each ICL evaluation dataset.

ICL metrics are required to support caching the model's responses including information on whether model was correct.
Metrics are responsible for returning the results of individual datapoints in a dictionary of lists.
The callback will log the metric name, the depadded and detokenized input, any data stored in state.metric_outputs, and
any keys from the batch pased into `batch_keys_to_log`. It will do so after every eval batch.
"""

def __init__(self, *args, **kwargs):
super().__init__(self, *args, **kwargs)
self.warn_batch_is_not_dict = True
maxisawesome marked this conversation as resolved.
Show resolved Hide resolved

def eval_batch_end(self, state: State, logger: Logger) -> None:
if not isinstance(state.batch, Dict):
if self.warn_batch_is_not_dict:
warnings.warn(f'''EvalOutputLogging only supports batchs that are dictionary. \
maxisawesome marked this conversation as resolved.
Show resolved Hide resolved
Found batch for type {type(state.batch)}. \
Not logging eval outputs.''')
self.warn_batch_is_not_dict = False
return

assert state.outputs is not None
assert state.metric_outputs is not None
logging_dict: Dict[str, Union[List[Any], torch.Tensor, Sequence[torch.Tensor]]] = deepcopy(state.metric_outputs)

if state.batch['mode'] == 'generate':
# Outputs are already detokenized
logging_dict['outputs'] = state.outputs

input_ids = state.batch['input_ids']
logged_input = []
assert state.dataloader is not None
# Depad and decode input_ids
for input_list in input_ids.tolist():
depadded_input = [
tok for tok in input_list
if tok != state.dataloader.dataset.pad_tok_id # pyright: ignore[reportGeneralTypeIssues]
]
logged_input.append(
state.dataloader.dataset.tokenizer.decode(depadded_input)) # pyright: ignore[reportGeneralTypeIssues]
logging_dict['input'] = logged_input

# Get column names
columns = list(logging_dict.keys())
# Convert logging_dict from kv pairs of column name and column values to a list of rows
rows = [list(item) for item in zip(*logging_dict.values())]
maxisawesome marked this conversation as resolved.
Show resolved Hide resolved

# NOTE: This assumes _any_ tensor logged are tokens to be decoded.
# This might not be true if, for example, logits are logged.
# detokenize data in rows
rows = [
[
state.dataloader.dataset.tokenizer.decode(x) # pyright: ignore[reportGeneralTypeIssues]
if isinstance(x, torch.Tensor) else x for x in row
] for row in rows
]

assert state.dataloader_label is not None
step = state.timestamp.batch.value
name = f'{state.dataloader_label}_step_{step}'
for dest_logger in logger.destinations:
if not isinstance(dest_logger, ConsoleLogger):
dest_logger.log_table(columns, rows, name=name, step=state.timestamp.batch.value)
2 changes: 2 additions & 0 deletions composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,8 @@ def __init__(
self.eval_metric_values: Dict[str, float] = {}
self.total_loss_dict: Dict[str, float] = {}

self.metric_outputs: Dict[str, Any] = {}

def _dataset_of(self, dataloader: Optional[Union[Evaluator, DataSpec, DataLoader, Iterable]]) -> Optional[Dataset]:
"""Get the dataset contained by the given dataloader-like object.

Expand Down
6 changes: 4 additions & 2 deletions composer/loggers/in_memory_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ def log_table(self,
raise MissingConditionalImportError(extra_deps_group='pandas',
conda_package='pandas',
conda_channel='conda-forge') from e
table = pd.DataFrame.from_records(data=rows, columns=columns).to_json(orient='split', index=False)
assert isinstance(table, str)
table = pd.DataFrame.from_records(data=rows, columns=columns).to_json(orient='split',
index=False,
force_ascii=False)
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved
assert table is not None
self.tables[name] = table

def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
Expand Down
13 changes: 11 additions & 2 deletions composer/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def __init__(
self.run_dir: Optional[str] = None
self.run_url: Optional[str] = None

self.table_dict = {}

def _set_is_in_atexit(self):
self._is_in_atexit = True

Expand All @@ -124,8 +126,15 @@ def log_table(self,
step: Optional[int] = None) -> None:
if self._enabled:
import wandb
table = wandb.Table(columns=columns, rows=rows)
wandb.log({name: table}, step)
if name in self.table_dict:
for row in rows:
self.table_dict[name].add_data(*row)
else:
table = wandb.Table(columns=columns, rows=rows)
self.table_dict[name] = table
# Need to do this copy because apparently wandb table logging is broken LOL
# https://github.com/wandb/wandb/issues/2981#issuecomment-1458447291
wandb.log({name: copy.copy(self.table_dict[name])}, step=step)

def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
if self._enabled:
Expand Down
92 changes: 88 additions & 4 deletions composer/metrics/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

"""A collection of common torchmetrics for NLP tasks."""

import copy
import functools
import logging
import os
import re
import string
import warnings
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -201,6 +203,31 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.needs_batch = True

def _wrap_update(self, update: Callable) -> Callable:
maxisawesome marked this conversation as resolved.
Show resolved Hide resolved

@functools.wraps(update)
def wrapped_func(*args: Any, **kwargs: Any) -> None:
self._computed = None
self._update_count += 1
with torch.set_grad_enabled(self._enable_grad):
try:
update_result = update(*args, **kwargs)
except RuntimeError as err:
if 'Expected all tensors to be on' in str(err):
raise RuntimeError(
'Encountered different devices in metric calculation (see stacktrace for details).'
' This could be due to the metric class not being on the same device as input.'
f' Instead of `metric={self.__class__.__name__}(...)` try to do'
f' `metric={self.__class__.__name__}(...).to(device)` where'
' device corresponds to the device of the input.') from err
raise err

if self.compute_on_cpu:
self._move_list_states_to_cpu()
return update_result

return wrapped_func

def update(self,
batch: dict,
output_logits: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -274,6 +301,12 @@ def __init__(self, dist_sync_on_step: bool = False):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.add_state('correct', default=torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum')
self.metric_result_dict = {
'cleaned_output': [],
'original_label': [],
'cleaned_label': [],
'result': [],
}

def normalize_answer(self, answer: str):
"""Lower text and remove punctuation, articles and extra whitespace.
Expand Down Expand Up @@ -301,8 +334,10 @@ def replace_underscore(text: str) -> str:

def update(self, outputs: List[str], labels: List[List[str]], batch: Dict[str, Any]):
cot_delimiter = batch.get('cot_delimiter', '')

maxisawesome marked this conversation as resolved.
Show resolved Hide resolved
do_normalization = batch.get('do_normalization', True)
stopping_criteria = batch.get('stopping_criteria', None)
metric_result_dict = copy.deepcopy(self.metric_result_dict)
for sample_output, sample_labels in zip(outputs, labels):
final_answer = sample_output

Expand All @@ -319,10 +354,20 @@ def update(self, outputs: List[str], labels: List[List[str]], batch: Dict[str, A
cleaned_final_answer = final_answer
cleaned_sample_labels = set(sample_labels)

metric_result_dict['original_label'].append(sample_labels)
metric_result_dict['cleaned_output'].append(cleaned_final_answer)
metric_result_dict['cleaned_label'].append(cleaned_sample_labels)

if any(cleaned_final_answer.startswith(label) for label in cleaned_sample_labels):
self.correct += torch.tensor(1.0)
metric_result_dict['result'].append(1)
else:
metric_result_dict['result'].append(0)

self.total += torch.tensor(1.0)

return metric_result_dict

def compute(self):
assert isinstance(self.correct, Tensor)
assert isinstance(self.total, Tensor)
Expand Down Expand Up @@ -358,6 +403,7 @@ def __init__(self, dist_sync_on_step: bool = False):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.add_state('correct', default=torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum')
self.metric_result_dict = {'context': [], 'label': [], 'output': [], 'result': []}

def update(self,
batch: dict,
Expand All @@ -369,13 +415,23 @@ def update(self,
labels=labels,
outputs=outputs)

metric_result_dict = copy.deepcopy(self.metric_result_dict)
for batch_idx, cont_idx in enumerate(batch['continuation_indices']):
cont_tok_pred = outputs[batch_idx].index_select(dim=0, index=cont_idx - 1).argmax(dim=-1)
cont_tok_targ = labels[batch_idx].index_select(dim=0, index=cont_idx - 1)
# TODO: okay to do context_tok here? or do we wanna do that in the logger?
maxisawesome marked this conversation as resolved.
Show resolved Hide resolved
metric_result_dict['context'].append(batch['input_ids'][batch_idx][:cont_idx[0]])
metric_result_dict['label'].append(cont_tok_targ)
metric_result_dict['output'].append(cont_tok_pred)

correct = (cont_tok_pred == cont_tok_targ).all().int()
self.correct += correct
metric_result_dict['result'].append(int(correct))

self.correct += (cont_tok_pred == cont_tok_targ).all().int()
self.total += torch.tensor(1.0)

return metric_result_dict

def compute(self):
assert isinstance(self.correct, Tensor)
assert isinstance(self.total, Tensor)
Expand Down Expand Up @@ -409,6 +465,7 @@ def __init__(self, dist_sync_on_step: bool = False):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.add_state('correct', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.metric_result_dict = {'context': [], 'correct_choice': [], 'selected_choice': [], 'result': []}

def update(self,
batch: dict,
Expand All @@ -430,14 +487,31 @@ def update(self,
perplexity = torch.exp(cross_entropy)
perplexities.append(perplexity)

metric_result_dict = copy.deepcopy(self.metric_result_dict)
for (start, end), gold_idx in zip(batch['choice_groupings'], batch['gold_indices']):
subset = perplexities[start:end]
idx_min = subset.index(min(subset))

if idx_min == gold_idx:
self.correct += torch.tensor(1.0)
metric_result_dict['result'].append(1)
else:
metric_result_dict['result'].append(0)

question = batch['input_ids'][start][:batch['continuation_indices'][start][0]]

# TODO: Seems broken for schema tasks
maxisawesome marked this conversation as resolved.
Show resolved Hide resolved
correct_choice = batch['input_ids'][start:end][gold_idx][batch['continuation_indices'][start:end][gold_idx][
0]:batch['continuation_indices'][start:end][gold_idx][-1] + 1]
selected_choice = batch['input_ids'][start:end][idx_min][batch['continuation_indices'][start:end][idx_min][
0]:batch['continuation_indices'][start:end][idx_min][-1] + 1]
metric_result_dict['context'].append(question)
metric_result_dict['correct_choice'].append(correct_choice)
metric_result_dict['selected_choice'].append(selected_choice)

self.total += torch.tensor(1.0)

return metric_result_dict

def compute(self):
assert isinstance(self.correct, Tensor)
assert isinstance(self.total, Tensor)
Expand Down Expand Up @@ -609,6 +683,8 @@ def __init__(self, dist_sync_on_step: bool = False):
if self.eval_device is not None:
self.eval_device = self.eval_device.upper()

self.metric_result_dict = {'context': [], 'output': [], 'result': [], 'sample_id': []}

def get_client(self) -> EvalClient:
"""Returns a client for the appropriate remote platform."""
client = None
Expand Down Expand Up @@ -647,7 +723,7 @@ def estimator(self, n: int, c: int, k: int) -> float:
return 1.0
return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

def _initialize_state(self, batch: dict[str, Any]):
def _initialize_state(self, batch: Dict[str, Any]):
maxisawesome marked this conversation as resolved.
Show resolved Hide resolved
device = batch['input_ids'].device
self.dataset_size = batch['dataset_size']
self.pass_at_k = batch['pass_at_k']
Expand Down Expand Up @@ -689,15 +765,19 @@ def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]):
del labels # never used
client = self.get_client()

metric_result_dict = copy.deepcopy(self.metric_result_dict)
maxisawesome marked this conversation as resolved.
Show resolved Hide resolved
for sample_id, code_gen, sample_prompt, test_inputs, test_outputs, entry_point, language in zip(
batch['sample_id'], outputs, batch['prompts'], batch['test_inputs'], batch['test_outputs'],
batch['entry_points'], batch['languages']):

idx = sample_id
self.total[idx] += 1.0
metric_result_dict['sample_id'].append(sample_id)

code_gen = re.split(r'\n[A-Za-z0-9#`]', code_gen)[0] # remove everything after function ends
final_code = sample_prompt + code_gen # combine prompt with the code generation
metric_result_dict['context'].append(sample_prompt)
metric_result_dict['output'].append(code_gen)

test_results = []
for test_input, test_output in zip(test_inputs, test_outputs):
Expand All @@ -714,8 +794,12 @@ def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]):

if all(test_results):
self.correct[idx] += 1.0
metric_result_dict['result'].append(1)
else:
metric_result_dict['result'].append(0)

client.close() # pyright: ignore [reportOptionalMemberAccess]
return metric_result_dict

def compute(self):
assert isinstance(self.correct, Tensor)
Expand Down
12 changes: 9 additions & 3 deletions composer/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,17 @@ def get_metrics(self, is_train: bool = False) -> Dict[str, Metric]:

return metrics if metrics else {}

def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> Dict:
if getattr(metric, 'needs_batch', False):
metric.update(batch=batch, outputs=outputs, labels=self.labels)
metric_result = metric.update(batch=batch, outputs=outputs, labels=self.labels)
else:
metric.update(outputs, self.labels)
metric_result = metric.update(outputs, self.labels)
if metric_result is not None:
# Add the metric name once for each datapoint in the batch
metric_result['metric_name'] = [metric.__class__.__name__ for _ in range(0, batch['input_ids'].shape[0])]
else:
metric_result = {}
return metric_result

def get_metadata(self):
model_output = {}
Expand Down
3 changes: 2 additions & 1 deletion composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3031,11 +3031,12 @@ def _eval_loop(
outputs = self.state.outputs

for metric in metrics.values():
self._original_model.update_metric(
metric_outputs = self._original_model.update_metric(
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved
self.state.batch,
outputs,
metric,
)
self.state.metric_outputs = metric_outputs or {}

except RuntimeError as e:
if evaluator.auto_microbatching and _is_cuda_oom(e):
Expand Down
Loading
Loading