From c40faf8bc6d20f4c8151965aa7e3f2f02b507670 Mon Sep 17 00:00:00 2001 From: ofivite <20295221+ofivite@users.noreply.github.com> Date: Wed, 5 Jun 2024 17:10:43 +0200 Subject: [PATCH] replace SLURM_JOB_ID with SHA --- llmfoundry/data/finetuning/dataloader.py | 7 +++++-- llmfoundry/data/finetuning/tasks.py | 7 +++++-- llmfoundry/models/hf/hf_causal_lm.py | 7 +++++-- llmfoundry/utils/builders.py | 7 +++++-- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 6e9e6d7ca8..5b14680820 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import logging import os +import random +import hashlib from typing import Any, Dict, Optional, Tuple, Union import torch @@ -532,10 +534,11 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: # Since we don't know exactly what the extension will be, since it is one of a list # use a signal file to wait for instead of the desired file - slurm_job_id = int(os.getenv('SLURM_JOB_ID', -1)) + random_number = random.randint(0, 999999) + sha_signature = hashlib.sha256(str(random_number).encode()).hexdigest() signal_file_path = os.path.join( finetune_dir, - f'.node_{dist.get_node_rank()}_slurm_job_id{slurm_job_id}_local_rank0_completed', + f'.sha_{sha_signature}_completed', ) if dist.get_local_rank() == 0: try: diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 3abac3ed4e..bd2b35d17f 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -34,6 +34,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import importlib import logging import os +import random +import hashlib import warnings from collections.abc import Mapping from functools import partial @@ -814,8 +816,9 @@ def build_from_hf( Returns: Dataset: The tokenized dataset. """ - slurm_job_id = int(os.getenv('SLURM_JOB_ID', -1)) - signal_file_path = f'.node_{dist.get_node_rank()}_slurm_job_id{slurm_job_id}_local_rank0_data_prep_completed' + random_number = random.randint(0, 999999) + sha_signature = hashlib.sha256(str(random_number).encode()).hexdigest() + signal_file_path = f'.sha_{sha_signature}_data_prep_completed' # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing. # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index a698425523..d43cc41f79 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -5,6 +5,8 @@ import logging import os +import random +import hashlib import warnings from typing import ( TYPE_CHECKING, @@ -339,8 +341,9 @@ def _autoset_attn_implementation_monkeypatch( f'init_device="{init_device}" must be either "cpu" or "meta".', ) - slurm_job_id = int(os.getenv('SLURM_JOB_ID', -1)) - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_slurm_job_id{slurm_job_id}_completed' + random_number = random.randint(0, 999999) + sha_signature = hashlib.sha256(str(random_number).encode()).hexdigest() + signal_file_path = f'.sha_{sha_signature}_completed' if dist.get_local_rank() == 0: with open(signal_file_path, 'wb') as f: f.write(b'local_rank0_completed_download') diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 01327a8a37..bfab5ab9b7 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -5,6 +5,8 @@ import functools import logging import os +import random +import hashlib import re from collections import OrderedDict from typing import ( @@ -453,8 +455,9 @@ def build_tokenizer( os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' os.environ['TOKENIZERS_PARALLELISM'] = 'false' - slurm_job_id = int(os.getenv('SLURM_JOB_ID', -1)) - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_slurm_job_id{slurm_job_id}_completed_tokenizer_setup' + random_number = random.randint(0, 999999) + sha_signature = hashlib.sha256(str(random_number).encode()).hexdigest() + signal_file_path = f'.sha_{sha_signature}_completed_tokenizer_setup' if dist.is_available() and dist.is_initialized( ) and dist.get_world_size() > 1: