Merge pull request #54 from cdpierse/feature/add-nsteps-param
Feature/add nsteps param Closes #51
cdpierse authored Jun 24, 2021
2 parents 0cd3324 + d864150 commit deb65f3
Showing 8 changed files with 139 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -22,7 +22,7 @@
"test",
]
),
version="0.5.0",
version="0.5.1",
license="Apache-2.0",
description="Transformers Interpret is a model explainability tool designed to work exclusively with 🤗 transformers.",
long_description=long_description,
18 changes: 18 additions & 0 deletions test/test_question_answering_explainer.py
@@ -180,3 +180,21 @@ def test_question_answering_visualize_save_append_html_file_ending():
qa_explainer.visualize(html_filename)
assert os.path.exists(html_filename + ".html")
os.remove(html_filename + ".html")


def test_question_answering_custom_steps():
qa_explainer = QuestionAnsweringExplainer(
DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
)
explainer_question = "what is his name ?"
explainer_text = "his name is Bob"
qa_explainer(explainer_question, explainer_text, n_steps=1)


def test_question_answering_custom_internal_batch_size():
qa_explainer = QuestionAnsweringExplainer(
DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
)
explainer_question = "what is his name ?"
explainer_text = "his name is Bob"
qa_explainer(explainer_question, explainer_text, internal_batch_size=1)
17 changes: 16 additions & 1 deletion test/test_sequence_classification_explainer.py
@@ -68,7 +68,6 @@ def test_sequence_classification_explainer_init_custom_labels_size_error():
)



def test_sequence_classification_encode():
seq_explainer = SequenceClassificationExplainer(
DISTILBERT_MODEL, DISTILBERT_TOKENIZER
@@ -263,3 +262,19 @@ def test_sequence_classification_viz():
)
seq_explainer(explainer_string)
seq_explainer.visualize()


def test_sequence_classification_custom_steps():
explainer_string = "I love you , I like you"
seq_explainer = SequenceClassificationExplainer(
DISTILBERT_MODEL, DISTILBERT_TOKENIZER
)
seq_explainer(explainer_string, n_steps=1)


def test_sequence_classification_internal_batch_size():
explainer_string = "I love you , I like you"
seq_explainer = SequenceClassificationExplainer(
DISTILBERT_MODEL, DISTILBERT_TOKENIZER
)
seq_explainer(explainer_string, internal_batch_size=1)
26 changes: 26 additions & 0 deletions test/test_zero_shot_explainer.py
@@ -185,3 +185,29 @@ def test_zero_shot_model_lowercase_entailment():
DISTILBERT_MNLI_MODEL,
DISTILBERT_MNLI_TOKENIZER,
)


def test_zero_shot_custom_steps():
zero_shot_explainer = ZeroShotClassificationExplainer(
DISTILBERT_MNLI_MODEL,
DISTILBERT_MNLI_TOKENIZER,
)

zero_shot_explainer(
"I have a problem with my iphone that needs to be resolved asap!!",
labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
n_steps=1,
)


def test_zero_shot_internal_batch_size():
zero_shot_explainer = ZeroShotClassificationExplainer(
DISTILBERT_MNLI_MODEL,
DISTILBERT_MNLI_TOKENIZER,
)

zero_shot_explainer(
"I have a problem with my iphone that needs to be resolved asap!!",
labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
internal_batch_size=1,
)
12 changes: 12 additions & 0 deletions transformers_interpret/attributions.py
@@ -29,6 +29,8 @@ def __init__(
position_ids: torch.Tensor = None,
ref_token_type_ids: torch.Tensor = None,
ref_position_ids: torch.Tensor = None,
internal_batch_size: int = None,
n_steps: int = 50,
):
super().__init__(custom_forward, embeddings, tokens)
self.input_ids = input_ids
@@ -38,6 +40,8 @@ def __init__(
self.position_ids = position_ids
self.ref_token_type_ids = ref_token_type_ids
self.ref_position_ids = ref_position_ids
self.internal_batch_size = internal_batch_size
self.n_steps = n_steps

self.lig = LayerIntegratedGradients(self.custom_forward, self.embeddings)

@@ -51,6 +55,8 @@ def __init__(
),
return_convergence_delta=True,
additional_forward_args=(self.attention_mask),
internal_batch_size=self.internal_batch_size,
n_steps=self.n_steps,
)
elif self.position_ids is not None:
self._attributions, self.delta = self.lig.attribute(
@@ -61,6 +67,8 @@
),
return_convergence_delta=True,
additional_forward_args=(self.attention_mask),
internal_batch_size=self.internal_batch_size,
n_steps=self.n_steps,
)
elif self.token_type_ids is not None:
self._attributions, self.delta = self.lig.attribute(
@@ -71,13 +79,17 @@
),
return_convergence_delta=True,
additional_forward_args=(self.attention_mask),
internal_batch_size=self.internal_batch_size,
n_steps=self.n_steps,
)

else:
self._attributions, self.delta = self.lig.attribute(
inputs=self.input_ids,
baselines=self.ref_input_ids,
return_convergence_delta=True,
internal_batch_size=self.internal_batch_size,
n_steps=self.n_steps,
)

@property
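For reference, the two new options are passed straight through to Captum's LayerIntegratedGradients.attribute call, as the hunks above show. Below is a minimal, self-contained sketch of that underlying call using a toy embedding layer and forward function (not the library's own LIGAttributions class), just to illustrate what internal_batch_size and n_steps control:

import torch
from captum.attr import LayerIntegratedGradients

# Toy setup: an embedding layer followed by a mean-pooled linear head
# that emits a single score per example.
embedding = torch.nn.Embedding(100, 16)
head = torch.nn.Linear(16, 1)

def forward_func(input_ids):
    return head(embedding(input_ids).mean(dim=1))

lig = LayerIntegratedGradients(forward_func, embedding)

input_ids = torch.tensor([[5, 7, 9, 3]])
ref_input_ids = torch.zeros_like(input_ids)  # reference/baseline token ids

# n_steps sets how many interpolation points the integral approximation uses;
# internal_batch_size caps how many of those points are evaluated per
# forward/backward pass, trading runtime for peak memory.
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    return_convergence_delta=True,
    internal_batch_size=2,
    n_steps=25,
)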
30 changes: 28 additions & 2 deletions transformers_interpret/explainers/question_answering.py
@@ -50,6 +50,9 @@ def __init__(

self.position = 0

self.internal_batch_size = None
self.n_steps = 50

def encode(self, text: str) -> list: # type: ignore
"Encode 'text' using tokenizer, special tokens are not added"
return self.tokenizer.encode(text, add_special_tokens=False)
@@ -320,6 +323,8 @@ def _calculate_attributions(self, embeddings: Embedding): # type: ignore
ref_position_ids=self.ref_position_ids,
token_type_ids=self.token_type_ids,
ref_token_type_ids=self.ref_token_type_ids,
internal_batch_size=self.internal_batch_size,
n_steps=self.n_steps,
)
start_lig.summarize()
self.start_attributions = start_lig
@@ -337,12 +342,21 @@
ref_position_ids=self.ref_position_ids,
token_type_ids=self.token_type_ids,
ref_token_type_ids=self.ref_token_type_ids,
internal_batch_size=self.internal_batch_size,
n_steps=self.n_steps,
)
end_lig.summarize()
self.end_attributions = end_lig
self.attributions = [self.start_attributions, self.end_attributions]

def __call__(self, question: str, text: str, embedding_type: int = 2) -> dict:
def __call__(
self,
question: str,
text: str,
embedding_type: int = 2,
internal_batch_size: int = None,
n_steps: int = None,
) -> dict:
"""
Calculates start and end position word attributions for `question` and `text` using the model
and tokenizer given in the constructor.
@@ -357,9 +371,21 @@ def __call__(self, question: str, text: str, embedding_type: int = 2) -> dict:
question (str): The question text
text (str): The text or context from which the model finds an answers
embedding_type (int, optional): The embedding type word(0), position(1), all(2) to calculate attributions for.
Defaults to 2.
Defaults to 2.
internal_batch_size (int, optional): Divides total #steps * #examples
data points into chunks of size at most internal_batch_size,
which are computed (forward / backward passes)
sequentially. If internal_batch_size is None, then all evaluations are
processed in one batch.
n_steps (int, optional): The number of steps used by the approximation
method. Default: 50.
Returns:
dict: Dict for start and end position word attributions.
"""

if n_steps:
self.n_steps = n_steps
if internal_batch_size:
self.internal_batch_size = internal_batch_size
return self._run(question, text, embedding_type)
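
With the new keyword arguments, callers can tune attribution cost directly from the explainer's __call__. A small usage sketch, assuming a SQuAD-style DistilBERT checkpoint (the model name below is illustrative; the tests use their own module-level constants):

from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers_interpret import QuestionAnsweringExplainer

model_name = "distilbert-base-cased-distilled-squad"  # illustrative checkpoint
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

qa_explainer = QuestionAnsweringExplainer(model, tokenizer)

# Fewer steps and a small internal batch size lower memory use and runtime
# at the cost of a coarser integrated-gradients approximation.
word_attributions = qa_explainer(
    "What is his name?",
    "His name is Bob",
    n_steps=20,
    internal_batch_size=2,
)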
20 changes: 19 additions & 1 deletion transformers_interpret/explainers/sequence_classification.py
@@ -81,6 +81,9 @@ def __init__(

self._single_node_output = False

self.internal_batch_size = None
self.n_steps = 50

@staticmethod
def _get_id2label_and_label2id_dict(
labels: List[str],
@@ -239,6 +242,8 @@ def _calculate_attributions( # type: ignore
self.attention_mask,
position_ids=self.position_ids,
ref_position_ids=self.ref_position_ids,
internal_batch_size=self.internal_batch_size,
n_steps=self.n_steps,
)
lig.summarize()
self.attributions = lig
@@ -279,6 +284,8 @@ def __call__(
index: int = None,
class_name: str = None,
embedding_type: int = 0,
internal_batch_size: int = None,
n_steps: int = None,
) -> list:
"""
Calculates attribution for `text` using the model
@@ -299,10 +306,21 @@
index (int, optional): Optional output index to provide attributions for. Defaults to None.
class_name (str, optional): Optional output class name to provide attributions for. Defaults to None.
embedding_type (int, optional): The embedding type word(0) or position(1) to calculate attributions for. Defaults to 0.
internal_batch_size (int, optional): Divides total #steps * #examples
data points into chunks of size at most internal_batch_size,
which are computed (forward / backward passes)
sequentially. If internal_batch_size is None, then all evaluations are
processed in one batch.
n_steps (int, optional): The number of steps used by the approximation
method. Default: 50.
Returns:
list: List of tuples containing words and their associated attribution scores.
"""

if n_steps:
self.n_steps = n_steps
if internal_batch_size:
self.internal_batch_size = internal_batch_size
return self._run(text, index, class_name, embedding_type=embedding_type)

def __str__(self):
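The sequence classification explainer accepts the same two keyword arguments; a corresponding sketch (checkpoint name again illustrative):

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

model_name = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

cls_explainer = SequenceClassificationExplainer(model, tokenizer)

# Returns a list of (word, attribution score) tuples for the predicted class.
word_attributions = cls_explainer(
    "I love you, I like you",
    n_steps=20,
    internal_batch_size=2,
)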
20 changes: 19 additions & 1 deletion transformers_interpret/explainers/zero_shot_classification.py
@@ -66,6 +66,9 @@ def __init__(
self.include_hypothesis = False
self.attributions = []

self.internal_batch_size = None
self.n_steps = 50

@property
def word_attributions(self) -> dict:
"Returns the word attributions for model and the text provided. Raises error if attributions not calculated."
@@ -235,6 +238,8 @@ def _calculate_attributions( # type: ignore
ref_position_ids=self.ref_position_ids,
token_type_ids=self.token_type_ids,
ref_token_type_ids=self.ref_token_type_ids,
internal_batch_size=self.internal_batch_size,
n_steps=self.n_steps,
)
if self.include_hypothesis:
lig.summarize()
@@ -249,6 +254,8 @@ def __call__(
embedding_type: int = 0,
hypothesis_template="this text is about {} .",
include_hypothesis: bool = False,
internal_batch_size: int = None,
n_steps: int = None,
) -> dict:
"""
Calculates attribution for `text` using the model and
@@ -285,10 +292,21 @@
Defaults to "this text is about {} .".
include_hypothesis (bool, optional): Alternative option to include hypothesis text in attributions
and visualization. Defaults to False.
internal_batch_size (int, optional): Divides total #steps * #examples
data points into chunks of size at most internal_batch_size,
which are computed (forward / backward passes)
sequentially. If internal_batch_size is None, then all evaluations are
processed in one batch.
n_steps (int, optional): The number of steps used by the approximation
method. Default: 50.
Returns:
list: List of tuples containing words and their associated attribution scores.
"""

if n_steps:
self.n_steps = n_steps
if internal_batch_size:
self.internal_batch_size = internal_batch_size
self.attributions = []
self.pred_probs = []
self.include_hypothesis = include_hypothesis
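The zero-shot explainer follows the same pattern, which matters most here because it computes attributions once per candidate label; a sketch assuming an MNLI-style checkpoint (name illustrative):

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import ZeroShotClassificationExplainer

model_name = "typeform/distilbert-base-uncased-mnli"  # illustrative
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

zero_shot_explainer = ZeroShotClassificationExplainer(model, tokenizer)

# One attribution pass runs per label, so lowering n_steps and capping
# internal_batch_size compounds the savings across labels.
word_attributions = zero_shot_explainer(
    "I have a problem with my iphone that needs to be resolved asap!!",
    labels=["urgent", "phone", "tablet", "computer"],
    n_steps=20,
    internal_batch_size=2,
)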
