Skip to content

Commit

Permalink
apply loss_v_len changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML committed Feb 3, 2024
1 parent cd6e0d6 commit 1627634
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 11 deletions.
10 changes: 8 additions & 2 deletions composer/loggers/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,15 @@ def log_hyperparameters(self, parameters: Dict[str, Any]):
for destination in self.destinations:
destination.log_hyperparameters(parameters)

def log_table(self,
              columns: List[str],
              rows: List[List[Any]],
              name: str = 'Table',
              step: Optional[int] = None) -> None:
    """Log a table to every registered destination.

    Args:
        columns (List[str]): Names of the columns in the table.
        rows (List[List[Any]]): 2D row-oriented array of values.
        name (str): Name of table. (Default: ``'Table'``)
        step (int, optional): Step to log the table at. When ``None``, the
            current batch count from the state timestamp is used.
    """
    if step is None:
        # Same fallback as ``log_metrics``: use the current batch count.
        step = self._state.timestamp.batch.value
    for destination in self.destinations:
        destination.log_table(columns, rows, name, step)

def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
if step is None:
Expand Down
8 changes: 6 additions & 2 deletions composer/loggers/logger_destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,19 @@ def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
del hyperparameters # unused
pass

def log_table(self,
              columns: List[str],
              rows: List[List[Any]],
              name: str = 'Table',
              step: Optional[int] = None) -> None:
    """Log a table.

    This base implementation is a no-op: it discards every argument.

    Args:
        columns (List[str]): Names of the columns in the table.
        rows (List[List[Any]]): 2D row-oriented array of values.
        name (str): Name of table. (Default: ``'Table'``)
        step (int, optional): Step to log the table at. (Default: ``None``)
    """
    del columns, rows, name, step  # unused

def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
Expand Down
8 changes: 6 additions & 2 deletions composer/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,15 @@ def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
import wandb
wandb.config.update(hyperparameters)

def log_table(self,
              columns: List[str],
              rows: List[List[Any]],
              name: str = 'Table',
              step: Optional[int] = None) -> None:
    """Log a table to Weights & Biases.

    Does nothing when this logger is not enabled.

    Args:
        columns (List[str]): Names of the columns in the table.
        rows (List[List[Any]]): 2D row-oriented array of values.
        name (str): Key the table is logged under. (Default: ``'Table'``)
        step (int, optional): Step passed through to ``wandb.log``.
            (Default: ``None``)
    """
    if self._enabled:
        # Imported lazily so wandb is only required when logging is enabled.
        import wandb
        table = wandb.Table(columns=columns, rows=rows)
        wandb.log({name: table}, step=step)

def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
if self._enabled:
Expand Down
5 changes: 4 additions & 1 deletion composer/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError, InContextLearningMetric,
InContextLearningMultipleChoiceAccuracy, InContextLearningQAAccuracy,
LanguageCrossEntropy, LanguagePerplexity, MaskedAccuracy)
LanguageCrossEntropy, LanguagePerplexity, MaskedAccuracy, MetricsRequiringBatchInfo,
LossPerpVLen)

__all__ = [
'MAP',
Expand All @@ -28,6 +29,8 @@
'InContextLearningLMExpectedCalibrationError',
'InContextLearningMetric',
'InContextLearningCodeEvalAccuracy',
'MetricsRequiringBatchInfo',
'LossPerpVLen'
]

METRIC_DEFAULT_CTORS = {
Expand Down
134 changes: 134 additions & 0 deletions composer/metrics/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,140 @@ def compute(self) -> Tensor:
assert isinstance(self.false_negative, Tensor)
f1 = (self.true_positive) / (self.true_positive + (0.5 * (self.false_negative + self.false_positive)))
return f1

class MetricsRequiringBatchInfo(Metric):
    """Base class for metrics whose ``update`` needs the batch as well as the outputs and targets."""

    def update(self, batch: dict, output: Union[Mapping, torch.Tensor], target: torch.Tensor) -> None:
        """Abstract interface for computing metrics that require the batch.

        Args:
            batch (dict): Batch must consist minimally of ``input_ids`` as well as any other structure
                needed to compute the metric.
            output (Mapping | torch.Tensor): The model outputs evaluated on the batch ``input_ids``.
            target (torch.Tensor): The correct outputs.

        Raises:
            NotImplementedError: Abstract method must be implemented by subclasses.
        """
        raise NotImplementedError

class LossPerpVLen(MetricsRequiringBatchInfo):
    """Mean token loss and perplexity as a function of token position.

    Two groups of running sums are kept, each a 1-D tensor indexed by position:

    * ``sum_loss`` / ``sum_perplexity`` / ``sum_length`` are indexed by the token's
      position in its row of the batch.
    * ``sum_loss_seq_id`` / ``sum_perplexity_seq_id`` / ``sum_length_seq_id`` are indexed
      by the token's position inside its own sequence, and are only updated when the
      batch contains a ``'sequence_id'`` entry.

    Args:
        dist_sync_on_step (bool): Forwarded to the base metric constructor. (Default: ``False``)
        ignore_index (int): Target value that is excluded from the loss, the perplexity
            and the token counts. (Default: ``-100``)
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.ignore_index = ignore_index
        # reduction='none' keeps one loss value per token so it can be binned by position.
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none')

        # States start empty because the sequence length is unknown until the first
        # ``update``; they are sized lazily there.
        for state_name in (
                'sum_loss',
                'sum_perplexity',
                'sum_length',
                'sum_loss_seq_id',
                'sum_perplexity_seq_id',
                'sum_length_seq_id',
        ):
            self.add_state(state_name, default=torch.Tensor(), dist_reduce_fx='sum')

    def update(self, batch: dict, output: Union[Mapping, torch.Tensor], target: torch.Tensor) -> None:
        """Updates the internal state with results from a new batch.

        Args:
            batch (dict): The batch. If it contains ``'sequence_id'``, the per-sequence
                statistics are updated as well.
            output (Mapping | torch.Tensor): The output from the model: either the logits
                tensor itself or a mapping holding it under ``'logits'``.
            target (~torch.Tensor): A 2-D ``(batch, sequence)`` tensor of ground-truth values.

        Raises:
            TypeError: If ``output`` is neither a mapping nor a tensor.
        """
        if isinstance(output, Mapping):
            logits = output['logits']
        elif isinstance(output, torch.Tensor):
            logits = output
        else:
            raise TypeError(f'Type {type(output)} for the output is unsupported.')

        bsz, seq_len = target.shape
        flat_target = target.view(-1)
        logits = logits.view(flat_target.shape[0], -1)
        loss = self.loss_fn(logits, flat_target).view(bsz, seq_len)
        target = flat_target.view(bsz, seq_len)

        is_valid = target != self.ignore_index
        valid_target_mask = is_valid.to(target.dtype)

        # Ignored targets get a loss of exactly 0, and exp(0) == 1, so the perplexity
        # must be masked explicitly or every ignored token would add 1 to the sums
        # without being counted in the lengths.
        perplexity = torch.where(is_valid, torch.exp(loss), torch.zeros_like(loss))

        if self.sum_loss.numel() == 0:
            # First batch (or first after a reset): allocate the per-position accumulators.
            self.sum_loss = torch.zeros(seq_len, device=loss.device, dtype=loss.dtype)
            self.sum_perplexity = torch.zeros(seq_len, device=loss.device, dtype=loss.dtype)
            self.sum_length = torch.zeros(seq_len, device=loss.device, dtype=torch.long)
            self.sum_loss_seq_id = torch.zeros(seq_len, device=loss.device, dtype=loss.dtype)
            self.sum_perplexity_seq_id = torch.zeros(seq_len, device=loss.device, dtype=loss.dtype)
            self.sum_length_seq_id = torch.zeros(seq_len, device=loss.device, dtype=torch.long)

        self.sum_loss += torch.sum(loss, dim=0)
        self.sum_perplexity += torch.sum(perplexity, dim=0)
        self.sum_length += valid_target_mask.sum(dim=0)

        if 'sequence_id' in batch:
            # NOTE(review): assumes ``sequence_id`` is a (bsz, seq_len) tensor of
            # non-negative integer ids, and that each sequence occupies a contiguous span
            # with spans ordered by id -- confirm against the data loader.
            seq_id = batch['sequence_id']
            # Indicator of which tokens belong to which sequence: (bsz, num_seq, seq_len).
            seq_membership = torch.nn.functional.one_hot(seq_id).transpose(-1, -2)
            seq_lens = seq_membership.sum(dim=-1)
            max_num_seq = seq_lens.shape[1]

            # Position of a token inside its sequence, for every (row, sequence) pair.
            seq_tok_ids = torch.arange(seq_len, device=seq_id.device)[None, None, :].expand(bsz, max_num_seq, -1)
            mask = seq_tok_ids < seq_lens[:, :, None]

            # Shift by where each sequence starts in the row to get row indices;
            # out-of-range slots point at index 0 and are zeroed out after the gather.
            seq_len_offsets = torch.nn.functional.pad(seq_lens.cumsum(dim=1)[:, :-1], (1, 0), value=0)
            seq_tok_ids = seq_tok_ids + seq_len_offsets[:, :, None]
            seq_tok_ids = torch.where(mask, seq_tok_ids, torch.zeros_like(seq_tok_ids))

            loss = loss[:, None, :].expand(-1, max_num_seq, -1)
            perplexity = perplexity[:, None, :].expand(-1, max_num_seq, -1)
            valid_target_mask = valid_target_mask[:, None, :].expand(-1, max_num_seq, -1)

            loss = torch.where(mask, torch.gather(input=loss, dim=2, index=seq_tok_ids), torch.zeros_like(loss))
            perplexity = torch.where(mask, torch.gather(input=perplexity, dim=2, index=seq_tok_ids),
                                     torch.zeros_like(perplexity))
            valid_count = torch.where(mask, torch.gather(input=valid_target_mask, dim=2, index=seq_tok_ids),
                                      torch.zeros_like(valid_target_mask))

            self.sum_loss_seq_id += torch.sum(loss, dim=(0, 1))
            self.sum_perplexity_seq_id += torch.sum(perplexity, dim=(0, 1))
            self.sum_length_seq_id += torch.sum(valid_count, dim=(0, 1))

    def compute(self) -> dict:
        """Aggregate the accumulated sums into per-position means.

        Returns:
            dict: ``'loss_perp_v_len_metrics'`` (always ``True``, a marker identifying this
            result type) plus four tensors indexed by position: ``'mean_loss_v_len'``,
            ``'mean_perplexity_v_len'``, ``'mean_loss_seq_id_v_len'`` and
            ``'mean_perplexity_seq_id_v_len'``. Positions with no valid target hold ``-1``.
        """
        has_tokens = self.sum_length != 0
        # Empty positions become -1 / 1 == -1 instead of dividing by zero.
        sum_perplexity = torch.where(has_tokens, self.sum_perplexity, -1)
        sum_loss = torch.where(has_tokens, self.sum_loss, -1)
        sum_length = torch.where(has_tokens, self.sum_length, 1)

        has_tokens_seq_id = self.sum_length_seq_id != 0
        sum_perplexity_seq_id = torch.where(has_tokens_seq_id, self.sum_perplexity_seq_id, -1)
        sum_loss_seq_id = torch.where(has_tokens_seq_id, self.sum_loss_seq_id, -1)
        sum_length_seq_id = torch.where(has_tokens_seq_id, self.sum_length_seq_id, 1)

        return {
            'loss_perp_v_len_metrics': True,
            'mean_loss_v_len': sum_loss / sum_length,
            'mean_perplexity_v_len': sum_perplexity / sum_length,
            'mean_loss_seq_id_v_len': sum_loss_seq_id / sum_length_seq_id,
            'mean_perplexity_seq_id_v_len': sum_perplexity_seq_id / sum_length_seq_id,
        }


class LanguagePerplexity(LanguageCrossEntropy):
Expand Down
4 changes: 2 additions & 2 deletions composer/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
from torchmetrics import Metric

from composer.metrics import InContextLearningMetric, InContextLearningQAAccuracy
from composer.metrics import InContextLearningMetric, InContextLearningQAAccuracy, MetricsRequiringBatchInfo
from composer.models.base import ComposerModel
from composer.utils import MissingConditionalImportError, dist, get_file, import_object, is_model_fsdp, safe_torch_load

Expand Down Expand Up @@ -476,7 +476,7 @@ def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
if isinstance(metric, InContextLearningQAAccuracy):
assert self.labels is not None
metric.update(batch=batch, outputs=outputs, labels=self.labels) # pyright: ignore [reportGeneralTypeIssues]
elif isinstance(metric, InContextLearningMetric):
elif isinstance(metric, InContextLearningMetric) or isinstance(metric, MetricsRequiringBatchInfo):
assert self.labels is not None
metric.update(batch, outputs, self.labels) # pyright: ignore [reportGeneralTypeIssues]
else:
Expand Down
11 changes: 9 additions & 2 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1977,11 +1977,18 @@ def _compute_and_log_metrics(self, dataloader_label: str, metrics: Dict[str, Met

# log computed metrics
computed_metrics = {}
metrics_logged_in_table = []
for metric_name, metric in metrics.items():
computed_metrics[metric_name] = metric.compute()
metric_value = metric.compute()
computed_metrics[metric_name] = metric_value
if isinstance(metric_value, dict) and metric_value.get('loss_perp_v_len_metrics', False):
metrics_logged_in_table.append(metric_name)
for k, v in metric_value.items():
if k != 'loss_perp_v_len_metrics':
self.logger.log_table(columns=['context_length', k], rows=[[i, b] for (i, b) in enumerate(v.tolist())], name=f'metrics/{dataloader_label}/{metric_name}/{k}/{self.logger._state.timestamp.batch.value}')

self.logger.log_metrics(
{f'metrics/{dataloader_label}/{name}': val for (name, val) in computed_metrics.items()},)
{f'metrics/{dataloader_label}/{name}': val for (name, val) in computed_metrics.items() if name not in metrics_logged_in_table},)

# store metric instances
for metric_name, metric in metrics.items():
Expand Down

0 comments on commit 1627634

Please sign in to comment.