Skip to content

Commit

Permalink
Model probes
Browse files Browse the repository at this point in the history
  • Loading branch information
janchorowski committed Dec 30, 2020
1 parent 3ed2a13 commit 3573075
Show file tree
Hide file tree
Showing 7 changed files with 439 additions and 5 deletions.
71 changes: 71 additions & 0 deletions fairseq/criterions/probe_criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round
from fairseq.models.probed_model import reduce_probe_metrics

@dataclass
class ProbeCriterionConfig(FairseqDataclass):
    # Intentionally empty: the probes themselves are configured on the model
    # side (``probe_defs`` in the model config), not on the criterion.
    pass


@register_criterion("probes", dataclass=ProbeCriterionConfig)
class ProbeCriterion(FairseqCriterion):
    """Criterion that optimizes only the probes attached to a ProbedModel.

    The main model's own objective is ignored: the total loss is the
    weighted sum of all probe losses returned by ``model.get_probe_losses``.
    """

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        # The forward pass is needed only for its side effect: it fires the
        # probes' forward hooks, which capture intermediate activations.
        # The network output itself is intentionally discarded, so we do not
        # bind it to a (previously unused) local.
        model(**sample["net_input"])
        # Probe losses are already normalized per batch, so the gradient
        # denominator is a constant 1.
        sample_size = 1

        probe_loss, probe_log_outs = model.get_probe_losses(sample)
        loss = probe_loss

        logging_output = {
            "loss": loss.item(),
            # Token/sentence counts are not meaningful for probe training;
            # 1 keeps downstream averaging per-update.
            "ntokens": 1,
            "nsentences": 1,
            "sample_size": sample_size,
        }
        logging_output.update(probe_log_outs)

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )
        metrics.log_scalar(
            "loss", loss_sum / sample_size, sample_size, round=3
        )
        # Probe metrics carry their own per-key weights; delegate.
        reduce_probe_metrics(logging_outputs, metrics)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        # Probe metrics must be reduced with their ``*_weigth`` twins by
        # reduce_probe_metrics; plain summation would lose the weighting.
        return False
11 changes: 10 additions & 1 deletion fairseq/criterions/wav2vec_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round

from fairseq.models.probed_model import reduce_probe_metrics

@dataclass
class Wav2VecCriterionConfig(FairseqDataclass):
Expand Down Expand Up @@ -109,6 +109,11 @@ def forward(self, model, sample, reduce=True):
for i, l in enumerate(losses):
logging_output[f"loss_{i}"] = l.item()

if hasattr(model, 'get_probe_losses'):
probe_loss, probe_log_outs = model.get_probe_losses(sample)
loss += probe_loss
logging_output.update(probe_log_outs)

if self.infonce:
with torch.no_grad():
if logits.numel() == 0:
Expand Down Expand Up @@ -170,6 +175,9 @@ def reduce_metrics(logging_outputs) -> None:
"count",
}

handled_keys = reduce_probe_metrics(logging_outputs, metrics)
builtin_keys.update(handled_keys)

for k in logging_outputs[0]:
if k not in builtin_keys:
val = sum(log.get(k, 0) for log in logging_outputs)
Expand All @@ -179,6 +187,7 @@ def reduce_metrics(logging_outputs) -> None:
)
else:
metrics.log_scalar(k, val / len(logging_outputs), round=3)


@staticmethod
def logging_outputs_can_be_summed() -> bool:
Expand Down
187 changes: 187 additions & 0 deletions fairseq/models/probed_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import torch.nn
import torch.nn.functional as F
import logging

logger = logging.getLogger(__name__)


def _pick_nth(tensor_or_sequence, which=0):
if isinstance(tensor_or_sequence, (list, tuple)):
tensor_or_sequence = tensor_or_sequence[which]
else:
if which > 0:
raise ValueError("Requested output not present")
return tensor_or_sequence


def _detach(tensor_or_iterable):
if isinstance(tensor_or_iterable, (list, tuple)):
return [_detach(elem) for elem in tensor_or_iterable]
elif isinstance(tensor_or_iterable, dict):
return {k: _detach(v) for k, v in tensor_or_iterable.items()}
else:
return tensor_or_iterable.detach()


def _compile_selector(selector, default):
if selector is None:
return default
elif isinstance(selector, str):
return eval(selector)
else:
return selector


class Probe(torch.nn.Module):
    """Base probe: hooks a sub-module of ``model`` and computes an
    auxiliary loss on the activations captured during the forward pass.

    Args:
        model: the model to probe.
        module_name: dotted name (as in ``model.named_modules()``) of the
            sub-module whose output is captured.
        backprop_to_main: if True, the probe loss backpropagates into the
            main model (auxiliary loss); otherwise activations are detached
            and only the probe's own parameters are trained.
        output_selector: callable (or eval-able string) mapping the hooked
            module's raw outputs to a kwargs dict for ``forward``.
        target_selector: callable (or eval-able string) mapping the
            minibatch to a kwargs dict of probe targets.
        loss_weigth: multiplier for this probe's loss (the misspelling is
            part of the public config interface and is kept deliberately).
    """

    def __init__(
        self,
        model,
        module_name,
        backprop_to_main=False,
        output_selector=None,
        target_selector=None,
        loss_weigth=1.0,
    ):
        super().__init__()
        self._saved_tensor = None
        self._target_selector = _compile_selector(
            target_selector, default=lambda x: {"target": x}
        )
        self._loss_weigth = loss_weigth

        output_selector = _compile_selector(
            output_selector, default=lambda x: {"output": x}
        )
        hook_fn = self._get_hook(output_selector, backprop_to_main)
        self._attach(model, module_name, hook_fn)
        if backprop_to_main:
            logger.info("Registered an auxiliary loss at %s: %s", module_name, self)
        else:
            logger.info("Registered a probe at %s: %s", module_name, self)

    def _get_hook(self, output_selector, backprop_to_main):
        # The hook stashes the selected outputs on the probe; compute_loss
        # picks them up after the main forward pass.
        def hook_fn(mod, unused_inputs, outputs):
            outputs = output_selector(outputs)
            if backprop_to_main:
                self._saved_tensor = outputs
            else:
                # Detach so probe training cannot update the main model.
                self._saved_tensor = _detach(outputs)

        return hook_fn

    def _attach(self, model, module_name, hook_fn):
        module = dict(model.named_modules())[module_name]
        # Keep the handle (the original discarded it) so the hook can be
        # removed later, e.g. when detaching probes before export.
        self._hook_handle = module.register_forward_hook(hook_fn)

    def compute_loss(self, minibatch):
        """Combine captured activations with targets and run ``forward``."""
        if self._saved_tensor is None:
            # The original raised an opaque AttributeError on None here.
            raise RuntimeError(
                "Probe.compute_loss called before the probed module ran a "
                "forward pass; no activations were captured."
            )
        self._saved_tensor.update(self._target_selector(minibatch))
        ret = self(**self._saved_tensor)
        # Clear the stash so stale activations are never reused.
        self._saved_tensor = None
        return ret


class FeedForwardProbe(Probe):
    """An MLP probe over flat (non-sequential) activations.

    ``layer_dims`` is ``[input_dim, hidden..., output_dim]``; a linear
    layer is created per consecutive pair with ``activation`` in between.
    """

    def __init__(
        self,
        layer_dims,
        activation="torch.nn.ReLU",
        loss="torch.nn.CrossEntropyLoss",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # NOTE(review): eval on config strings — trusted configuration only.
        activation = eval(activation)
        in_dim, last_dim, *rest = layer_dims
        modules = [torch.nn.Linear(in_dim, last_dim)]
        for dim in rest:
            modules.append(activation())
            modules.append(torch.nn.Linear(last_dim, dim))
            last_dim = dim
        self.layers = torch.nn.Sequential(*modules)
        self.loss = eval(loss)()

    def forward(self, output, target):
        """Return ``(loss, log_dict)``.

        Bug fix: the original returned the bare loss, but
        ``ProbedModel.get_probe_losses`` unpacks a 2-tuple from
        ``compute_loss`` and ``reduce_probe_metrics`` requires a
        ``loss_weigth`` entry — so this probe crashed its only consumer.
        The loss is NOT pre-multiplied by ``_loss_weigth``;
        ``get_probe_losses`` applies the weight.
        """
        output = self.layers(output)
        loss = self.loss(output, target)
        # Weight by the number of targets so cross-worker reduction can
        # average correctly ("weigth" spelling matches the rest of the file).
        weigth = float(target.numel())
        probe_logs = {
            "loss": loss.item(),
            "loss_weigth": weigth,
        }
        return loss, probe_logs


class Conv1DProbe(Probe):
    """A 1-D convolutional probe for framewise classification.

    ``layer_dims`` is ``[input_channels, hidden..., num_classes]``; odd
    kernels with ``kernel_size // 2`` padding preserve sequence length.
    """

    def __init__(self, layer_dims, kernel_size=1, activation="torch.nn.ReLU", **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): eval on config strings — trusted configuration only.
        activation = eval(activation)
        in_dim, last_dim, *rest = layer_dims
        # Same-padding only works for odd kernel sizes.
        assert kernel_size % 2 == 1
        modules = [
            torch.nn.Conv1d(in_dim, last_dim, kernel_size, padding=kernel_size // 2)
        ]
        for dim in rest:
            modules.append(activation())
            modules.append(
                torch.nn.Conv1d(last_dim, dim, kernel_size, padding=kernel_size // 2)
            )
            last_dim = dim
        self.layers = torch.nn.Sequential(*modules)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, output, target, padding_mask):
        """Masked framewise cross-entropy plus accuracy logging.

        Args:
            output: (N, C_in, L) activations captured by the hook.
            target: per-frame integer labels at the mask's resolution.
            padding_mask: (N, 1, L_m) mask, nonzero at padded positions.

        Returns:
            ``(loss, log_dict)``. Bug fix: the loss is no longer
            pre-multiplied by ``_loss_weigth`` here —
            ``ProbedModel.get_probe_losses`` multiplies by the weight too,
            so the original applied it twice (weight squared).
        """
        N, Cin, L = output.shape
        Nm, Cpad, Lm = padding_mask.shape
        assert Cpad == 1
        assert N == Nm
        # Upsample activations to the labels' temporal resolution.
        # NOTE(review): integer scale factor assumes Lm is a multiple of L;
        # confirm against the feature extractor's striding.
        output = F.interpolate(output, scale_factor=Lm // L)
        output = self.layers(output)
        padding_mask = padding_mask.float().squeeze(1)
        neg_mask = 1.0 - padding_mask
        # Padded frames are mapped to ignore_index so CrossEntropyLoss
        # skips them.
        target = (target * neg_mask + padding_mask * self.loss.ignore_index).long()
        loss = self.loss(output, target)
        weigth = neg_mask.sum()
        acc = (neg_mask * (torch.argmax(output, 1) == target).float()).sum() / weigth
        probe_logs = {
            "loss": loss.item(),
            "loss_weigth": weigth.item(),
            "acc": acc.item(),
            "acc_weigth": weigth.item(),
        }
        return loss, probe_logs


class ProbedModel:
    """Mixin for models that can attach small diagnostic probes."""

    def _build_probe(self, cls, **kwargs):
        # NOTE(review): eval resolves the probe class from a config string
        # (e.g. "Conv1DProbe") — trusted configuration only.
        cls = eval(cls)
        return cls(model=self, **kwargs)

    def attach_probes(self, probe_defs):
        """Instantiate probes from config; no-op when ``probe_defs`` is empty."""
        if not probe_defs:
            return
        self._probes = torch.nn.ModuleDict(
            {
                probe_name: self._build_probe(**probe_def)
                for probe_name, probe_def in probe_defs.items()
            }
        )

    def get_probe_losses(self, minibatch):
        """Sum the weighted losses of all attached probes.

        Returns ``(loss, extra_log_keys)``. Bug fix: when no probes were
        configured, ``attach_probes`` returns early and never sets
        ``_probes``, so the original raised AttributeError here (and the
        ``hasattr(model, 'get_probe_losses')`` guard in the wav2vec
        criterion always passes for this mixin). We now treat "no probes"
        as a zero loss with no extra logs.
        """
        loss = 0.0
        extra_log_keys = {}
        for probe_name, probe in getattr(self, "_probes", {}).items():
            probe_loss, probe_log_keys = probe.compute_loss(minibatch)
            loss += probe_loss * probe._loss_weigth
            for k, v in probe_log_keys.items():
                extra_log_keys[f"probe_{probe_name}_{k}"] = v
        return loss, extra_log_keys

def reduce_probe_metrics(logging_outputs, metrics):
    """Log every ``probe_*`` metric, weighted by its ``*_weigth`` twin.

    Returns the set of keys consumed, so callers can exclude them from
    their own generic reduction.

    NOTE(review): values are summed across workers and logged with the
    summed weight — this assumes ``metrics.log_scalar`` performs weighted
    averaging; confirm against fairseq's AverageMeter semantics.
    """
    handled_keys = set()

    def summed(key):
        handled_keys.add(key)
        return sum(log.get(key, 0) for log in logging_outputs)

    for key in logging_outputs[0]:
        if not key.startswith("probe_") or key.endswith("_weigth"):
            continue
        value = summed(key)
        weigth = summed(f"{key}_weigth")
        metrics.log_scalar(key, value, weigth, round=3)
    return handled_keys
10 changes: 6 additions & 4 deletions fairseq/models/wav2vec/wav2vec2_scribblelens.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import math
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict, Any

import numpy as np
import torch
Expand All @@ -14,7 +14,7 @@
from fairseq import utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models import BaseFairseqModel, register_model, probed_model
from fairseq.modules import (
Fp32GroupNorm,
Fp32LayerNorm,
Expand All @@ -33,10 +33,10 @@

@dataclass
class Wav2Vec2SLConfig(Wav2Vec2Config):
pass
probe_defs: Optional[Dict[str, Any]] = field(default=None, metadata={"help": "probes"})

@register_model("wav2vec2_scribblelens", dataclass=Wav2Vec2SLConfig)
class Wav2Vec2ModelSL(BaseFairseqModel):
class Wav2Vec2ModelSL(BaseFairseqModel, probed_model.ProbedModel):
def __init__(self, cfg: Wav2Vec2Config):
super().__init__()
self.cfg = cfg
Expand Down Expand Up @@ -135,6 +135,8 @@ def __init__(self, cfg: Wav2Vec2Config):

self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

self.attach_probes(cfg.probe_defs)

def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
Expand Down
10 changes: 10 additions & 0 deletions uwr_related/configs/scribblelens_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ task:
max_sample_size: 250000
min_sample_size: 32000
normalize: false
labels: true

dataset:
num_workers: 0
Expand Down Expand Up @@ -66,3 +67,12 @@ model:
latent_vars: 320
latent_groups: 2
latent_temp: [2,0.5,0.999995]

probe_defs:
post_extract_proj_mlp:
cls: Conv1DProbe
module_name: post_extract_proj
layer_dims: [768, 512, 73]
kernel_size: 3
output_selector: 'lambda x: {"output": x.transpose(1, 2)}'
target_selector: 'lambda x: {"target":x["alignments"], "padding_mask": x["net_input"].get("padding_mask")}'
Loading

0 comments on commit 3573075

Please sign in to comment.