scripts/vsmlrt.py: rename obey_fp16 to force_fp16
WolframRhodium committed Jan 18, 2023
1 parent 6775d42 commit e45053c
Showing 1 changed file with 8 additions and 8 deletions.
--- a/scripts/vsmlrt.py
+++ b/scripts/vsmlrt.py
@@ -132,7 +132,7 @@ class TRT:
     output_format: int = 0 # 0: fp32, 1: fp16
     min_shapes: typing.Tuple[int, int] = (0, 0)
     faster_dynamic_shapes: bool = True
-    obey_fp16: bool = False
+    force_fp16: bool = False
 
     # internal backend attributes
     supports_onnx_serialization: bool = False
@@ -1039,7 +1039,7 @@ def trtexec(
     output_format: int = 0,
     min_shapes: typing.Tuple[int, int] = (0, 0),
     faster_dynamic_shapes: bool = True,
-    obey_fp16: bool = False
+    force_fp16: bool = False
 ) -> str:
 
     # TensorRT runtime version, e.g. 8401 => 8.4.1
@@ -1051,7 +1051,7 @@ def trtexec(
     if isinstance(max_shapes, int):
         max_shapes = (max_shapes, max_shapes)
 
-    if obey_fp16:
+    if force_fp16:
         fp16 = True
         tf32 = False
 
@@ -1158,15 +1158,15 @@ def trtexec(
     if faster_dynamic_shapes and not static_shape and trt_version >= 8500:
         args.append("--preview=+fasterDynamicShapes0805")
 
-    if obey_fp16:
+    if force_fp16:
         if trt_version >= 8401:
             args.extend([
                 "--layerPrecisions=*:fp16",
                 "--layerOutputTypes=*:fp16",
                 "--precisionConstraints=obey"
             ])
         else:
-            raise ValueError('"obey_fp16" is not available')
+            raise ValueError('"force_fp16" is not available')
 
     if log:
         env_key = "TRTEXEC_LOG_FILE"
@@ -1376,7 +1376,7 @@ def _inference(
             output_format=backend.output_format,
             min_shapes=backend.min_shapes,
             faster_dynamic_shapes=backend.faster_dynamic_shapes,
-            obey_fp16=backend.obey_fp16
+            force_fp16=backend.force_fp16
         )
         clip = core.trt.Model(
             clips, engine_path,
@@ -1502,7 +1502,6 @@ class BackendV2:
     @staticmethod
     def TRT(*,
         num_streams: int = 1,
-        obey_fp16: bool = False,
         fp16: bool = False,
         tf32: bool = True,
         output_format: int = 0, # 0: fp32, 1: fp16
@@ -1512,6 +1511,7 @@ def TRT(*,
         min_shapes: typing.Tuple[int, int] = (0, 0),
         opt_shapes: typing.Optional[typing.Tuple[int, int]] = None,
         max_shapes: typing.Optional[typing.Tuple[int, int]] = None,
+        force_fp16: bool = False,
         use_cublas: bool = False,
         use_cudnn: bool = True,
         device_id: int = 0,
@@ -1520,7 +1520,7 @@ def TRT(*,
 
     return Backend.TRT(
         num_streams=num_streams,
-        fp16=fp16, obey_fp16=obey_fp16, tf32=tf32, output_format=output_format,
+        fp16=fp16, force_fp16=force_fp16, tf32=tf32, output_format=output_format,
         workspace=workspace, use_cuda_graph=use_cuda_graph,
         static_shape=static_shape,
         min_shapes=min_shapes, opt_shapes=opt_shapes, max_shapes=max_shapes,
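Usage note: for callers, the rename is a one-keyword change. A minimal sketch, assuming a typical vsmlrt filter entry point (the DPIR call, its strength value, and the BlankClip test source are illustrative assumptions, not part of this commit):

    import vapoursynth as vs
    from vsmlrt import Backend, DPIR

    core = vs.core

    # Illustrative test clip; any clip in a format the filter accepts works.
    clip = core.std.BlankClip(format=vs.GRAYS, width=1920, height=1080)

    # force_fp16 (formerly obey_fp16) also implies fp16=True and tf32=False,
    # as set in trtexec() above.
    backend = Backend.TRT(num_streams=2, force_fp16=True)

    denoised = DPIR(clip, strength=5.0, backend=backend)

As the diff shows, force_fp16 maps to trtexec's --layerPrecisions=*:fp16, --layerOutputTypes=*:fp16, and --precisionConstraints=obey flags, and raises ValueError on TensorRT builds older than 8.4.1.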
