Skip to content

Commit

Permalink
added if statement
Browse files Browse the repository at this point in the history
  • Loading branch information
Vitaly Protasov committed Mar 23, 2024
1 parent 20a276c commit 3a184b1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions probing/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from time import time
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, get_args

import numpy as np
import torch
Expand Down Expand Up @@ -170,7 +170,7 @@ def run(
save_checkpoints: bool = False,
verbose: bool = True,
) -> None:
if path_to_task_file:
if path_to_task_file or probe_task in get_args(UDProbingTaskName):
task_data = TextFormer(probe_task, path_to_task_file)
task_dataset, num_classes = task_data.samples, len(task_data.unique_labels)
probing_task_language, probing_task_category = lang_category_extraction(
Expand All @@ -191,7 +191,7 @@ def run(

clear_memory()
start_time = time()
if path_to_task_file:
if path_to_task_file or probe_task in get_args(UDProbingTaskName):
(
probing_dataloaders,
mapped_labels,
Expand Down

0 comments on commit 3a184b1

Please sign in to comment.