v1.10.0
Intel® Extension for PyTorch* v1.10.0-cpu Release Notes
The Intel® Extension for PyTorch* 1.10 is based on PyTorch 1.10. In this release, we polished the front-end APIs, which are now simpler, more stable, and more straightforward. Following the PyTorch community's recommendation, we changed the underlying device from `XPU` to `CPU`. With this change, models and tensors no longer need to be converted to the extension device to get a performance improvement, which simplifies the changes required in model code.
Besides that, we continue to optimize Transformer* and CNN models by fusing more operators and applying the NHWC memory layout. We measured 1.10 performance on TorchVision and Hugging Face models, and as expected, 1.10 speeds up models from both model zoos. In addition, 1.10 releases a C++ SDK to facilitate deploying PyTorch models with the extension.
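For readers unfamiliar with NHWC (channels-last): the same 4-D tensor can be stored in NCHW or NHWC order, and the only difference is how a logical index `(n, c, h, w)` maps to a memory offset. The pure-Python sketch below is an illustration, not extension code; the helper names are hypothetical. It shows that in NHWC all channel values of one pixel sit next to each other, an access pattern that is generally friendlier to vectorized convolution kernels on CPUs. In stock PyTorch, a tensor or module is switched to this layout with `.to(memory_format=torch.channels_last)`.

```python
def nchw_offset(n, c, h, w, C, H, W):
    """Flat offset of element (n, c, h, w) in a contiguous NCHW buffer."""
    return ((n * C + c) * H + h) * W + w


def nhwc_offset(n, c, h, w, C, H, W):
    """Flat offset of the same logical element in a contiguous NHWC buffer."""
    return ((n * H + h) * W + w) * C + c


C, H, W = 3, 2, 2
# Channels of pixel (h=1, w=1): contiguous in NHWC, strided by H*W in NCHW
print([nhwc_offset(0, c, 1, 1, C, H, W) for c in range(C)])  # [9, 10, 11]
print([nchw_offset(0, c, 1, 1, C, H, W) for c in range(C)])  # [3, 7, 11]
```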
Highlights
- The package name changed from `intel_pytorch_extension` to `intel_extension_for_pytorch`. This change is intended to avoid potential legal issues.
| v1.9.0-cpu | v1.10.0-cpu |
| --- | --- |
| `import intel_pytorch_extension as ipex` | `import intel_extension_for_pytorch as ipex` |
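Scripts that must run against either release can resolve the module name at import time. The sketch below is a hypothetical helper (`load_ipex` is not part of the extension) that tries the new package name first and falls back to the old one. It only resolves the import: the 1.9 and 1.10 APIs differ (for example, `ipex.DEVICE` versus the standard CPU device), so version-specific code paths are still needed.

```python
import importlib


def load_ipex():
    """Import the extension under its 1.10+ name, falling back to the pre-1.10 name.

    Returns the imported module, or None if neither package is installed.
    """
    for name in ("intel_extension_for_pytorch", "intel_pytorch_extension"):
        try:
            return importlib.import_module(name)
        except ImportError:
            continue
    return None


ipex = load_ipex()
if ipex is None:
    print("Intel Extension for PyTorch* is not installed")
else:
    print("Loaded", ipex.__name__)
```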
- The underlying device changed from the extension-specific device (`XPU`) to the standard `CPU` device, which aligns with the PyTorch CPU device design in both its dispatch mechanism and its operator registration mechanism. The model no longer needs to be explicitly converted to the extension device.
v1.9.0-cpu:
```python
import torch
import torchvision.models as models
# Import the extension
import intel_pytorch_extension as ipex
resnet18 = models.resnet18(pretrained=True)
# Explicitly convert the model to the extension device
resnet18_xpu = resnet18.to(ipex.DEVICE)
```
v1.10.0-cpu:
```python
import torch
import torchvision.models as models
# Import the extension
import intel_extension_for_pytorch as ipex
resnet18 = models.resnet18(pretrained=True)
# No device conversion needed: the model stays on the standard CPU device
```
- Compared to 1.9.0, 1.10.0 follows the PyTorch AMP API (`torch.cpu.amp`) to support automatic mixed precision. `torch.cpu.amp` converts data types automatically at runtime, and it now supports `torch.bfloat16` to boost performance on Intel CPUs that have BFloat16 instructions.
```python
import torch

# A simple model shared by the v1.9.0 and v1.10.0 examples
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv = torch.nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=(1, 1), bias=False)

    def forward(self, x):
        return self.conv(x)
```
v1.9.0-cpu:
```python
# Import the extension
import intel_pytorch_extension as ipex
# Automatically mix precision
ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16)
model = SimpleNet().eval()
x = torch.rand(64, 64, 224, 224)
with torch.no_grad():
    model = torch.jit.trace(model, x)
    model = torch.jit.freeze(model)
    y = model(x)
```
v1.10.0-cpu:
```python
# Import the extension
import intel_extension_for_pytorch as ipex
model = SimpleNet().eval()
x = torch.rand(64, 64, 224, 224)
# Automatically mix precision with the standard PyTorch AMP API
with torch.cpu.amp.autocast(), torch.no_grad():
    model = torch.jit.trace(model, x)
    model = torch.jit.freeze(model)
    y = model(x)
```
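BFloat16 keeps the 8-bit exponent of float32 but only 7 explicit mantissa bits, so a float32 value becomes bfloat16 by rounding away its lower 16 bits. The pure-Python sketch below illustrates that number format with round-to-nearest-even; it is not the conversion routine used by `torch.cpu.amp` or the extension, and the function names are hypothetical.

```python
import struct


def float_to_bf16_bits(x: float) -> int:
    """Round a float32-representable value to bfloat16 and return its 16-bit pattern.

    Uses round-to-nearest-even on the 16 low-order bits that bfloat16 drops.
    """
    if x != x:  # NaN: return a canonical quiet NaN instead of rounding
        return 0x7FC0
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1   # lowest bit that bfloat16 keeps
    bits += 0x7FFF + lsb     # rounding bias; exact ties go to the even result
    return (bits >> 16) & 0xFFFF


def bf16_bits_to_float(b: int) -> float:
    """Widen a bfloat16 bit pattern back to float32 by zero-filling the low 16 bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]


# Same exponent range as float32, but only roughly 2 to 3 significant decimal digits
print(bf16_bits_to_float(float_to_bf16_bits(3.14159265)))  # 3.140625
print(hex(float_to_bf16_bits(1.0)))                        # 0x3f80
```

The reduced mantissa is why mixed precision keeps some operators in float32: BFloat16 covers the same range as float32, but with far less precision.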
- The 1.10 release provides INT8 calibration as an experimental feature; it currently supports only post-training static quantization. Compared to 1.9.0, the front-end quantization APIs are more straightforward and easier to use.
```python
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = nn.Conv2d(10, 10, 3)

    def forward(self, x):
        x = self.conv(x)
        return x

model = MyModel().eval()
# User dataset for calibration
xx_c = [torch.randn(1, 10, 28, 28) for i in range(2)]
# User dataset for validation
xx_v = [torch.randn(1, 10, 28, 28) for i in range(20)]
```
- Calibration
v1.9.0-cpu:
```python
# Import the extension
import intel_pytorch_extension as ipex
# Convert the model to the extension device
model = MyModel().to(ipex.DEVICE)
# Create a configuration file to save quantization parameters
conf = ipex.AmpConf(torch.int8)
with torch.no_grad():
    for x in xx_c:
        # Run the model in calibration mode to collect quantization parameters
        with ipex.AutoMixPrecision(conf, running_mode='calibration'):
            y = model(x.to(ipex.DEVICE))
# Save the configuration file
conf.save('configure.json')
```
v1.10.0-cpu:
```python
# Import the extension
import intel_extension_for_pytorch as ipex
conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine)
with torch.no_grad():
    for x in xx_c:
        # Run the model in calibration mode to collect quantization parameters
        with ipex.quantization.calibrate(conf):
            y = model(x)
# Save the configuration file
conf.save('configure.json')
```
- Inference
v1.9.0-cpu:
```python
# Import the extension
import intel_pytorch_extension as ipex
# Convert the model to the extension device
model = MyModel().to(ipex.DEVICE)
# Load the quantization parameters saved during calibration
conf = ipex.AmpConf(torch.int8, 'configure.json')
with torch.no_grad():
    for x in xx_v:
        with ipex.AutoMixPrecision(conf, running_mode='inference'):
            y = model(x.to(ipex.DEVICE))
```
v1.10.0-cpu:
```python
# Import the extension
import intel_extension_for_pytorch as ipex
# Load the quantization parameters saved during calibration
conf = ipex.quantization.QuantConf('configure.json')
# An example input used for tracing; its shape matches the user dataset
example_input = torch.randn(1, 10, 28, 28)
with torch.no_grad():
    trace_model = ipex.quantization.convert(model, conf, example_input)
    for x in xx_v:
        y = trace_model(x)
```
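Post-training static quantization with `torch.per_tensor_affine` comes down to two steps: calibration derives a scale and zero point from the value range observed on representative data, and INT8 inference uses them to map between real values and 8-bit integers. The pure-Python sketch below shows both steps with the standard affine formulas on plain lists. It is a simplified illustration, not the extension's implementation: the unsigned 8-bit range, the min/max observer, and the JSON layout standing in for `configure.json` are assumptions, and helper names such as `calibrate` are hypothetical.

```python
import json

QMIN, QMAX = 0, 255  # unsigned 8-bit range (an assumption for this sketch)


def calibrate(batches):
    """Return per-tensor affine (scale, zero_point) from the min/max seen in all batches."""
    lo = min(min(b) for b in batches)
    hi = max(max(b) for b in batches)
    lo, hi = min(lo, 0.0), max(hi, 0.0)        # keep 0.0 exactly representable
    scale = (hi - lo) / (QMAX - QMIN) or 1.0   # avoid a zero scale for all-zero data
    zero_point = max(QMIN, min(QMAX, round(QMIN - lo / scale)))
    return scale, zero_point


def quantize(x, scale, zero_point):
    """q = clamp(round(x / scale) + zero_point, QMIN, QMAX)"""
    return max(QMIN, min(QMAX, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    """x_hat = (q - zero_point) * scale"""
    return (q - zero_point) * scale


# Calibration: observe representative batches, then persist the parameters
calibration_batches = [[-1.0, 0.5, 2.0], [0.0, 3.1, -0.2]]
scale, zero_point = calibrate(calibration_batches)
saved = json.dumps({"scale": scale, "zero_point": zero_point})  # e.g. written to a file

# Inference: reload the parameters and map values to and from 8-bit integers
params = json.loads(saved)
q = [quantize(v, params["scale"], params["zero_point"]) for v in [-1.0, 0.0, 1.0, 5.0]]
print(q)  # [0, 62, 124, 255]
print([round(dequantize(v, params["scale"], params["zero_point"]), 3) for v in q])
```

The value 5.0 lies outside the calibrated range and saturates at 255, which is one reason calibration data should be representative of real inputs.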
- This release introduces the `optimize` API in the Python front end to optimize models. The new API supports FP32 and BF16, for both inference and training.
- Runtime Extension (Experimental) provides a runtime CPU pool API to bind threads to cores, and it also supports asynchronous tasks. Please note: the Intel® Extension for PyTorch* Runtime Extension is still at the proof-of-concept (POC) stage, and its API is subject to change. More detailed descriptions are available in the extension documentation.
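The Runtime Extension's CPU pool binds threads to a set of cores. As a rough illustration of the idea only (this is not the Runtime Extension API, which is still a POC and subject to change), the hypothetical helper `split_cores` below partitions a list of core IDs into disjoint, near-equal pools, for example one pool per concurrent task.

```python
def split_cores(core_ids, num_pools):
    """Partition core IDs into num_pools contiguous, disjoint groups of near-equal size."""
    if num_pools <= 0:
        raise ValueError("num_pools must be positive")
    size, extra = divmod(len(core_ids), num_pools)
    pools, start = [], 0
    for i in range(num_pools):
        # The first `extra` pools get one additional core
        end = start + size + (1 if i < extra else 0)
        pools.append(core_ids[start:end])
        start = end
    return pools


print(split_cores(list(range(8)), 2))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(split_cores(list(range(7)), 3))  # [[0, 1, 2], [3, 4], [5, 6]]

# On Linux, the current process could then be restricted to one pool, e.g.:
# os.sched_setaffinity(0, split_cores(sorted(os.sched_getaffinity(0)), 2)[0])
```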
Known Issues
- `omp_set_num_threads` fails to change the number of OpenMP threads used by oneDNN operators if the number was already set
The `omp_set_num_threads` function is provided in Intel® Extension for PyTorch* to change the number of threads used with OpenMP. However, it fails to change the number of OpenMP threads if the number was already set. Pseudo-code:
```
omp_set_num_threads(6)
model_execution()
omp_set_num_threads(4)
same_model_execution_again()
```
Reason: The oneDNN primitive descriptor stores the number of OpenMP threads, and the current oneDNN integration caches primitive descriptors in the extension. As a result, if the runtime extension is used with oneDNN on top of PyTorch or the extension, it fails to change the number of OpenMP threads in use.
- Low performance with INT8 support for dynamic shapes
The support for dynamic shapes in the Intel® Extension for PyTorch* INT8 integration is still a work in progress. For use cases where the input shapes are dynamic, for example, inputs of variable image sizes in an object detection task or of variable sequence lengths in NLP tasks, the Intel® Extension for PyTorch* INT8 path may slow down model inference. In this case, please use the stock PyTorch INT8 functionality.
- Low throughput with DLRM FP32 training
A 'Sparse Add' PR is pending review. The issue will be fixed when the PR is merged.
What's Changed
Full Changelog: v1.9.0...v1.10.0