Skip to content

Commit

Permalink
Pr 2290 ci run (#2329)
Browse files (browse the repository at this point in the history)
* MODEL_ID propagation fix

* fix: remove global model id

---------

Co-authored-by: root <[email protected]>
  • Loading branch information
drbh and root authored Jul 31, 2024
1 parent 34f7dcf commit f7f6187
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 13 deletions.
3 changes: 1 addition & 2 deletions server/text_generation_server/models/flash_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
BLOCK_SIZE,
CUDA_GRAPHS,
get_adapter_to_index,
MODEL_ID,
)
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
Expand Down Expand Up @@ -1156,7 +1155,7 @@ def warmup(self, batch: FlashCausalLMBatch):

tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE,
f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
)

log_master(
Expand Down
9 changes: 0 additions & 9 deletions server/text_generation_server/models/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,6 @@

CUDA_GRAPHS = cuda_graphs

# This is overridden at model loading.
MODEL_ID = None


def set_model_id(model_id: str) -> None:
    """Store ``model_id`` in the module-level ``MODEL_ID`` global.

    Args:
        model_id: Identifier of the model; assigned to ``MODEL_ID`` as-is,
            without validation or normalisation.
    """
    # ``global`` is required so the assignment rebinds the module attribute
    # instead of creating a function-local name.
    global MODEL_ID
    MODEL_ID = model_id


# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
Expand Down
3 changes: 1 addition & 2 deletions server/text_generation_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id, set_adapter_to_index
from text_generation_server.models.globals import set_adapter_to_index


class SignalHandler:
Expand Down Expand Up @@ -271,7 +271,6 @@ async def serve_inner(
while signal_handler.KEEP_PROCESSING:
await asyncio.sleep(0.5)

set_model_id(model_id)
asyncio.run(
serve_inner(
model_id,
Expand Down

0 comments on commit f7f6187

Please sign in to comment.