[example] add grok-1 inference (#5485)
* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [exmaple] add test ci

* [exmaple] update readme
ver217 authored Mar 21, 2024
1 parent d158fc0 commit 848a574
Showing 9 changed files with 297 additions and 0 deletions.
43 changes: 43 additions & 0 deletions examples/language/grok-1/README.md
@@ -0,0 +1,43 @@
# Grok-1 Inference

## Install

```bash
# Make sure you install colossalai from the latest source code
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
pip install .
cd examples/language/grok-1
pip install -r requirements.txt
```

## Tokenizer preparation

Download the tokenizer from the official grok-1 repository:

```bash
wget https://github.com/xai-org/grok-1/raw/main/tokenizer.model
```
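
Optionally, you can check that the downloaded file loads correctly with `sentencepiece` (the same library the example scripts use). This is a minimal sanity-check sketch, not part of the example itself:

```python
from sentencepiece import SentencePieceProcessor

# Load the tokenizer downloaded above; an error here usually means the
# download is incomplete or corrupted.
sp = SentencePieceProcessor(model_file="tokenizer.model")
print(sp.vocab_size(), sp.encode("Hello, Grok!"))
```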

## Inference

You need 8x A100 80GB GPUs (or equivalent) to run inference.

We provide two inference scripts. `run_inference_fast.sh` uses tensor parallelism provided by ColossalAI and is faster; `run_inference_slow.sh` uses transformers' `device_map="auto"` feature and is slower.

Command format:

```bash
./run_inference_fast.sh <model_name_or_path> <tokenizer_path>
./run_inference_slow.sh <model_name_or_path> <tokenizer_path>
```

`model_name_or_path` can be a local path or a model name on the Hugging Face model hub. We provide weights on the model hub under the name `hpcaitech/grok-1`.

Command example:

```bash
./run_inference_fast.sh hpcaitech/grok-1 tokenizer.model
```

Loading the checkpoints takes 5-10 minutes; don't worry, it is not stuck.
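
If you would rather call the model from Python than use the shell scripts, the slow path boils down to roughly the following. This is a minimal sketch mirroring `inference.py`; it assumes `tokenizer.model` is in the working directory and that the checkpoint fits across your GPUs via `device_map="auto"`:

```python
import torch
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM

torch.set_default_dtype(torch.bfloat16)
# trust_remote_code is required because the hub repo ships custom Grok-1 modeling code.
model = AutoModelForCausalLM.from_pretrained(
    "hpcaitech/grok-1",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
sp = SentencePieceProcessor(model_file="tokenizer.model")

text = "Hi, what's your name?"
input_ids = torch.tensor([sp.encode(text)]).cuda()  # batch of one prompt
output = model.generate(input_ids, max_new_tokens=30)
print(sp.decode(output[0].tolist()))
```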
99 changes: 99 additions & 0 deletions examples/language/grok-1/grok1_policy.py
@@ -0,0 +1,99 @@
from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class Grok1Policy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self) -> nn.Module:
        if self.shard_config.enable_tensor_parallelism:
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
            assert vocab_size % world_size == 0, f"vocab_size {vocab_size} must be divisible by world_size {world_size}"
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}
        if self.shard_config.enable_tensor_parallelism:
            decoder_attribute_replacement = {
                "attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "attn.num_key_value_heads": self.model.config.num_key_value_heads
                // self.shard_config.tensor_parallel_size,
            }
            decoder_submodule_replacement = [
                SubModuleReplacementDescription(
                    suffix="attn.q_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.k_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.v_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.o_proj",
                    target_module=Linear1D_Row,
                ),
            ]
            for i in range(self.model.config.num_experts):
                decoder_submodule_replacement.extend(
                    [
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_v",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_1",
                            target_module=Linear1D_Row,
                        ),
                    ]
                )

            policy["DecoderLayer"] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=decoder_submodule_replacement,
            )
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=VocabParallelEmbedding1D,
                ),
                policy=policy,
                target_key="Grok1Model",
            )
        return policy

    def postprocess(self):
        return self.model


class Grok1ModelPolicy(Grok1Policy):
    pass


class Grok1ForCausalLMPolicy(Grok1Policy):
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = super().module_policy()
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="lm_head",
                target_module=Linear1D_Col,
                kwargs={"gather_output": not self.shard_config.parallel_output},
            ),
            policy=policy,
            target_key="Grok1ModelForCausalLM",
        )
        return policy
32 changes: 32 additions & 0 deletions examples/language/grok-1/inference.py
@@ -0,0 +1,32 @@
import time

import torch
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM
from utils import get_defualt_parser, inference, print_output

if __name__ == "__main__":
    parser = get_defualt_parser()
    args = parser.parse_args()
    start = time.time()
    torch.set_default_dtype(torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    sp = SentencePieceProcessor(model_file=args.tokenizer)
    for text in args.text:
        output = inference(
            model,
            sp,
            text,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print_output(text, sp.decode(output))
    print(f"Overall time: {time.time() - start} seconds.")
50 changes: 50 additions & 0 deletions examples/language/grok-1/inference_tp.py
@@ -0,0 +1,50 @@
import time

import torch
from grok1_policy import Grok1ForCausalLMPolicy
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM
from utils import get_defualt_parser, inference, print_output

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.utils import get_current_device

if __name__ == "__main__":
    parser = get_defualt_parser()
    args = parser.parse_args()
    start = time.time()
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()
    plugin = HybridParallelPlugin(
        tp_size=coordinator.world_size,
        pp_size=1,
        precision="bf16",
        parallel_output=False,
        custom_policy=Grok1ForCausalLMPolicy(),
    )
    booster = Booster(plugin=plugin)
    torch.set_default_dtype(torch.bfloat16)
    with LazyInitContext(default_device=get_current_device()):
        model = AutoModelForCausalLM.from_pretrained(
            args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
        )
    model, *_ = booster.boost(model)
    sp = SentencePieceProcessor(model_file=args.tokenizer)
    for text in args.text:
        output = inference(
            model.unwrap(),
            sp,
            text,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        if coordinator.is_master():
            print_output(text, sp.decode(output))
    coordinator.print_on_master(f"Overall time: {time.time() - start} seconds.")
4 changes: 4 additions & 0 deletions examples/language/grok-1/requirements.txt
@@ -0,0 +1,4 @@
torch>=2.1.0,<2.2.0
colossalai>=0.3.6
sentencepiece==0.1.99
transformers==4.35.0
11 changes: 11 additions & 0 deletions examples/language/grok-1/run_inference_fast.sh
@@ -0,0 +1,11 @@
#!/usr/bin/env bash

PRETRAINED=${1:-"hpcaitech/grok-1"}
TOKENIZER=${2:-"tokenizer.model"}

torchrun --standalone --nproc_per_node 8 inference_tp.py --pretrained "$PRETRAINED" \
    --tokenizer "$TOKENIZER" \
    --max_new_tokens 64 \
    --text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \
    "将以下句子翻译成英语。 我喜欢看电影和读书。" \
    "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?"
11 changes: 11 additions & 0 deletions examples/language/grok-1/run_inference_slow.sh
@@ -0,0 +1,11 @@
#!/usr/bin/env bash

PRETRAINED=${1:-"hpcaitech/grok-1"}
TOKENIZER=${2:-"tokenizer.model"}

python3 inference.py --pretrained "$PRETRAINED" \
    --tokenizer "$TOKENIZER" \
    --max_new_tokens 64 \
    --text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \
    "将以下句子翻译成英语。 我喜欢看电影和读书。" \
    "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?"
1 change: 1 addition & 0 deletions examples/language/grok-1/test_ci.sh
@@ -0,0 +1 @@
pip install -r requirements.txt
46 changes: 46 additions & 0 deletions examples/language/grok-1/utils.py
@@ -0,0 +1,46 @@
import argparse

import torch


class Bcolors:
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def print_output(text, output):
    print(f"-----\n{Bcolors.OKBLUE}{text}{Bcolors.ENDC}{output[len(text):]}")


@torch.no_grad()
def inference(model, sp, text, **generate_kwargs):
    input_ids = sp.encode(text)
    input_ids = torch.tensor([input_ids]).cuda()
    attention_mask = torch.ones_like(input_ids)
    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        **generate_kwargs,
    }
    outputs = model.generate(**inputs)
    return outputs[0].tolist()


def get_defualt_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1")
    parser.add_argument("--tokenizer", type=str, default="tokenizer.model")
    parser.add_argument("--text", type=str, nargs="+", default=["Hi, what's your name?"])
    parser.add_argument("--max_new_tokens", type=int, default=30)
    parser.add_argument("--do_sample", action="store_true", default=False)
    parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
    parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
    parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation")
    return parser
