diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 0691da9b3e0..fc002082c71 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -382,8 +382,13 @@ def __init__(
 
         prefix = f"{prefix}.h.{layer_id}"
 
+        # NOTE: Falcon 180B uses the ln_attn prefix
+        ln_prefix = "input_layernorm"
+        if config.num_hidden_layers == 80:
+            ln_prefix = "ln_attn"
+
         self.input_layernorm = FastLayerNorm.load(
-            prefix=f"{prefix}.input_layernorm",
+            prefix=f"{prefix}.{ln_prefix}",
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
@@ -477,6 +482,10 @@ def __init__(self, config, prefix: str, weights):
         # in the case no number of layer norms is provided, we default to 1
         self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1)
 
+        # Falcon 180B uses the ln_attn prefix and has 2 layer norms
+        if config.num_hidden_layers == 80:
+            self.num_ln = 2
+
         if self.num_ln == 1:
             self.input_ln = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",
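
For context, here is a minimal standalone sketch (not part of the patch itself) of the decision both hunks introduce: Falcon 180B checkpoints ship 80 hidden layers, store their layer norms under the `ln_attn` prefix instead of `input_layernorm`, and use two layer norms in the parallel-attention block. The `FalconLNConfig` dataclass and `resolve_layer_norm` helper below are illustrative names, not part of `flash_rw_modeling.py`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FalconLNConfig:
    """Illustrative stand-in for the relevant fields of the model config."""
    num_hidden_layers: int
    num_ln_in_parallel_attn: Optional[int] = None


def resolve_layer_norm(config: FalconLNConfig) -> tuple[str, int]:
    """Return (layer-norm weight prefix, number of layer norms) for a layer."""
    # Default: a single layer norm stored under `input_layernorm`.
    ln_prefix = "input_layernorm"
    num_ln = config.num_ln_in_parallel_attn or 1

    # Falcon 180B (80 hidden layers) stores its norms under `ln_attn`
    # and uses two layer norms in the parallel-attention block.
    if config.num_hidden_layers == 80:
        ln_prefix = "ln_attn"
        num_ln = 2

    return ln_prefix, num_ln


if __name__ == "__main__":
    print(resolve_layer_norm(FalconLNConfig(num_hidden_layers=32)))  # non-180B -> ('input_layernorm', 1)
    print(resolve_layer_norm(FalconLNConfig(num_hidden_layers=80)))  # 180B     -> ('ln_attn', 2)
```

The patch applies the same check at both weight-loading sites, keying on `config.num_hidden_layers == 80` rather than on a dedicated config field.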