Skip to content

Commit

Permalink
Merge branch 'main' into uc-hf
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Nov 9, 2023
2 parents e1f4891 + efaa545 commit 2f19d2c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
save_interval: Union[str, int, Time],
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = False,
overwrite: bool = True,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
):
Expand Down
16 changes: 12 additions & 4 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,12 @@ def dataset_mapper(example: Dict):
num_proc=num_cpus_to_use,
desc='Tokenizing dataset',
)

def filter_long_prompts(example: Dict) -> bool:
return len(example['input_ids']) < max_seq_len

prompt_length_filtered_dataset = tokenized_dataset.filter(
lambda example: len(example['input_ids']) < max_seq_len,
filter_long_prompts,
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)
Expand All @@ -376,10 +380,14 @@ def dataset_mapper(example: Dict):
)

pad_token_id = tokenizer.pad_token_id

def filter_empty_examples(example: Dict) -> bool:
return len(example['input_ids']) > 0 and len(
example['labels']) > 0 and any(
token_id != pad_token_id for token_id in example['labels'])

empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
lambda example: len(example['input_ids']) > 0 and len(example[
'labels']) > 0 and any(token_id != pad_token_id
for token_id in example['labels']),
filter_empty_examples,
num_proc=num_cpus_to_use,
desc='Filtering out empty examples')

Expand Down
15 changes: 15 additions & 0 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ def build_tokenizer(
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'

# Make sure the tokenizer files are downloaded and cached first by local rank 0
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass

if tokenizer_name.startswith('tiktoken'):
tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
else:
Expand All @@ -202,6 +208,15 @@ def build_tokenizer(
int(1e30),
)

if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_tokenizer_setup')

dist.barrier()

if dist.get_local_rank() == 0:
os.remove(signal_file_path)

return tokenizer


Expand Down

0 comments on commit 2f19d2c

Please sign in to comment.