Skip to content

Commit

Permalink
Harmonize postprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
adelavega committed Nov 15, 2022
1 parent 4637ce4 commit 4a6543c
Showing 1 changed file with 40 additions and 29 deletions.
69 changes: 40 additions & 29 deletions pliers/extractors/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class TFHubExtractor(Extractor):
output for compatibility with extractor result
transform_inp (optional): function to transform Stim.data
for compatibility with model input format
output_key (str): key to desired in output in
dictionary. Set to None if the output is not a dictionary,
or to output all keys in dictionary.
keras_kwargs (dict): arguments to hub.KerasLayer call
'''

Expand All @@ -59,12 +62,12 @@ class TFHubExtractor(Extractor):

def __init__(self, url_or_path, features=None,
transform_out=None, transform_inp=None,
keras_kwargs=None):
output_key=None, keras_kwargs=None):
verify_dependencies(['tensorflow_hub'])
if keras_kwargs is None:
keras_kwargs = {}
self.keras_kwargs = keras_kwargs

self.output_key = output_key
self.model = hub.KerasLayer(url_or_path, **keras_kwargs)

self.url_or_path = url_or_path
Expand All @@ -74,11 +77,18 @@ def __init__(self, url_or_path, features=None,
super().__init__()

def get_feature_names(self, out):
# Manual feature names always take precedence
if self.features:
return listify(self.features)
# Infer feature names from output
else:
# If dict, use provided output key, or all keys
if isinstance(out, dict):
return list(out.keys())
if self.output_key:
return [self.output_key]
else:
return list(out.keys())
# Worst case, use generic feature names
else:
return ['feature_' + str(i)
for i in range(out.shape[-1])]
Expand All @@ -93,10 +103,27 @@ def _preprocess(self, stim):
return stim.data

def _postprocess(self, out):
# If key is provided, return only that key
if self.output_key:
try:
out = out[self.output_key]
except KeyError:
raise ValueError(f'{self.output_key} is not a valid key.'
'Check which keys are available in the output '
'at the model URL ({self.url_or_path})')
except (IndexError, TypeError):
raise ValueError(f'Model output is not a dictionary. '
'Try initialize the extractor with output_key=None.')

# If output is a dict and no output key, return all keys
if isinstance(out, dict):
out = np.vstack(list(out.values())).T

# Always squeeze last dimension if it is 1
out = out.numpy().squeeze()

if self.transform_out:
out = self.transform_out(out)
if not isinstance(out, np.ndarray):
out = out.numpy().squeeze()
return out

def _get_timing(self, out, stim):
Expand All @@ -116,12 +143,7 @@ def _get_timing(self, out, stim):
def _extract(self, stim):
inp = self._preprocess(stim)
out = self.model(inp)

features = self.get_feature_names(out)

if isinstance(out, dict):
out = np.vstack(list(out.values())).T

out = self._postprocess(out)

onsets, durations, orders = self._get_timing(out, stim)
Expand Down Expand Up @@ -232,7 +254,11 @@ class TFHubTextExtractor(TFHubExtractor):
The number of items must match the number of features
in the model output. For example, if a text encoder
outputting 768-dimensional encoding is passed
(e.g. base BERT), this must be a list containing 768 items.
output_key (str): key to desired embedding in output
dictionary (see documentation at
https://www.tensorflow.org/hub/common_saved_model_apis/text).
Set to None is the output is not a dictionary, or to
output all keys (e.g. base BERT), this must be a list containing 768 items.
Each dimension in the model output will be returned as a
separate feature in the ExtractorResult.
Alternatively, the model output can be packed into a single
Expand All @@ -246,7 +272,8 @@ class TFHubTextExtractor(TFHubExtractor):
output_key (str): key to desired embedding in output
dictionary (see documentation at
https://www.tensorflow.org/hub/common_saved_model_apis/text).
Set to None is the output is not a dictionary.
Set to None is the output is not a dictionary, or to
output all keys
preprocessor_url_or_path (str): if the model requires
preprocessing through another TFHub model, specifies the
url or path to the preprocessing module. Information on
Expand All @@ -266,7 +293,7 @@ def __init__(self,
preprocessor_kwargs=None,
keras_kwargs=None,
**kwargs):
super().__init__(url_or_path, features,
super().__init__(url_or_path, features, output_key=output_key,
keras_kwargs=keras_kwargs,
**kwargs)
self.output_key = output_key
Expand All @@ -287,22 +314,6 @@ def _preprocess(self, stim):
self.preprocessor_url_or_path, **self.preprocessor_kwargs)
x = preprocessor(x)
return x

def _postprocess(self, out):
if not self.output_key:
return out.numpy().squeeze()
else:
try:
return out[self.output_key].numpy().squeeze()
except KeyError:
raise ValueError(f'{self.output_key} is not a valid key.'
'Check which keys are available in the output '
'embedding dictionary in TFHub docs '
'(https://www.tensorflow.org/hub/common_saved_model_apis/text)'
f' or at the model URL ({self.url_or_path})')
except (IndexError, TypeError):
raise ValueError(f'Model output is not a dictionary. '
'Try initialize the extractor with output_key=None.')


class TensorFlowKerasApplicationExtractor(ImageExtractor):
Expand Down

0 comments on commit 4a6543c

Please sign in to comment.