
Refactoring for maintainability #4

Conversation

DhruvaBansal00

Refactoring the Marlin MoE implementation for maintainability and mirroring the AWQ codepath


github-actions bot commented Aug 7, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config


def fused_moe_gptq(
Collaborator:

can you call this fused_moe_marlin? We want to separate the naming of the kernel from the algorithm
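
For illustration, the naming separation could look like the sketch below; the signature and docstring are assumptions for this example, not the PR's actual code.

def fused_moe_marlin(hidden_states, w13_qweight, w2_qweight, gating_output,
                     topk, renormalize=True):
    """Fused MoE forward pass backed by the Marlin kernel.

    The name refers to the kernel (Marlin) rather than the quantization
    algorithm (GPTQ or AWQ), so several quantization methods can share
    the same entry point.
    """
    ...  # kernel dispatch elided in this sketch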

hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):

def create_weights(
Collaborator:

nit: don't touch unchanged lines

@@ -386,7 +162,7 @@ def forward_tpu(
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.

This layer contains both MergedColumnParallel weights (gate_up_proj /
This layer contains both MergedColumnParallel weights (gate_up_proj /
Collaborator:

nit: don't touch unchanged lines

@@ -377,6 +152,7 @@ def forward_tpu(
topk_group: Optional[int],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe

Collaborator:

nit: don't touch unchanged lines

@@ -491,8 +267,8 @@ def weight_loader(self,
else:
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
if (param_data[expert_id] != 1 and
Collaborator:

nit: don't touch unchanged lines

if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
if (param_data[expert_id] != 1 and
(param_data[expert_id] - loaded_weight).abs() > 1e-5):
Collaborator:

nit: don't touch unchanged lines

@@ -546,7 +322,8 @@ def forward(self, hidden_states: torch.Tensor,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group)
topk_group=self.topk_group,
Collaborator:

nit: don't touch unchanged lines

ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, int]]:

cls,
Collaborator:

nit: don't touch unchanged lines

num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
) -> torch.Tensor:
if layer.marlin_state == GPTQMarlinState.REPACK:
Collaborator:

Instead of this REPACK, please do the repacking in process_weights_after_loading

you can look at gptq_marlin.py for an example
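
For illustration, a minimal sketch of that pattern follows; the attribute names (w13_qweight, w2_qweight) and the repack helper are assumptions for this example and are not taken from gptq_marlin.py or this PR.

import torch


class MarlinMoEMethodSketch:
    """Sketch: repack the quantized expert weights once after loading,
    instead of branching on a REPACK state inside the forward path."""

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Convert the GPTQ-packed expert weights into the Marlin layout a
        # single time, right after the checkpoint weights are loaded.
        layer.w13_qweight = torch.nn.Parameter(
            self.marlin_repack(layer.w13_qweight.data), requires_grad=False)
        layer.w2_qweight = torch.nn.Parameter(
            self.marlin_repack(layer.w2_qweight.data), requires_grad=False)

    def marlin_repack(self, qweight: torch.Tensor) -> torch.Tensor:
        # Placeholder for the actual Marlin repack op.
        raise NotImplementedError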

Author:

Addressed this comment. Also deduplicated code with marlin_utils

@robertgshaw2-neuralmagic (Collaborator)

Thanks for the PR! This looks really good.

Other than spurious change nits, the key feedback is:

  • the repacking should happen in process_weights_after_loading. You can look at gptq_marlin.py for an example

One other thing: I wonder if there is a better way to do the make_expert_params_mapping, but this could be a follow-up.
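
For context, the mapping in question pairs each per-expert checkpoint weight with the fused parameter it loads into. The sketch below shows the rough shape of such a mapping; the string prefixes are assumptions based only on the signature visible in this diff.

from typing import List, Tuple


def make_expert_params_mapping_sketch(
        ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int) -> List[Tuple[str, str, int, int]]:
    # Each tuple: (fused param prefix, checkpoint weight prefix,
    #              expert id, shard id).
    projections = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]
    mapping: List[Tuple[str, str, int, int]] = []
    for expert_id in range(num_experts):
        for shard_id, proj_name in enumerate(projections):
            # Gate/up projections load into the merged w13 weight, the down
            # projection into w2 (illustrative names).
            fused_prefix = ("experts.w2_" if proj_name == ckpt_down_proj_name
                            else "experts.w13_")
            mapping.append((fused_prefix,
                            f"experts.{expert_id}.{proj_name}.",
                            expert_id, shard_id))
    return mapping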

@DhruvaBansal00 (Author)

I think the change nits happened because of the formatter I am using on save. Will fix it right now.
The other two comments should be addressed.

@DhruvaBansal00 (Author)

@robertgshaw2-neuralmagic resolved all formatting changes. Let me know if this is good to go!

@DhruvaBansal00 (Author)

/ready

github-actions bot added the ready label Aug 15, 2024
Comment on lines 13 to 14
from vllm.model_executor.layers.fused_moe import fused_moe, single_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin


nit: I think it would be good to keep single_marlin_moe in the same place as fused_moe_marlin, even if the former is only used for testing
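
Illustratively, once both helpers live in the same module, the test import could collapse to something like the following (hypothetical layout, assuming single_marlin_moe is moved next to fused_moe_marlin):

from vllm.model_executor.layers.fused_moe.fused_moe_marlin import (
    fused_moe_marlin, single_marlin_moe)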

@@ -22,33 +22,49 @@
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple

import re


nit: can you make sure to remove all unused imports?

@DhruvaBansal00 (Author)

@ElizaWszola thank you for the feedback! I have made the requested changes and also ran tests again. Hope things look good to merge now!

Would love to help expedite work on supporting 8-bit quantized models as well (these are returning incorrect outputs on my end). Happy to chat sometime!

@ElizaWszola

This looks good overall!

Just two small remaining things:

  • can you make sure that offline_inference.py runs to completion and produces sane output for llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ", revision="gptq-4bit-128g-actorder_True")? (a minimal smoke-test sketch follows after this list)
  • can you double check that the code conforms to the output of format.sh?
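
As mentioned in the first item above, a minimal smoke test along these lines could be used; the prompts and sampling settings are illustrative, while the model and revision are the ones named above.

from vllm import LLM, SamplingParams

llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ",
          revision="gptq-4bit-128g-actorder_True")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

prompts = ["Hello, my name is", "The capital of France is"]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    # Eyeball the completions: garbage text here would indicate a broken
    # quantized MoE path even if the run finishes without errors.
    print(output.prompt, "->", output.outputs[0].text)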

@robertgshaw2-neuralmagic (Collaborator)

@ElizaWszola - I think we can merge this + take it from here

@ElizaWszola

I'm merging this now. Thanks @DhruvaBansal00!

ElizaWszola merged commit 34bb5b0 into neuralmagic:marlin-moe-integration Aug 22, 2024
LucasWilkinson pushed a commit that referenced this pull request Sep 3, 2024
magic_wand semi_structured_sparse_tensor_linear branch integrates 2:4 semi-structured sparsity into SparseTensor. This PR adds a new sparsity config for 2:4 sparsity to neuralmagic-vllm, using the SparseTensor 2:4 support.

This PR also refactors the sparse linear method into a separate file, vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py, which supports all sparsity formats.