Skip to content

Commit

Permalink
fix for #9233
Browse files Browse the repository at this point in the history
  • Loading branch information
tlrmchlsmth committed Oct 11, 2024
1 parent e80b82a commit 609e9fb
Showing 1 changed file with 10 additions and 26 deletions.
36 changes: 10 additions & 26 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ class _ModelInfo:
is_embedding_model: bool
supports_multimodal: bool
supports_pp: bool
has_inner_state: bool
is_attention_free: bool

@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
Expand All @@ -167,6 +169,8 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
is_embedding_model=is_embedding_model(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model),
)


Expand Down Expand Up @@ -382,6 +386,12 @@ def is_pp_supported_model(
) -> bool:
return self.inspect_model_cls(architectures).supports_pp

def model_has_inner_state(self, architectures: Union[str, List[str]]) -> bool:
return self.inspect_model_cls(architectures).has_inner_state

def is_attention_free_model(self, architectures: Union[str, List[str]]) -> bool:
return self.inspect_model_cls(architectures).is_attention_free


ModelRegistry = _ModelRegistry({
model_arch: _LazyRegisteredModel(
Expand Down Expand Up @@ -430,32 +440,6 @@ def _run() -> None:
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))

@staticmethod
def model_has_inner_state(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")

has_instate = partial(ModelRegistry._check_stateless,
has_inner_state,
default=False)

return any(has_instate(arch) for arch in architectures)

@staticmethod
def is_attention_free_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")

is_attn_free = partial(ModelRegistry._check_stateless,
is_attention_free,
default=False)

return any(is_attn_free(arch) for arch in architectures)


if __name__ == "__main__":
_run()

0 comments on commit 609e9fb

Please sign in to comment.