Skip to content

Commit

Permalink
Clean up code comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Faraz9877 committed Dec 11, 2024
1 parent 6d574af commit 72f4577
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 42 deletions.
5 changes: 0 additions & 5 deletions benchmarks/cutlass_benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5

# # Initialize a to all ones
# a = torch.ones((m, k), device='cuda')
# # Initialize b to all ones
# b = torch.ones((n, k), device='cuda')

b = prune_to_2_4(b.t()).t()

if dtype == torch.int8:
Expand Down
52 changes: 26 additions & 26 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
}
}

void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& e,
torch::Tensor const& b,
torch::Tensor const& a_scales,
Expand All @@ -306,36 +307,35 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == c.dtype(),
"currently bias dtype must match output dtype ", c.dtype());
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBias>(
c, a, e, b, a_scales, b_scales, *bias);
out, a, e, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogue>(
c, a, e, b, a_scales, b_scales);
out, a, e, b, a_scales, b_scales);
}
}

// void cutlass_scaled_sparse_mm_azp_sm90(torch::Tensor& out, torch::Tensor
// const& a,
// torch::Tensor const& e,
// torch::Tensor const& b,
// torch::Tensor const& a_scales,
// torch::Tensor const& b_scales,
// torch::Tensor const& azp_adj,
// c10::optional<torch::Tensor> const& azp,
// c10::optional<torch::Tensor> const& bias) {
// TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
// TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

// if (azp) {
// return
// cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
// out, a, e, b, a_scales, b_scales, azp_adj, *azp, bias);
// } else {
// return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
// out, a, e, b, a_scales, b_scales, azp_adj, bias);
// }
// }
// SM90 entry point for the sparse scaled-mm variant with azp
// (asymmetric zero point) correction. Dispatches to the epilogue that
// matches the azp mode — presumably per-token when the optional `azp`
// tensor is supplied, per-tensor otherwise (inferred from the epilogue
// names; confirm against the ScaledEpilogueBiasAzp* definitions).
void cutlass_scaled_sparse_mm_azp_sm90(torch::Tensor& out,
                                       torch::Tensor const& a,
                                       torch::Tensor const& e,
                                       torch::Tensor const& b,
                                       torch::Tensor const& a_scales,
                                       torch::Tensor const& b_scales,
                                       torch::Tensor const& azp_adj,
                                       c10::optional<torch::Tensor> const& azp,
                                       c10::optional<torch::Tensor> const& bias) {
  // The epilogues only accept fp32 scale factors.
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    // No azp tensor: azp_adj alone carries the correction term.
    return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
        out, a, e, b, a_scales, b_scales, azp_adj, bias);
  }
  // azp tensor present: forward it alongside azp_adj.
  return cutlass_scaled_sparse_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
      out, a, e, b, a_scales, b_scales, azp_adj, *azp, bias);
}

#endif
5 changes: 0 additions & 5 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,6 @@ struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc = AccType;
// typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
// float>::type;

using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
Expand Down Expand Up @@ -432,9 +430,6 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(1);

// using StrideB = Stride<int64_t, Int<1>, int64_t>;
// using StrideC = typename Gemm::StrideC;

using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
using StrideB = typename Gemm::GemmKernel::StrideB;
Expand Down
4 changes: 2 additions & 2 deletions csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,

// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1); // Row-major
// TORCH_CHECK(b.stride(0) == 1 && c.stride(0) == 1); // Column-major
// TORCH_CHECK(c.stride(0) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(b.stride(0) == 1 && c.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,11 @@ def apply_weights(self,
input_scale = layer.input_scale
q_input = x

out = ops.cutlass_scaled_sparse_mm(a=layer.weight,
out = ops.cutlass_scaled_sparse_mm(a=q_input,
b=layer.weight,
e=layer.meta,
b=q_input.t(),
scale_a=layer.weight_scale,
scale_b=input_scale,
scale_a=input_scale,
scale_b=layer.weight_scale,
out_dtype=self.output_dtype,
bias=bias)
assert out.is_contiguous()
Expand Down

0 comments on commit 72f4577

Please sign in to comment.