From fc12593556706884d74af7f3e61151340db33efe Mon Sep 17 00:00:00 2001 From: ZZZZkp Date: Fri, 5 Apr 2024 21:43:46 +0800 Subject: [PATCH] Use datasets instead of nlp. And add requirements.txt. --- data/squad_multitask/squad_multitask.py | 22 +++++++++++----------- prepare_data.py | 6 +++--- requirements.txt | 4 ++++ 3 files changed, 18 insertions(+), 14 deletions(-) create mode 100644 requirements.txt diff --git a/data/squad_multitask/squad_multitask.py b/data/squad_multitask/squad_multitask.py index b917002..3cd41f5 100644 --- a/data/squad_multitask/squad_multitask.py +++ b/data/squad_multitask/squad_multitask.py @@ -25,7 +25,7 @@ import nltk nltk.download('punkt') -import nlp +import datasets _CITATION = """\ @@ -56,7 +56,7 @@ ] -class SquadMultitaskConfig(nlp.BuilderConfig): +class SquadMultitaskConfig(datasets.BuilderConfig): """BuilderConfig for SQUAD.""" def __init__(self, qg_format="highlight", **kwargs): @@ -69,7 +69,7 @@ def __init__(self, qg_format="highlight", **kwargs): self.qg_format = qg_format -class SquadMultitask(nlp.GeneratorBasedBuilder): +class SquadMultitask(datasets.GeneratorBasedBuilder): """SQUAD: The Stanford Question Answering Dataset. 
Version 1.1.""" _URL = "https://rajpurkar.github.io/SQuAD-explorer/dataset/" @@ -79,7 +79,7 @@ class SquadMultitask(nlp.GeneratorBasedBuilder): BUILDER_CONFIGS = [ SquadMultitaskConfig( name=f"{format_}_qg_format", - version=nlp.Version("1.0.0", "New split API (https://tensorflow.org/datasets/splits)"), + version=datasets.Version("1.0.0", "New split API (https://tensorflow.org/datasets/splits)"), description="Plain text", qg_format=format_ ) @@ -87,13 +87,13 @@ ] def _info(self): - return nlp.DatasetInfo( + return datasets.DatasetInfo( description=_DESCRIPTION, - features=nlp.Features( + features=datasets.Features( { - "source_text": nlp.Value("string"), - "target_text": nlp.Value("string"), - "task": nlp.Value("string"), + "source_text": datasets.Value("string"), + "target_text": datasets.Value("string"), + "task": datasets.Value("string"), } ), # No default supervised_keys (as we have to pass both question @@ -111,8 +111,8 @@ def _split_generators(self, dl_manager): downloaded_files = dl_manager.download_and_extract(urls_to_download) return [ - nlp.SplitGenerator(name=nlp.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}), - nlp.SplitGenerator(name=nlp.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}), + datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}), + datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}), ] def _get_correct_alignement(self, context, answer): diff --git a/prepare_data.py b/prepare_data.py index 765b7aa..ce060fa 100644 --- a/prepare_data.py +++ b/prepare_data.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional import torch -import nlp +import datasets from transformers import T5Tokenizer, BartTokenizer, HfArgumentParser @@ -152,8 +152,8 @@ def main(): tokenizer.add_tokens(['<sep>', '<hl>']) - train_dataset = nlp.load_dataset(data_args.dataset_path, 
name=data_args.qg_format, split=nlp.Split.TRAIN) - valid_dataset = nlp.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=nlp.Split.VALIDATION) + train_dataset = datasets.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=datasets.Split.TRAIN) + valid_dataset = datasets.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=datasets.Split.VALIDATION) processor = DataProcessor( tokenizer, diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b535730 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +transformers>=3.0.0 +nltk +torch +datasets>=2.12.0 \ No newline at end of file