Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/scripts/setup-vulkan-linux-deps.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ install_vulkan_sdk() {
export PATH="${PATH}:${_vulkan_sdk_dir}/${VULKAN_SDK_VERSION}/x86_64/bin/"
}

VULKAN_SDK_VERSION="1.3.296.0"
VULKAN_SDK_VERSION="1.4.321.1"

install_swiftshader
install_vulkan_sdk "${VULKAN_SDK_VERSION}"
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
shader.kernel_name, vkapi::VulkanExtension::INT8_STORAGE);
}
}
if (shader.requires_integer_dot_product) {
if (!adapter_p_->supports_int8_dot_product()) {
throw vkapi::ShaderNotSupportedError(
shader.kernel_name, vkapi::VulkanExtension::INTEGER_DOT_PRODUCT);
}
}
}

vkapi::DescriptorSet Context::get_descriptor_set(
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,7 @@ class ShaderInfo:
requires_shader_int16_ext: bool = False
requires_16bit_storage_ext: bool = False
requires_8bit_storage_ext: bool = False
requires_integer_dot_product_ext: bool = False


def getName(filePath: str) -> str:
Expand Down Expand Up @@ -1213,6 +1214,8 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
shader_info.requires_16bit_storage_ext = True
if "GL_EXT_shader_8bit_storage" in line:
shader_info.requires_8bit_storage_ext = True
if "GL_EXT_integer_dot_product" in line:
shader_info.requires_integer_dot_product_ext = True

return shader_info

Expand Down Expand Up @@ -1288,6 +1291,7 @@ def to_cpp_str(val: bool):
to_cpp_str(shader_info.requires_shader_int16_ext),
to_cpp_str(shader_info.requires_16bit_storage_ext),
to_cpp_str(shader_info.requires_8bit_storage_ext),
to_cpp_str(shader_info.requires_integer_dot_product_ext),
]

shader_info_str = textwrap.indent(
Expand Down
5 changes: 5 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ ComputeGraph::ComputeGraph(GraphConfig config)
config_.execute_threshold_node_count = 128;
config_.execute_initial_threshold_node_count = 64;
}

// Check if the underlying GPU can access accelerated integer dot product
// instructions
can_use_int8_dot_product_ =
context_->adapter_ptr()->supports_int8_dot_product();
}

ComputeGraph::~ComputeGraph() {
Expand Down
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ class ComputeGraph final {
// config.execute_threshold_node_count.
size_t execute_threshold_node_count_ = 0;

// Whether the underlying GPU support accelerated integer dot product
// extensions
bool can_use_int8_dot_product_ = false;

public:
//
// Accessors
Expand Down Expand Up @@ -1013,6 +1017,10 @@ class ComputeGraph final {
return execute_count_;
}

inline bool can_use_int8_dot_product() const {
return can_use_int8_dot_product_;
}

/*
* Check whether the GPU supports 8 bit buffers.
*/
Expand Down
112 changes: 112 additions & 0 deletions backends/vulkan/runtime/vk_api/Adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ VkDevice create_logical_device(
#ifdef VK_KHR_shader_float16_int8
VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME,
#endif /* VK_KHR_shader_float16_int8 */
#ifdef VK_KHR_shader_integer_dot_product
VK_KHR_SHADER_INTEGER_DOT_PRODUCT_EXTENSION_NAME,
#endif /* VK_KHR_shader_integer_dot_product */
#if defined(VK_KHR_pipeline_executable_properties) && defined(VULKAN_DEBUG)
VK_KHR_PIPELINE_EXECUTABLE_PROPERTIES_EXTENSION_NAME,
#endif /* VK_KHR_pipeline_executable_properties */
Expand Down Expand Up @@ -160,6 +163,14 @@ VkDevice create_logical_device(
extension_list_top = &shader_float16_int8_types;
#endif /* VK_KHR_shader_float16_int8 */

#ifdef VK_KHR_shader_integer_dot_product
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR
shader_int_dot_product_features{
physical_device.shader_int_dot_product_features};
shader_int_dot_product_features.pNext = extension_list_top;
extension_list_top = &shader_int_dot_product_features;
#endif /* VK_KHR_shader_integer_dot_product */

device_create_info.pNext = extension_list_top;

VkDevice handle = nullptr;
Expand Down Expand Up @@ -401,6 +412,107 @@ std::string Adapter::stringize() const {
#endif /* VK_KHR_shader_float16_int8 */
ss << " }" << std::endl;

#ifdef VK_KHR_shader_integer_dot_product
ss << " Shader Integer Dot Product Features {" << std::endl;
PRINT_PROP(
physical_device_.shader_int_dot_product_features,
shaderIntegerDotProduct);
ss << " }" << std::endl;

ss << " Shader Integer Dot Product Properties {" << std::endl;
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct8BitUnsignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct8BitSignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct8BitMixedSignednessAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct4x8BitPackedUnsignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct4x8BitPackedSignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct4x8BitPackedMixedSignednessAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct16BitUnsignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct16BitSignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct16BitMixedSignednessAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct32BitUnsignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct32BitSignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct32BitMixedSignednessAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct64BitUnsignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct64BitSignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProduct64BitMixedSignednessAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating8BitUnsignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating8BitSignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating8BitMixedSignednessAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating4x8BitPackedUnsignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating4x8BitPackedSignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating4x8BitPackedMixedSignednessAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating16BitUnsignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating16BitSignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating16BitMixedSignednessAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating32BitUnsignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating32BitSignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating32BitMixedSignednessAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating64BitUnsignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating64BitSignedAccelerated);
PRINT_PROP(
physical_device_.shader_int_dot_product_properties,
integerDotProductAccumulatingSaturating64BitMixedSignednessAccelerated);
ss << " }" << std::endl;
#endif /* VK_KHR_shader_integer_dot_product */

const VkPhysicalDeviceMemoryProperties& mem_props =
physical_device_.memory_properties;

Expand Down
9 changes: 9 additions & 0 deletions backends/vulkan/runtime/vk_api/Adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,15 @@ class Adapter final {
#endif /* VK_KHR_shader_float16_int8 */
}

inline bool supports_int8_dot_product() {
#ifdef VK_KHR_shader_integer_dot_product
return physical_device_.shader_int_dot_product_features
.shaderIntegerDotProduct == VK_TRUE;
#else
return false;
#endif /* VK_KHR_shader_integer_dot_product */
}

inline bool supports_int16_shader_types() {
return physical_device_.supports_int16_shader_types;
}
Expand Down
13 changes: 13 additions & 0 deletions backends/vulkan/runtime/vk_api/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
shader_float16_int8_types{
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR},
#endif /* VK_KHR_shader_float16_int8 */
#ifdef VK_KHR_shader_integer_dot_product
shader_int_dot_product_features{
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR},
shader_int_dot_product_properties{
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_PROPERTIES_KHR},
#endif
queue_families{},
num_compute_queues(0),
supports_int16_shader_types(false),
Expand Down Expand Up @@ -77,6 +83,13 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
extension_list_top = &shader_float16_int8_types;
#endif /* VK_KHR_shader_float16_int8 */

#ifdef VK_KHR_shader_integer_dot_product
shader_int_dot_product_features.pNext = extension_list_top;
extension_list_top = &shader_int_dot_product_features;
shader_int_dot_product_properties.pNext = extension_list_top;
extension_list_top = &shader_int_dot_product_properties;
#endif /* VK_KHR_shader_integer_dot_product */

features2.pNext = extension_list_top;

vkGetPhysicalDeviceFeatures2(handle, &features2);
Expand Down
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/vk_api/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ struct PhysicalDevice final {
#ifdef VK_KHR_shader_float16_int8
VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_types;
#endif /* VK_KHR_shader_float16_int8 */
#ifdef VK_KHR_shader_integer_dot_product
VkPhysicalDeviceShaderIntegerDotProductFeatures
shader_int_dot_product_features;
VkPhysicalDeviceShaderIntegerDotProductProperties
shader_int_dot_product_properties;
#endif /* VK_KHR_shader_integer_dot_product */

// Available GPU queues
std::vector<VkQueueFamilyProperties> queue_families;
Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/vk_api/Exception.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ std::ostream& operator<<(std::ostream& out, const VulkanExtension result) {
case VulkanExtension::INT8_STORAGE:
out << "VK_KHR_8bit_storage";
break;
case VulkanExtension::INTEGER_DOT_PRODUCT:
out << "VK_KHR_shader_integer_dot_product";
break;
}
return out;
}
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/runtime/vk_api/Exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ enum class VulkanExtension : uint8_t {
SHADER_INT16,
INT16_STORAGE,
INT8_STORAGE,
INTEGER_DOT_PRODUCT,
};

class ShaderNotSupportedError : public std::exception {
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/vk_api/QueryPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ std::string QueryPool::generate_string_report() {

std::stringstream ss;

int kernel_name_w = 40;
int kernel_name_w = 120;
int global_size_w = 25;
int local_size_w = 25;
int duration_w = 25;
Expand Down
6 changes: 4 additions & 2 deletions backends/vulkan/runtime/vk_api/Shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ ShaderInfo::ShaderInfo(
const utils::uvec3 tile_size,
const bool requires_shader_int16_ext,
const bool requires_16bit_storage_ext,
const bool requires_8bit_storage_ext)
const bool requires_8bit_storage_ext,
const bool requires_integer_dot_product_ext)
: src_code{
spirv_bin,
size,
Expand All @@ -41,7 +42,8 @@ ShaderInfo::ShaderInfo(
out_tile_size(tile_size),
requires_shader_int16(requires_shader_int16_ext),
requires_16bit_storage(requires_16bit_storage_ext),
requires_8bit_storage(requires_8bit_storage_ext) {
requires_8bit_storage(requires_8bit_storage_ext),
requires_integer_dot_product(requires_integer_dot_product_ext) {
}

bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/vk_api/Shader.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ struct ShaderInfo final {
bool requires_shader_int16 = false;
bool requires_16bit_storage = false;
bool requires_8bit_storage = false;
bool requires_integer_dot_product = false;

explicit ShaderInfo();

Expand All @@ -76,7 +77,8 @@ struct ShaderInfo final {
const utils::uvec3 tile_size,
const bool requires_shader_int16_ext,
const bool requires_16bit_storage_ext,
const bool requires_8bit_storage_ext);
const bool requires_8bit_storage_ext,
const bool requires_integer_dot_product_ext);

operator bool() const {
return src_code.bin != nullptr;
Expand Down
Loading