diff --git a/scripts/vsmlrt.py b/scripts/vsmlrt.py index debf834..1d2ea13 100644 --- a/scripts/vsmlrt.py +++ b/scripts/vsmlrt.py @@ -132,7 +132,7 @@ class TRT: output_format: int = 0 # 0: fp32, 1: fp16 min_shapes: typing.Tuple[int, int] = (0, 0) faster_dynamic_shapes: bool = True - obey_fp16: bool = False + force_fp16: bool = False # internal backend attributes supports_onnx_serialization: bool = False @@ -1039,7 +1039,7 @@ def trtexec( output_format: int = 0, min_shapes: typing.Tuple[int, int] = (0, 0), faster_dynamic_shapes: bool = True, - obey_fp16: bool = False + force_fp16: bool = False ) -> str: # tensort runtime version, e.g. 8401 => 8.4.1 @@ -1051,7 +1051,7 @@ def trtexec( if isinstance(max_shapes, int): max_shapes = (max_shapes, max_shapes) - if obey_fp16: + if force_fp16: fp16 = True tf32 = False @@ -1158,7 +1158,7 @@ def trtexec( if faster_dynamic_shapes and not static_shape and trt_version >= 8500: args.append("--preview=+fasterDynamicShapes0805") - if obey_fp16: + if force_fp16: if trt_version >= 8401: args.extend([ "--layerPrecisions=*:fp16", @@ -1166,7 +1166,7 @@ def trtexec( "--precisionConstraints=obey" ]) else: - raise ValueError('"obey_fp16" is not available') + raise ValueError('"force_fp16" is not available') if log: env_key = "TRTEXEC_LOG_FILE" @@ -1376,7 +1376,7 @@ def _inference( output_format=backend.output_format, min_shapes=backend.min_shapes, faster_dynamic_shapes=backend.faster_dynamic_shapes, - obey_fp16=backend.obey_fp16 + force_fp16=backend.force_fp16 ) clip = core.trt.Model( clips, engine_path, @@ -1502,7 +1502,6 @@ class BackendV2: @staticmethod def TRT(*, num_streams: int = 1, - obey_fp16: bool = False, fp16: bool = False, tf32: bool = True, output_format: int = 0, # 0: fp32, 1: fp16 @@ -1512,6 +1511,7 @@ def TRT(*, min_shapes: typing.Tuple[int, int] = (0, 0), opt_shapes: typing.Optional[typing.Tuple[int, int]] = None, max_shapes: typing.Optional[typing.Tuple[int, int]] = None, + force_fp16: bool = False, use_cublas: bool = False, use_cudnn: bool = True, device_id: int = 0, @@ -1520,7 +1520,7 @@ def TRT(*, return Backend.TRT( num_streams=num_streams, - fp16=fp16, obey_fp16=obey_fp16, tf32=tf32, output_format=output_format, + fp16=fp16, force_fp16=force_fp16, tf32=tf32, output_format=output_format, workspace=workspace, use_cuda_graph=use_cuda_graph, static_shape=static_shape, min_shapes=min_shapes, opt_shapes=opt_shapes, max_shapes=max_shapes,