diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index bd44032145a..a6348b5b907 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -99,7 +99,7 @@ def __init__(self, prefix: str, weights): self.offset = 2 self.weight = nn.Parameter( weights.get_tensor( - f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight" + f"{prefix if prefix else ''}decoder.embed_positions.weight" ) ) @@ -317,7 +317,7 @@ def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights): super().__init__() self.process_group = weights.process_group self.hidden_size = config.hidden_size - prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}" + prefix = f"{prefix if prefix else ''}decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, prefix=f"{prefix}.self_attn", @@ -755,6 +755,8 @@ def forward( class OPTForCausalLM(OPTPreTrainedModel): def __init__(self, prefix, config, weights): super().__init__(config) + if not prefix and any(s.startswith("model") for s in weights.routing.keys()): + prefix = "model" self.model = OPTModel(prefix, config, weights)