[aoti] Remove example inputs from aoti_compile_and_package (pytorch#140991)

Differential Revision: [D66136724](https://our.internmc.facebook.com/intern/diff/D66136724)
Pull Request resolved: pytorch#140991
Approved by: https://github.com/yushangdi, https://github.com/desertfire
ghstack dependencies: pytorch#140990
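
To illustrate the signature change, a minimal sketch (not code from this commit; the toy module `M` and its shapes are made up). The example inputs recorded by torch.export.export now travel on the ExportedProgram itself:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

example_inputs = (torch.randn(8, 10),)
ep = torch.export.export(M(), example_inputs)

# Before this change:
#   package_path = torch._inductor.aoti_compile_and_package(ep, example_inputs)
# After this change, the inputs are read from ep.example_inputs:
package_path = torch._inductor.aoti_compile_and_package(ep)
compiled = torch._inductor.aoti_load_package(package_path)
result = compiled(*example_inputs)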
angelayi authored and pytorchmergebot committed Nov 20, 2024
1 parent cb6a21b commit 878a849
Showing 6 changed files with 24 additions and 12 deletions.
4 changes: 1 addition & 3 deletions benchmarks/dynamo/common.py
@@ -1521,9 +1521,7 @@ def load(cls, model, example_inputs):
                 strict=False,
             )
             with torch.no_grad():
-                package_path = torch._inductor.aoti_compile_and_package(
-                    ep, example_args, example_kwargs
-                )  # type: ignore[arg-type]
+                package_path = torch._inductor.aoti_compile_and_package(ep)  # type: ignore[arg-type]

             cls.cache[key] = torch._inductor.aoti_load_package(package_path)
1 change: 0 additions & 1 deletion docs/source/torch.compiler_aot_inductor.rst
@@ -75,7 +75,6 @@ For more details on ``torch.export.export``, you can refer to the :ref:`torch.ex
         # then load it back using torch.export.load on your inference platform to run AOT compilation.
         output_path = torch._inductor.aoti_compile_and_package(
             exported,
-            example_inputs,
             # [Optional] Specify the generated shared library path. If not specified,
             # the generated artifact is stored in your system temp directory.
             package_path=os.path.join(os.getcwd(), "model.pt2"),
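
As a hedged aside (not part of the upstream docs): the explicit example_inputs argument can be dropped because torch.export.export already records the inputs on the returned ExportedProgram, which the packaging step now unpacks internally. A minimal sketch, with Model and the tensor shapes being illustrative:

import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return x.sin()

example_inputs = (torch.randn(2, 3),)
exported = torch.export.export(Model(), example_inputs)

# ExportedProgram.example_inputs stores the captured (args, kwargs) pair;
# aoti_compile_and_package reads it instead of taking inputs explicitly.
args, kwargs = exported.example_inputs
print(len(args), kwargs)  # one positional example input; kwargs dict (empty here)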
2 changes: 1 addition & 1 deletion docs/source/torch.compiler_aot_inductor_minifier.rst
@@ -43,7 +43,7 @@ Here is sample code which will generate an error because we injected an error on
     model = Model().to("cuda")
     example_inputs = (torch.randn(8, 10).to("cuda"),)
     ep = torch.export.export(model, example_inputs)
-    package_path = torch._inductor.aoti_compile_and_package(ep, example_inputs)
+    package_path = torch._inductor.aoti_compile_and_package(ep)
     compiled_model = torch._inductor.aoti_load_package(package_path)
     result = compiled_model(*example_inputs)
3 changes: 1 addition & 2 deletions test/inductor/test_aot_inductor_package.py
@@ -32,7 +32,7 @@ def compile(
             strict=False,
         )
         package_path = torch._inductor.aoti_compile_and_package(
-            ep, args, kwargs, package_path=package_path, inductor_configs=inductor_configs
+            ep, package_path=package_path, inductor_configs=inductor_configs
         )  # type: ignore[arg-type]
         loaded = load_package(package_path)
         return loaded
@@ -138,7 +138,6 @@ def forward(self, x, y):
             # cubin files are removed when exiting this context
             package_path = torch._inductor.aoti_compile_and_package(
                 ep,
-                example_inputs,
                 package_path=f.name,
             )  # type: ignore[arg-type]
             loaded = torch._inductor.aoti_load_package(package_path)
2 changes: 1 addition & 1 deletion test/inductor/test_minifier.py
@@ -200,7 +200,7 @@ def forward(self, x):
                 model, example_inputs
             )
             torch._inductor.aoti_compile_and_package(
-                ep, example_inputs
+                ep
             )
             """
         return self._run_full_test(run_code, None, expected_error, isolate=True)
24 changes: 20 additions & 4 deletions torch/_inductor/__init__.py
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import logging
 import os
 from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

@@ -20,6 +21,9 @@
 ]


+log = logging.getLogger(__name__)
+
+
 def compile(
     gm: torch.fx.GraphModule,
     example_inputs: List["InputType"],
@@ -44,8 +48,8 @@ def compile(

 def aoti_compile_and_package(
     exported_program,
-    args: Tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    _deprecated_unused_args=None,
+    _deprecated_unused_kwargs=None,
     *,
     package_path: Optional[str] = None,
     inductor_configs: Optional[Dict[str, Any]] = None,
@@ -72,8 +76,6 @@ def aoti_compile_and_package(
     Args:
         exported_program: An exported program created through a call from torch.export
-        args: Example positional inputs
-        kwargs: Optional example keyword inputs
         package_path: Optional specified path to the generated .pt2 artifact.
         inductor_configs: Optional dictionary of configs to control inductor.
@@ -85,6 +87,18 @@ def aoti_compile_and_package(
     if not isinstance(exported_program, ExportedProgram):
         raise ValueError("Only ExportedProgram is supported")

+    if exported_program.example_inputs is None:
+        raise RuntimeError(
+            "exported_program.example_inputs is required to be set in order "
+            "for AOTInductor compilation."
+        )
+
+    if _deprecated_unused_args is not None or _deprecated_unused_kwargs is not None:
+        log.warning(
+            "You no longer need to specify args/kwargs to aoti_compile_and_package "
+            "as we can get this information from exported_program.example_inputs."
+        )
+
     assert package_path is None or package_path.endswith(
         ".pt2"
     ), f"Expect package path to end with .pt2, got {package_path}"
@@ -97,6 +111,8 @@
             "of setting the aot_inductor.output_path config."
         )

+    args, kwargs = exported_program.example_inputs
+
     # a wrapper around aoti_compile_and_package_inner.
     return aoti_compile_and_package_debug_wrapper(
         exported_program,
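
For context, a sketch of the backward-compatibility path added above (assuming the deprecated positional parameters behave as the hunk shows; the toy module is made up): an old-style call still runs, but its positional inputs now bind to _deprecated_unused_args and trigger the new warning.

import logging
import torch

logging.basicConfig()  # make the torch._inductor log warning visible

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

example_inputs = (torch.randn(4),)
ep = torch.export.export(M(), example_inputs)

# New-style call: inputs come from ep.example_inputs.
torch._inductor.aoti_compile_and_package(ep)

# Old-style call: still compiles, but logs "You no longer need to
# specify args/kwargs ..." via the warning added in this commit.
torch._inductor.aoti_compile_and_package(ep, example_inputs)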
