diff --git a/probing/encoder.py b/probing/encoder.py index 5e375ae..7ab9ccd 100644 --- a/probing/encoder.py +++ b/probing/encoder.py @@ -162,7 +162,11 @@ def _get_embeddings_by_layers( aggregation_embeddings: AggregationType, ) -> List[torch.Tensor]: layers_outputs = [] - for output in model_outputs[1:]: # type: ignore + if len(model_outputs) == 1: + process_outputs = model_outputs + else: + process_outputs = model_outputs[1:] + for output in process_outputs: # type: ignore if aggregation_embeddings == AggregationType("first"): sent_vector = output[:, 0, :] # type: ignore elif aggregation_embeddings == AggregationType("last"): @@ -255,11 +259,13 @@ def model_layers_forward( return_dict=self.return_dict, ) - model_outputs = ( - model_outputs["hidden_states"] - if "hidden_states" in model_outputs - else model_outputs["encoder_hidden_states"] - ) + if "hidden_states" in model_outputs: + model_outputs = model_outputs["hidden_states"] + elif "last_hidden_state" in model_outputs: + model_outputs = model_outputs["last_hidden_state"] + else: + model_outputs = model_outputs["encoder_hidden_states"] + layers_outputs = self._get_embeddings_by_layers( model_outputs, aggregation_embeddings=aggregation_embeddings ) @@ -357,10 +363,6 @@ def get_encoded_dataloaders( verbose: bool = True, do_control_task: bool = False, ) -> Tuple[Dict[Literal["tr", "va", "te"], DataLoader], Dict[str, int]]: - # if self.tokenizer.model_max_length > self.model_max_length: - # logger.warning( - # f"In tokenizer model_max_length = {self.tokenizer.model_max_length}. Changed to {self.model_max_length} for preventing Out-Of-Memory." - # ) if self.Caching is None: if self.tokenizer is None: raise RuntimeError("Tokenizer is None") @@ -368,7 +370,7 @@ def get_encoded_dataloaders( tokenized_datasets = self.get_tokenized_datasets(task_dataset) encoded_dataloaders = {} - for stage, _ in tokenized_datasets.items(): + for stage in tokenized_datasets: stage_dataloader_tokenized = DataLoader( tokenized_datasets[stage], batch_size=encoding_batch_size ) diff --git a/probing/pipeline.py b/probing/pipeline.py index 91e053c..a79d0dd 100644 --- a/probing/pipeline.py +++ b/probing/pipeline.py @@ -205,13 +205,18 @@ def run( do_control_task=do_control_task, ) + if self.probing_type == ProbingType.LAYERWISE: + num_layers_to_test = self.transformer_model.config.num_hidden_layers + elif self.probing_type == ProbingType.SINGLERUN: + num_layers_to_test = 1 + probing_iter_range = ( trange( - self.transformer_model.config.num_hidden_layers, + num_layers_to_test, desc="Probing by layers", ) if verbose - else range(self.transformer_model.config.num_hidden_layers) + else range(num_layers_to_test) ) for layer in probing_iter_range: diff --git a/probing/types.py b/probing/types.py index a3157e0..05652bf 100644 --- a/probing/types.py +++ b/probing/types.py @@ -20,6 +20,7 @@ class ClassifierType(str, Enum): class ProbingType(str, Enum): LAYERWISE = "layerwise" + SINGLERUN = "singlerun" class AggregationType(str, Enum):