Skip to content

Commit

Permalink
update table for rwkv4 support, fix process count for dataset (#822)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Nov 5, 2023
1 parent 6459ac7 commit cdc71f7
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Features:
| gpt-j ||||||||
| XGen ||||||||
| phi ||||||||
| RWKV ||||||||


## Quickstart ⚡
Expand Down
10 changes: 8 additions & 2 deletions src/axolotl/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import os
from typing import List
from typing import List, Optional

import torch
from datasets import Dataset, IterableDataset
Expand Down Expand Up @@ -30,14 +30,20 @@ def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
process_count: Optional[int] = None,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
self.process_count = process_count
super().__init__(self.process(dataset).data, **kwargs)

def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, os.cpu_count())
num_proc = (
min(64, self.process_count)
if self.process_count
else min(64, os.cpu_count())
)
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
Expand Down
40 changes: 30 additions & 10 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,14 @@ def get_dataset_wrapper(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy(
Expand All @@ -494,7 +498,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "explainchoice":
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
Expand All @@ -504,7 +510,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "concisechoice":
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
Expand All @@ -514,7 +522,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "summarizetldr":
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
Expand All @@ -524,7 +534,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "jeopardy":
dataset_prompter = JeopardyPrompter(d_prompt_style)
Expand All @@ -534,7 +546,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "oasst":
dataset_prompter = AlpacaPrompter(d_prompt_style)
Expand All @@ -544,7 +558,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "gpteacher":
dataset_prompter = GPTeacherPrompter(d_prompt_style)
Expand All @@ -554,7 +570,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "reflection":
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
Expand All @@ -564,7 +582,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
else:
suffix = ""
Expand Down

0 comments on commit cdc71f7

Please sign in to comment.