-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
90c3770
commit 2a149f4
Showing
11 changed files
with
194 additions
and
212 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import os
import warnings

import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import freeze, unfreeze
from jax.experimental import PartitionSpec as P
from jax.experimental import maps
from jax.experimental.compilation_cache import compilation_cache as cc
from jax.experimental.pjit import pjit
from transformers import FlaxGPTJForCausalLM, GPTJConfig
from transformers import AutoTokenizer

from bloom_inference.partitions import set_partitions
|
||
# Enable JAX's on-disk compilation cache. The "~" is expanded explicitly:
# passed through unexpanded, "~/jax_cache" would name a directory literally
# called "~" relative to the current working directory instead of a folder in
# the user's home directory.
cc.initialize_cache(os.path.expanduser("~/jax_cache"))

# Silence every warning (ResourceWarning is listed explicitly as well) ...
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=ResourceWarning)

# ... except on the first host, where the "default" filter action is installed
# so warnings are shown again.
if jax.host_id() == 0:
    warnings.filterwarnings("default")
|
||
|
||
def head_print(*args, **kwargs):
    """Forward all arguments to ``print`` on host 0; do nothing on other hosts."""
    if jax.host_id() != 0:
        return
    print(*args, **kwargs)
|
||
class Generator:
    """Load a Flax GPT-J checkpoint, shard it over a device mesh and generate text.

    Expected call order: construct, ``load_model_and_params()``,
    ``shard_params()``, then ``generate(prompts)`` as often as needed.
    """

    def __init__(self, mesh_shape, ckpt="EleutherAI/gpt-j-6B"):
        """Create the device mesh; the model is loaded by ``load_model_and_params``.

        Args:
            mesh_shape: Shape that the array of ``jax.devices()`` is reshaped
                to. Its two axes are bound to the mesh axis names ``"dp"``
                and ``"mp"``.
            ckpt: Checkpoint identifier handed to ``from_pretrained`` for both
                the model and the tokenizer.
        """
        self.mesh_shape = mesh_shape
        self.ckpt = ckpt

        # create a mesh and bind names to mesh axes
        devices = np.array(jax.devices()).reshape(self.mesh_shape)
        self.mesh = maps.Mesh(devices, ("dp", "mp"))

        # Everything below is filled in by load_model_and_params(). The pjit
        # wrappers cannot be built here because they need the loaded model and
        # its partition spec.
        self.model = None
        self.tokenizer = None
        self.params = None
        self.spec = None
        self.p_shard_params = None
        self.p_generate = None

    def load_model_and_params(self):
        """Load model, params and tokenizer for ``self.ckpt`` and build the pjit functions."""
        # TODO loading params should be done in a thread
        model, self.params = FlaxGPTJForCausalLM.from_pretrained(self.ckpt, _do_init=False)
        self.spec = set_partitions(model.params_shape_tree)

        tokenizer = AutoTokenizer.from_pretrained(self.ckpt)
        # setup for generation: pad with the EOS token, on the left
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        model.config.max_length = 128
        model.config.num_beams = 1
        model.config.do_sample = True
        # The config field holds the integer token id, not the token string.
        model.config.pad_token_id = tokenizer.pad_token_id

        self.model = model
        self.tokenizer = tokenizer
        self._build_pjit_functions()

    def _build_pjit_functions(self):
        """Wrap bf16 casting and generation in ``pjit`` using ``self.spec``."""
        # Stored under a name distinct from the shard_params() method so the
        # instance attribute does not shadow the method.
        self.p_shard_params = pjit(
            self.model.to_bf16,
            in_axis_resources=(self.spec,),
            out_axis_resources=self.spec,
        )

        def generate(params, input_ids, attention_mask):
            return self.model.generate(input_ids, attention_mask=attention_mask, params=params).sequences

        self.p_generate = pjit(
            generate,
            in_axis_resources=(self.spec, P("dp"), P("dp")),
            out_axis_resources=P("dp"),
        )

    def _require_loaded(self):
        """Raise a clear error when the model has not been loaded yet."""
        if self.p_generate is None or self.p_shard_params is None:
            raise RuntimeError("call load_model_and_params() before using the Generator")

    def shard_params(self):
        """Replace ``self.params`` with the output of the pjit-wrapped ``to_bf16``.

        Raises:
            RuntimeError: If ``load_model_and_params()`` has not been called.
        """
        self._require_loaded()
        with self.mesh:
            # Frozen for the same reason as in generate(): both calls are
            # matched against the same partition spec.
            self.params = self.p_shard_params(freeze(self.params))

    def generate(self, prompts):
        """Tokenize ``prompts``, run sharded generation and return the decoded text.

        Args:
            prompts: Input passed straight to the tokenizer; each prompt is
                padded or truncated to 32 tokens.

        Returns:
            The result of ``tokenizer.batch_decode`` on the generated ids,
            with special tokens skipped.

        Raises:
            RuntimeError: If ``load_model_and_params()`` has not been called.
        """
        self._require_loaded()
        # NOTE(review): the original carried the remark "BS = 8" here,
        # presumably the expected batch size - confirm against the caller.
        inputs = self.tokenizer(
            prompts, return_tensors="jax", padding="max_length", truncation=True, max_length=32
        )

        with self.mesh:
            gen_ids = self.p_generate(freeze(self.params), inputs["input_ids"], inputs["attention_mask"])

        generated_text = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        return generated_text
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# copied from https://github.com/kingoflolz/mesh-transformer-jax | ||
|
||
import ray | ||
import time | ||
import numpy as np | ||
from queue import Queue | ||
|
||
|
||
@ray.remote(resources={"tpu": 1})
class TPUHostWorker(object):
    """Ray actor that owns a ``Generator`` and serves requests through two queues.

    ``run()`` sets up the generator and then loops forever reading
    ``(operation, prompts)`` tuples from ``input_q``; ``generate()`` puts a
    request on ``input_q`` and blocks on ``output_q`` for the result.

    NOTE(review): because ``run()`` never returns, ``generate()`` can only be
    served while it is running if the actor allows concurrent method calls -
    confirm how the actor is created at the call site.
    """

    def __init__(self, mesh_shape):
        """Store the mesh shape and create the single-slot request/response queues."""
        self.mesh_shape = mesh_shape

        self.input_q = Queue(maxsize=1)
        self.output_q = Queue(maxsize=1)

    def run(self):
        """Initialize JAX, load and shard the model, then serve requests forever.

        Raises:
            Exception: If a queued operation is anything other than
                ``"generate"``.
        """
        print("jax runtime initialization starting")
        # Imported here rather than at module level, so jax is first loaded
        # when run() executes.
        import jax

        from bloom_inference.generator import Generator, head_print

        start = time.time()
        jax.devices()
        head_print(f"jax devices: {jax.device_count()}")
        head_print(f"jax runtime initialized in {time.time() - start:.06}s")

        # load model and params here
        head_print("Loading model")
        generator = Generator(self.mesh_shape)
        generator.load_model_and_params()
        head_print("Loading complete")

        start = time.time()
        generator.shard_params()
        head_print(f"Initialized in {time.time() - start:.06}s")

        while True:
            operation, prompts = self.input_q.get()
            if operation == "generate":
                generated_text = generator.generate(prompts)
                self.output_q.put(generated_text)
            else:
                raise Exception("Not implemented")

    def generate(self, input):
        """Queue ``input`` for generation and block until the result is available."""
        self.input_q.put(("generate", input))
        return self.output_q.get()
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.