Skip to content

Commit

Permalink
Fix: Add datasets cache to loading tags job runner (#2549)
Browse files (browse the repository at this point in the history)
Add datasets cache to loading tags
  • Loading branch information
AndreaFrancis authored Mar 5, 2024
1 parent 35a1891 commit e4aac49
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
1 change: 1 addition & 0 deletions services/worker/src/worker/job_runner_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def _create_job_runner(self, job_info: JobInfo) -> JobRunner:
return DatasetLoadingTagsJobRunner(
job_info=job_info,
app_config=self.app_config,
hf_datasets_cache=self.hf_datasets_cache,
)

supported_job_types = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
DatasetTag,
LoadingCode,
)
from worker.job_runners.dataset.dataset_job_runner import DatasetJobRunner
from worker.job_runners.dataset.dataset_job_runner import DatasetJobRunnerWithDatasetsCache

NON_WORD_GLOB_SEPARATOR = f"[{NON_WORDS_CHARS}/]"
NON_WORD_REGEX_SEPARATOR = NON_WORD_GLOB_SEPARATOR.replace(".", "\.").replace("/", "\/")
Expand Down Expand Up @@ -621,7 +621,7 @@ def compute_loading_tags_response(dataset: str, hf_token: Optional[str] = None)
return DatasetLoadingTagsResponse(tags=tags, libraries=libraries)


class DatasetLoadingTagsJobRunner(DatasetJobRunner):
class DatasetLoadingTagsJobRunner(DatasetJobRunnerWithDatasetsCache):
@staticmethod
def get_job_type() -> str:
return "dataset-loading-tags"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DatasetLoadingTagsJobRunner,
get_builder_configs_with_simplified_data_files,
)
from worker.resources import LibrariesResource

from ..utils import REVISION_NAME, UpstreamResponse

Expand Down Expand Up @@ -207,6 +208,7 @@ def __init__(self, path: str = str(hf), target_protocol: str = "local", **kwargs

@pytest.fixture
def get_job_runner(
libraries_resource: LibrariesResource,
cache_mongo_resource: CacheMongoResource,
queue_mongo_resource: QueueMongoResource,
) -> GetJobRunner:
Expand All @@ -228,6 +230,7 @@ def _get_job_runner(
"difficulty": 20,
},
app_config=app_config,
hf_datasets_cache=libraries_resource.hf_datasets_cache,
)

return _get_job_runner
Expand Down

0 comments on commit e4aac49

Please sign in to comment.