Migrate from capture_pre_autograd_graph to export_for_training (#8398)
yushangdi authored Dec 17, 2024
1 parent 3f915ab · commit b2b890e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/stablehlo/test_pt2e_qdq.py
@@ -6,7 +6,7 @@
 import torch
 import torch_xla.core.xla_model as xm
 import torchvision
-from torch._export import capture_pre_autograd_graph
+from torch.export import export_for_training
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     XNNPACKQuantizer, get_symmetric_quantization_config)
@@ -88,7 +88,7 @@ def test_resnet18(self):
     # Step 1: export resnet18
     args = (torch.randn(1, 3, 224, 224),)
     m = torchvision.models.resnet18().eval()
-    m = capture_pre_autograd_graph(m, args)
+    m = export_for_training(m, args).module()
 
     # Step 2: Insert observers or fake quantize modules
     quantizer = XNNPACKQuantizer().set_global(
@@ -120,7 +120,7 @@ def test_resnet18_per_channel(self):
     # Step 1: export resnet18
     args = (torch.randn(1, 3, 224, 224),)
     m = torchvision.models.resnet18().eval()
-    m = capture_pre_autograd_graph(m, args)
+    m = export_for_training(m, args).module()
 
     # Step 2: Insert observers or fake quantize modules
     quantizer = XNNPACKQuantizer().set_global(
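For context, below is a minimal sketch of the full PT2E quantization flow after this migration, assuming a torch version where torch.export.export_for_training is available. The calibration and convert steps (Steps 3 and 4) are the standard prepare_pt2e/convert_pt2e sequence implied by the file's imports, not lines changed by this diff. Note that export_for_training returns an ExportedProgram, so .module() is needed where capture_pre_autograd_graph returned a module directly.

    import torch
    import torchvision
    from torch.export import export_for_training
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer, get_symmetric_quantization_config)

    args = (torch.randn(1, 3, 224, 224),)
    m = torchvision.models.resnet18().eval()

    # Step 1: export. .module() recovers an nn.Module from the ExportedProgram.
    m = export_for_training(m, args).module()

    # Step 2: insert observers or fake-quantize modules.
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)

    # Step 3: calibrate with representative inputs (assumption: one batch suffices here).
    m(*args)

    # Step 4: convert to the reference quantized model.
    m = convert_pt2e(m)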
