diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py
index ea3f6cac067..1f6ad974f20 100644
--- a/test/stablehlo/test_pt2e_qdq.py
+++ b/test/stablehlo/test_pt2e_qdq.py
@@ -6,7 +6,7 @@
 import torch
 import torch_xla.core.xla_model as xm
 import torchvision
-from torch._export import capture_pre_autograd_graph
+from torch.export import export_for_training
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     XNNPACKQuantizer, get_symmetric_quantization_config)
@@ -88,7 +88,7 @@ def test_resnet18(self):
     # Step 1: export resnet18
     args = (torch.randn(1, 3, 224, 224),)
     m = torchvision.models.resnet18().eval()
-    m = capture_pre_autograd_graph(m, args)
+    m = export_for_training(m, args).module()
 
     # Step 2: Insert observers or fake quantize modules
     quantizer = XNNPACKQuantizer().set_global(
@@ -120,7 +120,7 @@ def test_resnet18_per_channel(self):
     # Step 1: export resnet18
     args = (torch.randn(1, 3, 224, 224),)
     m = torchvision.models.resnet18().eval()
-    m = capture_pre_autograd_graph(m, args)
+    m = export_for_training(m, args).module()
 
     # Step 2: Insert observers or fake quantize modules
     quantizer = XNNPACKQuantizer().set_global(
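
For context, here is a minimal sketch of the updated export-and-quantize flow this test exercises after the change: `export_for_training(...).module()` replaces the deprecated `capture_pre_autograd_graph` as the entry point for PT2E quantization. The calibration pass and the `prepare_pt2e`/`convert_pt2e` steps below follow the existing test body (they are not all visible in the hunks above), so treat this as an illustrative outline rather than the exact test code.

```python
import torch
import torchvision
from torch.export import export_for_training
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer, get_symmetric_quantization_config)

args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()

# Step 1: export resnet18 with the torch.export entry point;
# .module() yields a GraphModule comparable to what
# capture_pre_autograd_graph used to return.
m = export_for_training(m, args).module()

# Step 2: insert observers / fake-quantize modules
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)

# Step 3: calibrate with representative inputs (random here, as in the test)
m(*args)

# Step 4: convert to the reference quantized (q/dq) representation
m = convert_pt2e(m)
```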