Intel® Extension for PyTorch* v1.12.0-cpu Release Notes
We are excited to bring you the release of Intel® Extension for PyTorch* 1.12.0-cpu, which closely follows the PyTorch 1.12 release. In this release, we matured automatic INT8 quantization into a stable feature. We stabilized Runtime Extension and introduced the MultiStreamModule feature to further boost throughput in offline inference scenarios. We also brought various operator and graph enhancements that improve performance across a broad set of workloads.
- Automatic INT8 quantization became a stable feature, with a well-tuned default quantization recipe built in, support for both static and dynamic quantization, and a wide range of calibration algorithms.
- Runtime Extension, featuring MultiStreamModule, became a stable feature that can further enhance throughput in offline inference scenarios.
- More graph and operator optimizations improve the performance of a broad set of models, including but not limited to wav2vec, T5, and Albert.
- A pre-built experimental binary with oneDNN Graph Compiler turned on can deliver additional performance gains for Bert, Albert, and Roberta in INT8 inference.
Highlights
- Matured the automatic INT8 quantization feature, which now bakes in a well-tuned default quantization recipe. We improved the user experience and provided a wide range of calibration algorithms, such as Histogram, MinMax, and MovingAverageMinMax. We also polished static quantization for better flexibility and enabled dynamic quantization. The main API changes compared with the previous version are shown below. Refer to the tutorial page for more details.
v1.11.0-cpu:

```python
import torch
import intel_extension_for_pytorch as ipex

# Calibrate the model
qconfig = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine)
for data in calibration_data_set:
    with ipex.quantization.calibrate(qconfig):
        model_to_be_calibrated(data)
qconfig.save('qconfig.json')

# Convert the model to a JIT model
conf = ipex.quantization.QuantConf('qconfig.json')
with torch.no_grad():
    traced_model = ipex.quantization.convert(model, conf, example_input)
    # Do inference
    y = traced_model(x)
```

v1.12.0-cpu:

```python
import torch
import intel_extension_for_pytorch as ipex

# Calibrate the model
qconfig = ipex.quantization.default_static_qconfig  # Histogram calibration algorithm
calibrated_model = ipex.quantization.prepare(model_to_be_calibrated, qconfig, example_inputs=example_inputs)
for data in calibration_data_set:
    calibrated_model(data)

# Convert the model to a JIT model
quantized_model = ipex.quantization.convert(calibrated_model)
with torch.no_grad():
    traced_model = torch.jit.trace(quantized_model, example_inputs)
    traced_model = torch.jit.freeze(traced_model)
    # Do inference
    y = traced_model(x)
```
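The calibration algorithms map onto PyTorch observer classes. As a minimal sketch, assuming a custom QConfig built from stock `torch.ao.quantization` observers can be passed to `ipex.quantization.prepare` in place of `default_static_qconfig`, the snippet below swaps in MinMax calibration for activations and shows the scale and zero point such an observer derives from the values it has seen. The per-channel symmetric weight observer is an illustrative choice, not necessarily what the default recipe uses.

```python
import torch
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

# Hypothetical custom recipe: MinMax calibration for activations,
# per-channel symmetric MinMax for weights (an assumed choice).
minmax_qconfig = QConfig(
    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
    weight=PerChannelMinMaxObserver.with_args(qscheme=torch.per_channel_symmetric, dtype=torch.qint8),
)

# Feed a few "calibration" tensors through the activation observer by hand.
observer = minmax_qconfig.activation()
for batch in (torch.tensor([-1.0, 0.5]), torch.tensor([2.0, 3.0])):
    observer(batch)

# The observer tracks the running min/max and turns them into affine qparams.
scale, zero_point = observer.calculate_qparams()
print(observer.min_val.item(), observer.max_val.item())  # -1.0 3.0
print(scale.item(), zero_point.item())
```

With MinMax, the quantization range is set directly by the smallest and largest values observed, which is simpler than Histogram calibration but more sensitive to outliers.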
- Runtime Extension, featuring MultiStreamModule, became a stable feature. In this release, we improved the heuristic rule to further boost throughput in offline inference scenarios. We also provide `ipex.cpu.runtime.MultiStreamModuleHint` to customize how the input is split into streams and how the output of each stream is concatenated.
v1.11.0-cpu:

```python
import intel_extension_for_pytorch as ipex

# Create CPU pool
cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
# Create multi-stream model
multi_stream_model = ipex.cpu.runtime.MultiStreamModule(model, num_streams=2, cpu_pool=cpu_pool)
```

v1.12.0-cpu:

```python
import intel_extension_for_pytorch as ipex

# Create CPU pool
cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
# Optional
multi_stream_input_hint = ipex.cpu.runtime.MultiStreamModuleHint(0)
multi_stream_output_hint = ipex.cpu.runtime.MultiStreamModuleHint(0)
# Create multi-stream model
multi_stream_model = ipex.cpu.runtime.MultiStreamModule(model, num_streams=2, cpu_pool=cpu_pool,
                                                         input_split_hint=multi_stream_input_hint,    # optional
                                                         output_concat_hint=multi_stream_output_hint)  # optional
```
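Conceptually, the input hint says which dimension of each input is split into per-stream mini-batches, and the output hint says which dimension the per-stream results are concatenated along. The pure-Python sketch below models that with dim 0 and plain lists. It is not the extension's implementation: the streams run sequentially here rather than on separate core groups, and spreading the remainder over the first streams is an assumed policy. It also shows how a batch size that is not divisible by `num_streams` yields unequal mini-batches.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def split_batch(batch: Sequence[T], num_streams: int) -> List[List[T]]:
    """Split a batch along dim 0 into num_streams contiguous mini-batches.

    When len(batch) is not divisible by num_streams, the first
    len(batch) % num_streams mini-batches get one extra sample (assumed policy).
    """
    if num_streams < 1:
        raise ValueError("num_streams must be >= 1")
    base, extra = divmod(len(batch), num_streams)
    chunks, start = [], 0
    for i in range(num_streams):
        size = base + (1 if i < extra else 0)
        chunks.append(list(batch[start:start + size]))
        start += size
    return chunks


def run_multi_stream(model: Callable[[List[T]], List[R]], batch: Sequence[T], num_streams: int) -> List[R]:
    """Run model on each mini-batch (sequentially here) and concatenate along dim 0."""
    outputs = [model(chunk) for chunk in split_batch(batch, num_streams)]
    return [y for out in outputs for y in out]


if __name__ == "__main__":
    print([len(c) for c in split_batch(list(range(10)), 3)])  # [4, 3, 3]
    print(run_multi_stream(lambda xs: [x * 2 for x in xs], range(6), 2))  # [0, 2, 4, 6, 8, 10]
```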
- Polished `ipex.optimize` to accept input shape information through a sample input, from which it determines the optimal memory layout for better kernel efficiency.
v1.11.0-cpu:

```python
import torch
import intel_extension_for_pytorch as ipex

model = ...
model.load_state_dict(torch.load(PATH))
model.eval()
optimized_model = ipex.optimize(model, dtype=torch.bfloat16)
```

v1.12.0-cpu:

```python
import torch
import intel_extension_for_pytorch as ipex

model = ...
model.load_state_dict(torch.load(PATH))
model.eval()
optimized_model = ipex.optimize(model, dtype=torch.bfloat16, sample_input=input)
```
- Provided a pre-built experimental binary with oneDNN Graph Compiler turned on, which can deliver additional performance gains for Bert, Albert, and Roberta in INT8 inference.
- Provided more optimizations in graphs and operators:
  - Fuse Adam to improve training performance #822
  - Enable Normalization operators to support channels-last 3D #642
  - Support Deconv3D to serve most models and implement most fusions like Conv
  - Enable LSTM to support static and dynamic quantization #692
  - Enable Linear to support dynamic quantization #787
  - Fusions:
    - Fuse `Add`+`Swish` to accelerate the FSI Riskful model #551
    - Fuse `Conv`+`LeakyReLU` #589
    - Fuse `BMM`+`Add` #407
    - Fuse `Concat`+`BN`+`ReLU` #647
    - Optimize `Convolution1D` to support channels-last memory layout and fuse `GeLU` as its post operation #657
    - Fuse `Einsum`+`Add` to boost Alphafold2 #674
    - Fuse `Linear`+`Tanh` #711
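The fusions above target operator sequences as they appear in model code. As an illustrative sketch using stock PyTorch only, the module below contains a `Conv`+`LeakyReLU` sequence and is traced and frozen the same way as the quantization examples in these notes. The extension is not imported here, and whether a given pattern is actually fused depends on the extension's graph passes (for example after `ipex.optimize`), which this sketch does not check.

```python
import torch


class ConvLeakyReLU(torch.nn.Module):
    """Toy block whose forward is a Conv followed directly by LeakyReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3)
        self.act = torch.nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.conv(x))


model = ConvLeakyReLU().eval()
example_input = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    # Tracing + freezing yields a static graph in which graph passes can
    # look for the Conv -> LeakyReLU sequence.
    traced = torch.jit.freeze(torch.jit.trace(model, example_input))
    out = traced(example_input)
print(out.shape)  # torch.Size([1, 4, 6, 6])
```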
Known Issues
- `RuntimeError: Overflow when unpacking long` is raised when a tensor's min/max value exceeds the int range during INT8 calibration. Please customize the QConfig to use the MinMax calibration method.
- When calibrating with `quantize_per_tensor` and benchmarking with 1 OpenMP* thread, results might be incorrect with large tensors (find more detailed info here). If you do need to explicitly set `OMP_NUM_THREADS=1` for benchmarking, editing your code following the pseudocode below works around this issue. However, there could be a performance regression if the oneDNN Graph Compiler prototype feature is used.

  Workaround pseudocode (`N`, `model`, `qconfig`, `input`, and `run_benchmark` are placeholders):

```python
import torch
from intel_extension_for_pytorch.quantization import prepare, convert

# Perform convert/trace/freeze with more than one OpenMP thread (N > 1)
torch.set_num_threads(N)
prepared_model = prepare(model, qconfig, example_inputs=input)
# (calibration on prepared_model omitted)
converted_model = convert(prepared_model)
with torch.no_grad():
    traced_model = torch.jit.trace(converted_model, input)
    frozen_model = torch.jit.freeze(traced_model)
    # Run the frozen model once so the optimization passes are applied
    frozen_model(input)

    # Benchmark with a single OpenMP thread
    torch.set_num_threads(1)
    run_benchmark(frozen_model, input)
```
- Low performance with INT8 support for dynamic shapes

  The support for dynamic shapes in the Intel® Extension for PyTorch* INT8 integration is still a work in progress. When input shapes are dynamic, for example images of variable sizes in an object detection task or sequences of variable lengths in NLP tasks, the Intel® Extension for PyTorch* INT8 path may slow down model inference. In this case, use the stock PyTorch INT8 functionality.

  Note: When using the Runtime Extension feature, if the batch size cannot be divided evenly by the number of streams, the mini-batch sizes on the streams are not equal, so scripts can run into this issue.
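One option on the stock PyTorch side is eager-mode dynamic quantization. The sketch below applies `torch.ao.quantization.quantize_dynamic` to a toy model (layer sizes and sequence lengths are arbitrary) and runs it on inputs of two different lengths. It assumes a PyTorch build with a quantized CPU engine, such as fbgemm, available.

```python
import torch
from torch.ao.quantization import quantize_dynamic

# Toy float model; only its nn.Linear layers will be quantized.
float_model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
).eval()

# Weights are converted to int8 ahead of time; activation scales are computed
# on the fly for each call, so no calibration pass is needed.
int8_model = quantize_dynamic(float_model, {torch.nn.Linear}, dtype=torch.qint8)

with torch.no_grad():
    for seq_len in (5, 17):  # variable-length inputs
        out = int8_model(torch.randn(seq_len, 16))
        print(seq_len, tuple(out.shape))
```

Because activation scales are recomputed per call, nothing in this path is tied to one input shape.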
- BF16 AMP (auto mixed precision) runs abnormally with the extension on AVX2-only machines if the topology contains `Conv`, `Matmul`, `Linear`, and `BatchNormalization`.
- Runtime Extension's MultiStreamModule doesn't support DLRM inference, since the input of DLRM (EmbeddingBag specifically) can't simply be split by batch.
- Runtime Extension's MultiStreamModule performs poorly for RNNT inference compared with the native throughput mode. Only part of the RNNT model (`joint_net` specifically) can be JIT-traced into a graph, yet in a single batch inference `joint_net` is invoked multiple times. Each invocation adds MultiStreamModule overhead for input batch splitting, thread synchronization, and output concatenation.
- Incorrect Conv and Linear results if the number of OMP threads is changed at runtime

  The oneDNN memory layout depends on the number of OMP threads, which requires the caller to detect changes in the number of OMP threads; this release does not implement that yet.
- Low throughput with DLRM FP32 training

  A 'Sparse Add' PR is pending review. The issue will be fixed when the PR is merged.
- If inference is done with a custom function, the `conv+bn` folding feature of the `ipex.optimize()` function doesn't work.

```python
import torch
import intel_extension_for_pytorch as ipex


class Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 10, 5, 1)
        self.bn = torch.nn.BatchNorm2d(10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def inference(self, x):
        return self.forward(x)


if __name__ == '__main__':
    m = Module()
    m.eval()
    # level="O1" applies conv+bn folding, which triggers the issue
    m = ipex.optimize(m, dtype=torch.float32, level="O1")
    d = torch.rand(1, 1, 112, 112)
    with torch.no_grad():
        m.inference(d)
```

  This is a PyTorch FX limitation. You can avoid this error by calling `m = ipex.optimize(m, level="O0")`, which doesn't apply IPEX optimizations, or by disabling `conv+bn` folding with `m = ipex.optimize(m, level="O1", conv_bn_folding=False)`.
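The FX limitation mentioned above can be seen with stock PyTorch alone. Assuming the folding pass operates on a `torch.fx.symbolic_trace` of the module (an assumption; these notes only say it is an FX limitation), the sketch below shows the core of it: the traced `GraphModule` reproduces only `forward`, so a custom method such as `inference` does not exist on the result.

```python
import torch
import torch.fx


class Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 10, 5, 1)
        self.bn = torch.nn.BatchNorm2d(10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

    def inference(self, x):
        return self.forward(x)


m = Module().eval()
gm = torch.fx.symbolic_trace(m)
# The traced module is a new GraphModule, not an instance of Module,
# so the custom method is not carried over.
print(isinstance(gm, Module), hasattr(gm, "inference"))  # False False
```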