Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hotfix/fix_nvtx_pop #7

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion grouped_gemm/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def forward(ctx,
nvtx.range_push("permute_topK forward")
# Empty input check
if not input_act.numel():
if ENABLE_NVTX:
nvtx.range_pop()
return input_act, None

# For top1 case, view the indices as 2D tensor to unify the shape for topk>=2 cases.
Expand Down Expand Up @@ -133,6 +135,8 @@ def backward(ctx, permuted_act_grad, _):
nvtx.range_push("permute_topK backward")
# Empty input check
if not permuted_act_grad.numel():
if ENABLE_NVTX:
nvtx.range_pop()
return permuted_act_grad, None, None, None

if not permuted_act_grad.is_contiguous():
Expand Down Expand Up @@ -170,6 +174,8 @@ def forward(ctx,
# Empty input check
if not input_act.numel():
ctx.probs = probs
if ENABLE_NVTX:
nvtx.range_pop()
return input_act

# Device check
Expand Down Expand Up @@ -229,6 +235,8 @@ def backward(ctx, unpermuted_act_grad):
nvtx.range_push("unpermute_topK backward")
# Empty input check
if not unpermuted_act_grad.numel():
if ENABLE_NVTX:
nvtx.range_pop()
return unpermuted_act_grad, None, ctx.probs

if not unpermuted_act_grad.is_contiguous():
Expand All @@ -255,4 +263,4 @@ def permute(input_act, indices, num_out_tokens=None, max_token_num=0):
return PermuteMoE_topK.apply(input_act, indices, num_out_tokens, max_token_num)

def unpermute(input_act, row_id_map, probs=None):
return UnpermuteMoE_topK.apply(input_act, row_id_map, probs)
return UnpermuteMoE_topK.apply(input_act, row_id_map, probs)