From 754b797470841fb9754612cc1f67b1f5752424ac Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Wed, 7 Feb 2024 13:17:17 -0800 Subject: [PATCH] [export] Remove torch._export.export uses (#6486) --- test/stablehlo/test_export_llama.py | 7 +++---- test/stablehlo/test_stablehlo_inference.py | 2 +- test/stablehlo/test_stablehlo_save_load.py | 7 +++---- test/test_core_aten_ops.py | 1 - 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/test/stablehlo/test_export_llama.py b/test/stablehlo/test_export_llama.py index 211f6f3c4d8..9387d11b2bf 100644 --- a/test/stablehlo/test_export_llama.py +++ b/test/stablehlo/test_export_llama.py @@ -2,7 +2,6 @@ import torch_xla.core.xla_model as xm from torch_xla.stablehlo import save_as_stablehlo, StableHLOExportOptions import torch -import torch._export import torchvision import tempfile @@ -18,7 +17,7 @@ def test_llama_export(self): model = llama_model.Transformer(options) arg = (torch.randint(0, 1000, (1, 100)), 0) """ - exported = torch._export.export(model, arg) + exported = torch.export.export(model, arg) with tempfile.TemporaryDirectory() as tempdir: save_as_stablehlo(exported, arg, tempdir) @@ -30,7 +29,7 @@ def test_llama_export(self): options = StableHLOExportOptions() options.override_tracing_arguments = arg with torch.no_grad(): - exported2 = torch._export.export(gen, arg) + exported2 = torch.export.export(gen, arg) with tempfile.TemporaryDirectory() as tempdir: save_as_stablehlo(exported2, tempdir, options) @@ -40,7 +39,7 @@ def test_llama_export(self): arg = (torch.randint(0, 1000, (8, 100)), torch.arange(0, 100), None) options = StableHLOExportOptions() options.override_tracing_arguments = arg - exported = torch._export.export(model, arg) + exported = torch.export.export(model, arg) with tempfile.TemporaryDirectory() as tempdir: save_as_stablehlo(exported, tempdir, options) diff --git a/test/stablehlo/test_stablehlo_inference.py b/test/stablehlo/test_stablehlo_inference.py index d8f09343b89..311aed0c4b5 100644 --- a/test/stablehlo/test_stablehlo_inference.py +++ b/test/stablehlo/test_stablehlo_inference.py @@ -10,7 +10,7 @@ def export_torch_model(model, args): - exported = torch._export.export(model, args) + exported = torch.export.export(model, args) options = StableHLOExportOptions() options.override_tracing_arguments = args return exported_program_to_stablehlo(exported, options) diff --git a/test/stablehlo/test_stablehlo_save_load.py b/test/stablehlo/test_stablehlo_save_load.py index 7753b4f631a..1e1e41c3513 100644 --- a/test/stablehlo/test_stablehlo_save_load.py +++ b/test/stablehlo/test_stablehlo_save_load.py @@ -4,7 +4,6 @@ from torch_xla import save_torch_model_as_stablehlo, save_as_stablehlo from torch_xla.stablehlo import StableHLOExportOptions, StableHLOGraphModule import torch -import torch._export import torchvision import unittest import os @@ -92,7 +91,7 @@ def test_cat(self): def test_save_load(self): model = ElementwiseAdd() inputs = model.get_random_inputs() - exported = torch._export.export(model, inputs) + exported = torch.export.export(model, inputs) options = StableHLOExportOptions() options.override_tracing_arguments = inputs with tempfile.TemporaryDirectory() as tempdir: @@ -104,7 +103,7 @@ def test_save_load(self): def test_save_load_without_saving_weights(self): model = ElementwiseAdd() inputs = model.get_random_inputs() - exported = torch._export.export(model, inputs) + exported = torch.export.export(model, inputs) options = StableHLOExportOptions() options.override_tracing_arguments = inputs options.save_weights = False @@ -138,7 +137,7 @@ def test_save_load2_without_saving_weights(self): def test_save_load3(self): model = ElementwiseAdd() inputs = model.get_random_inputs() - exported = torch._export.export(model, inputs) + exported = torch.export.export(model, inputs) with tempfile.TemporaryDirectory() as tempdir: # Shouldnt need specify options because exported has example_input inside save_as_stablehlo(exported, tempdir) diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 5bf70dc0edc..ac36f88ebe6 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -3,7 +3,6 @@ from torch_xla.stablehlo import exported_program_to_stablehlo from torch.utils import _pytree as pytree import torch -import torch._export import os import tempfile