diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 152516e7f51..36bb26621f8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -43,7 +43,6 @@
     BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
-    MODEL_ID,
 )
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -1156,7 +1155,7 @@ def warmup(self, batch: FlashCausalLMBatch):
 
             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
-                f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+                f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
             )
 
             log_master(
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index ac42df30c04..8d2431dbc9d 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -29,15 +29,6 @@
 
 CUDA_GRAPHS = cuda_graphs
 
-# This is overridden at model loading.
-MODEL_ID = None
-
-
-def set_model_id(model_id: str):
-    global MODEL_ID
-    MODEL_ID = model_id
-
-
 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.
 ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 22bd759f541..b92ab572a85 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -30,7 +30,7 @@
 
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.globals import set_model_id, set_adapter_to_index
+from text_generation_server.models.globals import set_adapter_to_index
 
 
 class SignalHandler:
@@ -271,7 +271,6 @@ async def serve_inner(
         while signal_handler.KEEP_PROCESSING:
             await asyncio.sleep(0.5)
 
-    set_model_id(model_id)
     asyncio.run(
         serve_inner(
             model_id,