
Commit

replace SLURM_JOB_ID with SHA
ofivite committed Jun 5, 2024
1 parent 3e77715 commit c40faf8
Showing 4 changed files with 20 additions and 8 deletions.
7 changes: 5 additions & 2 deletions llmfoundry/data/finetuning/dataloader.py
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
 import os
+import random
+import hashlib
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
@@ -532,10 +534,11 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
 
     # Since we don't know exactly what the extension will be, since it is one of a list
     # use a signal file to wait for instead of the desired file
-    slurm_job_id = int(os.getenv('SLURM_JOB_ID', -1))
+    random_number = random.randint(0, 999999)
+    sha_signature = hashlib.sha256(str(random_number).encode()).hexdigest()
     signal_file_path = os.path.join(
         finetune_dir,
-        f'.node_{dist.get_node_rank()}_slurm_job_id{slurm_job_id}_local_rank0_completed',
+        f'.sha_{sha_signature}_completed',
     )
     if dist.get_local_rank() == 0:
         try:
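
For context, a minimal sketch of the download barrier this file name feeds into: local rank 0 downloads and then creates the marker file, while the other local ranks poll for it before reading the cached data. The helper names (download_dataset, local_rank) are hypothetical stand-ins, not llmfoundry APIs, and every rank must be given the same signal_file_path for the rendezvous to work.

import os
import time

def download_with_signal_file(
    signal_file_path: str,   # must be identical on every local rank
    local_rank: int,
    download_dataset,        # hypothetical callable that fetches and caches the data
) -> None:
    if local_rank == 0:
        download_dataset()   # only one process per node performs the download
        with open(signal_file_path, 'wb') as f:
            f.write(b'local_rank0_completed_download')
    else:
        # Other local ranks wait for the marker file before touching the cache.
        while not os.path.exists(signal_file_path):
            time.sleep(1)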
7 changes: 5 additions & 2 deletions llmfoundry/data/finetuning/tasks.py
@@ -34,6 +34,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 import importlib
 import logging
 import os
+import random
+import hashlib
 import warnings
 from collections.abc import Mapping
 from functools import partial
@@ -814,8 +816,9 @@ def build_from_hf(
         Returns:
             Dataset: The tokenized dataset.
         """
-        slurm_job_id = int(os.getenv('SLURM_JOB_ID', -1))
-        signal_file_path = f'.node_{dist.get_node_rank()}_slurm_job_id{slurm_job_id}_local_rank0_data_prep_completed'
+        random_number = random.randint(0, 999999)
+        sha_signature = hashlib.sha256(str(random_number).encode()).hexdigest()
+        signal_file_path = f'.sha_{sha_signature}_data_prep_completed'
 
         # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
         # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks
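
Each process draws its own random number here; as a hedged sketch (an assumption on my part, not part of this commit), one way for every rank to derive the same signature is to broadcast rank 0's number before hashing. This assumes torch.distributed has already been initialized.

import hashlib
import random
import torch.distributed as torch_dist

def shared_sha_signature() -> str:
    # Rank 0 draws the number; all other ranks receive it, so the hash matches everywhere.
    payload = [random.randint(0, 999999)] if torch_dist.get_rank() == 0 else [None]
    torch_dist.broadcast_object_list(payload, src=0)
    return hashlib.sha256(str(payload[0]).encode()).hexdigest()

# e.g. signal_file_path = f'.sha_{shared_sha_signature()}_data_prep_completed'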
7 changes: 5 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -5,6 +5,8 @@
 
 import logging
 import os
+import random
+import hashlib
 import warnings
 from typing import (
     TYPE_CHECKING,
@@ -339,8 +341,9 @@ def _autoset_attn_implementation_monkeypatch(
                 f'init_device="{init_device}" must be either "cpu" or "meta".',
             )
 
-        slurm_job_id = int(os.getenv('SLURM_JOB_ID', -1))
-        signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_slurm_job_id{slurm_job_id}_completed'
+        random_number = random.randint(0, 999999)
+        sha_signature = hashlib.sha256(str(random_number).encode()).hexdigest()
+        signal_file_path = f'.sha_{sha_signature}_completed'
         if dist.get_local_rank() == 0:
             with open(signal_file_path, 'wb') as f:
                 f.write(b'local_rank0_completed_download')
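
The same three lines now appear in all four touched files; a hypothetical helper (not present in the repo) could factor the naming scheme into one place, with each call site passing its own suffix.

import hashlib
import random

def random_sha_signal_file(suffix: str) -> str:
    """Return a hidden marker-file name keyed by the SHA-256 of a random number."""
    random_number = random.randint(0, 999999)
    sha_signature = hashlib.sha256(str(random_number).encode()).hexdigest()
    return f'.sha_{sha_signature}_{suffix}'

# e.g. signal_file_path = random_sha_signal_file('completed')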
7 changes: 5 additions & 2 deletions llmfoundry/utils/builders.py
@@ -5,6 +5,8 @@
 import functools
 import logging
 import os
+import random
+import hashlib
 import re
 from collections import OrderedDict
 from typing import (
@@ -453,8 +455,9 @@ def build_tokenizer(
     os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
-    slurm_job_id = int(os.getenv('SLURM_JOB_ID', -1))
-    signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_slurm_job_id{slurm_job_id}_completed_tokenizer_setup'
+    random_number = random.randint(0, 999999)
+    sha_signature = hashlib.sha256(str(random_number).encode()).hexdigest()
+    signal_file_path = f'.sha_{sha_signature}_completed_tokenizer_setup'
 
     if dist.is_available() and dist.is_initialized(
     ) and dist.get_world_size() > 1:
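
For reference, a rough sketch of the guarded pattern this hunk feeds into: coordinate through the marker file only when more than one rank is running, and remove the file after a collective barrier so no rank is still polling for it. setup_tokenizer and barrier are hypothetical placeholders (e.g. a callable wrapping torch.distributed.barrier), not llmfoundry functions.

import os
import time

def guarded_tokenizer_setup(
    signal_file_path: str,   # must be identical on every rank
    local_rank: int,
    world_size: int,
    setup_tokenizer,         # hypothetical callable that downloads/builds the tokenizer
    barrier,                 # hypothetical collective barrier across all ranks
) -> None:
    if world_size > 1 and local_rank != 0:
        # Non-zero ranks wait until local rank 0 has populated the cache.
        while not os.path.exists(signal_file_path):
            time.sleep(1)
    setup_tokenizer()
    if world_size > 1:
        if local_rank == 0:
            with open(signal_file_path, 'wb') as f:
                f.write(b'local_rank0_completed_tokenizer_setup')
        barrier()            # ensure every rank has passed the wait loop
        if local_rank == 0:
            os.remove(signal_file_path)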
