Commit

dataloader-fix: small changes
Vitaly Protasov committed Mar 23, 2024
1 parent 707c852 commit 20a276c
Showing 3 changed files with 21 additions and 13 deletions.
24 changes: 13 additions & 11 deletions probing/encoder.py
@@ -162,7 +162,11 @@ def _get_embeddings_by_layers(
         aggregation_embeddings: AggregationType,
     ) -> List[torch.Tensor]:
         layers_outputs = []
-        for output in model_outputs[1:]:  # type: ignore
+        if len(model_outputs) == 1:
+            process_outputs = model_outputs
+        else:
+            process_outputs = model_outputs[1:]
+        for output in process_outputs:  # type: ignore
             if aggregation_embeddings == AggregationType("first"):
                 sent_vector = output[:, 0, :]  # type: ignore
             elif aggregation_embeddings == AggregationType("last"):
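Context for this hunk: when a HuggingFace model runs with output_hidden_states=True, the returned tuple of hidden states carries the embedding-layer output at index 0, which is why the code slices model_outputs[1:]; the new guard keeps the sole tensor when only one is returned. A minimal sketch of that guarded slice, with placeholder tensors (shapes and the 12-layer count are illustrative assumptions, not from the commit):

import torch

def select_layer_outputs(model_outputs):
    # Index 0 of a hidden-states tuple is normally the embedding layer,
    # so probing skips it unless the model returned a single tensor.
    if len(model_outputs) == 1:
        return model_outputs
    return model_outputs[1:]

single = (torch.zeros(2, 5, 8),)                        # one output only
multi = tuple(torch.zeros(2, 5, 8) for _ in range(13))  # embeddings + 12 layers
assert len(select_layer_outputs(single)) == 1
assert len(select_layer_outputs(multi)) == 12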
@@ -255,11 +259,13 @@ def model_layers_forward(
             return_dict=self.return_dict,
         )

-        model_outputs = (
-            model_outputs["hidden_states"]
-            if "hidden_states" in model_outputs
-            else model_outputs["encoder_hidden_states"]
-        )
+        if "hidden_states" in model_outputs:
+            model_outputs = model_outputs["hidden_states"]
+        elif "last_hidden_state" in model_outputs:
+            model_outputs = model_outputs["last_hidden_state"]
+        else:
+            model_outputs = model_outputs["encoder_hidden_states"]
+
         layers_outputs = self._get_embeddings_by_layers(
             model_outputs, aggregation_embeddings=aggregation_embeddings
         )
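The replaced conditional only knew two output layouts; the new chain also handles models whose outputs expose only last_hidden_state. In transformers, outputs are dict-like ModelOutput objects, so membership tests and key lookups behave as below. A minimal sketch of the same dispatch over a plain dict (the tensor is a stand-in, not a real model output):

import torch

def pick_hidden_states(model_outputs):
    # Prefer the per-layer tuple, fall back to the final layer only,
    # then to the encoder side of a seq2seq output.
    if "hidden_states" in model_outputs:
        return model_outputs["hidden_states"]
    elif "last_hidden_state" in model_outputs:
        return model_outputs["last_hidden_state"]
    return model_outputs["encoder_hidden_states"]

outputs = {"last_hidden_state": torch.zeros(2, 5, 8)}
print(pick_hidden_states(outputs).shape)  # torch.Size([2, 5, 8])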
@@ -357,18 +363,14 @@ def get_encoded_dataloaders(
         verbose: bool = True,
         do_control_task: bool = False,
     ) -> Tuple[Dict[Literal["tr", "va", "te"], DataLoader], Dict[str, int]]:
-        # if self.tokenizer.model_max_length > self.model_max_length:
-        #     logger.warning(
-        #         f"In tokenizer model_max_length = {self.tokenizer.model_max_length}. Changed to {self.model_max_length} for preventing Out-Of-Memory."
-        #     )
         if self.Caching is None:
             if self.tokenizer is None:
                 raise RuntimeError("Tokenizer is None")
             self.Caching = Cacher(tokenizer=self.tokenizer, cache={})

         tokenized_datasets = self.get_tokenized_datasets(task_dataset)
         encoded_dataloaders = {}
-        for stage, _ in tokenized_datasets.items():
+        for stage in tokenized_datasets:
             stage_dataloader_tokenized = DataLoader(
                 tokenized_datasets[stage], batch_size=encoding_batch_size
             )
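Iterating the dict directly yields the same stage keys as .items() without binding an unused value. A toy sketch of the per-stage DataLoader construction, with placeholder TensorDatasets standing in for the project's tokenized splits:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder splits; the real code gets these from get_tokenized_datasets.
tokenized_datasets = {
    stage: TensorDataset(torch.zeros(4, 8, dtype=torch.long))
    for stage in ("tr", "va", "te")
}

encoded_dataloaders = {}
for stage in tokenized_datasets:  # same keys as .items(), no unused value
    encoded_dataloaders[stage] = DataLoader(
        tokenized_datasets[stage], batch_size=2
    )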
9 changes: 7 additions & 2 deletions probing/pipeline.py
@@ -205,13 +205,18 @@ def run(
             do_control_task=do_control_task,
         )

+        if self.probing_type == ProbingType.LAYERWISE:
+            num_layers_to_test = self.transformer_model.config.num_hidden_layers
+        elif self.probing_type == ProbingType.SINGLERUN:
+            num_layers_to_test = 1
+
         probing_iter_range = (
             trange(
-                self.transformer_model.config.num_hidden_layers,
+                num_layers_to_test,
                 desc="Probing by layers",
             )
             if verbose
-            else range(self.transformer_model.config.num_hidden_layers)
+            else range(num_layers_to_test)
         )

         for layer in probing_iter_range:
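The new branch makes the loop length depend on the probing mode: every transformer layer for LAYERWISE, a single pass for SINGLERUN. A minimal sketch of the same selection (the 12-layer default is an illustrative assumption; trange is tqdm's progress-bar range):

from enum import Enum
from tqdm import trange

class ProbingType(str, Enum):
    LAYERWISE = "layerwise"
    SINGLERUN = "singlerun"

def layer_range(probing_type, num_hidden_layers=12, verbose=True):
    # LAYERWISE probes all layers; SINGLERUN collapses the loop to one pass.
    if probing_type == ProbingType.LAYERWISE:
        num_layers_to_test = num_hidden_layers
    elif probing_type == ProbingType.SINGLERUN:
        num_layers_to_test = 1
    return (
        trange(num_layers_to_test, desc="Probing by layers")
        if verbose
        else range(num_layers_to_test)
    )

assert len(layer_range(ProbingType.SINGLERUN, verbose=False)) == 1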
1 change: 1 addition & 0 deletions probing/types.py
@@ -20,6 +20,7 @@ class ClassifierType(str, Enum):

 class ProbingType(str, Enum):
     LAYERWISE = "layerwise"
+    SINGLERUN = "singlerun"


 class AggregationType(str, Enum):
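Since ProbingType subclasses str, the new mode can be selected by its raw string value as well as by attribute. A quick stdlib-only illustration:

from enum import Enum

class ProbingType(str, Enum):
    LAYERWISE = "layerwise"
    SINGLERUN = "singlerun"

# str-valued enums round-trip by value and compare equal to plain strings.
assert ProbingType("singlerun") is ProbingType.SINGLERUN
assert ProbingType.SINGLERUN == "singlerun"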
