Skip to content

Commit

Permalink
Make kernels extern, use stdlib for fns
Browse files (browse the repository at this point in the history)
  • Loading branch information
EricLBuehler committed Mar 12, 2024
1 parent cb5f4fa commit f31db99
Showing 1 changed file with 5 additions and 16 deletions.
21 changes: 5 additions & 16 deletions candle-kernels/src/layernorm.cu
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,7 @@
#define C10_WARP_SIZE 32
#include <stdint.h>

// Single-precision overload of max: returns the result of the
// global-namespace fmaxf applied to x and y.
// NOTE(review): no __host__/__device__ qualifier is visible in this diff
// view — confirm how the helper is qualified in the actual file.
float max(float x, float y) {
return ::fmaxf(x, y);
}
// Double-precision overload of max: returns the result of the
// global-namespace fmax applied to x and y.
// NOTE(review): no __host__/__device__ qualifier is visible in this diff
// view — confirm how the helper is qualified in the actual file.
double max(double x, double y) {
return ::fmax(x, y);
}
#define C10_WARP_SIZE 32

// Single-precision overload of rsqrt: forwards x to the global-namespace
// rsqrtf and returns its result.
// NOTE(review): no __host__/__device__ qualifier is visible in this diff
// view — confirm how the helper is qualified in the actual file.
float rsqrt(float x) {
return ::rsqrtf(x);
}
// Double-precision overload of rsqrt: forwards x to the global-namespace
// rsqrt and returns its result.
// NOTE(review): no enclosing namespace is visible here. If this function is
// itself declared at global scope, ::rsqrt(x) with a double argument can
// resolve to this same overload (or clash with a toolkit-provided
// rsqrt(double)) — confirm against the full file.
double rsqrt(double x) {
return ::rsqrt(x);
}

struct Block1D {
static __forceinline__ __device__ int Tid() { return threadIdx.x; }
Expand Down Expand Up @@ -67,7 +56,7 @@ __inline__ __device__ T BlockReduceSum(T val, T* shared) {
}

#define ROWWISEMOMENTS(TYPENAME, FN_NAME) \
__global__ void FN_NAME( \
extern "C" __global__ void FN_NAME( \
int64_t N, \
double eps, \
const TYPENAME* X, \
Expand Down Expand Up @@ -95,7 +84,7 @@ __global__ void FN_NAME( \
}

#define LAYERNORM_BIAS(TYPENAME, FN_NAME) \
__global__ void FN_NAME( \
extern "C" __global__ void FN_NAME( \
int64_t N, \
const TYPENAME* X, \
const TYPENAME* mean, \
Expand All @@ -116,7 +105,7 @@ __global__ void FN_NAME( \
}

#define LAYERNORM(TYPENAME, FN_NAME) \
__global__ void FN_NAME( \
extern "C" __global__ void FN_NAME( \
int64_t N, \
const TYPENAME* X, \
const TYPENAME* mean, \
Expand Down

0 comments on commit f31db99

Please sign in to comment.