From d6ab93a3b4aa22b4d7f7ff6094a3d04ffba7174a Mon Sep 17 00:00:00 2001
From: Fanhai Lu <154379058+FanhaiLu1@users.noreply.github.com>
Date: Wed, 24 Apr 2024 21:27:21 -0700
Subject: [PATCH] Clean up engine (#42)

---
 jetstream_pt/engine.py | 48 ++++++++++++++++++++++++++----------------
 1 file changed, 30 insertions(+), 18 deletions(-)

diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index 875585c6..44898b48 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -18,22 +18,23 @@
 import threading
 import functools
 
+from etils import epath
 from flax import struct
 import jax
 from jax import numpy as jnp
+from safetensors import safe_open
 import torch
 import numpy as np
 
 from jetstream.engine import engine_api, tokenizer_pb2, token_utils
 import torch_xla2
+from torch.utils import _pytree as pytree
 from jetstream_pt.third_party.llama2 import model_exportable, model_args
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
 
-from torch.utils import _pytree as pytree
-
 
 Mesh = jax.sharding.Mesh
 P = jax.sharding.PartitionSpec
@@ -43,6 +44,7 @@
 
 
 @struct.dataclass
+# pylint: disable-next=all
 class Prefix:
   token: jax.Array  # [1, seqlen]
   caches: List[Tuple[jax.Array, jax.Array]]
@@ -50,6 +52,7 @@ class Prefix:
 
 
 @struct.dataclass
+# pylint: disable-next=all
 class DecodeState:
   tokens: jax.Array  # [batch_size, seqlen]
   caches: List[Tuple[jax.Array, jax.Array]]
@@ -65,6 +68,7 @@ class DecodeState:
 
 
 # NOTE model specific
+# pylint: disable-next=all
 class PyTorchEngine(engine_api.Engine):
   """Wraps functions to the Jet Engine API format."""
 
@@ -107,6 +111,7 @@ def __init__(
     #     out_shardings=self.get_decode_state_sharding())
     self._lock = threading.RLock()
 
+  # pylint: disable-next=all
   def sharding_by_name(self, name):
 
     # This allows easier way to edit shardings
@@ -123,17 +128,16 @@ def sharding_by_name(self, name):
     if "attention." in name:
       if "wo" in name:
         return self.y_sharding
-      else:
-        return self.x_sharding
+      return self.x_sharding
     if "feed_forward." in name:
       if "w2" in name:
         return self.y_sharding
-      else:
-        return self.x_sharding
+      return self.x_sharding
     if "output" in name:
       return self.x_sharding
     return self.replicated
 
+  # pylint: disable-next=all
   def init_decode_state(
       self,
   ) -> DecodeState:
@@ -156,6 +160,7 @@ def init_decode_state(
         ),
     )
 
+  # pylint: disable-next=all
   def _call_model_generate(
       self,
       weights,
@@ -266,10 +271,12 @@ def prefill(
   def shrink_prefix(
       self,
       prefix: Prefix,
-      new_length: int,
+      new_length: int,  # pylint: disable=unused-argument
   ) -> Prefix:
+    """shrink prefix"""
     return prefix
 
+  # pylint: disable-next=all
   def _insert_no_wrap(
       self,
       prefix: Prefix,
@@ -345,6 +352,7 @@ def insert(cache, scaler, new_entry):
         mask,
     )
 
+  # pylint: disable-next=all
   def _insert_wrap(
       self,
       prefix: Prefix,
@@ -353,7 +361,6 @@ def _insert_wrap(
   ):  # returns Decode State
 
     start_insert = decode_state.current_position - prefix.seq_len
-    end_insert = start_insert + prefix.caches[0][0].shape[2]  # padded seclen
     tokens = decode_state.tokens.at[slot].set(prefix.token)
 
     start_insert = start_insert % self.env.cache_sequence_length
@@ -517,7 +524,9 @@ def generate(
 
     return new_decode_state, result_tokens
 
+  # pylint: disable-next=all
   def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
+    # pylint: disable-next=all
     return tokenizer_pb2.TokenizerParameters(path=self.env.tokenizer_path)
 
   def join_prefixes(
@@ -527,6 +536,7 @@ def join_prefixes(
       self,
       prefix1: engine_api.Prefix,
       length1: int,
       prefix2: engine_api.Prefix,
       length2: int,
   ) -> tuple[engine_api.Prefix, int]:
+    """join prefixes"""
     raise NotImplementedError("join_prefixes not supported")
 
   def _make_state_dict_jax(self, model_args_meta):
@@ -540,7 +550,6 @@ def make_array(t):
     return pytree.tree_map_only(torch.Tensor, make_array, model_args_meta)
 
   def _load_from_safetensors(self, path):
-    from safetensors import safe_open
 
     weights = {}
     with safe_open(path, framework="flax", device="cpu") as f:
@@ -563,8 +572,9 @@ def _load_from_safetensors(self, path):
 
     return weights
 
+  # pylint: disable-next=all
   def load_params(self) -> Params:
-    # TODO load from files
+    # We want to fix this: load from files
     with jax.default_device(self.colocated_cpus()):
       if self.env.checkpoint_path:
         if self.env.checkpoint_format == "safetensors":
@@ -581,6 +591,7 @@ def load_params(self) -> Params:
         print(f"Name: {k}, shape: {v.shape} x {v.dtype}")
     return jax_weights
 
+  @property
   def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:
     return jax.devices("cpu")[0]
@@ -624,18 +635,21 @@ def max_prefill_length(self) -> int:
 
   @property
   def max_decode_length(self) -> int:
     """Maximum decode length."""
+    # pylint: disable-next=all
     return self.env._data.max_decode_length
 
   @property
   def mesh(self):
     return self._mesh
 
 
+# pylint: disable-next=all
 def create_pytorch_engine(
+    # pylint: disable-next=all
    devices: list[Any],
    tokenizer_path: str,
    ckpt_path: Optional[str] = None,
-    samples_per_slot: int = 1,
+    samples_per_slot: int = 1,  # pylint: disable=unused-argument
    bf16_enable: bool = False,
    param_size: str = "7b",
    context_length: int = 1024,
@@ -659,7 +673,7 @@ def create_pytorch_engine(
   checkpoint_format = ""
   checkpoint_path = ""
 
-  if not ckpt_path or ckpt_path == None:
+  if not ckpt_path or ckpt_path is None:
     print("WARNING: Using random weights instead of checkpoints.")
   elif ".safetensors" in ckpt_path:
     checkpoint_format = "safetensors"
@@ -669,9 +683,7 @@ def create_pytorch_engine(
     raise NotImplementedError(
         "Loading from Pytorch raw checkpoint is not supported!"
     )
   else:
-    from etils import epath
-
-    path = epath.Path(ckpt_path) if ckpt_path and ckpt_path != None else ""
+    path = epath.Path(ckpt_path) if ckpt_path and ckpt_path is not None else ""
     if not path.exists():
       raise ValueError(f"Checkpoint path {ckpt_path} not exists!")
@@ -680,6 +692,7 @@ def create_pytorch_engine(
     paths = list(path.glob("*.safetensors"))
     assert (
         len(paths) == 1
     ), f"Expects 1 *.safetensors in the checkpoint dir, see {len(paths)}"
     checkpoint_format = "safetensors"
     checkpoint_path = paths[0]
+
   env_data = JetEngineEnvironmentData(
       tokenizer_path=tokenizer_path,
       checkpoint_path=checkpoint_path,
@@ -697,7 +710,6 @@ def create_pytorch_engine(
 
   tokenizer = token_utils.load_vocab(tokenizer_path)
   pt_model = None
-  shard_weights_fn = None
   if model_name == "llama":
     args = model_args.get_model_args(
@@ -712,7 +724,7 @@ def create_pytorch_engine(
 
     num_params_size = 0
     num_params = 0
-    for k, v in pt_model.state_dict().items():
+    for _, v in pt_model.state_dict().items():
      num_params += 1
      num_params_size += np.prod(v.shape) * (1 if v.dtype == jnp.int8 else 2)
    print("Number of param Gbytes:", num_params_size / (1 << 30))
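
---

For reviewers who want to exercise the cleaned-up factory, a minimal usage sketch follows. It is not part of the patch: the device list, tokenizer path, and checkpoint path are hypothetical placeholders, and only keyword names visible in the signature above are used.

    # Usage sketch only -- not part of this patch. Paths below are placeholders.
    import jax

    from jetstream_pt import engine

    pt_engine = engine.create_pytorch_engine(
        devices=jax.devices(),  # assumption: hand over the visible JAX devices
        tokenizer_path="/tmp/tokenizer.model",  # hypothetical tokenizer file
        ckpt_path="/tmp/llama-7b/model.safetensors",  # hypothetical checkpoint
        model_name="llama",  # assumption: matches the `if model_name == "llama"` branch
        bf16_enable=True,
        param_size="7b",
        context_length=1024,
    )

    # load_params() materializes weights (random if ckpt_path is empty), and
    # init_decode_state() builds the DecodeState threaded through decoding.
    params = pt_engine.load_params()
    decode_state = pt_engine.init_decode_state()

Passing a directory for ckpt_path instead of a file exercises the epath.Path glob branch above, which asserts exactly one *.safetensors file in the checkpoint directory.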