Skip to content

Commit

Permalink
rebase fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka committed Dec 7, 2024
1 parent e5c6d10 commit 96e4a7a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 2 additions & 2 deletions csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ struct ScaledEpilogueBase {
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

// This utility function constructs the arguments for the load descriptors
Expand Down
12 changes: 12 additions & 0 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
#include <stddef.h>

#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
#include <vector>

#include "cutlass/cutlass.h"
#include "scaled_mm_c3x.cuh"

Expand Down

0 comments on commit 96e4a7a

Please sign in to comment.