🐛 [Bug] cannot convert x.to(torch.uint8) #3247

Open
braindevices opened this issue Oct 18, 2024 · 1 comment
Labels: bug (Something isn't working)
Bug Description

Torch-TensorRT reports that _to_copy is not supported when we use something like x.to(torch.uint8):

torch_tensorrt.dynamo.conversion._TRTInterpreter.UnsupportedOperatorException: Conversion of function torch._ops.aten.aten::_to_copy not currently supported!

While executing %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%mul,), kwargs = {dtype: torch.uint8, _itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7f3a1e94bc30>: ((1, 3, 5, 7), torch.float32, False, (105, 35, 7, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3a20b3d3b0>: ((1, 3, 5, 7), torch.float32, False, (105, 35, 7, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7f3a20aa39b0>: ((1, 3, 5, 7), torch.float32, False, (105, 35, 7, 1), torch.contiguous_format, False, {})}})

But TensorRT actually supports this operation: if we convert the model to ONNX and then load the ONNX model in TensorRT, it works.

To Reproduce

Steps to reproduce the behavior:

  1. Define a dummy model that contains to():
import torch
from torch import nn
class dummy_t(nn.Module):
    def __init__(self) -> None:
        super().__init__()
    def forward(self, x: torch.Tensor):
        return x.clamp_(0, 1).mul_(255).to(dtype=torch.uint8)
xs = [torch.randn((1,3,5,7)).cuda()]
exported = torch.export.export(
    dummy_t().cuda(),
    args=tuple(xs)
)
  2. Run the Torch-TensorRT conversion; it fails:
torch_tensorrt.dynamo.convert_module_to_trt_engine(
    exported,
    assume_dynamic_shape_support=False,
    inputs=tuple(xs),
    use_python_runtime=False,
    enabled_precisions={torch.float32},
    use_fast_partitioner=False,
    debug=True,
    min_block_size=1,
    require_full_compilation=True
)
  3. Run the ONNX -> TensorRT path; it works fine:
from tempfile import NamedTemporaryFile
import onnx
import tensorrt as trt
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
import io
output_names = ['output0']
input_names = ["x"]
with NamedTemporaryFile() as f:
    onnx_program = torch.onnx.export(
        dummy_t().cuda(),
        tuple(xs),
        f.name,
        verbose=False,
        opset_version=20,
        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
        # https://github.com/pytorch/pytorch/issues/73843
        input_names=input_names,
        output_names=output_names,
        dynamo=False,
        training=torch.onnx.TrainingMode.EVAL # we can export trainable model!
    )
    model_onnx: onnx.ModelProto
    model_onnx = onnx.load(f.name)
workspace = 10*1024**2
trt_logger = trt.Logger(trt.Logger.INFO)
trt_logger.min_severity = trt.Logger.Severity.VERBOSE

builder = trt.Builder(trt_logger)
config = builder.create_builder_config()
config.set_memory_pool_limit(pool=trt.MemoryPoolType.WORKSPACE, pool_size=workspace)
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, trt_logger)
if not parser.parse(model_onnx.SerializeToString()):
    raise RuntimeError(f'failed to load ONNX model')

inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]

with builder.build_serialized_network(network, config) as engine, io.BytesIO() as engine_bytes: # type: ignore
    engine_bytes.write(engine)
    engine_bytes.seek(0)
    serialized_trt_engine = engine_bytes.read()

pt_trt = PythonTorchTensorRTModule(
    serialized_trt_engine,
    input_names=input_names,
    output_names=["output0"]
)
print(pt_trt(*xs).dtype)
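
As a quick sanity check (not part of the original report; a minimal sketch assuming dummy_t, xs, and pt_trt from the snippets above), the engine output can be compared against eager PyTorch. Note that the rounding behavior of the float -> uint8 cast may differ between backends.

# hypothetical sanity check, not from the issue
ref = dummy_t().cuda()(xs[0].clone())  # eager result; clone() avoids mutating xs via clamp_/mul_
out = pt_trt(xs[0].clone())            # result from the ONNX-built TRT engine
print(out.dtype, ref.dtype, torch.equal(out.cpu().to(torch.uint8), ref.cpu()))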

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT 2.4.0
  • PyTorch Version (e.g. 1.0): 2.4.1
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Almalinux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Python version: 3.11
  • CUDA version: 12.3
  • GPU models and configuration: RTX4k

Additional context

apbose (Collaborator) commented Nov 14, 2024

Hi, the above complains because the copy validator in the converter does not accept uint8 as a valid input data type. That is because TRT does not support uint8 in its operations. For example, if you run the code below (note that there are some changes in the API names) for ONNX conversion and then load it in TRT:

import torch
import torch_tensorrt
import tensorrt as trt 
from torch import nn
class dummy_t(nn.Module):
    def __init__(self) -> None:
        super().__init__()
    def forward(self, x: torch.Tensor):
        y = x.clamp_(0, 1).mul_(255).to(dtype=torch.uint8)
        return torch.mul(y,1)
xs = [torch.randn((1,3,5,7)).cuda()]
exported = torch.export.export(
    dummy_t().cuda(),
    args=tuple(xs)
)
from tempfile import NamedTemporaryFile
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
import io
import onnx
output_names = ['output0']
input_names = ["x"]
with NamedTemporaryFile() as f:
    onnx_program = torch.onnx.export(
        dummy_t().cuda(),
        tuple(xs),
        f.name,
        verbose=False,
        opset_version=20,
        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
        # https://github.com/pytorch/pytorch/issues/73843
        input_names=input_names,
        output_names=output_names,
        dynamo=False,
        training=torch.onnx.TrainingMode.EVAL # we can export trainable model!
    )
    model_onnx: onnx.ModelProto
    model_onnx = onnx.load(f.name)
workspace = 10*1024**2
trt_logger = trt.Logger(trt.Logger.INFO)
trt_logger.min_severity = trt.Logger.Severity.VERBOSE

builder = trt.Builder(trt_logger)
config = builder.create_builder_config()
config.set_memory_pool_limit(pool=trt.MemoryPoolType.WORKSPACE, pool_size=workspace)
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, trt_logger)
if not parser.parse(model_onnx.SerializeToString()):
    raise RuntimeError(f'failed to load ONNX model')

inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]

with builder.build_serialized_network(network, config) as engine, io.BytesIO() as engine_bytes: # type: ignore
    engine_bytes.write(engine)
    engine_bytes.seek(0)
    serialized_trt_engine = engine_bytes.read()

pt_trt = PythonTorchTensorRTModule(
    serialized_trt_engine,
    input_binding_names=input_names,
    output_binding_names=["output0"]
)
print(pt_trt(*xs).dtype)

The above will fail with

[11/14/2024-20:39:08] [TRT] [V] Static check for parsing node: /Mul_1 [Mul]
Traceback (most recent call last):
  File "/code/torchTRT/TensorRT/issue_3247.py", line 62, in <module>
    raise RuntimeError(f'failed to load ONNX model')
RuntimeError: failed to load ONNX model

since the parser cannot process the uint8 input.
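
A possible workaround (not suggested in this thread; a minimal sketch assuming the torch_executed_ops, require_full_compilation, and min_block_size settings of torch_tensorrt.dynamo.compile) is to keep the unsupported cast out of the TensorRT graph and let it fall back to PyTorch:

# hypothetical workaround sketch: exclude aten._to_copy from conversion so the
# uint8 cast runs in eager PyTorch while the rest of the graph runs in TensorRT
import torch
import torch_tensorrt

trt_module = torch_tensorrt.dynamo.compile(
    exported,                                   # exported program from the repro above
    inputs=tuple(xs),
    enabled_precisions={torch.float32},
    min_block_size=1,
    require_full_compilation=False,             # allow a PyTorch fallback block for the cast
    torch_executed_ops={"torch.ops.aten._to_copy.default"},
)
print(trt_module(*xs).dtype)                    # expected torch.uint8, produced by the fallback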
