Commit 5cd6a94 "Fix lints"

xiaohanzhan-db committed Jan 8, 2024
1 parent 4651be7 commit 5cd6a94

Showing 2 changed files with 160 additions and 107 deletions.

169 changes: 97 additions & 72 deletions scripts/data_prep/validate_and_tokenize_data.py
@@ -1,3 +1,6 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
 # Databricks notebook source
 # MAGIC %md
 # MAGIC JIRA: https://databricks.atlassian.net/jira/software/c/projects/STR/issues/STR-141?filter=allissues
@@ -73,17 +76,16 @@

 import os
 import re
-from enum import Enum
-from composer.utils import (ObjectStore, maybe_create_object_store_from_uri, parse_uri)
-from torch.utils.data import DataLoader
-from streaming import StreamingDataset
-import numpy as np
+from argparse import ArgumentParser, Namespace
+from typing import Tuple, Union
+
+from composer.utils import (ObjectStore, maybe_create_object_store_from_uri,
+                            parse_uri)
+from datasets import get_dataset_split_names
+from huggingface_hub import dataset_info
 from omegaconf import OmegaConf as om
-from argparse import Namespace
-from typing import Union, Tuple

 from llmfoundry.utils import build_tokenizer
-from huggingface_hub import dataset_info
-from datasets import get_dataset_split_names

 # COMMAND ----------

@@ -93,23 +95,28 @@
 # COMMAND ----------

 FT_API_args = Namespace(
-    model = 'EleutherAI/gpt-neox-20b',
-    train_data_path = 'tatsu-lab/alpaca',  # 'mosaicml/dolly_hhrlhf/train', # tatsu-lab/alpaca/train',
-    save_folder = 'dbfs:/databricks/mlflow-tracking/EXPERIMENT_ID/RUN_ID/artifacts/checkpoints',
-    task_type = "INSTRUCTION_FINETUNE",
-    eval_data_path = None,
-    eval_prompts = None,
-    custom_weights_path = None,
-    training_duration = None,
-    learning_rate = None,
-    context_length = 2048,
-    experiment_trackers = None,
-    disable_credentials_check = None,
+    model='EleutherAI/gpt-neox-20b',
+    train_data_path=
+    'tatsu-lab/alpaca',  # 'mosaicml/dolly_hhrlhf/train', # tatsu-lab/alpaca/train',
+    save_folder=
+    'dbfs:/databricks/mlflow-tracking/EXPERIMENT_ID/RUN_ID/artifacts/checkpoints',
+    task_type='INSTRUCTION_FINETUNE',
+    eval_data_path=None,
+    eval_prompts=None,
+    custom_weights_path=None,
+    training_duration=None,
+    learning_rate=None,
+    context_length=2048,
+    experiment_trackers=None,
+    disable_credentials_check=None,
     # Extra argument to add to FT API
     # See comment https://databricks.atlassian.net/browse/STR-141?focusedCommentId=4308948
-    data_prep_config = {'data_validation': True, 'data_prep': False},
-    timeout = 10,
-    future = False,
+    data_prep_config={
+        'data_validation': True,
+        'data_prep': False
+    },
+    timeout=10,
+    future=False,
 )

 os.environ['HF_ASSETS_CACHE'] = '/tmp/'
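Note: `FT_API_args` stands in for the Finetuning API call with an `argparse.Namespace`, which gives attribute-style access to the mocked arguments. A minimal illustration (values here are illustrative, not part of this commit):

```python
from argparse import Namespace

# Illustrative only: Namespace fields read like attributes, which is why the
# notebook uses it to mock the FT API's keyword arguments.
args = Namespace(model='EleutherAI/gpt-neox-20b', context_length=2048)
assert args.model == 'EleutherAI/gpt-neox-20b'
assert args.context_length == 2048
```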
@@ -131,14 +138,12 @@

 import logging
 import math
-import os
 import tempfile
-from argparse import ArgumentParser, Namespace
+from argparse import Namespace
 from concurrent.futures import ProcessPoolExecutor
 from glob import glob
 from typing import Iterable, List, Tuple, cast

-import psutil
 from composer.utils import (ObjectStore, maybe_create_object_store_from_uri,
                             parse_uri)
 from streaming import MDSWriter
Expand All @@ -153,27 +158,33 @@
 DONE_FILENAME = '.text_to_mds_conversion_done'


-def parse_args( tokenizer,
-                concat_tokens,
-                output_folder,
-                input_folder,
-                compression = 'zstd',
-                bos_text = '',
-                eos_text = '',
-                no_wrap = False ,
-                processes = 32,  # min(max(psutil.cpu_count() - 2, 1), 32),
-                reprocess = False ) -> Namespace:
-
-    parsed = Namespace(tokenizer = tokenizer,
-                       concat_tokens = model_max_length,
-                       output_folder = output_folder,
-                       input_folder = input_folder,
-                       eos_text = eos_text,
-                       bos_text = bos_text,
-                       no_wrap = no_wrap,
-                       compression = compression,
-                       processes = processes,
-                       reprocess = reprocess)
+def parse_args(
+        tokenizer: str,
+        concat_tokens: int,
+        output_folder: str,
+        input_folder: str,
+        compression: str = 'zstd',
+        bos_text: str = '',
+        eos_text: str = '',
+        no_wrap: bool = False,
+        processes: int = 32,  # min(max(psutil.cpu_count() - 2, 1), 32),
+        reprocess: bool = False
+) -> Namespace:
+
+    parser = ArgumentParser(
+        description=
+        'Convert text files into MDS format, optionally concatenating and tokenizing',
+    )
+    parsed = Namespace(tokenizer=tokenizer,
+                       concat_tokens=concat_tokens,
+                       output_folder=output_folder,
+                       input_folder=input_folder,
+                       eos_text=eos_text,
+                       bos_text=bos_text,
+                       no_wrap=no_wrap,
+                       compression=compression,
+                       processes=processes,
+                       reprocess=reprocess)

     # Make sure we have needed concat options
     if (parsed.concat_tokens is not None and
@@ -505,10 +516,12 @@ def _args_str(original_args: Namespace) -> str:

 # COMMAND ----------

-from streaming.base.storage.upload import CloudUploader
-from streaming.base.storage.download import download_file
 import json

+from streaming.base.storage.download import download_file
+from streaming.base.storage.upload import CloudUploader
+
+
 def integrity_check(out: Union[str, Tuple[str, str]]):
     """Check if the index file has integrity.
@@ -540,34 +553,37 @@ def count_shards(mds_root: str):
     actual_n_shard_files = count_shards(cu.local)

     merged_index = json.load(open(local_merged_index_path, 'r'))
-    n_shard_files = len({b['raw_data']['basename'] for b in merged_index['shards']})
+    n_shard_files = len(
+        {b['raw_data']['basename'] for b in merged_index['shards']})
     return n_shard_files == actual_n_shard_files
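Note: `count_shards` itself is collapsed in this diff. A minimal sketch of the idea it implements, assuming shards are `*.mds*` files under the local root (the name and glob pattern here are assumptions, not this commit's code):

```python
import os
from glob import glob

def count_shards_sketch(mds_root: str) -> int:
    # Count shard files such as shard.00000.mds or shard.00000.mds.zstd;
    # index.json is not matched by the *.mds* pattern.
    return len(glob(os.path.join(mds_root, '**', '*.mds*'), recursive=True))
```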

-def check_HF_datasets(dataset_names_with_splits):
-    token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
+
+def check_HF_datasets(dataset_names_with_splits: list):
+    token = os.environ.get('HUGGING_FACE_HUB_TOKEN')
     for dataset_name_with_split in dataset_names_with_splits:
         dataset_name, split = os.path.split(dataset_name_with_split)
         # make sure we have a dataset and split
         if not dataset_name or not split:
             return False, f"Failed to load Hugging Face dataset {dataset_name_with_split}. Please ensure that you include the split name (e.g. 'mosaicml/dolly_hhrlhf/train')."
         # check user access to the dataset
         try:
-            info = dataset_info(dataset_name)
+            _ = dataset_info(dataset_name)
         except:
-            token_warning = ""
+            token_warning = ''
             if not token:
-                token_warning = " If this is a private dataset, please set your HUGGING_FACE_HUB_TOKEN using: mcli create secret hf."
+                token_warning = ' If this is a private dataset, please set your HUGGING_FACE_HUB_TOKEN using: mcli create secret hf.'
             return False, f"Failed to load Hugging Face dataset {dataset_name_with_split}. Please ensure that the dataset exists and that you have access to it. Remember to include the split name (e.g. 'mosaicml/dolly_hhrlhf/train')." + token_warning
         # check that split exists
         try:
             splits = get_dataset_split_names(dataset_name)
         except:  # error raised in the case of multiple subsets
-            return False, f"Failed to load Hugging Face dataset {dataset_name_with_split}. Please make sure that the split is valid and that your dataset does not have subsets."
+            return False, f'Failed to load Hugging Face dataset {dataset_name_with_split}. Please make sure that the split is valid and that your dataset does not have subsets.'
         if split not in splits:
-            return False, f"Failed to load Hugging Face dataset {dataset_name_with_split}. Split not found."
-    return True, ""
+            return False, f'Failed to load Hugging Face dataset {dataset_name_with_split}. Split not found.'
+    return True, ''
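Note: `check_HF_datasets` returns a `(passed, message)` tuple rather than raising, so callers are expected to surface the message themselves. A hypothetical usage sketch (not part of this commit; note the parameter is a list of `org/dataset/split` strings):

```python
# Hypothetical usage; a private dataset additionally needs HUGGING_FACE_HUB_TOKEN.
passed, message = check_HF_datasets(['mosaicml/dolly_hhrlhf/train'])
if not passed:
    raise ValueError(message)
```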

-def is_hf_dataset_path(path):
+
+def is_hf_dataset_path(path: str):
     """Check if a given string is a dataset path used by Hugging Face.

     Args:
@@ -577,11 +593,12 @@ def is_hf_dataset_path(path):
         bool: True if the string is a dataset path, False otherwise.
     """
     # Regular expression to match the dataset path pattern
-    pattern = r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+/?(train|validation|test)?/?$"
+    pattern = r'^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+/?(train|validation|test)?/?$'

     return bool(re.match(pattern, path))
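Note: the pattern accepts `org/dataset` with an optional trailing `train`/`validation`/`test` split and optional trailing slash. A quick sanity check (illustrative, not part of this commit):

```python
assert is_hf_dataset_path('mosaicml/dolly_hhrlhf/train')
assert is_hf_dataset_path('tatsu-lab/alpaca')            # split is optional
assert not is_hf_dataset_path('tatsu-lab/alpaca/bogus')  # unrecognized split
assert not is_hf_dataset_path('alpaca')                  # org prefix required
```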

-def create_om_cfg(FT_API_args):
+
+def create_om_cfg(FT_API_args: Namespace):
     task_type = FT_API_args.task_type
     train_data_path = FT_API_args.train_data_path
     model = FT_API_args.model
@@ -631,14 +648,18 @@ def create_om_cfg(FT_API_args):

     return cfg, tokenizer

+
 # COMMAND ----------

+
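Note: the body of `create_om_cfg` is mostly collapsed above; judging from its visible tail, it builds an OmegaConf config plus a tokenizer from `FT_API_args`, so the collapsed parts of `main()` presumably consume it roughly like this (an assumption, not this commit's code):

```python
# Assumed call shape, based on the visible `return cfg, tokenizer` above.
cfg, tokenizer = create_om_cfg(FT_API_args)
```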
 # build cfg from the inputs
 def main():
     if FT_API_args.task_type == 'INSTRUCTION_FINETUNE':
         # check if train_data_path is a valid HF dataset url with splits.
         if not is_hf_dataset_path(FT_API_args.train_data_path):
-            raise ValueError(f"Input path {FT_API_args.train_data_path} is not supported. It needs to be a valid Huggingface dataset path.")
+            raise ValueError(
+                f'Input path {FT_API_args.train_data_path} is not supported. It needs to be a valid Huggingface dataset path.'
+            )
         # load dataset.info and see if HF tokens are correctly set.
         check_HF_datasets(FT_API_args.train_data_path)

@@ -669,16 +690,20 @@ def main():

         # Check if the MDS dataset is integral by checking index.json
         if integrity_check(args.output_folder):
-            raise RuntimeError(f"{args.output_folder} has mismatched number of shard files between merged index.json and actual shards!")
+            raise RuntimeError(
+                f'{args.output_folder} has mismatched number of shard files between merged index.json and actual shards!'
+            )

-        print("Converted data for continnued pre-training was saved in: ", args.output_folder)
+        print('Converted data for continnued pre-training was saved in: ',
+              args.output_folder)

     else:
-        raise ValueError(f"task_type can only be INSTRUCTION_FINETUNE or Continued_Pretraining but got {FT_API_args.task_type} instead!")
-    # Run a few checks on resulted MDS datasets
-    # 1. no shards in output_folder
-    # 2. check shard completeness by downloading and inspecting index.json
-
+        raise ValueError(
+            f'task_type can only be INSTRUCTION_FINETUNE or Continued_Pretraining but got {FT_API_args.task_type} instead!'
+        )
+    # Run a few checks on resulted MDS datasets
+    # 1. no shards in output_folder
+    # 2. check shard completeness by downloading and inspecting index.json

     from llmfoundry.data.finetuning import build_finetuning_dataloader
     tokenizer_name = 'EleutherAI/gpt-neox-20b'
@@ -694,10 +719,10 @@ def main():
     for batch in dataloader:
         total_tokens += token_counting_func(batch)

-    print("Total number of tokens:", total_tokens)
+    print('Total number of tokens:', total_tokens)
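Note: `token_counting_func` is defined in a collapsed part of this diff. A minimal stand-in that counts non-padding tokens in an HF-style collated batch (an assumption about the batch format, not this commit's helper):

```python
import torch

def token_counting_func_sketch(batch: dict) -> int:
    # Prefer the attention mask so padding tokens are excluded from the count.
    if 'attention_mask' in batch:
        return int(batch['attention_mask'].sum().item())
    return int(torch.numel(batch['input_ids']))
```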

 # COMMAND ----------

 if __name__ == '__main__':
     main()