From 99a70a6152b3a94a99b7019fdbdbd78bf4d8af2c Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Wed, 22 Mar 2023 12:10:30 +0900 Subject: [PATCH] update doc for 2.0.0 release (#1486) update tuning guide add known issues add release notes update cppsdk installation guide --- docs/tutorials/getting_started.md | 2 +- docs/tutorials/installation.md | 2 +- .../performance_tuning/known_issues.md | 112 +++++++++++------- .../performance_tuning/tuning_guide.md | 54 +++++---- docs/tutorials/releases.md | 32 +++++ 5 files changed, 137 insertions(+), 65 deletions(-) diff --git a/docs/tutorials/getting_started.md b/docs/tutorials/getting_started.md index eed80291e..09f7de64b 100644 --- a/docs/tutorials/getting_started.md +++ b/docs/tutorials/getting_started.md @@ -51,7 +51,7 @@ with torch.no_grad(), torch.cpu.amp.autocast(): model(data) ########################################## -############ T##orchDynamo ############### +############## TorchDynamo ############### model = ipex.optimize(model) model = torch.compile(model, backend="ipex") diff --git a/docs/tutorials/installation.md b/docs/tutorials/installation.md index 2f261a1c2..b88a7d09d 100644 --- a/docs/tutorials/installation.md +++ b/docs/tutorials/installation.md @@ -152,7 +152,7 @@ docker pull intel/intel-optimized-pytorch:latest |Version|Pre-cxx11 ABI|cxx11 ABI| |--|--|--| -| 2.0.0 | [libintel-ext-pt-2.0.0+cpu.run](http://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-2.0.0%2Bcpu.run) | [libintel-ext-pt-cxx11-abi-2.0.0+cpu.run](http://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-cxx11-abi-2.0.0%2Bcpu.run) | +| 2.0.0 | [libintel-ext-pt-2.0.0+cpu.run](https://intel-extension-for-pytorch.s3.amazonaws.com/libipex/cpu/libintel-ext-pt-2.0.0%2Bcpu.run) | [libintel-ext-pt-cxx11-abi-2.0.0+cpu.run](https://intel-extension-for-pytorch.s3.amazonaws.com/libipex/cpu/libintel-ext-pt-cxx11-abi-2.0.0%2Bcpu.run) | | 1.13.100 | [libintel-ext-pt-1.13.100+cpu.run](http://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-1.13.100%2Bcpu.run) | [libintel-ext-pt-cxx11-abi-1.13.100+cpu.run](http://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-cxx11-abi-1.13.100%2Bcpu.run) | | 1.13.0 | [libintel-ext-pt-1.13.0+cpu.run](http://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-1.13.0%2Bcpu.run) | [libintel-ext-pt-cxx11-abi-1.13.0+cpu.run](http://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-cxx11-abi-1.13.0%2Bcpu.run) | | 1.12.300 | [libintel-ext-pt-1.12.300+cpu.run](http://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-1.12.300%2Bcpu.run) | [libintel-ext-pt-cxx11-abi-1.12.300+cpu.run](http://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-cxx11-abi-1.12.300%2Bcpu.run) | diff --git a/docs/tutorials/performance_tuning/known_issues.md b/docs/tutorials/performance_tuning/known_issues.md index f21b146bb..5095163c8 100644 --- a/docs/tutorials/performance_tuning/known_issues.md +++ b/docs/tutorials/performance_tuning/known_issues.md @@ -1,13 +1,76 @@ Known Issues ============ +## Usage + - There might be Python packages having PyTorch as their hard dependency. If you installed `+cpu` version of PyTorch, installation of these packages might replace the `+cpu` version with the default version released on Pypi.org. If anything goes wrong, please reinstall the `+cpu` version back. 
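
  A quick way to check whether the `+cpu` builds are still the ones installed is to print the version strings (a minimal sketch, not part of the original guide; the exact version numbers depend on the release you installed):

  ```python
  # Both version strings are expected to carry a "+cpu" suffix when the
  # CPU-only builds of PyTorch and Intel(R) Extension for PyTorch* are installed.
  import torch
  import intel_extension_for_pytorch as ipex

  print(torch.__version__)   # e.g. "2.0.0+cpu"
  print(ipex.__version__)    # e.g. "2.0.0+cpu"
  ```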
- If you find that a workload run with Intel® Extension for PyTorch\* occupies a remarkably large amount of memory, you can reduce the memory footprint by setting the `weights_prepack` parameter of the `ipex.optimize()` function to `False`.

-- Supporting of EmbeddingBag with INT8 when bag size > 1 is working in progress.

- If inference is done with a custom function, the `conv+bn` folding feature of the `ipex.optimize()` function doesn't work.

  ```
  import torch
  import intel_extension_for_pytorch as ipex

  class Module(torch.nn.Module):
      def __init__(self):
          super(Module, self).__init__()
          self.conv = torch.nn.Conv2d(1, 10, 5, 1)
          self.bn = torch.nn.BatchNorm2d(10)
          self.relu = torch.nn.ReLU()

      def forward(self, x):
          x = self.conv(x)
          x = self.bn(x)
          x = self.relu(x)
          return x

      def inference(self, x):
          return self.forward(x)

  if __name__ == '__main__':
      m = Module()
      m.eval()
      m = ipex.optimize(m, dtype=torch.float32, level="O0")
      d = torch.rand(1, 1, 112, 112)
      with torch.no_grad():
          m.inference(d)
  ```

  This is a PyTorch FX limitation. You can avoid this error by calling `m = ipex.optimize(m, level="O0")`, which doesn't apply the ipex optimizations, or by disabling `conv+bn` folding with `m = ipex.optimize(m, level="O1", conv_bn_folding=False)`.

## TorchDynamo

- Support for torch.compile() with ipex as the backend is still an experimental feature. If the workload fails to run or demonstrates poor performance, you can use the `torch.jit` APIs and the graph optimization APIs of ipex instead. Currently, the HuggingFace models below fail to run using torch.compile() with the ipex backend due to memory issues:
  - masked-language-modeling+xlm-roberta-base
  - causal-language-modeling+gpt2
  - causal-language-modeling+xlm-roberta-base
  - summarization+t5-base
  - text-classification+allenai-longformer-base-409

## Dynamic Shape

- When running inference of an NLP model with dynamic input data lengths using TorchScript (either `torch.jit.trace` or `torch.jit.script`), performance with Intel® Extension for PyTorch\* may be lower than without it. In this case, the workarounds below can help solve the issue.
  - Python interface
    ```python
    torch._C._jit_set_texpr_fuser_enabled(False)
    ```
  - C++ interface
    ```c++
    #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
    torch::jit::setTensorExprFuserEnabled(false);
    ```

## INT8

-- Compiling with gcc 11 might result in `illegal instruction` error.

- Low performance with INT8 support for dynamic shapes

  The support for dynamic shapes in Intel® Extension for PyTorch\* INT8 integration is still work in progress. When the input shapes are dynamic, for example inputs of variable image sizes in an object detection task or of variable sequence lengths in NLP tasks, the Intel® Extension for PyTorch\* INT8 path may slow down the model inference. In this case, use stock PyTorch INT8 functionality.

  **Note**: When using the Runtime Extension feature, if the batch size cannot be divided evenly by the number of streams, the mini-batch sizes on the streams are not equal and scripts may run into this issue.

- Support for EmbeddingBag with INT8 when bag size > 1 is still work in progress.

- `RuntimeError: Overflow when unpacking long` occurs when a tensor's min/max value exceeds the int range while performing int8 calibration. Please customize the QConfig to use the min-max calibration method (a sketch follows below).
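
  For reference, one way to select the min-max calibration method is to build a QConfig from min-max observers and pass it to the Intel® Extension for PyTorch\* static quantization API. The sketch below makes that concrete; the toy model, input shape, and random calibration data are illustrative placeholders rather than part of the original document:

  ```python
  import torch
  import intel_extension_for_pytorch as ipex
  from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

  # Toy model and data, purely for illustration.
  model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
  example_input = torch.rand(1, 3, 224, 224)

  # Observers based on running min/max statistics implement the min-max calibration method.
  qconfig = QConfig(
      activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
      weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
  )

  prepared = ipex.quantization.prepare(model, qconfig, example_inputs=example_input, inplace=False)
  with torch.no_grad():
      for _ in range(10):                          # calibration loop over representative data
          prepared(torch.rand(1, 3, 224, 224))
      converted = ipex.quantization.convert(prepared)
      traced = torch.jit.trace(converted, example_input)
      frozen = torch.jit.freeze(traced)
  ```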
@@ -31,55 +94,24 @@ Known Issues
    run_benchmark(freezed_model, input)
  ```

## BFloat16

- BF16 AMP (auto-mixed-precision) runs abnormally with the extension on AVX2-only machines if the topology contains `Conv`, `Matmul`, `Linear`, and `BatchNormalization`

## Runtime Extension

- Runtime extension of MultiStreamModule doesn't support DLRM inference, since the input of DLRM (EmbeddingBag specifically) can't simply be batch split.

- Runtime extension of MultiStreamModule has poor performance for RNNT inference compared with the native throughput mode. Only part of the RNNT model (joint_net specifically) can be jit traced into a graph. However, in one batch inference, `joint_net` is invoked multiple times, which increases the overhead of MultiStreamModule for input batch splitting, thread synchronization and output concatenation.

## Correctness

- Incorrect Conv and Linear result if the number of OMP threads is changed at runtime

  The oneDNN memory layout depends on the number of OMP threads, which requires the caller to detect changes in the number of OMP threads; this release has not implemented that yet.

-- Low performance with INT8 support for dynamic shapes
-
-  The support for dynamic shapes in Intel® Extension for PyTorch\* INT8 integration is still work in progress. When the input shapes are dynamic, for example inputs of variable image sizes in an object detection task or of variable sequence lengths in NLP tasks, the Intel® Extension for PyTorch\* INT8 path may slow down the model inference. In this case, use stock PyTorch INT8 functionality.
-
-  **Note**: Using Runtime Extension feature if batch size cannot be divided by number of streams, because mini batch size on each stream are not equivalent, scripts run into this issues.

## Float32 Training

- Low throughput with DLRM FP32 Train

  A 'Sparse Add' [PR](https://github.com/pytorch/pytorch/pull/23057) is pending review. The issue will be fixed when the PR is merged.
diff --git a/docs/tutorials/performance_tuning/tuning_guide.md b/docs/tutorials/performance_tuning/tuning_guide.md
index 842587a53..6dadcfb7c 100644
--- a/docs/tutorials/performance_tuning/tuning_guide.md
+++ b/docs/tutorials/performance_tuning/tuning_guide.md
@@ -3,9 +3,9 @@ Performance Tuning Guide

## Overview

-Intel® Extension for PyTorch\* (IPEX) is a Python package to extend official PyTorch. It makes the out-of-box user experience of PyTorch CPU better while achieving good performance.
To fully utilize the power of Intel® architecture and thus yield high performance, PyTorch, as well as IPEX, are powered by [oneAPI Deep Neural Network Library (oneDNN)](https://github.com/oneapi-src/oneDNN), an open-source cross-platform performance library of basic building blocks for deep learning applications. It is developed and optimized for Intel Architecture Processors, Intel Processor Graphics, and Xe architecture-based Graphics.
+Intel® Extension for PyTorch\* is a Python package to extend official PyTorch. It improves the out-of-the-box user experience of PyTorch on CPU while achieving good performance. To fully utilize the power of Intel® architecture and thus yield high performance, PyTorch, as well as Intel® Extension for PyTorch\*, are powered by [oneAPI Deep Neural Network Library (oneDNN)](https://github.com/oneapi-src/oneDNN), an open-source cross-platform performance library of basic building blocks for deep learning applications. It is developed and optimized for Intel Architecture Processors, Intel Processor Graphics, and Xe architecture-based Graphics.

-Although default primitives of PyTorch and IPEX are highly optimized, there are things users can do improve performance. Most optimized configurations can be automatically set by the launcher script. This article introduces common methods recommended by Intel developers.
+Although the default primitives of PyTorch and Intel® Extension for PyTorch\* are highly optimized, there are things users can do to improve performance. Most optimized configurations can be automatically set by the launcher script. This article introduces common methods recommended by Intel developers.

## Contents of this Document
* [Hardware Configuration](#hardware-configuration)
@@ -16,6 +16,7 @@ Although default primitives of PyTorch and IPEX are highly optimized, there are
  * [Numactl](#numactl)
  * [OpenMP](#openmp)
    * [OMP_NUM_THREADS](#omp-num-threads)
+   * [OMP_THREAD_LIMIT](#omp-thread-limit)
    * [GNU OpenMP](#gnu-openmp)
    * [Intel OpenMP](#intel-openmp)
  * [Memory Allocator](#memory-allocator)
@@ -60,7 +61,8 @@ Figure 3: An ASUS Z11PA-D8 Intel® Xeon® server motherboard. It contains two so

It is a good thing that more and more CPU cores are provided to users in one socket, because this brings more computation resources. However, this also brings memory access contention. A program can stall because the memory it needs is busy. To address this problem, Non-Uniform Memory Access (NUMA) was introduced. Compared to Uniform Memory Access (UMA), in which all memory is connected to all cores equally, NUMA divides memory into multiple groups. A certain amount of memory is directly attached to one socket's integrated memory controller to become the local memory of this socket. As described in the previous section, local memory access is much faster than remote memory access.

-Users can get CPU information with ```lscpu``` command on Linux to learn how many cores, sockets there on the machine. Also, NUMA information like how CPU cores are distributed can also be retrieved. The following is an example of ```lscpu``` execution on a machine with two Intel(R) Xeon(R) Platinum 8180M CPUs. 2 sockets were detected. Each socket has 28 physical cores onboard. Since Hyper-Threading is enabled, each core can run 2 threads. I.e. each socket has another 28 logical cores. Thus, there are 112 CPU cores on service. When indexing CPU cores, usually physical cores are indexed before logical core.
In this case, the first 28 cores (0-27) are physical cores on the first NUMA socket (node), the second 28 cores (28-55) are physical cores on the second NUMA socket (node). Logical cores are indexed afterward. 56-83 are 28 logical cores on the first NUMA socket (node), 84-111 are the second 28 logical cores on the second NUMA socket (node). Typically, running IPEX should avoid using logical cores to get a good performance.
+Users can get CPU information with the `lscpu` command on Linux to learn how many cores and sockets there are on the machine. NUMA information, such as how CPU cores are distributed, can also be retrieved. The following is an example of `lscpu` execution on a machine with two Intel(R) Xeon(R) Platinum 8180M CPUs. 2 sockets were detected. Each socket has 28 physical cores onboard. Since Hyper-Threading is enabled, each core can run 2 threads; that is, each socket has another 28 logical cores. Thus, there are 112 CPU cores in total. When indexing CPU cores, usually physical cores are indexed before logical cores. In this case, the first 28 cores (0-27) are physical cores on the first NUMA socket (node) and the second 28 cores (28-55) are physical cores on the second NUMA socket (node). Logical cores are indexed afterward: 56-83 are the 28 logical cores on the first NUMA socket (node), and 84-111 are the 28 logical cores on the second NUMA socket (node). Typically, avoid using logical cores when running Intel® Extension for PyTorch\* to get good performance.
+
```
$ lscpu
...
@@ -92,7 +94,7 @@ Since NUMA largely influences memory access performance, this functionality shou

During development of Linux kernels, more and more sophisticated implementations/optimizations/strategies have been brought out. Version 2.5 of the Linux kernel already contained basic NUMA support, which was further improved in subsequent kernel releases. Version 3.8 of the Linux kernel brought a new NUMA foundation that allowed development of more efficient NUMA policies in later kernel releases. Version 3.13 of the Linux kernel brought numerous policies that aim at putting a process near its memory, together with the handling of cases such as having memory pages shared between processes, or the use of transparent huge pages. New sysctl settings allow NUMA balancing to be enabled or disabled, as well as the configuration of various NUMA memory balancing parameters.[1] Behavior of Linux kernels thus differs according to kernel version. Newer Linux kernels may contain further optimizations of NUMA strategies, and thus have better performance. For some workloads, the NUMA strategy influences performance greatly.

-Linux provides a tool, ```numactl```, that allows user control of NUMA policy for processes or shared memory. It runs processes with a specific NUMA scheduling or memory placement policy. As described in previous section, cores share high-speed cache in one socket, thus it is a good idea to avoid cross socket computations. From a memory access perspective, bounding memory access locally is much faster than accessing remote memories.
+Linux provides a tool, `numactl`, that allows user control of NUMA policy for processes or shared memory. It runs processes with a specific NUMA scheduling or memory placement policy. As described in the previous section, cores share a high-speed cache within one socket, so it is a good idea to avoid cross-socket computation. From a memory access perspective, binding memory access locally is much faster than accessing remote memories.
The following is an example of numactl usage to run a workload on the Nth socket and limit memory access to its local memories on the Nth socket. A more detailed description of the numactl command can be found [on the numactl man page](https://linux.die.net/man/8/numactl).

@@ -118,13 +120,13 @@ Figure 4: A number of parallel block execution threads are forked from primary t

Users can control OpenMP behaviors through some environment variables to fit their workloads. Also, besides the GNU OpenMP library ([libgomp](https://gcc.gnu.org/onlinedocs/libgomp/)), Intel provides another OpenMP implementation, [libiomp](https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/optimization-and-programming-guide/openmp-support.html), for users to choose from. Environment variables that control the behavior of OpenMP threads may differ between libgomp and libiomp. They will be introduced separately in the sections below.

-GNU OpenMP (libgomp) is the default multi-threading library for both PyTorch and IPEX.
+GNU OpenMP (libgomp) is the default multi-threading library for both PyTorch and Intel® Extension for PyTorch\*.

[2] [Wikipedia - OpenMP](https://en.wikipedia.org/wiki/OpenMP)

#### OMP_NUM_THREADS

-Environment variable OMP_NUM_THREADS sets the number of threads used for parallel regions. By default, it is set to be the number of available physical cores. It can be used along with numactl settings, as the following example. If cores 0-3 are on socket 0, this example command runs \ on cores 0-3, with 4 OpenMP threads.
+Environment variable `OMP_NUM_THREADS` sets the number of threads used for parallel regions. By default, it is set to the number of available physical cores. It can be used along with numactl settings, as in the following example. If cores 0-3 are on socket 0, this example command runs the given script on cores 0-3, with 4 OpenMP threads.

This environment variable works on both libgomp and libiomp.

```
export OMP_NUM_THREADS=4
numactl -C 0-3 --membind 0 python
```
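
As a complementary check (not part of the original tuning guide), the effective intra-op thread count that PyTorch ends up using can also be inspected, or overridden, from Python:

```python
import torch

# When OMP_NUM_THREADS is exported before the process starts, it determines the
# default size of the intra-op thread pool that PyTorch reports here.
print(torch.get_num_threads())

# The same setting can be applied explicitly from Python if preferred.
torch.set_num_threads(4)
```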