diff --git a/onnx_export.py b/onnx_export.py index 3baab369ca..554d314b0b 100644 --- a/onnx_export.py +++ b/onnx_export.py @@ -90,6 +90,8 @@ def main(): check_forward=args.check_forward, training=args.training, verbose=args.verbose, + input_size=(3, args.img_size, args.img_size), + batch_size=args.batch_size, )