diff --git a/csrc/gpu/aten/core/DeviceInfo.h b/csrc/gpu/aten/core/DeviceInfo.h
index aa13b5f50..2bd0a4b01 100644
--- a/csrc/gpu/aten/core/DeviceInfo.h
+++ b/csrc/gpu/aten/core/DeviceInfo.h
@@ -35,6 +35,9 @@ struct DeviceInfo {
   std::vector<size_t> sub_group_sizes;
   bool support_fp64;
   bool support_cl_bf16_conversion;
+  bool support_cl_sg_matmul_acc;
+  bool support_cl_sg_matmul_acc_tf32;
+  bool support_cl_sg_2d_block_io;
 };
 
 } // namespace dpcpp
diff --git a/csrc/gpu/runtime/Device.cpp b/csrc/gpu/runtime/Device.cpp
index ce89de451..fcf679453 100644
--- a/csrc/gpu/runtime/Device.cpp
+++ b/csrc/gpu/runtime/Device.cpp
@@ -322,8 +322,11 @@ static void initDeviceProperty(DeviceId device_id) {
       : 8;
   device_prop.support_atomic64 = device.has(dpcpp_dev_aspect_atomic64);
   device_prop.support_fp64 = device.has(dpcpp_dev_aspect_fp64);
-  sycl::ext::oneapi::experimental::cl_version version{20, 20, 20};
+  sycl::ext::oneapi::experimental::cl_version version;
   device_prop.support_cl_bf16_conversion = device.ext_oneapi_supports_cl_extension("cl_intel_bfloat16_conversions", &version);
+  device_prop.support_cl_sg_matmul_acc = device.ext_oneapi_supports_cl_extension("cl_intel_subgroup_matrix_multiply_accumulate", &version);
+  device_prop.support_cl_sg_matmul_acc_tf32 = device.ext_oneapi_supports_cl_extension("cl_intel_subgroup_matrix_multiply_accumulate_tensor_float32", &version);
+  device_prop.support_cl_sg_2d_block_io = device.ext_oneapi_supports_cl_extension("cl_intel_subgroup_2d_block_io", &version);
 
   device_properties[device_id] = device_prop;
 
@@ -358,9 +361,12 @@
   dev_info.max_num_sub_groups = device_prop.max_num_subgroup;
   dev_info.sub_group_sizes = device_prop.subgroup_sizes;
   dev_info.support_fp64 = device_prop.support_fp64;
+  dev_info.support_cl_bf16_conversion = device_prop.support_cl_bf16_conversion;
+  dev_info.support_cl_sg_matmul_acc = device_prop.support_cl_sg_matmul_acc;
+  dev_info.support_cl_sg_matmul_acc_tf32 = device_prop.support_cl_sg_matmul_acc_tf32;
+  dev_info.support_cl_sg_2d_block_io = device_prop.support_cl_sg_2d_block_io;
 #if (defined(__INTEL_LLVM_COMPILER) && __INTEL_LLVM_COMPILER >= 20240100)
   dev_info.device_arch = static_cast<uint64_t>(device_prop.device_arch);
-  dev_info.support_cl_bf16_conversion = device_prop.support_cl_bf16_conversion;
 #else
   dev_info.device_arch = (uint64_t)0;
 #endif
diff --git a/csrc/gpu/runtime/DeviceProp.h b/csrc/gpu/runtime/DeviceProp.h
index f2af1843d..b4bd5ce70 100644
--- a/csrc/gpu/runtime/DeviceProp.h
+++ b/csrc/gpu/runtime/DeviceProp.h
@@ -144,6 +144,9 @@ struct DeviceProp {
   bool support_fp64;
   bool support_atomic64;
   bool support_cl_bf16_conversion;
+  bool support_cl_sg_matmul_acc;
+  bool support_cl_sg_matmul_acc_tf32;
+  bool support_cl_sg_2d_block_io;
 };
 
 } // namespace dpcpp
diff --git a/intel_extension_for_pytorch/csrc/xpu/Module.cpp b/intel_extension_for_pytorch/csrc/xpu/Module.cpp
index 34bc3117e..fb0a96a3f 100644
--- a/intel_extension_for_pytorch/csrc/xpu/Module.cpp
+++ b/intel_extension_for_pytorch/csrc/xpu/Module.cpp
@@ -578,6 +578,9 @@ static void register_xpu_device_info(PyObject* module) {
       .def_readonly("sub_group_sizes", &DeviceInfo::sub_group_sizes)
       .def_readonly("has_fp64", &DeviceInfo::support_fp64)
       .def_readonly("support_cl_bf16_conversion", &DeviceInfo::support_cl_bf16_conversion)
+      .def_readonly("support_cl_sg_matmul_acc", &DeviceInfo::support_cl_sg_matmul_acc)
+      .def_readonly("support_cl_sg_matmul_acc_tf32", &DeviceInfo::support_cl_sg_matmul_acc_tf32)
+      .def_readonly("support_cl_sg_2d_block_io", &DeviceInfo::support_cl_sg_2d_block_io)
       .def_readonly("device_arch", &DeviceInfo::device_arch)
       .def_property_readonly(
           "dev_type", [](const DeviceInfo& info) { return get_dev_type(info); })
@@ -591,7 +594,11 @@
               << "MB, max_compute_units=" << info.max_compute_units
               << ", gpu_eu_count=" << info.gpu_eu_count
               << ", device_arch=" << info.device_arch
-              << ", support_cl_bf16_conversion=" << info.support_cl_bf16_conversion << ")";
+              << ", support_cl_bf16_conversion=" << info.support_cl_bf16_conversion
+              << ", support_cl_sg_matmul_acc=" << info.support_cl_sg_matmul_acc
+              << ", support_cl_sg_matmul_acc_tf32=" << info.support_cl_sg_matmul_acc_tf32
+              << ", support_cl_sg_2d_block_io=" << info.support_cl_sg_2d_block_io
+              << ")";
     return stream.str();
   });
}