implement gpt-2 forward language modeling extractor #464

Draft · wants to merge 9 commits into base: master
3 changes: 2 additions & 1 deletion pliers/extractors/__init__.py
@@ -72,7 +72,7 @@
VADERSentimentExtractor, SpaCyExtractor,
WordCounterExtractor, BertExtractor,
BertSequenceEncodingExtractor, BertLMExtractor,
- BertSentimentExtractor)
+ BertSentimentExtractor, GPTForwardLMExtractor)
from .video import (FarnebackOpticalFlowExtractor)

__all__ = [
@@ -154,6 +154,7 @@
'BertSequenceEncodingExtractor',
'BertLMExtractor',
'BertSentimentExtractor',
+ 'GPTForwardLMExtractor',
'AudiosetLabelExtractor',
'WordCounterExtractor',
'MetricExtractor',
153 changes: 153 additions & 0 deletions pliers/extractors/text.py
@@ -855,3 +855,156 @@ def _extract(self, stims):
        return ExtractorResult(word_counter, stims, self,
                               features=self.features,
                               onsets=onsets, durations=durations)


class GPTForwardLMExtractor(ComplexTextExtractor):

    ''' Returns next-word predictions for GPT models.

    Args:
        pretrained_model (str): A string specifying which transformer
            model to use (e.g. 'gpt2').
        tokenizer (str): Type of tokenization used in the tokenization
            step. If different from the model's, out-of-vocabulary tokens
            may be treated as unknown tokens.
        model_class (str): Name of the transformers model class to
            instantiate. Defaults to 'GPT2LMHeadModel'.
        tokenizer_class (str): Name of the transformers tokenizer class
            to instantiate. Defaults to 'GPT2TokenizerFast'.
        framework (str): Name of the deep learning framework to use.
            Must be 'pt' (PyTorch) or 'tf' (TensorFlow). Defaults to 'pt'.
        top_n (int): Specifies how many of the highest-probability tokens
            are to be returned. Mutually exclusive with target and
            threshold.
        target (str or list): Vocabulary token(s) for which probability is
            to be returned. Tokens defined in the vocabulary change across
            tokenizers. Mutually exclusive with top_n and threshold.
        threshold (float): If defined, only values above this threshold
            will be returned. Mutually exclusive with top_n and target.
        return_softmax (bool): If True, returns softmax probability scores
            instead of raw logits.
        return_true_token (bool): If True, returns the true next token and
            its score.
        return_true_word (bool): If True, returns the true next word
            (which may consist of more than one token).
        return_context (bool): Whether to return the context sequence.
        return_input (bool): Whether to return the input sequence.
        onset (str): Whether the onset in the result is the one from the
            target word ('target') or from the last word in the context
            ('last_context').
        model_kwargs (dict): Named arguments for the pretrained model.
        tokenizer_kwargs (dict): Named arguments for the tokenizer.
    '''

    _log_attributes = ('pretrained_model', 'framework', 'top_n', 'target',
                       'threshold', 'tokenizer_type', 'return_softmax',
                       'return_true_word', 'return_true_token',
                       'return_input', 'return_context', 'onset')
    _model_attributes = ('pretrained_model', 'framework', 'top_n',
                         'target', 'threshold', 'tokenizer_type')

    def __init__(self,
                 pretrained_model='gpt2',
                 tokenizer='gpt2',
                 model_class='GPT2LMHeadModel',
                 tokenizer_class='GPT2TokenizerFast',
                 framework='pt',
                 top_n=None,
                 threshold=None,
                 target=None,
                 return_true_token=True,
                 return_true_word=False,
                 return_softmax=None,
                 return_input=True,
                 return_context=True,
                 onset='target',
                 model_kwargs=None,
                 tokenizer_kwargs=None):
        verify_dependencies(['transformers'])
        if framework not in ['pt', 'tf']:
            raise ValueError("Invalid framework; must be one of "
                             "'pt' (PyTorch) or 'tf' (TensorFlow).")
        if onset not in ['target', 'last_context']:
            raise ValueError("onset must be one of "
                             "'target' or 'last_context'.")
        self.pretrained_model = pretrained_model
        self.tokenizer_type = tokenizer
        self.model_class = model_class
        self.framework = framework
        self.model_kwargs = model_kwargs if model_kwargs else {}
        self.tokenizer_kwargs = tokenizer_kwargs if tokenizer_kwargs else {}
        # TensorFlow versions of transformers model classes are prefixed 'TF'
        model = model_class if self.framework == 'pt' else 'TF' + model_class
        self.model = getattr(transformers, model).from_pretrained(
            pretrained_model, **self.model_kwargs)
        self.tokenizer = getattr(transformers, tokenizer_class).from_pretrained(
            tokenizer, **self.tokenizer_kwargs)
        self.target = listify(target)
        if self.target:
            missing = set(self.target) - set(self.tokenizer.vocab.keys())
            if missing:
                logging.warning(f'{missing} not in vocabulary. Dropping.')
            present = set(self.target) & set(self.tokenizer.vocab.keys())
            self.target = list(present)
            if not self.target:
                raise ValueError('No valid target token. Import transformers'
                                 ' and run transformers.GPT2Tokenizer.from_pretrained'
                                 f"('{tokenizer}').vocab.keys() to see available tokens.")
        self.top_n = top_n
        self.threshold = threshold
        self.return_softmax = return_softmax
        self.return_context = return_context
        self.return_true_word = return_true_word
        self.return_true_token = return_true_token
        self.return_input = return_input
        self.onset = onset
        super().__init__()

    def _preprocess(self, stims):
        ''' Tokenizes the input and returns context and target info. '''
        els = [(e.text, e.onset, e.duration) for e in stims.elements]
        wds, ons, dur = map(list, zip(*els))
        # The context is every word except the last (the target)
        c_wds, c_ons, c_dur = (l[:-1] for l in [wds, ons, dur])
        c_tok = self.tokenizer.encode(' '.join(c_wds),
                                      return_tensors=self.framework)
        stims.name = ' '.join(wds) if stims.name == '' else stims.name
        # Prefix the target with a space: GPT-2's BPE assigns different
        # ids to tokens that start a word mid-sequence
        t_wds = ' ' + wds[-1]
        # Only the first token of a multi-token target word is scored
        t_id = self.tokenizer.encode(t_wds, return_tensors=self.framework)[0, 0]
        t_tok = self.tokenizer.decode(t_id)
        return ((c_ons, c_dur, c_tok, c_wds),
                (t_id, t_tok, t_wds, ons[-1], dur[-1]))

    def _extract(self, stims):
        c_outs, t_outs = self._preprocess(stims)
        c_ons, c_dur, c_tok, c_wds = c_outs
        t_id, t_tok, t_wds, t_ons, t_dur = t_outs
        outputs = self.model(c_tok)
        # Logits at the last context position score the next token
        if self.framework == 'pt':
            preds = outputs.logits[0, -1, :].detach().numpy()
        else:
            preds = outputs.logits[0, -1, :].numpy()
        if self.return_softmax:
            preds = scipy.special.softmax(preds, axis=-1)
        out_idx = preds.argsort()[::-1]
        if self.top_n:
            sub_idx = out_idx[:self.top_n]
        elif self.target:
            sub_idx = self.tokenizer.convert_tokens_to_ids(self.target)
        elif self.threshold:
            sub_idx = np.where(preds >= self.threshold)[0]
        else:
            sub_idx = out_idx
        # Keep the selected subset in descending-probability order
        sub_idx = set(sub_idx)
        out_idx = [idx for idx in out_idx if idx in sub_idx]
        feat = [self.tokenizer.decode(o) for o in out_idx]
        data = [listify(float(p)) for p in preds[out_idx]]
        if self.return_true_token:
            feat += ['true_token', 'true_token_score']
            data += [t_tok, float(preds[int(t_id)])]
        if self.return_true_word:
            feat += ['true_word']
            data += [t_wds]
        if self.return_context:
            feat += ['lm_context']
            data += [' '.join(c_wds)]
        if self.return_input:
            feat += ['lm_sequence']
            data += [stims.name]
        if self.onset == 'target':
            ons = listify(t_ons)
            dur = listify(t_dur)
        else:
            ons = listify(c_ons[-1])
            dur = listify(c_dur[-1])
        return ExtractorResult(data, stims, self,
                               features=feat, onsets=ons, durations=dur)

    def _to_df(self, result):
        # Single-row DataFrame with one column per returned feature
        res_df = pd.DataFrame(dict(zip(result.features, result._data)))
        res_df['object_id'] = range(res_df.shape[0])
        return res_df
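
For reviewers trying the branch locally, here is a minimal usage sketch. It assumes the standard pliers transform interface and that `ComplexTextStim` accepts a raw `text` argument; the exact output columns are illustrative rather than verified against this branch.

```python
# Minimal usage sketch (assumes this branch is installed and the
# 'gpt2' weights can be downloaded; output names are illustrative).
from pliers.stimuli import ComplexTextStim
from pliers.extractors import GPTForwardLMExtractor

# The final word ('mat') is held out as the prediction target;
# everything before it is the context fed to GPT-2.
stim = ComplexTextStim(text='the cat sat on the mat')

# top_n, target, and threshold are mutually exclusive ways of
# selecting which next-token scores to return.
ext = GPTForwardLMExtractor(top_n=5, return_softmax=True)

result = ext.transform(stim)
df = result.to_df()

# Expect one row: the five highest-probability next tokens as columns,
# plus true_token, true_token_score, lm_context, and lm_sequence.
print(df.columns)
```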