Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Converters for Basic Casting Operations #607

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions torch2trt/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@
from .transpose import *
from .unary import *
from .view import *
from .cast import *
72 changes: 72 additions & 0 deletions torch2trt/converters/cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def convert_cast(ctx):
    """Shared converter for simple tensor casting operations.

    Handles ``Tensor.float()`` / ``Tensor.int()`` / ``Tensor.bool()`` by
    mapping the call to a TensorRT identity layer and wiring the layer's
    output to the PyTorch output tensor.

    IMPORTANT: TensorRT does not support 64-bit data types, so
    ``.long()`` is not supported.

    Fix: removed stray review-UI text that had been pasted into the body,
    which made the function syntactically invalid.
    """
    input_tensor = ctx.method_args[0]
    # NOTE(review): an identity layer passes data through unchanged; it is
    # not obvious from this block alone that it performs the dtype
    # conversion — confirm whether layer.set_output_type(0, ...) is needed
    # for the cast to survive in the built engine.
    layer = ctx.network.add_identity(input_tensor._trt)
    output = ctx.method_return
    output._trt = layer.get_output(0)


@tensorrt_converter("torch.Tensor.float")
def convert_float(ctx):
    """Convert ``Tensor.float()`` by delegating to the shared cast converter."""
    return convert_cast(ctx)


@tensorrt_converter("torch.Tensor.int")
def convert_int(ctx):
    """Convert ``Tensor.int()`` by delegating to the shared cast converter."""
    return convert_cast(ctx)


@tensorrt_converter("torch.Tensor.bool")
def convert_bool(ctx):
    """Convert ``Tensor.bool()`` by delegating to the shared cast converter."""
    return convert_cast(ctx)


class DotFloat(torch.nn.Module):
    """Minimal module whose forward pass casts its input with ``.float()``.

    Used as a fixture for exercising the float-cast converter.
    """

    def forward(self, x):
        return x.float()


class DotInt(torch.nn.Module):
    """Minimal module whose forward pass casts its input with ``.int()``.

    Used as a fixture for exercising the int-cast converter.
    """

    def forward(self, x):
        return x.int()


class DotBool(torch.nn.Module):
    """Minimal module whose forward pass casts its input with ``.bool()``.

    Used as a fixture for exercising the bool-cast converter.
    """

    def forward(self, x):
        return x.bool()


@add_module_test(torch.bool, torch.device('cuda'), [(1, 3, 3)])
@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)])
def test_torch_float_cast():
    """Register the ``.float()`` cast fixture for bool and int32 inputs."""
    return DotFloat()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)])
def test_torch_int_cast():
    """Register the ``.int()`` cast fixture for float32 and int32 inputs."""
    return DotInt()


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)])
def test_torch_bool_cast():
    """Register the ``.bool()`` cast fixture for float32 and int32 inputs."""
    return DotBool()