Commit
boom boom
patil-suraj committed Jul 7, 2022
1 parent 90c3770 commit 2a149f4
Showing 11 changed files with 194 additions and 212 deletions.
File renamed without changes.
81 changes: 81 additions & 0 deletions bloom_inference/generator.py
@@ -0,0 +1,81 @@
import warnings

import jax
import numpy as np
import jax.numpy as jnp
from flax.core.frozen_dict import freeze, unfreeze
from jax.experimental import PartitionSpec as P
from jax.experimental import maps
from jax.experimental.pjit import pjit
from jax.experimental.compilation_cache import compilation_cache as cc

from transformers import FlaxGPTJForCausalLM, GPTJConfig
from transformers import AutoTokenizer

from bloom_inference.partitions import set_partitions

cc.initialize_cache("~/jax_cache")

warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=ResourceWarning)

if jax.host_id() == 0:
    warnings.filterwarnings("default")


# print only on the first host
def head_print(*args, **kwargs):
    if jax.host_id() == 0:
        print(*args, **kwargs)

class Generator:
    def __init__(self, mesh_shape, ckpt="EleutherAI/gpt-j-6B"):
        # create a mesh and bind names to the mesh axes
        self.mesh_shape = mesh_shape
        devices = np.array(jax.devices()).reshape(self.mesh_shape)
        self.mesh = maps.Mesh(devices, ("dp", "mp"))
        self.ckpt = ckpt

        # model loading is deferred to load_model_and_params(), which also
        # builds the pjit-ed functions (they need self.model and self.spec)

    def load_model_and_params(self):
        # TODO: loading params should be done in a thread
        model, self.params = FlaxGPTJForCausalLM.from_pretrained(self.ckpt, _do_init=False)
        self.spec = set_partitions(model.params_shape_tree)

        tokenizer = AutoTokenizer.from_pretrained(self.ckpt)
        # set up for generation
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        model.config.max_length = 128
        model.config.num_beams = 1
        model.config.do_sample = True
        model.config.pad_token_id = tokenizer.pad_token_id

        self.model = model
        self.tokenizer = tokenizer

        # cast-and-shard function: params -> bf16 params laid out per self.spec
        self.p_shard_params = pjit(
            self.model.to_bf16,
            in_axis_resources=(self.spec,),
            out_axis_resources=self.spec,
        )

        def generate(params, input_ids, attention_mask):
            output_ids = self.model.generate(input_ids, attention_mask=attention_mask, params=params).sequences
            return output_ids

        self.p_generate = pjit(generate, in_axis_resources=(self.spec, P("dp"), P("dp")), out_axis_resources=P("dp"))

    def shard_params(self):
        with self.mesh:
            self.params = self.p_shard_params(self.params)

    def generate(self, prompts):
        inputs = self.tokenizer(prompts, return_tensors="jax", padding="max_length", truncation=True, max_length=32)  # batch size = 8

        with self.mesh:
            gen_ids = self.p_generate(freeze(self.params), inputs["input_ids"], inputs["attention_mask"])

        generated_text = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        return generated_text
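
For orientation, a single-host driver for this class would look roughly like the sketch below. This is a hypothetical example, not part of the commit: the (1, 8) mesh shape and the prompts are illustrative, and the load/shard/generate sequence mirrors what host_worker.py does further down.

# Hypothetical single-host usage of Generator (illustrative mesh and prompts).
from bloom_inference.generator import Generator

generator = Generator(mesh_shape=(1, 8))  # 1-way data, 8-way model parallelism
generator.load_model_and_params()         # downloads GPT-J and builds the pjit functions
generator.shard_params()                  # casts to bf16 and shards across the mesh

print(generator.generate(["Hello, my name is"] * 8))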
48 changes: 48 additions & 0 deletions bloom_inference/host_worker.py
@@ -0,0 +1,48 @@
# copied from https://github.com/kingoflolz/mesh-transformer-jax

import ray
import time
import numpy as np
from queue import Queue


@ray.remote(resources={"tpu": 1})
class TPUHostWorker(object):
    def __init__(self, mesh_shape):
        self.mesh_shape = mesh_shape

        self.input_q = Queue(maxsize=1)
        self.output_q = Queue(maxsize=1)

    def run(self):
        print("jax runtime initialization starting")
        import jax

        from bloom_inference.generator import Generator, head_print

        start = time.time()
        jax.devices()
        head_print(f"jax devices: {jax.device_count()}")
        head_print(f"jax runtime initialized in {time.time() - start:.06}s")

        # load model and params here
        head_print("Loading model")
        generator = Generator(self.mesh_shape)
        generator.load_model_and_params()
        head_print("Loading complete")

        start = time.time()
        generator.shard_params()
        head_print(f"Initialized in {time.time() - start:.06}s")

        while True:
            operation, prompts = self.input_q.get()
            if operation == "generate":
                generated_text = generator.generate(prompts)
                self.output_q.put(generated_text)
            else:
                raise Exception("Not implemented")

    def generate(self, input):
        self.input_q.put(("generate", input))
        return self.output_q.get()
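
A single TPUHostWorker could be exercised over Ray as in the hypothetical sketch below; it assumes a Ray cluster that advertises a "tpu" resource, and the (1, 8) mesh is again illustrative. Because run() loops forever inside the actor, the actor is created with max_concurrency=2 so that generate() calls can still get through.

# Hypothetical driver for one TPUHostWorker (not part of this commit).
import ray

from bloom_inference.host_worker import TPUHostWorker

ray.init()  # assumes a running Ray cluster exposing a "tpu" resource

worker = TPUHostWorker.options(max_concurrency=2).remote((1, 8))
worker.run.remote()  # run() blocks in its loop; max_concurrency=2 keeps generate() callable

print(ray.get(worker.generate.remote(["The capital of France is"] * 8)))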
File renamed without changes.
30 changes: 15 additions & 15 deletions inference/tpu_cluster.py → bloom_inference/tpu_manager.py
@@ -3,43 +3,43 @@
 import time
 
 import ray
 
 import numpy as np
 
-from .runner import NetworkRunner
 # from func_timeout import func_set_timeout
 
 
-class TPUCluster:
+class TPUManager:
     # @func_set_timeout(1200)
     def __init__(self,
                  mesh_shape,
                  node_count):
         assert ray.is_initialized()  # needs a valid ray cluster to start
 
+        from bloom_inference.host_worker import TPUHostWorker
+
         self.nodes = []
         self.node_count = node_count
         self.dp, self.mp = mesh_shape
 
         start = time.time()
 
         for i in range(node_count):
-            self.nodes.append(NetworkRunner.options(max_concurrency=2).remote(mesh_shape))
+            self.nodes.append(TPUHostWorker.options(max_concurrency=2).remote(mesh_shape))
 
-        for n in self.nodes:
-            n.run.remote()
-
-        params = []
-        for n in self.nodes:
-            params.append(n.get_params.remote())
-
-        self.param_count = ray.get(params)[0]
-        print(f"Ray actors created in {time.time() - start:.06}s")
+        for node in self.nodes:
+            node.run.remote()
+
+        print(f"TPU workers created in {time.time() - start:.06}s")
 
 
     # @func_set_timeout(600)
     def generate(self, context):
-        context = np.array_split(context, len(self.nodes), axis=0)
+        # context = np.array_split(context, len(self.nodes), axis=0)
         res = []
+        # for n, ctx in zip(self.nodes, context):
+        #     res.append(n.generate.remote(ctx))
+
+        # inputs = tokenizer(prompts, return_tensors="np", padding="max_length", truncation=True, max_length=32)
+        # inputs["input_ids"] =
+
         for n, ctx in zip(self.nodes, context):
             res.append(n.generate.remote(ctx))
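
The multi-host path would then look roughly like the sketch below. This is hypothetical: the tail of generate() is cut off by the diff view, so the sketch assumes it ultimately returns ray.get(res), and since the batch-splitting logic is clearly in flux in this commit (the array_split is commented out), the pre-batched input shape is a guess as well.

# Hypothetical pod-level driver for TPUManager (not part of this commit).
# Assumes generate() ends with something like `return ray.get(res)` in the
# portion of the diff that is not shown, and that prompts are pre-batched.
import ray

from bloom_inference.tpu_manager import TPUManager

ray.init(address="auto")  # attach to the Ray cluster running on the TPU pod

manager = TPUManager(mesh_shape=(1, 8), node_count=4)  # e.g. 4 hosts of a v3-32
print(manager.generate([["Recipe for pasta:"] * 8] * 4))  # one prompt batch per host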
58 changes: 0 additions & 58 deletions inference/generation.py

This file was deleted.

103 changes: 0 additions & 103 deletions inference/runner.py

This file was deleted.

