diff --git a/pliers/extractors/models.py b/pliers/extractors/models.py index ed7cb7a2..672e32d5 100644 --- a/pliers/extractors/models.py +++ b/pliers/extractors/models.py @@ -51,6 +51,9 @@ class TFHubExtractor(Extractor): output for compatibility with extractor result transform_inp (optional): function to transform Stim.data for compatibility with model input format + output_key (str): key of the desired entry in the output + dictionary. Set to None if the output is not a dictionary, + or to output all keys in the dictionary. keras_kwargs (dict): arguments to hub.KerasLayer call ''' @@ -59,12 +62,12 @@ class TFHubExtractor(Extractor): def __init__(self, url_or_path, features=None, transform_out=None, transform_inp=None, - keras_kwargs=None): + output_key=None, keras_kwargs=None): verify_dependencies(['tensorflow_hub']) if keras_kwargs is None: keras_kwargs = {} self.keras_kwargs = keras_kwargs - + self.output_key = output_key self.model = hub.KerasLayer(url_or_path, **keras_kwargs) self.url_or_path = url_or_path @@ -74,11 +77,18 @@ def __init__(self, url_or_path, features=None, super().__init__() def get_feature_names(self, out): + # Manual feature names always take precedence if self.features: return listify(self.features) + # Infer feature names from output else: + # If dict, use provided output key, or all keys if isinstance(out, dict): - return list(out.keys()) + if self.output_key: + return [self.output_key] + else: + return list(out.keys()) + # Worst case, use generic feature names else: return ['feature_' + str(i) for i in range(out.shape[-1])] @@ -93,10 +103,27 @@ def _preprocess(self, stim): return stim.data def _postprocess(self, out): + # If key is provided, return only that key + if self.output_key: + try: + out = out[self.output_key] + except KeyError: + raise ValueError(f'{self.output_key} is not a valid key.' 
+ ' Check which keys are available in the output ' + f'at the model URL ({self.url_or_path})') + except (IndexError, TypeError): + raise ValueError(f'Model output is not a dictionary. ' + 'Try initialize the extractor with output_key=None.') + + # If output is a dict and no output key, return all keys + if isinstance(out, dict): + out = np.vstack(list(out.values())).T + + # Convert tensors to numpy (stacked dicts already are) and drop size-1 axes + out = np.squeeze(out if isinstance(out, np.ndarray) else out.numpy()) + if self.transform_out: out = self.transform_out(out) - if not isinstance(out, np.ndarray): - out = out.numpy().squeeze() return out def _get_timing(self, out, stim): @@ -116,12 +143,7 @@ def _get_timing(self, out, stim): def _extract(self, stim): inp = self._preprocess(stim) out = self.model(inp) - features = self.get_feature_names(out) - - if isinstance(out, dict): - out = np.vstack(list(out.values())).T - out = self._postprocess(out) onsets, durations, orders = self._get_timing(out, stim) @@ -232,7 +254,11 @@ class TFHubTextExtractor(TFHubExtractor): The number of items must match the number of features in the model output. For example, if a text encoder outputting 768-dimensional encoding is passed - (e.g. base BERT), this must be a list containing 768 items. + output_key (str): key to desired embedding in output + dictionary (see documentation at + https://www.tensorflow.org/hub/common_saved_model_apis/text). + Set to None if the output is not a dictionary, or to + output all keys (e.g. base BERT), this must be a list containing 768 items. Each dimension in the model output will be returned as a separate feature in the ExtractorResult. Alternatively, the model output can be packed into a single @@ -246,7 +272,8 @@ class TFHubTextExtractor(TFHubExtractor): output_key (str): key to desired embedding in output dictionary (see documentation at https://www.tensorflow.org/hub/common_saved_model_apis/text). - Set to None is the output is not a dictionary. 
+ Set to None if the output is not a dictionary, or to + output all keys. preprocessor_url_or_path (str): if the model requires preprocessing through another TFHub model, specifies the url or path to the preprocessing module. Information on @@ -266,7 +293,7 @@ def __init__(self, preprocessor_kwargs=None, keras_kwargs=None, **kwargs): - super().__init__(url_or_path, features, + super().__init__(url_or_path, features, output_key=output_key, keras_kwargs=keras_kwargs, **kwargs) self.output_key = output_key @@ -287,22 +314,6 @@ def _preprocess(self, stim): self.preprocessor_url_or_path, **self.preprocessor_kwargs) x = preprocessor(x) return x - - def _postprocess(self, out): - if not self.output_key: - return out.numpy().squeeze() - else: - try: - return out[self.output_key].numpy().squeeze() - except KeyError: - raise ValueError(f'{self.output_key} is not a valid key.' - 'Check which keys are available in the output ' - 'embedding dictionary in TFHub docs ' - '(https://www.tensorflow.org/hub/common_saved_model_apis/text)' - f' or at the model URL ({self.url_or_path})') - except (IndexError, TypeError): - raise ValueError(f'Model output is not a dictionary. ' - 'Try initialize the extractor with output_key=None.') class TensorFlowKerasApplicationExtractor(ImageExtractor):