Skip to content

Commit

Permalink
play with datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed Dec 20, 2023
1 parent 82ad103 commit d89c8d8
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 5 deletions.
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
python3 $HOME/transformers/examples/pytorch/xla_spawn.py --num_cores ${TPU_NUM_DEVICES} wtpsplit/train/train.py $1
python3 ~/wtpsplit/xla_spawn.py --num_cores ${TPU_NUM_DEVICES} wtpsplit/train/train.py $1
25 changes: 21 additions & 4 deletions wtpsplit/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ def main():
elif args.from_scratch:
backbone = LACanineForTokenClassification(config)
else:
backbone = LACanineForTokenClassification.from_pretrained(args.model_name_or_path,
ignore_mismatched_sizes=True,
config=config)
backbone = LACanineForTokenClassification.from_pretrained(
args.model_name_or_path, ignore_mismatched_sizes=True, config=config
)

model = Model(
backbone,
Expand All @@ -270,8 +270,15 @@ def prepare_dataset(
num_workers=1,
include_languages=None,
shuffle=False,
split="train",
):
dataset = load_dataset("parquet", data_files=path, split="train")
from datasets.download import DownloadConfig

dlconf = DownloadConfig(cache_dir="/home/Markus/.cache/huggingface/datasets")
dataset = load_dataset("markus583/mC4-TEST", split=split, download_config=dlconf)
# optional: delete downloaded dataset, it is stored in /dev/shm/cache now
# os.system("rm -rf /home/Markus/.cache/huggingface/datasets")

if include_languages is not None:
include_languages = set(include_languages)

Expand Down Expand Up @@ -443,6 +450,14 @@ def maybe_pad(text):
# a bit hacky but oh well, only drop if sentence
remove_columns=["ends_with_punctuation"] if args.text_column == "text" else [],
)

# used dataset is in cache now
# recursively remove the files in cache_dir starting with m_c4
for root, dirs, files in os.walk(os.environ.get("HF_DATASETS_CACHE")):
for file in files:
if file.startswith("m_c4"):
print(f"Removing {os.path.join(root, file)}")
os.remove(os.path.join(root, file))

return dataset

Expand All @@ -451,12 +466,14 @@ def maybe_pad(text):
num_workers=args.preprocessing_num_workers,
include_languages=args.include_languages,
shuffle=args.shuffle,
split="train",
)
valid_dataset = prepare_dataset(
args.valid_text_path,
num_workers=args.preprocessing_num_workers,
include_languages=args.include_languages,
shuffle=False,
split="valid",
)

eval_data = torch.load(
Expand Down
83 changes: 83 additions & 0 deletions xla_spawn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A simple launcher script for TPU training
Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py
::
>>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE
YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
arguments of your training script)
"""


import importlib
import sys
from argparse import REMAINDER, ArgumentParser
from pathlib import Path

import torch_xla.distributed.xla_multiprocessing as xmp


def parse_args():
"""
Helper function parsing the command line options
@retval ArgumentParser
"""
parser = ArgumentParser(
description=(
"PyTorch TPU distributed training launch helper utility that will spawn up multiple distributed processes"
)
)

# Optional arguments for the launch helper
parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).")

# positional
parser.add_argument(
"training_script",
type=str,
help=(
"The full path to the single TPU training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script"
),
)

# rest from the training program
parser.add_argument("training_script_args", nargs=REMAINDER)

return parser.parse_args()


def main():
args = parse_args()

# Import training_script as a module.
script_fpath = Path(args.training_script)
sys.path.append(str(script_fpath.parent.resolve()))
mod_name = script_fpath.stem
mod = importlib.import_module(mod_name)

# Patch sys.argv
sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]

xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)


if __name__ == "__main__":
main()

0 comments on commit d89c8d8

Please sign in to comment.