diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index 4a35d704..6aca686e 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -14,17 +14,12 @@ from absl import app from absl import flags -from absl import logging -import sys import jax import jax.numpy as jnp -import numpy as np from jetstream.engine import token_utils -from absl.testing import absltest import os -import sys from jetstream_pt import engine as je import time diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 620e4529..8c7881da 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -14,18 +14,14 @@ """Implement Jet Engine API.""" -import copy from typing import Any, List, Optional, Tuple, Union import threading import functools from flax import struct -from absl import logging import jax from jax import numpy as jnp -from jax.experimental import mesh_utils import torch -import jax.sharding as jsharding import numpy as np from jetstream.engine import engine_api, tokenizer_pb2, token_utils @@ -37,7 +33,6 @@ from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData from torch.utils import _pytree as pytree -import torch diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index fa13b732..c2d848b5 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -16,7 +16,7 @@ """This version contains modification to make it easier to trace and support batch.""" import math -from typing import Any, List, Optional, Tuple +from typing import Optional, Tuple import torch from torch import nn diff --git a/jetstream_pt/third_party/llama2/model_exportable.py b/jetstream_pt/third_party/llama2/model_exportable.py index eac15ddd..c56501de 100644 --- a/jetstream_pt/third_party/llama2/model_exportable.py +++ b/jetstream_pt/third_party/llama2/model_exportable.py @@ -1,20 +1,14 @@ # pylint: disable-all """This version contains modification to make it easier to trace and support batch.""" -import math -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional import torch from torch import nn import torch.nn.functional as F from . import model_args -import jax.sharding as jsharding -from jax.experimental import mesh_utils import jax -import jax.numpy as jnp -import torch_xla2 -import torch_xla2.extra from jetstream_pt.layers import Attention, RMSNorm, Int8Embedding, WeightOnlyInt8Linear diff --git a/jetstream_pt/third_party/llama2/tokenizer.py b/jetstream_pt/third_party/llama2/tokenizer.py index 7a8feb72..d39c9f93 100644 --- a/jetstream_pt/third_party/llama2/tokenizer.py +++ b/jetstream_pt/third_party/llama2/tokenizer.py @@ -2,7 +2,6 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. import os -from logging import getLogger from typing import List from sentencepiece import SentencePieceProcessor diff --git a/run_interactive.py b/run_interactive.py index 7ac733c6..6ec9bd42 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -17,22 +17,16 @@ from absl import logging import random from typing import List -import sys import jax -import jax.numpy as jnp -import numpy as np from jetstream.engine import token_utils -from absl.testing import absltest -from colorama import Fore, Back, Style +from colorama import Fore, Style import os -import sys from jetstream_pt import engine as je import time -import logging logging.getLogger().setLevel(logging.ERROR) diff --git a/run_server.py b/run_server.py index 0f7a0321..ff285e3a 100644 --- a/run_server.py +++ b/run_server.py @@ -21,7 +21,6 @@ from jetstream.core import server_lib import jetstream_pt -from jetstream_pt import config from jetstream.core.config_lib import ServerConfig