Commit
Generate code for elementwise operations (#962)
dcrc2 authored Jul 21, 2021
1 parent 0ab1a29 commit 88e1491
Showing 3 changed files with 161 additions and 16 deletions.
56 changes: 52 additions & 4 deletions src/python/ksc/cgen.py
@@ -49,7 +49,9 @@ def entry_point_cpp_type(t, use_torch):
     raise ValueError(f'Unable to generate C++ type for "{t}"')
 
 
-def generate_cpp_entry_points(bindings_to_generate, decls, use_torch=False):
+def generate_cpp_entry_points(
+    bindings_to_generate, decls, elementwise=False, use_torch=False
+):
     decls_by_name = {decl.name: decl for decl in decls}
 
     def lookup_decl(structured_name):
@@ -59,7 +61,10 @@ def lookup_decl(structured_name):
 
     cpp_entry_points = "".join(
         generate_cpp_entry_point(
-            binding_name, lookup_decl(structured_name), use_torch=use_torch
+            binding_name,
+            lookup_decl(structured_name),
+            elementwise=elementwise,
+            use_torch=use_torch,
         )
         for binding_name, structured_name in bindings_to_generate
     )
@@ -89,7 +94,12 @@ def arg_types_of_decl(decl):
     return arg_types
 
 
-def generate_cpp_entry_point(cpp_function_name, decl, use_torch):
+def generate_cpp_entry_point(cpp_function_name, decl, elementwise, use_torch):
+    if elementwise:
+        if not use_torch:
+            raise ValueError("Elementwise operations only available when using torch")
+        return generate_cpp_elementwise_entry_point(cpp_function_name, decl)
+
     arg_types = arg_types_of_decl(decl)
     num_args = len(arg_types)
 
@@ -110,7 +120,7 @@ def join_args(sep, callable):
     for i in range(num_args):
         cpp += f"    auto ks_arg{i} = convert_argument<{ks_cpp_type(arg_types[i])}>(arg{i});\n"
 
-    # auto ks_ret = ks::my_kernel(&g_alloc, ks_arg0, ..., ks_arg7)
+    # auto ks_ret = ks::my_kernel(&g_alloc, ks_arg0, ..., ks_arg7);
     cpp += f"""
     auto ks_ret = ks::{ks_function_name}(&g_alloc {join_args("", lambda i: f", ks_arg{i}")});
 """
@@ -121,3 +131,41 @@ def join_args(sep, callable):
 }}
 """
     return cpp
+
+
+def generate_cpp_elementwise_entry_point(cpp_function_name, decl):
+    arg_types = arg_types_of_decl(decl)
+    if not all(a == Type.Float for a in arg_types):
+        raise ValueError(
+            "Elementwise operations only available for floating-point element type"
+        )
+    num_args = len(arg_types)
+
+    def join_args(sep, callable):
+        return sep.join(callable(i) for i in range(num_args))
+
+    ks_function_name = utils.encode_name(decl.name.mangled())
+
+    # torch::Tensor entry_my_kernel(torch::Tensor arg0, ..., torch::Tensor arg7)
+    cpp = f"torch::Tensor {cpp_function_name}({join_args(', ', lambda i: f'torch::Tensor arg{i}')}) {{\n"
+
+    # auto* arg_data0 = arg0.data_ptr<float>();
+    # ...
+    # auto* arg_data7 = arg7.data_ptr<float>();
+    for i in range(num_args):
+        cpp += f"""
+    KS_ASSERT(arg{i}.is_contiguous());
+    KS_ASSERT(arg{i}.scalar_type() == scalar_type_of_Float);
+    auto* arg_data{i} = arg{i}.data_ptr<float>();
+"""
+    # ret_data[i] = ks::my_op(&g_alloc, arg_data0[i], arg_data1[i]);
+    cpp += f"""
+    auto ret = torch::empty_like(arg0);
+    auto* ret_data = ret.data_ptr<float>();
+    for (int i = 0, ne = arg0.numel(); i != ne; ++i) {{
+        ret_data[i] = ks::{ks_function_name}(&g_alloc {join_args("", lambda i: f", arg_data{i}[i]")});
+    }}
+    return ret;
+}}
+"""
+    return cpp
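For orientation, the entry point emitted above behaves like the following Python sketch. Here `scalar_op` is a hypothetical stand-in for the compiled ks:: kernel (the real kernel also takes the `&g_alloc` allocator, omitted here); the names are illustrative, not part of this commit.

```python
import torch

def scalar_op(x: float) -> float:
    # Hypothetical stand-in for the compiled ks:: scalar kernel.
    return x * x

def elementwise_entry_point_reference(arg0: torch.Tensor) -> torch.Tensor:
    # Mirrors the generated KS_ASSERTs: contiguous float32 input only.
    assert arg0.is_contiguous()
    assert arg0.dtype == torch.float32
    ret = torch.empty_like(arg0)
    in_flat, out_flat = arg0.view(-1), ret.view(-1)
    # Mirrors the generated loop: apply the scalar kernel element by element.
    for i in range(arg0.numel()):
        out_flat[i] = scalar_op(in_flat[i].item())
    return ret

x = torch.randn(4)
print(elementwise_entry_point_reference(x))  # == x * x, elementwise
```

The generated C++ runs the same loop over raw `data_ptr<float>()` arrays rather than per-element tensor accesses, which is the point of the dedicated elementwise path.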
20 changes: 16 additions & 4 deletions src/python/ksc/compile.py
@@ -185,7 +192,12 @@ def _():
 
 
 def generate_cpp_for_py_module_from_ks(
-    ks_str, bindings_to_generate, python_module_name, use_aten=True, use_torch=False
+    ks_str,
+    bindings_to_generate,
+    python_module_name,
+    elementwise=False,
+    use_aten=True,
+    use_torch=False,
 ):
     def mangled_with_type(structured_name):
         if not structured_name.has_type():
@@ -201,7 +206,7 @@ def mangled_with_type(structured_name):
 
     cpp_ks_functions, decls = generate_cpp_from_ks(ks_str, use_aten=use_aten)
     cpp_entry_points = cgen.generate_cpp_entry_points(
-        bindings_to_generate, decls, use_torch=use_torch
+        bindings_to_generate, decls, elementwise=elementwise, use_torch=use_torch
     )
     cpp_pybind_module_declaration = generate_cpp_pybind_module_declaration(
         bindings, python_module_name
@@ -238,13 +243,14 @@ def m_def(python_name, cpp_name):
 
 
 def build_py_module_from_ks(
-    ks_str, bindings_to_generate, use_aten=False, use_torch=False
+    ks_str, bindings_to_generate, elementwise=False, use_aten=False, use_torch=False
 ):
 
     cpp_str = generate_cpp_for_py_module_from_ks(
         ks_str,
         bindings_to_generate,
         "PYTHON_MODULE_NAME",
+        elementwise=elementwise,
         use_aten=use_aten,
         use_torch=use_torch,
     )
@@ -261,7 +267,12 @@ def build_py_module_from_ks(
 
 
 def build_module_using_pytorch_from_ks(
-    ks_str, bindings_to_generate, torch_extension_name, use_aten=False, extra_cflags=[]
+    ks_str,
+    bindings_to_generate,
+    torch_extension_name,
+    elementwise=False,
+    use_aten=False,
+    extra_cflags=[],
 ):
     """Uses PyTorch C++ extension mechanism to build and load a module
@@ -279,6 +290,7 @@
         ks_str,
         bindings_to_generate,
         torch_extension_name,
+        elementwise=elementwise,
         use_aten=use_aten,
         use_torch=True,
     )
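A hedged usage sketch of the extended builder follows; the .ks source, op name, and extension name are invented for illustration, and the structured-name handling is a best guess rather than confirmed API usage.

```python
# Hedged sketch, not from this commit: compile a hypothetical scalar op as an
# elementwise torch extension. Assumes ksc is importable and that PyTorch plus
# a C++ toolchain are available.
from ksc.compile import build_module_using_pytorch_from_ks
from ksc.expr import make_structured_name

ks_str = "(def my_op Float ((x : Float)) (mul x x))"  # illustrative .ks source
bindings = [("entry_my_op", make_structured_name("my_op"))]  # (python name, structured name)

mod = build_module_using_pytorch_from_ks(
    ks_str, bindings, "my_op_extension", elementwise=True
)
# mod.entry_my_op(t) then applies my_op to each element of a contiguous
# float32 tensor t, via the loop generated in cgen.py above.
```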
101 changes: 93 additions & 8 deletions src/python/ksc/torch_frontend.py
@@ -17,7 +17,21 @@
 )
 
 from ksc.type import Type
-from ksc.expr import Expr, Def, EDef, GDef, Rule, Const, Var, Lam, Call, Let, If, Assert
+from ksc.expr import (
+    Expr,
+    Def,
+    EDef,
+    GDef,
+    Rule,
+    Const,
+    Var,
+    Lam,
+    Call,
+    Let,
+    If,
+    Assert,
+    make_structured_name,
+)
 from ksc.expr import StructuredName
 
 from ksc.type_propagate import type_propagate_decls
@@ -467,7 +481,9 @@ def make_KscAutogradFunction(py_mod):
     )
 
 
-def ksc_defs_to_module(ksc_defs, entry_def, torch_extension_name, generate_lm):
+def ksc_defs_to_module(
+    ksc_defs, entry_def, torch_extension_name, elementwise, generate_lm
+):
     symtab = dict()
     ksc_dir = utils.get_ksc_dir()
     decls_prelude = list(parse_ks_filename(ksc_dir + "/src/runtime/prelude.ks"))
@@ -502,13 +518,14 @@ def ksc_defs_to_module(ksc_defs, entry_def, torch_extension_name, generate_lm):
         ks_str,
         entry_def.name,
         torch_extension_name,
+        elementwise,
         generate_lm,
         extra_cflags=default_cflags,
     )
 
 
 def ksc_string_to_module(
-    ks_str, entry_sn, torch_extension_name, generate_lm, extra_cflags
+    ks_str, entry_sn, torch_extension_name, elementwise, generate_lm, extra_cflags
 ):
     der = "rev" if generate_lm else "sufrev"
     bindings_to_generate = [
@@ -519,6 +536,7 @@
         ks_str,
         bindings_to_generate,
         torch_extension_name,
+        elementwise=elementwise,
         use_aten=True,
         extra_cflags=extra_cflags,
     )
@@ -541,9 +559,15 @@ def cpp_string_to_module(
 
 
 def ksc_defs_to_autograd_function(
-    ksc_defs, entry_def, torch_extension_name, generate_lm=True
+    ksc_defs, entry_def, torch_extension_name, elementwise=False, generate_lm=True
 ):
-    mod = ksc_defs_to_module(ksc_defs, entry_def, torch_extension_name, generate_lm)
+    mod = ksc_defs_to_module(
+        ksc_defs,
+        entry_def,
+        torch_extension_name,
+        elementwise=elementwise,
+        generate_lm=generate_lm,
+    )
     return make_KscAutogradFunction(mod)
 
 
@@ -555,7 +579,12 @@ def ksc_string_to_autograd_function(
     extra_cflags=default_cflags,
 ):
     mod = ksc_string_to_module(
-        ks_str, entry_sn, torch_extension_name, generate_lm, extra_cflags
+        ks_str,
+        entry_sn,
+        torch_extension_name,
+        elementwise=False,
+        generate_lm=generate_lm,
+        extra_cflags=extra_cflags,
     )
     return make_KscAutogradFunction(mod)
 
@@ -594,15 +623,71 @@ def tsmod2ksmod(
     ksc_def = ts2ks_fromgraph(False, fn_name, ts_graph, example_inputs)
     ksc_defs.insert(0, ksc_def)
 
+    elementwise = is_elementwise_operation(ksc_defs[-1])
+    if elementwise:
+        ksc_defs.pop()
+
     entry_def = ksc_defs[-1]
     return ksc_defs_to_autograd_function(
-        ksc_defs, entry_def, torch_extension_name, generate_lm
+        ksc_defs,
+        entry_def,
+        torch_extension_name,
+        elementwise=elementwise,
+        generate_lm=generate_lm,
     )
 
 
 def ts2mod(function, example_inputs, torch_extension_name, generate_lm=True):
     fn = torch.jit.script(function)
     ksc_def = ts2ks_fromgraph(False, fn.name, fn.graph, example_inputs)
     return ksc_defs_to_autograd_function(
-        [ksc_def], ksc_def, torch_extension_name, generate_lm
+        [ksc_def],
+        ksc_def,
+        torch_extension_name,
+        elementwise=False,
+        generate_lm=generate_lm,
     )
+
+
+def is_elementwise_operation(ksc_def):
+    """
+    Inspect the body of a def to determine whether it is a
+    simple elementwise operation, e.g.
+      (def vrelu3 None ((_x$o1 : (Tensor 1 Float)))
+        (let (_1 "relu3")
+        (let (_3 (map (lam (ts2ks$0 : Float)
+                        (relu3 ts2ks$0)) _x$o1))
+        _3)))
+    """
+
+    def is_map(expr, arg_name):
+        if isinstance(expr, Call):
+            if expr.name != make_structured_name("map"):
+                return False
+            assert len(expr.args) == 2
+            if not (isinstance(expr.args[1], Var) and expr.args[1].name == arg_name):
+                return False
+            lam = expr.args[0]
+            assert isinstance(lam, Lam)
+            return (
+                isinstance(lam.body, Call)
+                and len(lam.body.args) == 1
+                and isinstance(lam.body.args[0], Var)
+                and lam.body.args[0].name == lam.arg.name
+            )
+        elif isinstance(expr, Let):
+            if isinstance(expr.vars, Var):
+                if expr.body == expr.vars:
+                    return is_map(expr.rhs, arg_name)
+                else:
+                    return is_map(expr.body, arg_name)
+            else:
+                return False
+        else:
+            return False
+
+    if len(ksc_def.args) != 1:
+        print(f"Num args {len(ksc_def.args)}")
+        return False
+    return is_map(ksc_def.body, ksc_def.args[0].name)
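To make the accepted shape concrete, here is a self-contained toy version of the matcher; the dataclasses below are simplified stand-ins for ksc.expr's node classes, not the real ones, so this only illustrates the matching logic.

```python
from dataclasses import dataclass
from typing import Any, List

@dataclass
class Var:
    name: str

@dataclass
class Call:
    name: str
    args: List[Any]

@dataclass
class Lam:
    arg: Var
    body: Any

@dataclass
class Let:
    vars: Any
    rhs: Any
    body: Any

def is_map_shape(expr: Any, arg_name: str) -> bool:
    # Accept (map (lam (t) (op t)) arg), possibly wrapped in simple lets.
    if isinstance(expr, Call) and expr.name == "map":
        lam, operand = expr.args
        return (
            isinstance(operand, Var)
            and operand.name == arg_name
            and isinstance(lam, Lam)
            and isinstance(lam.body, Call)
            and len(lam.body.args) == 1
            and isinstance(lam.body.args[0], Var)
            and lam.body.args[0].name == lam.arg.name
        )
    if isinstance(expr, Let) and isinstance(expr.vars, Var):
        # Trivial (let (v rhs) v) unwraps to rhs; otherwise recurse into the body.
        target = expr.rhs if expr.body == expr.vars else expr.body
        return is_map_shape(target, arg_name)
    return False

# The docstring's example body: (let (_3 (map (lam (t) (relu3 t)) x)) _3)
body = Let(
    Var("_3"),
    Call("map", [Lam(Var("t"), Call("relu3", [Var("t")])), Var("x")]),
    Var("_3"),
)
print(is_map_shape(body, "x"))  # True
print(is_map_shape(body, "y"))  # False: maps over a different variable
```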
