# File: run_interactive_disaggregated.py
# Origin: repository forked from AI-Hypercomputer/jetstream-pytorch
# Size at time of capture: 202 lines (168 loc), 7.81 KB
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import time
from typing import List
from absl import app
from absl import flags
import jax
from jetstream.engine import token_utils
from jetstream_pt import ray_engine
# Command-line flags for the interactive disaggregated-serving demo.
# Each DEFINE_* call returns a holder whose `.value` is read after absl has
# parsed the command line.
FLAGS = flags.FLAGS

_TOKENIZER_PATH = flags.DEFINE_string(
    "tokenizer_path",
    "tokenizer.model",
    "The tokenizer model path",
    required=False,
)
_CKPT_PATH = flags.DEFINE_string(
    "checkpoint_path", None, "Directory for .pth checkpoints", required=False
)
# NOTE: create_disaggregated_engines in this file passes bf16_enable=True
# unconditionally and does not read this flag.
_BF16_ENABLE = flags.DEFINE_bool(
    "bf16_enable", False, "Whether to enable bf16", required=False
)
_CONTEXT_LENGTH = flags.DEFINE_integer(
    "context_length", 1024, "The context length", required=False
)
_BATCH_SIZE = flags.DEFINE_integer(
    "batch_size", 32, "The batch size", required=False
)
# An empty string disables profiling (main() only traces when this is truthy).
_PROFILING_OUTPUT = flags.DEFINE_string(
    "profiling_output",
    "",
    "The profiling output",
    required=False,
)
_SIZE = flags.DEFINE_string("size", "tiny", "size of model")
_QUANTIZE_WEIGHTS = flags.DEFINE_bool(
    "quantize_weights", False, "weight quantization"
)
_QUANTIZE_KV_CACHE = flags.DEFINE_bool(
    "quantize_kv_cache", False, "kv_cache_quantize"
)
# The help text used to be a copy of the quantize_kv_cache help; it now
# describes the flag that is actually being defined.
_MAX_CACHE_LENGTH = flags.DEFINE_integer(
    "max_cache_length", 1024, "The maximum kv cache length"
)
_MODEL_NAME = flags.DEFINE_string(
    "model_name", None, "model type", required=False
)
_SHARDING_CONFIG = flags.DEFINE_string(
    "sharding_config", "", "config file for sharding"
)
_IS_DISAGGREGATED = flags.DEFINE_bool(
    "is_disaggregated", False, "Disaggregated serving if it's True"
)
_NUM_HOSTS = flags.DEFINE_integer(
    "num_hosts", 4, "Number of TPU host", required=False
)
_DECODE_POD_SLICE_NAME = flags.DEFINE_string(
    "decode_pod_slice_name", "", "Decode pod slice name"
)
def create_disaggregated_engines():
  """Build the Ray-backed engines and hand back one prefill/decode pair.

  Engine options come from this module's command-line flags, with the
  exception of ``bf16_enable``, which is always passed as ``True``. The
  wall-clock time spent inside ``ray_engine.create_pytorch_ray_engine`` is
  printed to stdout.

  Returns:
    A ``(prefill_engine, decode_engine)`` tuple made of element 0 of each of
    the two lists returned by ``ray_engine.create_pytorch_ray_engine``.
  """
  # jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

  began = time.perf_counter()
  engine_options = dict(
      model_name=_MODEL_NAME.value,
      tokenizer_path=_TOKENIZER_PATH.value,
      ckpt_path=_CKPT_PATH.value,
      # NOTE(review): hard-coded; the --bf16_enable flag is not consulted.
      bf16_enable=True,
      param_size=_SIZE.value,
      context_length=_CONTEXT_LENGTH.value,
      batch_size=_BATCH_SIZE.value,
      quantize_weights=_QUANTIZE_WEIGHTS.value,
      quantize_kv=_QUANTIZE_KV_CACHE.value,
      max_cache_length=_MAX_CACHE_LENGTH.value,
      sharding_config=_SHARDING_CONFIG.value,
      is_disaggregated=_IS_DISAGGREGATED.value,
      num_hosts=_NUM_HOSTS.value,
      decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value,
  )
  prefill_engines, decode_engines = ray_engine.create_pytorch_ray_engine(
      **engine_options
  )
  print("Initialize engine", time.perf_counter() - began)

  # Only the first engine of each kind is used by this interactive script.
  return prefill_engines[0], decode_engines[0]
# pylint: disable-next=all
def main(argv):
  """Run a fixed set of prompts through disaggregated prefill and decode.

  For every prompt the script tokenizes the text, runs prefill on the prefill
  engine, transfers and inserts the result into a randomly chosen slot of the
  decode engine, and then generates one token per step until a stop token is
  produced or the reported length exceeds ``max_output_length``. The sampled
  token ids and their decoded text are printed.

  When ``--profiling_output`` is non-empty, a JAX profiler trace is recorded
  for the decode-state initialisation and all prompts. The trace is stopped in
  a ``finally`` clause so that it is also closed when a step raises.

  Args:
    argv: Positional command-line arguments supplied by ``app.run``; unused.
  """
  print("start the test")
  prefill_engine, decode_engine = create_disaggregated_engines()

  start = time.perf_counter()
  prefill_engine.load_params()
  decode_engine.load_params()
  print("Load params ", time.perf_counter() - start)

  metadata = prefill_engine.get_tokenizer()
  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
  stop_tokens = [vocab.eos_id, vocab.pad_id]
  max_output_length = 1024

  prompts: List[str] = [
      "I believe the meaning of life is",
      # pylint: disable-next=all
      "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList<Person>`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList<Person> peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.",
      # pylint: disable-next=all
      "<s>[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<</SYS>>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]",
      # NOTE(review): this prompt ends with "[/INST" (no closing bracket),
      # unlike its siblings; left untouched, confirm whether it is intended.
      # pylint: disable-next=all
      "<s>[INST] <<SYS>>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<</SYS>>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST",
      # pylint: disable-next=all
      "<s>[INST] <<SYS>>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<</SYS>>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
  ]

  # Read the flag once so start and stop are guaranteed to be paired.
  profiling_output = _PROFILING_OUTPUT.value
  if profiling_output:
    jax.profiler.start_trace(profiling_output)
  try:
    # The return value is not used here; later insert()/generate() calls pass
    # None for params/state and only thread the state returned by insert().
    decode_engine.init_decode_state()

    for prompt in prompts:
      # randint is inclusive on both ends, hence the "- 1".
      slot = random.randint(0, _BATCH_SIZE.value - 1)
      tokens, true_length = token_utils.tokenize_and_pad(
          prompt, vocab, is_bos=True, jax_padding=False
      )
      print(f"---- Input prompts are: {prompt}")
      print(f"---- Encoded tokens are: {tokens}")

      print(
          # pylint: disable-next=all
          f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}"
      )
      # prefill() returns a pair; only the first element is needed here.
      prefill_result, _ = prefill_engine.prefill(
          params=None, padded_tokens=tokens, true_length=true_length
      )

      print(
          # pylint: disable-next=all
          f"---- Transfer prefill result to decode engine pod_slice_name: {decode_engine.pod_slice_name}"
      )
      decode_engine.transfer(prefill_result)

      print(
          # pylint: disable-next=all
          f"---- Do insert in decode engine pod_slice_name: {decode_engine.pod_slice_name}"
      )
      decode_state = decode_engine.insert(prefill_result, None, slot=slot)

      sampled_tokens_list = []
      while True:
        # pylint: disable-next=all
        decode_state, result_tokens = decode_engine.generate(None, decode_state)
        result_tokens = result_tokens.convert_to_numpy()

        slot_data = result_tokens.get_result_at_slot(slot)
        slot_tokens = slot_data.tokens
        slot_lengths = slot_data.lengths

        # NOTE(review): indexes the per-slot result with `slot` again; this
        # presumes `tokens` is still 2-D here - confirm against the engine.
        token_id = slot_tokens[slot, 0].item()
        # The token that triggers the stop is not appended to the output.
        if slot_lengths > max_output_length or token_id in stop_tokens:
          break
        sampled_tokens_list.append(token_id)

      print("---- All output tokens.")
      print(sampled_tokens_list)
      print("---- All output text.")
      print(vocab.tokenizer.decode(sampled_tokens_list))
  finally:
    # Stop the trace even if prefill/transfer/decode raised.
    if profiling_output:
      jax.profiler.stop_trace()
if __name__ == "__main__":
  # Script entry point: set the TF C++ log-level variable in this process's
  # environment before absl parses the flags and invokes main().
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
  app.run(main)