Skip to content

Commit

Permalink
Use generic API for creating MLIR operations (#1477)
Browse files Browse the repository at this point in the history
Co-authored-by: Paulo Valente <[email protected]>
  • Loading branch information
jonatanklosko and polvalente authored May 6, 2024
1 parent 9d8a1ff commit 03a0c1c
Show file tree
Hide file tree
Showing 38 changed files with 2,174 additions and 5,164 deletions.
4 changes: 2 additions & 2 deletions exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
fi

SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/mlir/ops.cc $(EXLA_DIR)/mlir/builder.cc $(EXLA_DIR)/mlir/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
HEADERS = $(EXLA_DIR)/mlir/ops.h $(EXLA_DIR)/mlir/builder.h $(EXLA_DIR)/mlir/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o


Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#include "custom_calls.h"
#include "exla_nif_util.h"

#include <Eigen/Dense>
#include <Eigen/QR>
#include "xla/service/custom_call_target_registry.h"

#include "builder.h"
#include "Eigen/Dense"
#include "Eigen/QR"

template <typename DataType>
void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, int64_t m, int64_t k, int64_t n, bool complete) {
Expand Down Expand Up @@ -102,4 +103,9 @@ void qr_cpu_custom_call_f32(void *out[], const void *in[]) {

void qr_cpu_custom_call_f64(void *out[], const void *in[]) {
qr_cpu_custom_call<double>(out, in);
}
}

XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_custom_call_f32);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f64", qr_cpu_custom_call_f64);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

template <typename DataType>
void qr_cpu_custom_call(void *out[], const void *in[]);
#ifndef EXLA_MLIR_CUSTOM_CALLS_H_
#define EXLA_MLIR_CUSTOM_CALLS_H_

void qr_cpu_custom_call_bf16(void *out[], const void *in[]);
void qr_cpu_custom_call_f16(void *out[], const void *in[]);
void qr_cpu_custom_call_f32(void *out[], const void *in[]);
void qr_cpu_custom_call_f64(void *out[], const void *in[]);
void qr_cpu_custom_call_f64(void *out[], const void *in[]);

#endif
Loading

0 comments on commit 03a0c1c

Please sign in to comment.