forked from AI-Hypercomputer/jetstream-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_server_with_ray.py
156 lines (131 loc) · 4.99 KB
/
run_server_with_ray.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs a pytorch server with ray."""
import os
import time
from typing import Sequence
from absl import app, flags
# import torch_xla2 first!
import jax
from jetstream.core import server_lib
from jetstream.core.config_lib import ServerConfig
from jetstream_pt import ray_engine
from jetstream_pt.config import FLAGS
# Command-line flags owned by this script. Other flags read through FLAGS
# below (model_name, checkpoint_path, ...) are not defined in this file.

# --- Server endpoint -------------------------------------------------------
# `port` and `threads` are forwarded to server_lib.run() in main().
flags.DEFINE_integer(name="port", default=9000, help="port to listen on")
flags.DEFINE_integer(
    name="threads",
    default=64,
    help="number of worker threads in thread pool",
)
# NOTE(review): `config` and `prometheus_port` are not read anywhere in the
# visible part of this file -- confirm whether they are still needed.
flags.DEFINE_string(
    name="config",
    default="InterleavedCPUTestServer",
    help="available servers",
)
flags.DEFINE_integer(name="prometheus_port", default=0, help="")

# --- Devices ---------------------------------------------------------------
# main() builds the device list as range(tpu_chips); create_engine() also
# forwards the value to the engine factory.
flags.DEFINE_integer(name="tpu_chips", default=16, help="all devices tpu_chips")

# --- JAX profiler (both values are forwarded to the engine factory) ---------
flags.DEFINE_bool(
    name="enable_jax_profiler",
    default=False,
    help="enable jax profiler",
)
flags.DEFINE_integer(
    name="jax_profiler_port",
    default=9999,
    help="port of JAX profiler server",
)

# --- Topology --------------------------------------------------------------
# `is_disaggregated` selects which branch of main() builds the ServerConfig.
flags.DEFINE_bool(
    name="is_disaggregated",
    default=False,
    help="Disaggregated serving if it's True",
)
flags.DEFINE_integer(
    name="num_hosts",
    default=0,
    help="Number of TPU host",
    required=False,
)
flags.DEFINE_integer(
    name="worker_chips",
    default=4,
    help="Number of TPU chips per worker",
    required=False,
)
# Only passed along by create_disaggregated_engine().
flags.DEFINE_string(
    name="decode_pod_slice_name",
    default="",
    help="Decode pod slice name",
)
def create_engine():
    """Build the Ray-backed PyTorch engine described by the command-line flags.

    Before constructing the engine this switches JAX's default PRNG
    implementation to ``unsafe_rbg`` and sets the ``TF_CPP_MIN_LOG_LEVEL``
    environment variable to ``"0"``. The time spent inside the factory call
    is printed to stdout.

    Returns:
      Whatever ``ray_engine.create_pytorch_ray_engine`` returns for this
      set of arguments.
    """
    jax.config.update("jax_default_prng_impl", "unsafe_rbg")
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

    t0 = time.perf_counter()
    # Every option comes straight from a flag; collect them in one place so
    # the factory call itself stays short.
    engine_options = {
        "model_name": FLAGS.model_name,
        "tokenizer_path": FLAGS.tokenizer_path,
        "ckpt_path": FLAGS.checkpoint_path,
        "bf16_enable": FLAGS.bf16_enable,
        "param_size": FLAGS.size,
        "context_length": FLAGS.context_length,
        "batch_size": FLAGS.batch_size,
        "quantize_weights": FLAGS.quantize_weights,
        "quantize_kv": FLAGS.quantize_kv_cache,
        "max_cache_length": FLAGS.max_cache_length,
        "sharding_config": FLAGS.sharding_config,
        "enable_jax_profiler": FLAGS.enable_jax_profiler,
        "jax_profiler_port": FLAGS.jax_profiler_port,
        "num_hosts": FLAGS.num_hosts,
        "worker_chips": FLAGS.worker_chips,
        "tpu_chips": FLAGS.tpu_chips,
    }
    ray_backed_engine = ray_engine.create_pytorch_ray_engine(**engine_options)
    print("Initialize engine", time.perf_counter() - t0)
    return ray_backed_engine
def create_disaggregated_engine():
    """Build the prefill and decode engine lists for disaggregated serving.

    Applies the same process-wide setup as the interleaved path (JAX default
    PRNG implementation set to ``unsafe_rbg``, ``TF_CPP_MIN_LOG_LEVEL`` set
    to ``"0"``), then calls ``ray_engine.create_pytorch_ray_engine`` with
    ``is_disaggregated`` and ``decode_pod_slice_name`` taken from flags. The
    time spent inside the factory call is printed to stdout.

    Returns:
      A 2-tuple ``(prefill_engine_list, decode_engine_list)``, unpacked from
      the factory's return value.
    """
    jax.config.update("jax_default_prng_impl", "unsafe_rbg")
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

    t0 = time.perf_counter()
    # Every option comes straight from a flag; collect them in one place so
    # the factory call itself stays short.
    engine_options = {
        "model_name": FLAGS.model_name,
        "tokenizer_path": FLAGS.tokenizer_path,
        "ckpt_path": FLAGS.checkpoint_path,
        "bf16_enable": FLAGS.bf16_enable,
        "param_size": FLAGS.size,
        "context_length": FLAGS.context_length,
        "batch_size": FLAGS.batch_size,
        "quantize_weights": FLAGS.quantize_weights,
        "quantize_kv": FLAGS.quantize_kv_cache,
        "max_cache_length": FLAGS.max_cache_length,
        "sharding_config": FLAGS.sharding_config,
        "enable_jax_profiler": FLAGS.enable_jax_profiler,
        "jax_profiler_port": FLAGS.jax_profiler_port,
        "is_disaggregated": FLAGS.is_disaggregated,
        "num_hosts": FLAGS.num_hosts,
        "decode_pod_slice_name": FLAGS.decode_pod_slice_name,
    }
    # Unpacking (rather than returning the result as-is) keeps the check that
    # the factory produced exactly two values.
    prefill_engines, decode_engines = ray_engine.create_pytorch_ray_engine(
        **engine_options
    )
    print("Initialize engine", time.perf_counter() - t0)
    return (prefill_engines, decode_engines)
# pylint: disable-next=all
def main(argv: Sequence[str]):
    """Create the engine(s) selected by flags and serve them with JetStream.

    When ``--is_disaggregated`` is set, the first prefill engine and the
    first decode engine are each registered on a slice sized at half of
    ``--tpu_chips``; otherwise a single interleaved engine is registered on
    a slice covering all ``--tpu_chips``. The server is then started and this
    function waits in ``wait_for_termination()``.

    Args:
      argv: Positional command-line arguments handed over by ``app.run``;
        unused.
    """
    del argv
    # NOTE: this replaces any XLA_FLAGS value already set in the environment.
    os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text"

    # Devices are identified simply by index: 0 .. tpu_chips - 1.
    devices = list(range(FLAGS.tpu_chips))
    print(f"devices: {devices}")

    if FLAGS.is_disaggregated:
        prefill_engine_list, decode_engine_list = create_disaggregated_engine()
        # Prefill and generate each get half of the chips (rounded down).
        chips = len(devices) // 2
        server_config = ServerConfig(
            prefill_slices=(f"tpu={chips}",),
            prefill_engine_create_fns=(lambda a: prefill_engine_list[0],),
            generate_slices=(f"tpu={chips}",),
            generate_engine_create_fns=(lambda a: decode_engine_list[0],),
            is_ray_backend=True,
        )
    else:
        engine = create_engine()
        server_config = ServerConfig(
            interleaved_slices=(f"tpu={len(devices)}",),
            # The engine already exists; the factory just hands it back.
            interleaved_engine_create_fns=(lambda a: engine,),
        )
    print(f"server_config: {server_config}")

    jetstream_server = server_lib.run(
        threads=FLAGS.threads,
        port=FLAGS.port,
        config=server_config,
        devices=devices,
        jax_padding=False,  # Jax_padding must be set as False
    )
    print("Started jetstream_server....")
    jetstream_server.wait_for_termination()
if __name__ == "__main__":
app.run(main)