
[RWKV5] Add support for RWKV5 model #29095

Draft · wants to merge 102 commits into main
Conversation

@ArthurZucker (Collaborator) commented Feb 19, 2024

What does this PR do?

Adds RWKV5, supersedes #26963

@BBuf commented Mar 28, 2024

> Yeah, we can just set it to infinite; I just used 500 for a quick fix. I don't know if the original one has a limit as well.

I think it would be a good idea to set it to infinite, because RWKV has no sequence-length limit in theory.
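
For context, a minimal sketch of the quick fix versus the unbounded setting; the checkpoint name is illustrative, and int(1e30) is the sentinel value transformers itself uses to mean "no limit":

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True)

# Quick fix mentioned above: an arbitrary finite cap of 500 tokens.
tok.model_max_length = 500

# Effectively infinite: RWKV has no positional limit, so use the huge
# sentinel and let memory, not the tokenizer, bound the sequence length.
tok.model_max_length = int(1e30)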

@ArthurZucker (Collaborator, Author) commented:

Okay, my slow tests are all green for the tokenizer; time to focus on the model!

@ArthurZucker (Collaborator, Author) commented:

The fast CUDA path works thanks to @kashif, but the CPU path does not yet.
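
For reference, a minimal sketch of the sequential recurrence the CPU path needs, for a single head; all names and shapes here are assumptions for illustration, not the PR's actual implementation:

import torch

def naive_rwkv5_cpu(r, k, v, w, u, state):
    # r, k, v: (T, S) receptance/key/value per timestep for one head
    # w, u:    (S, 1) per-channel decay (time_decay) and bonus (time_first)
    # state:   (S, S) recurrent state carried between chunks
    T, S = r.shape
    out = torch.zeros(T, S)
    for t in range(T):
        kv = k[t].unsqueeze(1) @ v[t].unsqueeze(0)  # outer product, (S, S)
        out[t] = r[t] @ (u * kv + state)            # current token gets the bonus
        state = kv + w * state                      # then decay the state
    return out, state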

@JL-er commented Apr 4, 2024

I set trust_remote_code=True but still get an error (error screenshots omitted). Command:

python3 utils/prepare_dataset.py -i JeanKaddour/minipile -o /users/aigclab/copilot/data/Msample

import click

import datasets
from transformers import AutoTokenizer

import torch


def _rechunk_tokenize(rechunk_size: int, input_column: str, output_column: str, examples):
    # Instantiated inside the map function so it works with num_proc > 1 workers.
    tokenizer = AutoTokenizer.from_pretrained(
        "RWKV/rwkv-6-world-1b6", trust_remote_code=True
    )

    # Token id 0 is appended after each example as a document separator.
    special_token = torch.tensor([0], dtype=torch.long)
    seqs = []
    for e in examples[input_column]:
        seq = tokenizer(e, padding=False, truncation=False, return_tensors="pt")
        seqs.append(seq.input_ids[0])
        seqs.append(special_token)
    seqs = torch.cat(seqs)
    # Drop the tail so the total length is a multiple of rechunk_size,
    # then reshape into fixed-size rows.
    rechunked = seqs[: (seqs.size(0) // rechunk_size) * rechunk_size].view(
        -1, rechunk_size
    )
    return {output_column: rechunked}


@click.command()
@click.option("--rechunk_size", default=513, help="Rechunk size for the dataset")
@click.option("--input_column", default="text", help="Column to tokenize")
@click.option(
    "--output_column", default="input_ids", help="Output column for the tokenized dataset"
)
@click.option(
    "-i",
    "--input_name",
    help="HuggingFace dataset name to tokenize, accept format:"
    + '"dataset_name" or "json:file_a,file_b,..."',
)
@click.option("-o", "--output_dir", help="Output directory for the tokenized dataset")
def main(rechunk_size, input_column, output_column, input_name, output_dir):
    print(f"Tokenizing HuggingFace dataset {input_name} to locally saved {output_dir}")
    if ":" in input_name:
        input_name, input_data_file = input_name.split(":")
        if "," in input_data_file:
            input_data_file = input_data_file.split(",")
    else:
        input_data_file = None
    dataset = datasets.load_dataset(
        input_name, data_files=input_data_file, trust_remote_code=True
    )
    dataset.shuffle().flatten_indices(num_proc=8).map(
        lambda x: _rechunk_tokenize(rechunk_size, input_column, output_column, x),
        batched=True,
        remove_columns=dataset["train"].column_names,
        num_proc=8,
    ).save_to_disk(output_dir, num_proc=8)


if __name__ == "__main__":
    main()

@JL-er commented Apr 4, 2024

A month ago, there was no problem.

@SmerkyG commented Apr 4, 2024

Important: maybe less problematic for v5 (or maybe not!), but I found that for v6 the following line is absolutely terrible for inference accuracy:

out = self.ln_x(rwkv.to(hidden.dtype)).view(batch, seq_length, -1)

versus the original @BBuf version (which I tweaked and adapted to v6):

out = F.group_norm(
    out / self.config.head_size_divisor,
    num_groups=H,
    weight=self.ln_x.weight.to(out.dtype),
    bias=self.ln_x.bias.to(out.dtype),
    eps=self.ln_x.eps,
).reshape(B, T, H * S)
out = out.to(dtype=hidden.dtype) * gate

The issue is that the potential down-cast to bf16 prior to the groupnorm causes really bad inference quality. If you look closely, this is written differently from the original @BBuf version, where the down-cast occurs after the groupnorm during non-CUDA inference.

Please see these lines of Bo Peng's original ChatRWKV code for reference about this groupnorm needing float32:
https://github.com/BlinkDL/ChatRWKV/blob/28ed01a8423842c3082f668922a1b45ac182dff0/rwkv_pip_package/src/rwkv/model.py#L377
https://github.com/BlinkDL/ChatRWKV/blob/28ed01a8423842c3082f668922a1b45ac182dff0/rwkv_pip_package/src/rwkv/model.py#L669
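
To make the ordering concrete, here is a minimal, self-contained sketch of the two variants; shapes and names are illustrative, not the PR's code:

import torch
import torch.nn.functional as F

B, T, H, S = 2, 8, 4, 16                    # batch, seq, heads, head size
rwkv = torch.randn(B * T, H * S)            # kernel output, still float32
ln_x = torch.nn.GroupNorm(num_groups=H, num_channels=H * S)
hidden_dtype = torch.bfloat16               # dtype of the rest of the model

# Problematic ordering: the round-trip through bf16 *before* the groupnorm
# discards mantissa bits that the normalization statistics depend on.
out_lossy = ln_x(rwkv.to(hidden_dtype).float()).view(B, T, -1)

# Preferred ordering: normalize in float32, down-cast only afterwards.
out = F.group_norm(
    rwkv,                                   # still float32 here
    num_groups=H,
    weight=ln_x.weight.to(rwkv.dtype),
    bias=ln_x.bias.to(rwkv.dtype),
    eps=ln_x.eps,
).view(B, T, -1)
out = out.to(hidden_dtype)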

@ArthurZucker (Collaborator, Author) commented:

Will update the group norm!

@BBuf commented Apr 6, 2024

> I set trust_remote_code=True but still get an error. [...] (full comment quoted above)

Bug fixed by BBuf/RWKV-World-HF-Tokenizer@6dd44c8; it has no relation to this PR. I will update the HF repo RWKV/rwkv-6-world-1b6 later.

@BBuf commented Apr 22, 2024

> BTW @BBuf, don't you think it would be great to have a separate GitHub repo with installables for the kernels? (You can track usage, and it's easier to maintain and propagate here!) WDYT?

This has been solved in https://huggingface.co/RWKV/rwkv-5-world-1b5/blob/main/modeling_rwkv5.py ; you need to pip install flash-rwkv first, and then modeling_rwkv5.py in this PR can be replaced directly.
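
For anyone who wants to try it, a rough sketch of that flow, assuming the Hub checkpoint's remote code wires up the kernels (untested here):

# pip install flash-rwkv transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code pulls in the modeling_rwkv5.py hosted on the Hub,
# which imports the flash-rwkv kernels installed above.
model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-5-world-1b5", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    "RWKV/rwkv-5-world-1b5", trust_remote_code=True
)

inputs = tokenizer("Hello", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))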
