Merge pull request #49 from cdpierse/feature/multi-label-support-one-shot

Feature/multi label support one shot
cdpierse authored Jun 6, 2021
2 parents cbbeaf1 + 3289c59 commit f181ec5
Showing 4 changed files with 163 additions and 67 deletions.
100 changes: 73 additions & 27 deletions README.md
@@ -180,7 +180,7 @@ Let's start by initializing a transformers' sequence classification model and to
For this example we are using `facebook/bart-large-mnli` which is a checkpoint for a bart-large model trained on the
[MNLI dataset](https://huggingface.co/datasets/multi_nli). This model typically predicts whether a sentence pair is an entailment, neutral, or a contradiction; however, for zero-shot we only look at the entailment label.

Notice that we pass our own custom labels `["finance", "technology", "sports"]` to the class instance. Any number of labels can be passed, including as few as one. Whichever label scores highest for entailment is the predicted class. If you want to see the attributions for a particular label, it is recommended to pass in just that one label; the attributions are then guaranteed to be calculated w.r.t. that label.
Notice that we pass our own custom labels `["finance", "technology", "sports"]` to the class instance. Any number of labels can be passed, including as few as one. Whichever label scores highest for entailment can be accessed via `predicted_label`; however, the attributions themselves are calculated for every label. If you want to see the attributions for a particular label, it is recommended to pass in just that one label; the attributions are then guaranteed to be calculated w.r.t. that label.

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@@ -201,44 +201,90 @@ word_attributions = zero_shot_explainer(

```
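The hunk above collapses the middle of the snippet; for orientation, here is a minimal sketch of the full call, reconstructed from the surrounding README (the `transformers_interpret` import path is assumed, and the example text is approximated from the token output below):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumed import path for the explainer this PR modifies.
from transformers_interpret import ZeroShotClassificationExplainer

model_name = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

zero_shot_explainer = ZeroShotClassificationExplainer(model, tokenizer)

word_attributions = zero_shot_explainer(
    "Today apple released the new Macbook showing off a range of new "
    "features found in the proprietary silicon chip computer.",
    labels=["finance", "technology", "sports"],
)
```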

Which will return the following list of tuples:
Which will return the following dict, with a list of attribution tuples for each label:

```python
>>> word_attributions
[('<s>', 0.0),
('Today', 0.0),
('apple', 0.22505152647747717),
('released', -0.16164146624851905),
('the', 0.5026975657258089),
('new', 0.052589263167955536),
('Mac', 0.2528325960993759),
('book', -0.06445090203729663),
('showing', -0.21204922293777534),
('off', 0.06319714817612732),
('a', 0.032048012090796815),
('range', 0.08553079346908955),
('of', 0.1409201107994034),
('new', 0.0515261917112576),
('features', -0.09656406466213506),
('found', 0.02336613296843605),
('in', -0.0011649894272190678),
('the', 0.14229640664777807),
('proprietary', -0.23169065661847646),
('silicon', 0.5963924257008087),
('chip', -0.19908474233975806),
('computer', 0.030620295844734646),
('.', 0.1995076958535378)]
{'finance': [('<s>', 0.0),
('Today', 0.0),
('apple', -0.016100065046282107),
('released', 0.3348383988281792),
('the', -0.8932952916127369),
('new', 0.14207183688642497),
('Mac', 0.016309545780430777),
('book', -0.06956802041125129),
('showing', -0.12661404114316252),
('off', -0.11470154900720078),
('a', -0.03299250484912159),
('range', -0.002532332125100561),
('of', -0.022451943898971004),
('new', -0.01859870581213379),
('features', -0.020774327263810944),
('found', -0.007734346326330102),
('in', 0.005100588658589585),
('the', 0.04711084622588314),
('proprietary', 0.046352064964644286),
('silicon', -0.0033502000158946127),
('chip', -0.010419324929115785),
('computer', -0.11507972995022273),
('.', 0.12237840300907425)],
'technology': [('<s>', 0.0),
('Today', 0.0),
('apple', 0.22505152647747717),
('released', -0.16164146624851905),
('the', 0.5026975657258089),
('new', 0.052589263167955536),
('Mac', 0.2528325960993759),
('book', -0.06445090203729663),
('showing', -0.21204922293777534),
('off', 0.06319714817612732),
('a', 0.032048012090796815),
('range', 0.08553079346908955),
('of', 0.1409201107994034),
('new', 0.0515261917112576),
('features', -0.09656406466213506),
('found', 0.02336613296843605),
('in', -0.0011649894272190678),
('the', 0.14229640664777807),
('proprietary', -0.23169065661847646),
('silicon', 0.5963924257008087),
('chip', -0.19908474233975806),
('computer', 0.030620295844734646),
('.', 0.1995076958535378)],
'sports': [('<s>', 0.0),
('Today', 0.0),
('apple', 0.1776618164760026),
('released', 0.10067773539491479),
('the', 0.4813466937627506),
('new', -0.018555244191949295),
('Mac', 0.016338241133536224),
('book', 0.39311969562943677),
('showing', 0.03579210145504227),
('off', 0.0016710813632476176),
('a', 0.04367940034297261),
('range', 0.06076859006993011),
('of', 0.11039711284328052),
('new', 0.003932416031994724),
('features', -0.009660883377622588),
('found', -0.06507586539836184),
('in', 0.2957812911667922),
('the', 0.1584106228974514),
('proprietary', 0.0005789280604917397),
('silicon', -0.04693795680472678),
('chip', -0.1699508539245465),
('computer', -0.4290823663975582),
('.', 0.469314992542427)]}
```
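Attributions for any single label can be read straight off the returned dict, e.g.:

```python
>>> word_attributions["technology"][:3]
[('<s>', 0.0), ('Today', 0.0), ('apple', 0.22505152647747717)]
```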

We can find out which label was predicted with:

```python
>>> zero_shot_explainer.predicted_label
'technology (entailment)'
'technology'
```
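The updated explainer also exposes `pred_probs`, one normalized entailment probability per label in the same order as `labels` (attribute taken from the source diff below); a small sketch of pairing them up:

```python
# One normalized entailment probability per label, ordered like `labels`.
label_probs = dict(zip(["finance", "technology", "sports"], zero_shot_explainer.pred_probs))
```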
#### Visualize Zero Shot Classification attributions

For the `ZeroShotClassificationExplainer` the visualize() method returns a table similar to the `SequenceClassificationExplainer`.
For the `ZeroShotClassificationExplainer`, the `visualize()` method returns a table similar to the `SequenceClassificationExplainer`'s, but with attributions for every label.

```python
zero_shot_explainer.visualize("zero_shot.html")
```
Binary file modified images/zero_shot_example.png
22 changes: 14 additions & 8 deletions test/test_zero_shot_explainer.py
@@ -21,7 +21,7 @@ def test_zero_shot_explainer_init_distilbert():
)

assert zero_shot_explainer.attribution_type == "lig"
assert zero_shot_explainer.attributions is None
assert zero_shot_explainer.attributions == []
assert zero_shot_explainer.label_exists is True
assert zero_shot_explainer.entailment_key == "ENTAILMENT"

@@ -53,12 +53,14 @@ def test_zero_shot_explainer_word_attributions():
DISTILBERT_MNLI_MODEL,
DISTILBERT_MNLI_TOKENIZER,
)

labels = ["urgent", "phone", "tablet", "computer"]
word_attributions = zero_shot_explainer(
"I have a problem with my iphone that needs to be resolved asap!!",
labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
labels=labels,
)
assert isinstance(word_attributions, list)
assert isinstance(word_attributions, dict)
for label in labels:
assert label in word_attributions.keys()


def test_zero_shot_explainer_call_word_attributions_early_raises_error():
@@ -76,18 +78,22 @@ def test_zero_shot_explainer_word_attributions_include_hypothesis():
DISTILBERT_MNLI_MODEL,
DISTILBERT_MNLI_TOKENIZER,
)

labels = ["urgent", "phone", "tablet", "computer"]
word_attributions_with_hyp = zero_shot_explainer(
"I have a problem with my iphone that needs to be resolved asap!!",
labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
labels=labels,
include_hypothesis=True,
)
word_attributions_without_hyp = zero_shot_explainer(
"I have a problem with my iphone that needs to be resolved asap!!",
labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
labels=labels,
include_hypothesis=False,
)
assert len(word_attributions_with_hyp) > len(word_attributions_without_hyp)

for label in labels:
assert len(word_attributions_with_hyp[label]) > len(
word_attributions_without_hyp[label]
)
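# Illustrative note, not part of this commit: with include_hypothesis=True each
# label's list also carries the hypothesis tokens (e.g. "this text is about urgent ."),
# which is why every per-label list is strictly longer with the hypothesis included.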


def test_zero_shot_explainer_visualize():
108 changes: 76 additions & 32 deletions transformers_interpret/explainers/zero_shot_classification.py
@@ -30,10 +30,9 @@ class ZeroShotClassificationExplainer(
`model.config.label2id.keys()` in order for it to work correctly.
This explainer works by forcing the model to explain its output with respect to
the entailment class. For each label passed at inference the explainer forms a hypothesis with each and
determines which has the highest predicted probability and then feeds that label as
a hypothesis to the model for attribution.
the entailment class. For each label passed at inference, the explainer forms a hypothesis
and calculates attributions for it. The label with the highest predicted probability
can be accessed via the attribute `predicted_label`.
"""

def __init__(
@@ -65,15 +64,24 @@ def __init__(

self.entailment_idx = self.label2id[self.entailment_key]
self.include_hypothesis = False
self.attributions = []

@property
def word_attributions(self) -> list:
def word_attributions(self) -> dict:
"Returns the word attributions for model and the text provided. Raises error if attributions not calculated."
if self.attributions is not None:
if self.attributions != []:
if self.include_hypothesis:
return self.attributions.word_attributions
return dict(
zip(
self.labels,
[attr.word_attributions for attr in self.attributions],
)
)
else:
return self.attributions.word_attributions[: self.sep_idx]
spliced_wa = [
attr.word_attributions[: self.sep_idx] for attr in self.attributions
]
return dict(zip(self.labels, spliced_wa))
else:
raise ValueError(
"Attributions have not yet been calculated. Please call the explainer on text first."
@@ -94,14 +102,17 @@ def visualize(self, html_filepath: str = None, true_class: str = None):
if not self.include_hypothesis:
tokens = tokens[: self.sep_idx]

score_viz = self.attributions.visualize_attributions( # type: ignore
self.pred_probs,
self.predicted_label,
self.predicted_label,
self.predicted_label,
tokens,
)
html = viz.visualize_text([score_viz])
score_viz = [
self.attributions[i].visualize_attributions( # type: ignore
self.pred_probs[i],
self.labels[i],
self.labels[i],
self.labels[i],
tokens,
)
for i in range(len(self.attributions))
]
html = viz.visualize_text(score_viz)

if html_filepath:
if not html_filepath.endswith(".html"):
@@ -132,9 +143,15 @@ def _get_top_predicted_label_idx(self, text, hypothesis_labels: List[str]) -> in
input_ids, token_type_ids, position_ids, attention_mask
)
entailment_outputs.append(
float(torch.softmax(preds[0], dim=1)[0][self.entailment_idx])
float(torch.sigmoid(preds[0])[0][self.entailment_idx])
)

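# Normalize the raw per-label entailment scores so they sum to 1 across labels.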
normed_entailment_outputs = [
float(i) / sum(entailment_outputs) for i in entailment_outputs
]

self.pred_probs = normed_entailment_outputs

return entailment_outputs.index(max(entailment_outputs))

def _make_input_reference_pair(
@@ -167,6 +184,19 @@ def _make_input_reference_pair(
len(text_ids),
)

def _forward( # type: ignore
self,
input_ids: torch.Tensor,
token_type_ids=None,
position_ids: torch.Tensor = None,
attention_mask: torch.Tensor = None,
):

preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask)
preds = preds[0]

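# Return the softmax probability of the selected (entailment) index for each input.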
return torch.softmax(preds, dim=1)[:, self.selected_index]

def _calculate_attributions( # type: ignore
self, embeddings: Embedding, class_name: str, index: int = None
):
@@ -181,6 +211,11 @@ def _calculate_attributions( # type: ignore
self.ref_position_ids,
) = self._make_input_reference_position_id_pair(self.input_ids)

(
self.token_type_ids,
self.ref_token_type_ids,
) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx)

self.attention_mask = self._make_attention_mask(self.input_ids)

self.selected_index = int(self.label2id[class_name])
@@ -198,12 +233,14 @@ def _calculate_attributions( # type: ignore
self.attention_mask,
position_ids=self.position_ids,
ref_position_ids=self.ref_position_ids,
token_type_ids=self.token_type_ids,
ref_token_type_ids=self.ref_token_type_ids,
)
if self.include_hypothesis:
lig.summarize()
else:
lig.summarize(self.sep_idx)
self.attributions = lig
self.attributions.append(lig)

def __call__(
self,
@@ -212,15 +249,16 @@ def __call__(
embedding_type: int = 0,
hypothesis_template="this text is about {} .",
include_hypothesis: bool = False,
) -> list:
) -> dict:
"""
Calculates attribution for `text` using the model and
tokenizer given in the constructor. Since `self.model` is
an NLI-type model, each label in `labels` is formatted to the
`hypothesis_template`. Whichever label gets the highest prediction
score for entailment is selected as the predicted label.
`hypothesis_template`. By default attributions are provided for all
labels. The top predicted label can be found in the `predicted_label`
attribute.
Attribution is then forced to be on the axis of whatever index
Attribution is forced to be on the axis of whatever index
the entailment class resolves to. e.g. {"entailment": 0, "neutral": 1, "contradiction": 2 }
in the above case attributions would be for the label at index 0.
@@ -251,17 +289,23 @@
Expand Down Expand Up @@ -251,17 +289,23 @@ def __call__(
Returns:
dict: A dict mapping each label to a list of (word, attribution score) tuples.
"""
self.attributions = []
self.pred_probs = []
self.include_hypothesis = include_hypothesis
hypothesis_labels = [hypothesis_template.format(label) for label in labels]
self.labels = labels
self.hypothesis_labels = [hypothesis_template.format(label) for label in labels]

text_idx = self._get_top_predicted_label_idx(text, hypothesis_labels)
self.hypothesis_text = hypothesis_labels[text_idx]
self.predicted_label = (
labels[text_idx] + " (" + self.entailment_key.lower() + ")"
predicted_text_idx = self._get_top_predicted_label_idx(
text, self.hypothesis_labels
)

return super().__call__(
text,
class_name=self.entailment_key,
embedding_type=embedding_type,
)
for i, _ in enumerate(self.labels):
self.hypothesis_text = self.hypothesis_labels[i]
self.predicted_label = labels[i] + " (" + self.entailment_key.lower() + ")"
super().__call__(
text,
class_name=self.entailment_key,
embedding_type=embedding_type,
)
self.predicted_label = self.labels[predicted_text_idx]
return self.word_attributions
