Remove subclasses from composer #2962

Closed
composer/datasets/__init__.py (1 addition, 10 deletions)
```diff
@@ -3,17 +3,8 @@

 """Natively supported datasets."""

-from composer.datasets.in_context_learning_evaluation import (InContextLearningCodeEvalDataset,
-                                                              InContextLearningDataset, InContextLearningLMTaskDataset,
-                                                              InContextLearningMultipleChoiceTaskDataset,
-                                                              InContextLearningQATaskDataset,
-                                                              InContextLearningSchemaTaskDataset)
+from composer.datasets.in_context_learning_evaluation import InContextLearningDataset

 __all__ = [
     'InContextLearningDataset',
-    'InContextLearningQATaskDataset',
-    'InContextLearningLMTaskDataset',
-    'InContextLearningCodeEvalDataset',
-    'InContextLearningMultipleChoiceTaskDataset',
-    'InContextLearningSchemaTaskDataset',
 ]
```
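
For downstream code, the practical effect is that only the base `InContextLearningDataset` remains importable from `composer.datasets`. A minimal sketch of the resulting import surface, assuming an installed composer build that includes this change:

```python
# Minimal sketch of the import surface after this PR (assumes this change is installed).
from composer.datasets import InContextLearningDataset  # still exported

try:
    # Task-specific subclasses such as InContextLearningQATaskDataset are gone.
    from composer.datasets import InContextLearningQATaskDataset
except ImportError as err:
    print(f'No longer part of composer: {err}')
```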
composer/datasets/in_context_learning_evaluation.py (3 additions, 1,109 deletions)

Large diffs are not rendered by default.

composer/metrics/__init__.py (2 additions, 18 deletions)
```diff
@@ -5,11 +5,8 @@

 from composer.metrics.map import MAP
 from composer.metrics.metrics import CrossEntropy, Dice, LossMetric, MIoU
-from composer.metrics.nlp import (BinaryF1Score, InContextLearningCodeEvalAccuracy, InContextLearningLMAccuracy,
-                                  InContextLearningLMExpectedCalibrationError,
-                                  InContextLearningMCExpectedCalibrationError, InContextLearningMetric,
-                                  InContextLearningMultipleChoiceAccuracy, InContextLearningQAAccuracy,
-                                  LanguageCrossEntropy, LanguagePerplexity, MaskedAccuracy)
+from composer.metrics.nlp import (BinaryF1Score, InContextLearningMetric, LanguageCrossEntropy, LanguagePerplexity,
+                                  MaskedAccuracy)

 __all__ = [
     'MAP',
@@ -21,18 +18,5 @@
     'LanguageCrossEntropy',
     'MaskedAccuracy',
     'LanguagePerplexity',
-    'InContextLearningLMAccuracy',
-    'InContextLearningMultipleChoiceAccuracy',
-    'InContextLearningQAAccuracy',
-    'InContextLearningMCExpectedCalibrationError',
-    'InContextLearningLMExpectedCalibrationError',
-    'InContextLearningMetric',
-    'InContextLearningCodeEvalAccuracy',
 ]

-METRIC_DEFAULT_CTORS = {
-    'InContextLearningLMAccuracy': InContextLearningLMAccuracy,
-    'InContextLearningMultipleChoiceAccuracy': InContextLearningMultipleChoiceAccuracy,
-    'InContextLearningQAAccuracy': InContextLearningQAAccuracy,
-    'InContextLearningCodeEvalAccuracy': InContextLearningCodeEvalAccuracy,
-}
```
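
Callers that resolved metric constructors by name through the removed `METRIC_DEFAULT_CTORS` mapping now need to hold their own mapping (or import the classes from wherever the subclasses migrated). A minimal sketch under that assumption; `MyLMAccuracy` and `MY_METRIC_CTORS` are hypothetical names, not composer APIs:

```python
# Hypothetical stand-in for the removed METRIC_DEFAULT_CTORS lookup.
import torch
from torchmetrics import Metric

class MyLMAccuracy(Metric):
    """Illustrative replacement metric; not a composer class."""

    def __init__(self):
        super().__init__()
        self.add_state('correct', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        return self.correct / self.total

# Previously: metric = METRIC_DEFAULT_CTORS[name]()
MY_METRIC_CTORS = {'MyLMAccuracy': MyLMAccuracy}
metric = MY_METRIC_CTORS['MyLMAccuracy']()
```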
composer/metrics/nlp.py (2 additions, 458 deletions)

Large diffs are not rendered by default.

composer/models/huggingface.py (2 additions, 5 deletions)
```diff
@@ -21,7 +21,7 @@
 import torch
 from torchmetrics import Metric

-from composer.metrics import InContextLearningMetric, InContextLearningQAAccuracy
+from composer.metrics import InContextLearningMetric
 from composer.models.base import ComposerModel
 from composer.utils import MissingConditionalImportError, dist, get_file, import_object, is_model_fsdp, safe_torch_load

@@ -532,10 +532,7 @@ def get_metrics(self, is_train: bool = False) -> Dict[str, Metric]:
         return metrics if metrics else {}

     def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
-        if isinstance(metric, InContextLearningQAAccuracy):
-            assert self.labels is not None
-            metric.update(batch=batch, outputs=outputs, labels=self.labels)  # pyright: ignore [reportGeneralTypeIssues]
-        elif isinstance(metric, InContextLearningMetric):
+        if isinstance(metric, InContextLearningMetric):
             assert self.labels is not None
             metric.update(batch, outputs, self.labels)  # pyright: ignore [reportGeneralTypeIssues]
         else:
```
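
With the QA-specific branch removed, `update_metric` treats every `InContextLearningMetric` uniformly and calls `metric.update(batch, outputs, labels)` positionally. A hedged sketch of a custom metric written against that unified contract; `MyICLAccuracy`, its states, and the argmax decoding are illustrative assumptions:

```python
# Hedged sketch: a custom metric compatible with the unified update path above.
from typing import Any

import torch
from composer.metrics import InContextLearningMetric

class MyICLAccuracy(InContextLearningMetric):
    """Hypothetical metric; only the positional update(batch, outputs, labels)
    call reflects the contract enforced by HuggingFaceModel.update_metric."""

    def __init__(self):
        super().__init__()
        self.add_state('correct', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, batch: Any, outputs: torch.Tensor, labels: torch.Tensor) -> None:
        preds = outputs.argmax(dim=-1)  # assumes token-level logits
        self.correct += (preds == labels).sum()
        self.total += labels.numel()

    def compute(self) -> torch.Tensor:
        return self.correct / self.total
```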