Skip to content

Commit

Permalink
[export] Remove torch._export.export uses (pytorch#6486)
Browse files Browse the repository at this point in the history
  • Loading branch information
angelayi authored and amithrm committed Mar 1, 2024
1 parent 78a6706 commit e46f925
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 10 deletions.
7 changes: 3 additions & 4 deletions test/stablehlo/test_export_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torch_xla.core.xla_model as xm
from torch_xla.stablehlo import save_as_stablehlo, StableHLOExportOptions
import torch
import torch._export
import torchvision

import tempfile
Expand All @@ -18,7 +17,7 @@ def test_llama_export(self):
model = llama_model.Transformer(options)
arg = (torch.randint(0, 1000, (1, 100)), 0)
"""
exported = torch._export.export(model, arg)
exported = torch.export.export(model, arg)
with tempfile.TemporaryDirectory() as tempdir:
save_as_stablehlo(exported, arg, tempdir)
Expand All @@ -30,7 +29,7 @@ def test_llama_export(self):
options = StableHLOExportOptions()
options.override_tracing_arguments = arg
with torch.no_grad():
exported2 = torch._export.export(gen, arg)
exported2 = torch.export.export(gen, arg)
with tempfile.TemporaryDirectory() as tempdir:
save_as_stablehlo(exported2, tempdir, options)

Expand All @@ -40,7 +39,7 @@ def test_llama_export(self):
arg = (torch.randint(0, 1000, (8, 100)), torch.arange(0, 100), None)
options = StableHLOExportOptions()
options.override_tracing_arguments = arg
exported = torch._export.export(model, arg)
exported = torch.export.export(model, arg)
with tempfile.TemporaryDirectory() as tempdir:
save_as_stablehlo(exported, tempdir, options)

Expand Down
2 changes: 1 addition & 1 deletion test/stablehlo/test_stablehlo_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def export_torch_model(model, args):
exported = torch._export.export(model, args)
exported = torch.export.export(model, args)
options = StableHLOExportOptions()
options.override_tracing_arguments = args
return exported_program_to_stablehlo(exported, options)
Expand Down
7 changes: 3 additions & 4 deletions test/stablehlo/test_stablehlo_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch_xla import save_torch_model_as_stablehlo, save_as_stablehlo
from torch_xla.stablehlo import StableHLOExportOptions, StableHLOGraphModule
import torch
import torch._export
import torchvision
import unittest
import os
Expand Down Expand Up @@ -92,7 +91,7 @@ def test_cat(self):
def test_save_load(self):
model = ElementwiseAdd()
inputs = model.get_random_inputs()
exported = torch._export.export(model, inputs)
exported = torch.export.export(model, inputs)
options = StableHLOExportOptions()
options.override_tracing_arguments = inputs
with tempfile.TemporaryDirectory() as tempdir:
Expand All @@ -104,7 +103,7 @@ def test_save_load(self):
def test_save_load_without_saving_weights(self):
model = ElementwiseAdd()
inputs = model.get_random_inputs()
exported = torch._export.export(model, inputs)
exported = torch.export.export(model, inputs)
options = StableHLOExportOptions()
options.override_tracing_arguments = inputs
options.save_weights = False
Expand Down Expand Up @@ -138,7 +137,7 @@ def test_save_load2_without_saving_weights(self):
def test_save_load3(self):
model = ElementwiseAdd()
inputs = model.get_random_inputs()
exported = torch._export.export(model, inputs)
exported = torch.export.export(model, inputs)
with tempfile.TemporaryDirectory() as tempdir:
# Shouldnt need specify options because exported has example_input inside
save_as_stablehlo(exported, tempdir)
Expand Down
1 change: 0 additions & 1 deletion test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch.utils import _pytree as pytree
import torch
import torch._export

import os
import tempfile
Expand Down

0 comments on commit e46f925

Please sign in to comment.