diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b33eb2f..827fbb5f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,19 +19,20 @@ set(CMAKE_CXX_EXTENSIONS NO) # Determine if SimSIMD is built as a subproject (using `add_subdirectory`) or if it is the main project set(SIMSIMD_IS_MAIN_PROJECT OFF) -if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) +if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) set(SIMSIMD_IS_MAIN_PROJECT ON) -endif () +endif() option(SIMSIMD_BUILD_SHARED "Compile a dynamic library" ${SIMSIMD_IS_MAIN_PROJECT}) option(SIMSIMD_BUILD_TESTS "Small compilation tests compile-time and run-time dispatch" OFF) option(SIMSIMD_BUILD_BENCHMARKS "Compile micro-benchmarks for current ISA" OFF) option(SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS "Include BLAS in micro-kernel benchmarks" OFF) +option(SIMSIMD_BUILD_WITH_OPENMP "Enable OpenMP support" OFF) # Default to Release build type if not set -if (NOT CMAKE_BUILD_TYPE) +if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) -endif () +endif() # Global compiler flags for debug and release set(CMAKE_CXX_FLAGS_DEBUG "-g -fsanitize=address") @@ -40,17 +41,18 @@ set(CMAKE_C_FLAGS_DEBUG "-g -fsanitize=address") set(CMAKE_C_FLAGS_RELEASE "-O3") # Compiler-specific flags -if (CMAKE_CXX_COMPILER_ID MATCHES "^(Apple)?Clang$") - if (NOT APPLE) +if(CMAKE_CXX_COMPILER_ID MATCHES "^(Apple)?Clang$") + if(NOT APPLE) add_compile_options(-march=native) - endif () + endif() add_compile_options(-pedantic -ferror-limit=1) -elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") +elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") add_compile_options(-march=native -pedantic -fmax-errors=1 -Wno-tautological-constant-compare) -elseif (CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") + add_compile_options(-mavx512fp16 -mavx512bf16 -mamx-int8 -mamx-bf16) +elseif(CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") add_compile_options(-w -ferror-limit=1) -endif () +endif() # Define the header-only library file(GLOB SIMSIMD_SOURCES include/simsimd/*.h) @@ -59,11 +61,9 @@ target_sources(simsimd INTERFACE ${SIMSIMD_SOURCES}) target_include_directories(simsimd INTERFACE "${PROJECT_SOURCE_DIR}/include") # Build benchmarks if required -if (SIMSIMD_BUILD_BENCHMARKS) - # Fetch external dependencies +if(SIMSIMD_BUILD_BENCHMARKS) include(FetchContent) - # Suppress building tests of Google Benchmark set(BENCHMARK_ENABLE_TESTING OFF) set(BENCHMARK_ENABLE_INSTALL OFF) set(BENCHMARK_ENABLE_DOXYGEN OFF) @@ -79,15 +79,16 @@ if (SIMSIMD_BUILD_BENCHMARKS) ) FetchContent_MakeAvailable(benchmark) - # Remove the Google Benchmark's "built in debug warning" - if (CMAKE_BUILD_TYPE STREQUAL "Release") + # Remove the "google benchmark built in debug" warning + if(CMAKE_BUILD_TYPE STREQUAL "Release") target_compile_definitions(benchmark PRIVATE NDEBUG) - endif () + endif() find_package(Threads REQUIRED) add_executable(simsimd_bench scripts/bench.cxx) target_link_libraries(simsimd_bench simsimd Threads::Threads benchmark) + # BLAS support if (SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS) find_package(BLAS REQUIRED) if (BLAS_FOUND) @@ -101,7 +102,17 @@ if (SIMSIMD_BUILD_BENCHMARKS) endif () endif () -endif () + # OpenMP support + if(SIMSIMD_BUILD_WITH_OPENMP) + find_package(OpenMP REQUIRED) + + if(OpenMP_CXX_FOUND) + target_compile_definitions(simsimd INTERFACE SIMSIMD_BUILD_WITH_OPENMP) + target_compile_options(simsimd INTERFACE ${OpenMP_CXX_FLAGS}) + target_link_libraries(simsimd INTERFACE OpenMP::OpenMP_CXX) + endif() + endif() +endif() if (SIMSIMD_BUILD_TESTS) add_executable(simsimd_test_compile_time 
scripts/test.c) @@ -110,11 +121,11 @@ if (SIMSIMD_BUILD_TESTS) add_executable(simsimd_test_run_time scripts/test.c c/lib.c) target_compile_definitions(simsimd_test_run_time PRIVATE SIMSIMD_DYNAMIC_DISPATCH=1) target_link_libraries(simsimd_test_run_time simsimd m) -endif () +endif() -if (SIMSIMD_BUILD_SHARED) +if(SIMSIMD_BUILD_SHARED) set(SIMSIMD_SOURCES ${SIMSIMD_SOURCES} c/lib.c) add_library(simsimd_shared SHARED ${SIMSIMD_SOURCES}) target_include_directories(simsimd_shared PUBLIC "${PROJECT_SOURCE_DIR}/include") set_target_properties(simsimd_shared PROPERTIES OUTPUT_NAME simsimd) -endif () +endif() diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f7b9b91c..d27fe8c6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,7 +12,10 @@ To rerun experiments utilize the following command: ```sh sudo apt install libopenblas-dev # BLAS installation is optional, but recommended for benchmarks -cmake -D CMAKE_BUILD_TYPE=Release -D SIMSIMD_BUILD_TESTS=1 -D SIMSIMD_BUILD_BENCHMARKS=1 -D SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS=1 -B build_release +cmake -D CMAKE_BUILD_TYPE=Release \ + -D SIMSIMD_BUILD_TESTS=1 -D SIMSIMD_BUILD_BENCHMARKS=1 \ + -D SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS=1 -D SIMSIMD_BUILD_WITH_OPENMP=1 \ + -B build_release cmake --build build_release --config Release build_release/simsimd_bench build_release/simsimd_bench --benchmark_filter=js @@ -63,6 +66,24 @@ cmake -D CMAKE_BUILD_TYPE=Release \ cmake --build build_release --config Release ``` +Similarly, using Clang on Linux: + +```sh +sudo apt install clang +cmake -D CMAKE_BUILD_TYPE=Release \ + -D SIMSIMD_BUILD_TESTS=1 -D SIMSIMD_BUILD_BENCHMARKS=1 \ + -D SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS=1 -D SIMSIMD_BUILD_WITH_OPENMP=0 \ + -D CMAKE_C_COMPILER=clang -D CMAKE_CXX_COMPILER=clang++ \ + -B build_release +cmake --build build_release --config Release +``` + +I'd recommend putting the following breakpoints: + +- `__asan::ReportGenericError` - to detect illegal memory accesses. +- `__GI_exit` - to stop at exit points - the end of running any executable. +- `__builtin_unreachable` - to catch all the places where the code is expected to be unreachable. + ## Python Testing: diff --git a/README.md b/README.md index 35b159fe..4ae82315 100644 --- a/README.md +++ b/README.md @@ -821,7 +821,9 @@ Possibly, in the future: Last, but not the least - don't build unless there is a demand for it. So if you have a specific use-case, please open an issue or a pull request, and ideally, bring in more users with similar needs. -### Cosine Similarity, Reciprocal Square Root, and Newton-Raphson Iteration +### Spatial Affinity + +> On cosine similarity, reciprocal square root approximations, and the Newton-Raphson iteration. The cosine similarity is the most common and straightforward metric used in machine learning and information retrieval. Interestingly, there are multiple ways to shoot yourself in the foot when computing it. @@ -876,7 +878,10 @@ On 1536-dimensional inputs on Intel Sapphire Rapids CPU a single such iteration | `float32` | 2.21e-08 ± 1.65e-08 | 3.47e-07 ± 3.49e-07 | 3.77e-09 ± 2.84e-09 | | `float64` | 0.00e+00 ± 0.00e+00 | 3.80e-07 ± 4.50e-07 | 1.35e-11 ± 1.85e-11 | -### Curved Spaces, Mahalanobis Distance, and Bilinear Quadratic Forms + +### Curved Spaces + +> On non-Euclidean geometry, Mahalanobis distance, and Bilinear Quadratic Forms. The Mahalanobis distance is a generalization of the Euclidean distance, which takes into account the covariance of the data. 
It's very similar in its form to the bilinear form, which is a generalization of the dot product. @@ -906,7 +911,9 @@ A $vector * matrix * vector$ product is a scalar, whereas its constituent parts SimSIMD doesn't produce intermediate vector results, like `a @ M @ b`, but computes the bilinear form directly. -### Set Intersection, Galloping, and Binary Search +### Sparse Vectors + +> On sorted set intersections, "Galloping", and Binary Search. The set intersection operation is generally defined as the number of elements that are common between two sets, represented as sorted arrays of integers. The most common way to compute it is a linear scan: @@ -934,7 +941,9 @@ Third approach is to use the SIMD instructions to compare multiple elements at o After benchmarking, the last approach was chosen, as it's the most flexible and often the fastest. -### Complex Dot Products, Conjugate Dot Products, and Complex Numbers +### Dot Products + +> On complex numbers and conjugate dot products. Complex dot products are a generalization of the dot product to complex numbers. They are supported by most BLAS packages, but almost never in mixed precision. @@ -971,7 +980,12 @@ def vdot(a: List[number], b: List[number]) -> number: return ab_real, ab_imaginary ``` -### Logarithms in Kullback-Leibler & Jensen–Shannon Divergences +### Binary Representations + +> On Hamming distance, Jaccard similarity, and bit-level population counts. + +### Probability Distributions + +> On fast logarithm approximations in Kullback-Leibler & Jensen–Shannon divergences. The Kullback-Leibler divergence is a measure of how one probability distribution diverges from a second, expected probability distribution. Jensen-Shannon divergence is a symmetrized and smoothed version of the Kullback-Leibler divergence, which can be used as a distance metric between probability distributions. @@ -986,7 +1000,60 @@ Jensen-Shannon divergence is a symmetrized and smoothed version of the Kullback- Both functions are defined for non-negative numbers, and the logarithm is a key part of their computation. -### Mixed Precision in Fused-Multiply-Add and Weighted Sums +### Dense Matrix Multiplications + +> On dense matrix multiplication and Normalized Cross-Correlation. + +Unlike most dense matrix multiplication libraries, SimSIMD: + +- Focuses on mixed-precision computation. +- Focuses on row-by-row multiplication, where the second matrix is transposed. +- Outputs into the same precision as the input matrices, normalizing the output, like in cosine similarity calculations. +- Doesn't implement parallelism, making it compatible with arbitrary concurrency models, third-party thread pools and task schedulers, like OpenMP, TBB, or Rayon. + +These decisions make the algorithm much easier to tile and more robust to noise and precision loss. +It also makes the multiplication more suitable for integration into multi-step Machine Learning and AI inference pipelines, where the output of one step is the input of the next, and a uniform representation is required.
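+Stripped of SIMD, tiling, and mixed precision, the row-by-row scheme above reduces to a short reference loop.
+The sketch below is illustrative only: the function name and the `float` specialization are not part of the library API, and, unlike the real kernels, it takes strides in elements rather than bytes.
+
+```c
+#include <math.h>
+#include <stddef.h>
+
+/* Reference only: C[i][j] = <A_i, B_j> / (|A_i| * |B_j|), both matrices traversed row by row. */
+void nxcor_f32_reference(                      //
+    size_t a_rows, size_t b_rows, size_t cols, //
+    float const *a, size_t a_stride,           //
+    float const *b, size_t b_stride,           //
+    float *c, size_t c_stride) {
+    for (size_t i = 0; i < a_rows; ++i) {
+        float const *a_row = a + i * a_stride;
+        for (size_t j = 0; j < b_rows; ++j) {
+            float const *b_row = b + j * b_stride;
+            float dot = 0, a_norm2 = 0, b_norm2 = 0;
+            for (size_t k = 0; k < cols; ++k) {
+                dot += a_row[k] * b_row[k];
+                a_norm2 += a_row[k] * a_row[k];
+                b_norm2 += b_row[k] * b_row[k];
+            }
+            float norms = sqrtf(a_norm2) * sqrtf(b_norm2);
+            c[i * c_stride + j] = norms != 0.0f ? dot / norms : 0.0f;
+        }
+    }
+}
+```
+
+The tiled and SIMD backends in `include/simsimd/matmul.h` keep this overall structure, but block all three loops to stay within caches and tile registers.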
+ +```math +C_{ij} = \frac{A_i \cdot B_j}{\|A_i\| \|B_j\|} = \frac{\sum_k A_{ik} B_{jk}}{\sqrt{\sum_k A_{ik}^2} \sqrt{\sum_k B_{jk}^2}} +``` + +The conventional matrix multiplication lacks the denominator and indexes the second matrix differently: + +```math +C_{ij} = \sum_k A_{ik} B_{kj} +``` + +Important to note: if you are going to use the Normalized Cross-Correlation in training pipelines, the normalization terms also enter the derivatives, so the gradient flow differs from plain matrix multiplication: + +Gradient with Respect to $A_i$: + +```math +\frac{\partial C_{ij}}{\partial A_i} = \frac{1}{\|A_i\| \|B_j\|} \left( B_j - \frac{A_i \cdot B_j}{\|A_i\|^2} A_i \right) +``` + +Similarly, the gradient with respect to $B_j$ is: + +```math +\frac{\partial C_{ij}}{\partial B_j} = \frac{1}{\|A_i\| \|B_j\|} \left( A_i - \frac{A_i \cdot B_j}{\|B_j\|^2} B_j \right) +``` + +#### Optimizing for Intel AMX + +Intel Advanced Matrix Extensions (AMX) is a new instruction set available in Sapphire Rapids CPUs and newer. +It provides 8 special tile registers, each __1 kilobyte__ in size, enough to store a 16-by-16 matrix tile of 4-byte words. + +### Sparse-Sparse Matrix Multiplications + +> On sparse matrix multiplication, the Compressed Sparse Row format, and the Compressed Sparse Column format. + +Sparse-Sparse matrix multiplication is just "matrix multiplication" when both matrices are sparse. +Hence, their values are stored in irregularly addressable arrays, and their indices are often stored in separate arrays. +Those representations are very effective when $>99\%$ of the values are close to zero and the matrices are large. +With a good algorithm, a 100x improvement in performance can be achieved. + +### Elementwise Operations + +> On mixed precision in Fused-Multiply-Add and Weighted Sums. The Fused-Multiply-Add (FMA) operation is a single operation that combines element-wise multiplication and addition with different scaling factors. The Weighted Sum is its simplified variant without element-wise multiplication. diff --git a/include/simsimd/binary.h b/include/simsimd/binary.h index be452882..57346690 100644 --- a/include/simsimd/binary.h +++ b/include/simsimd/binary.h @@ -3,6 +3,7 @@ * @brief SIMD-accelerated Binary Similarity Measures.
* @author Ash Vardanian * @date July 1, 2023 + * @see https://github.com/ashvardanian/simsimd?tab=readme-ov-file#binary-representations * * Contains: * - Hamming distance @@ -49,14 +50,16 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t c SIMSIMD_PUBLIC unsigned char simsimd_popcount_b8(simsimd_b8_t x) { static unsigned char lookup_table[] = { + // 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, // - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8}; + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, // + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, // + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, // + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, // + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, // + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, // + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8, // + }; return lookup_table[x]; } diff --git a/include/simsimd/curved.h b/include/simsimd/curved.h index 59a99fe6..9090c316 100644 --- a/include/simsimd/curved.h +++ b/include/simsimd/curved.h @@ -3,6 +3,7 @@ * @brief SIMD-accelerated Similarity Measures for curved spaces. 
* @author Ash Vardanian * @date August 27, 2024 + * @see https://github.com/ashvardanian/simsimd?tab=readme-ov-file#curved-spaces * * Contains: * - Bilinear form multiplication @@ -90,6 +91,10 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_genoa(simsimd_bf16_t const* a, simsimd SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_bf16_t const* c, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_bilinear_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_f16_t const* c, simsimd_size_t n, simsimd_distance_t* result); SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_f16_t const* c, simsimd_size_t n, simsimd_distance_t* result); + +SIMSIMD_PUBLIC void simsimd_bilinear_bf16_sapphire(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_bf16_t const* c, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_sapphire(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_bf16_t const* c, simsimd_size_t n, simsimd_distance_t* result); // clang-format on #define SIMSIMD_MAKE_BILINEAR(name, input_type, accumulator_type, load_and_convert) \ @@ -791,6 +796,100 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const *a, sim *result = _simsimd_sqrt_f32_haswell(_mm512_reduce_add_ph(sum_vec)); } +SIMSIMD_PUBLIC void simsimd_bilinear_bf16_sapphire(simsimd_bf16_t const *a, simsimd_bf16_t const *b, + simsimd_bf16_t const *c, simsimd_size_t n, + simsimd_distance_t *result) { + + __m512 sum_vec = _mm512_setzero_ps(); + // Using Intel AMX instructions we can perform a 16x32x16 brain-float multiplication on specialized + // matrix-multiplication hardware, accumulating the results in a 16x16 single-precision matrix. + SIMSIMD_ALIGN64 simsimd_u8_t tiles_config[64]; + + // Memset in one cycle, like a boss :) + // std::memset(tiles_config, 0, sizeof(tiles_config)); + _mm512_storeu_si512((__m512i *)tiles_config, _mm512_setzero_si512()); + + // Only one palette is currently supported: + simsimd_u8_t *palette_id_ptr = &tiles_config[0]; + *palette_id_ptr = 1; + + // The geniuses behind AMX decided to use different widths for the two fields: rows are 8-bit counts, + // while columns are 16-bit byte-widths. Wasted 2 hours of my life not noticing this! + simsimd_u16_t *tiles_colsb_ptr = (simsimd_u16_t *)(&tiles_config[16]); + simsimd_u8_t *tiles_rows_ptr = &tiles_config[48]; + + // Important to note: AMX doesn't care about the real shape of our matrix, + // it only cares about its own tile shape. Keep it simple, otherwise + // the next person reading this will be painting the walls with their brains :) + tiles_rows_ptr[0] = 16; + tiles_rows_ptr[1] = 16; + tiles_rows_ptr[2] = 16; + tiles_colsb_ptr[0] = 64; + tiles_colsb_ptr[1] = 64; + tiles_colsb_ptr[2] = 64; + + // Initialize the tiles configuration and zero-out the accumulators: + _tile_loadconfig(&tiles_config); + _tile_zero(2); + + // In every iteration we can load 16x32 = 512 values from the first vector. + _tile_loadd(0, a, 64); + + // When multiplying matrix chains, like $X * Y * Z$, the order doesn't matter. + // We can perform $(X * Y) * Z$ or $X * (Y * Z)$, the result will be the same. + // The last variant is better in our case, as X and Z are much smaller than Y. + // + // The output of $YZ = Y * Z$ will be a single row in our case.
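+    //
+    // In symbols for this kernel: X = a is (1 x n), Y = c is (n x n), and Z = b is (n x 1).
+    // The product Y * Z collapses to a single vector of n values, computed here 16 at a time:
+    // each horizontal band of 16 rows of Y, multiplied by the reshaped Z, yields 16 of those values.
+    // Dotting them with the matching 16 entries of X and accumulating into `sum_vec` produces
+    // X * (Y * Z) without ever materializing the full intermediate vector.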
+ simsimd_size_t const vertical_tiles = n / 16; + simsimd_size_t const horizontal_tiles = n / 32; + for (simsimd_size_t matrix_start_row = 0; matrix_start_row + 16 <= n; matrix_start_row += 16) { + + // We need one more loop over here?! + + // The column vector `b` contains twice as many rows as the curvature matrix. + simsimd_size_t vector_progress = matrix_start_row * 2; + simsimd_bf16_t b_reshaped[16][16][2]; + for (simsimd_size_t i = 0; i < 32; ++i) + for (simsimd_size_t j = 0; j < 16; ++j) + // We are shrinking the number of rows in the second matrix by 2x for `bf16` and by 4x for `i8` and + // `u8`. We are also practically making the rows longer by 2x for `bf16` and by 4x for `i8` and `u8`. + b_reshaped[i / 2][j][i % 2] = b[vector_progress + i * 16 + j]; + _tile_loadd(1, &b_reshaped, 64); + + // Within a horizontal band of 16 rows in a curvature matrix, + // we can compute 16 items of the YZ row in parallel, progressing + // aggregating 32 pairs into each of the 16 items. + for (simsimd_size_t matrix_start_col = 0; matrix_start_col + 32 <= n; matrix_start_col += 32) { + // Load a 16x32 block from the curvature matrix. + _tile_loadd(0, c + matrix_start_row * n + matrix_start_col, 64); + // Perform the matrix multiplication. + _tile_dpbf16ps(2, 0, 1); + } + + simsimd_f32_t cb_expanded[16][16]; + _tile_stored(2, &cb_expanded, 64); + + // Now we need to accumulate the results in every column into the top row. + // Each row is exactly 64 bytes or 512 bits, making it perfect for AVX-512. + __m512 cb_collapsed_vec = _mm512_setzero_ps(); + for (simsimd_size_t i = 0; i < 16; ++i) { + __m512 cb_row_vec = _mm512_loadu_ps(cb_expanded[i]); + cb_collapsed_vec = _mm512_add_ps(cb_collapsed_vec, cb_row_vec); + } + + // At this point the `cb_collapsed_vec` contains the product of a band of Y multiplied by the column Z. + // To compute the overall bilinear form, we need to combine it with the respective slice of row X. + __m512 a_vec = _simsimd_bf16x16_to_f32x16_skylake(_mm256_loadu_epi16(a + vector_progress)); + sum_vec = _mm512_fmadd_ps(a_vec, cb_collapsed_vec, sum_vec); + } + + *result = _mm512_reduce_add_ps(sum_vec); +} + +SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_sapphire(simsimd_bf16_t const *a, simsimd_bf16_t const *b, + simsimd_bf16_t const *c, simsimd_size_t n, + simsimd_distance_t *result) {} + #pragma clang attribute pop #pragma GCC pop_options #endif // SIMSIMD_TARGET_SAPPHIRE diff --git a/include/simsimd/dot.h b/include/simsimd/dot.h index 556940b1..c8db5812 100644 --- a/include/simsimd/dot.h +++ b/include/simsimd/dot.h @@ -3,6 +3,7 @@ * @brief SIMD-accelerated Dot Products for Real and Complex numbers. 
* @author Ash Vardanian * @date February 24, 2024 + * @see https://github.com/ashvardanian/simsimd?tab=readme-ov-file#dot-products * * Contains: * - Dot Product for Real and Complex vectors @@ -1256,6 +1257,15 @@ SIMSIMD_PUBLIC void simsimd_dot_bf16_haswell(simsimd_bf16_t const *a, simsimd_bf #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "bmi2") #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,bmi2"))), apply_to = function) +SIMSIMD_INTERNAL __m512 _simsimd_bf16x16_to_f32x16_skylake(__m256i a) { + // AVX-512 contains `_mm512_cvtpbh_ps`, but that's a sequential instruction + return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)); +} + +SIMSIMD_INTERNAL __m256i _simsimd_f32x16_to_bf16x16_skylake(__m512 a) { + return _mm512_cvtepi32_epi16(_mm512_srli_epi32(_mm512_castps_si512(a), 16)); +} + SIMSIMD_INTERNAL simsimd_f64_t _simsimd_reduce_f32x16_skylake(__m512 a) { __m512 x = _mm512_add_ps(a, _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(0, 0, 3, 2))); __m128 r = _mm512_castps512_ps128(_mm512_add_ps(x, _mm512_shuffle_f32x4(x, x, _MM_SHUFFLE(0, 0, 0, 1)))); diff --git a/include/simsimd/matmul.h b/include/simsimd/matmul.h new file mode 100644 index 00000000..3d2d124f --- /dev/null +++ b/include/simsimd/matmul.h @@ -0,0 +1,621 @@ +/** + * @file matmul.h + * @brief SIMD-accelerated mixed-precision Matrix Multiplication kernels. + * @author Ash Vardanian + * @date September 14, 2024 + * @see https://github.com/ashvardanian/simsimd?tab=readme-ov-file#dense-matrix-multiplications + * + * Implements matrix-multiplication kernels, focusing on mixed precision and row-major layouts. + * Assuming we are multiplying rows-by-rows and normalizing the products with magnitudes, + * as opposed to conventional rows-by-columns dot-products, a more suitable name + * is @b "normalized-cross-correlation" or @b "nxcor" for short. + * + * For datatypes: + * - 64-bit IEEE floating point numbers + * - 32-bit IEEE floating point numbers + * - 16-bit IEEE floating point numbers + * - 16-bit brain floating point numbers + * - 8-bit signed integers + * + * For hardware architectures: + * - x86 (AVX2, AVX512, AMX) + * - Arm (NEON, SVE, SME) + * + * x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/ + * Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/ + * + * Matrix Multiplication in 40 lines of C by Sergey Slotin: https://en.algorithmica.org/hpc/algorithms/matmul/ + * LLaMA Now Goes Faster on CPUs by Justine Tunney: https://justine.lol/matmul/ + * LLM.int8 quantization for PyTorch: https://github.com/bitsandbytes-foundation/bitsandbytes + */ +#ifndef SIMSIMD_MATMUL_H +#define SIMSIMD_MATMUL_H + +#include "types.h" + +#include "dot.h" // `_simsimd_bf16x16_to_f32x16_skylake` + +#ifdef __cplusplus +extern "C" { +#endif + +/* Serial backends for all numeric types. + * By default they use 32-bit arithmetic, unless the arguments themselves contain 64-bit floats. + * For double-precision computation check out the "*_accurate" variants of those "*_serial" functions. 
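+ *
+ *  For example, to multiply a 3x8 `f32` matrix by a 5x8 `f32` matrix row-by-row into a 3x5 output,
+ *  with every stride passed in bytes (here, the byte width of one row):
+ *
+ *      simsimd_f32_t a[3][8], b[5][8], c[3][5];
+ *      simsimd_nxcor_f32_serial(3, 5, 8, &a[0][0], sizeof(a[0]), &b[0][0], sizeof(b[0]), &c[0][0], sizeof(c[0]));
+ *
+ *  The shapes and values above are illustrative; only the argument order and the byte-level strides matter.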
+ */ +SIMSIMD_PUBLIC void simsimd_nxcor_f64_serial( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f64_t const *a, simsimd_size_t a_stride, // + simsimd_f64_t const *b, simsimd_size_t b_stride, // + simsimd_f64_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_f32_serial( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f32_t const *a, simsimd_size_t a_stride, // + simsimd_f32_t const *b, simsimd_size_t b_stride, // + simsimd_f32_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_f16_serial( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f16_t const *a, simsimd_size_t a_stride, // + simsimd_f16_t const *b, simsimd_size_t b_stride, // + simsimd_f16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_bf16_serial( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_bf16_t const *a, simsimd_size_t a_stride, // + simsimd_bf16_t const *b, simsimd_size_t b_stride, // + simsimd_bf16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_i8_serial( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_i8_t const *a, simsimd_size_t a_stride, // + simsimd_i8_t const *b, simsimd_size_t b_stride, // + simsimd_i8_t *c, simsimd_size_t c_stride); + +/* Double-precision serial backends for all numeric types. + * For single-precision computation check out the "*_serial" counterparts of those "*_accurate" functions. + */ +SIMSIMD_PUBLIC void simsimd_nxcor_f32_accurate( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f32_t const *a, simsimd_size_t a_stride, // + simsimd_f32_t const *b, simsimd_size_t b_stride, // + simsimd_f32_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_f16_accurate( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f16_t const *a, simsimd_size_t a_stride, // + simsimd_f16_t const *b, simsimd_size_t b_stride, // + simsimd_f16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_bf16_accurate( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_bf16_t const *a, simsimd_size_t a_stride, // + simsimd_bf16_t const *b, simsimd_size_t b_stride, // + simsimd_bf16_t *c, simsimd_size_t c_stride); + +/* SIMD-powered backends for Arm NEON, mostly using 32-bit arithmetic over 128-bit words. + * By far the most portable backend, covering most Arm v8 devices, over a billion phones, and almost all + * server CPUs produced before 2023. 
+ */ +SIMSIMD_PUBLIC void simsimd_nxcor_f32_neon( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f32_t const *a, simsimd_size_t a_stride, // + simsimd_f32_t const *b, simsimd_size_t b_stride, // + simsimd_f32_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_f16_neon( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f16_t const *a, simsimd_size_t a_stride, // + simsimd_f16_t const *b, simsimd_size_t b_stride, // + simsimd_f16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_bf16_neon( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_bf16_t const *a, simsimd_size_t a_stride, // + simsimd_bf16_t const *b, simsimd_size_t b_stride, // + simsimd_bf16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_i8_neon( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_i8_t const *a, simsimd_size_t a_stride, // + simsimd_i8_t const *b, simsimd_size_t b_stride, // + simsimd_i8_t *c, simsimd_size_t c_stride); + +/* SIMD-powered backends for Arm SVE, mostly using 32-bit arithmetic over variable-length platform-defined word sizes. + * Designed for Arm Graviton 3, Microsoft Cobalt, as well as Nvidia Grace and newer Ampere Altra CPUs. + */ +SIMSIMD_PUBLIC void simsimd_nxcor_f32_sve( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f32_t const *a, simsimd_size_t a_stride, // + simsimd_f32_t const *b, simsimd_size_t b_stride, // + simsimd_f32_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_f16_sve( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f16_t const *a, simsimd_size_t a_stride, // + simsimd_f16_t const *b, simsimd_size_t b_stride, // + simsimd_f16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_bf16_sve( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_bf16_t const *a, simsimd_size_t a_stride, // + simsimd_bf16_t const *b, simsimd_size_t b_stride, // + simsimd_bf16_t *c, simsimd_size_t c_stride); + +/* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer, using 32-bit arithmetic over 256-bit words. + * First demonstrated in 2011, at least one Haswell-based processor was still being sold in 2022 — the Pentium G3420. + * Practically all modern x86 CPUs support AVX2, FMA, and F16C, making it a perfect baseline for SIMD algorithms. + * On the other hand, there is no need to implement AVX2 versions of `f32` and `f64` functions, as those are + * properly vectorized by recent compilers.
+ */ +SIMSIMD_PUBLIC void simsimd_nxcor_f32_haswell( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f32_t const *a, simsimd_size_t a_stride, // + simsimd_f32_t const *b, simsimd_size_t b_stride, // + simsimd_f32_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_f16_haswell( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f16_t const *a, simsimd_size_t a_stride, // + simsimd_f16_t const *b, simsimd_size_t b_stride, // + simsimd_f16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_bf16_haswell( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_bf16_t const *a, simsimd_size_t a_stride, // + simsimd_bf16_t const *b, simsimd_size_t b_stride, // + simsimd_bf16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_i8_haswell( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_i8_t const *a, simsimd_size_t a_stride, // + simsimd_i8_t const *b, simsimd_size_t b_stride, // + simsimd_i8_t *c, simsimd_size_t c_stride); + +/* SIMD-powered backends for various generations of AVX512 CPUs. + * Skylake is handy, as it supports masked loads and other operations, avoiding the need for the tail loop. + * Ice Lake added VNNI, VPOPCNTDQ, IFMA, VBMI, VAES, GFNI, VBMI2, BITALG, VPCLMULQDQ, and other extensions for integral + * operations. Genoa added only BF16. Sapphire Rapids added tiled matrix operations in AMX, that we can use for `i8` and + * `bf16` types. + */ +SIMSIMD_PUBLIC void simsimd_nxcor_f32_skylake( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f32_t const *a, simsimd_size_t a_stride, // + simsimd_f32_t const *b, simsimd_size_t b_stride, // + simsimd_f32_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_i8_ice( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_i8_t const *a, simsimd_size_t a_stride, // + simsimd_i8_t const *b, simsimd_size_t b_stride, // + simsimd_i8_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_bf16_genoa( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_bf16_t const *a, simsimd_size_t a_stride, // + simsimd_bf16_t const *b, simsimd_size_t b_stride, // + simsimd_bf16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_f16_sapphire( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_f16_t const *a, simsimd_size_t a_stride, // + simsimd_f16_t const *b, simsimd_size_t b_stride, // + simsimd_f16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_bf16_sapphire( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_bf16_t const *a, simsimd_size_t a_stride, // + simsimd_bf16_t const *b, simsimd_size_t b_stride, // + simsimd_bf16_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_i8_sapphire( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_i8_t const *a, simsimd_size_t a_stride, // + simsimd_i8_t const *b, simsimd_size_t b_stride, // + simsimd_i8_t *c, simsimd_size_t c_stride); +SIMSIMD_PUBLIC void simsimd_nxcor_i4x2_sapphire( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_i4x2_t const *a, simsimd_size_t a_stride, // + simsimd_i4x2_t const *b, simsimd_size_t b_stride, // + simsimd_i4x2_t *c, simsimd_size_t c_stride); + +#define SIMSIMD_MAKE_MATMUL(name, input_type, 
accumulator_type, output_type, load_and_convert, convert_and_store) \ + void simsimd_nxcor_##input_type##_##name(simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, \ + simsimd_##input_type##_t const *a, simsimd_size_t a_stride, \ + simsimd_##input_type##_t const *b, simsimd_size_t b_stride, \ + simsimd_##output_type##_t *c, simsimd_size_t c_stride) { \ + for (simsimd_size_t i = 0; i < a_rows; ++i) { \ + simsimd_##input_type##_t const *a_row = \ + (simsimd_##input_type##_t const *)_simsimd_advance_by_bytes((void *)a, i * a_stride); \ + simsimd_##output_type##_t *c_row = \ + (simsimd_##output_type##_t *)_simsimd_advance_by_bytes((void *)c, i * c_stride); \ + for (simsimd_size_t j = 0; j < b_rows; ++j) { \ + simsimd_##input_type##_t const *b_row = \ + (simsimd_##input_type##_t const *)_simsimd_advance_by_bytes((void *)b, j * b_stride); \ + simsimd_##accumulator_type##_t sum = 0; \ + for (simsimd_size_t k = 0; k < cols; ++k) { \ + simsimd_##accumulator_type##_t aik = load_and_convert(a_row + k); \ + simsimd_##accumulator_type##_t bjk = load_and_convert(b_row + k); \ + sum += aik * bjk; \ + } \ + convert_and_store(sum, c_row + j); \ + } \ + } \ + } + +#define SIMSIMD_MAKE_TILED(name, input_type, accumulator_type, output_type, load_and_convert, convert_and_store, \ + tile_size) \ + void simsimd_nxcor_##input_type##_##name(simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, \ + simsimd_##input_type##_t const *a, simsimd_size_t a_stride, \ + simsimd_##input_type##_t const *b, simsimd_size_t b_stride, \ + simsimd_##output_type##_t *c, simsimd_size_t c_stride) { \ + for (simsimd_size_t ii = 0; ii < a_rows; ii += tile_size) { \ + for (simsimd_size_t jj = 0; jj < b_rows; jj += tile_size) { \ + for (simsimd_size_t kk = 0; kk < cols; kk += tile_size) { \ + simsimd_size_t i_max = (ii + tile_size < a_rows) ? (ii + tile_size) : a_rows; \ + simsimd_size_t j_max = (jj + tile_size < b_rows) ? (jj + tile_size) : b_rows; \ + simsimd_size_t k_max = (kk + tile_size < cols) ? 
(kk + tile_size) : cols; \ + for (simsimd_size_t i = ii; i < i_max; ++i) { \ + simsimd_##input_type##_t const *a_row = \ + (simsimd_##input_type##_t const *)_simsimd_advance_by_bytes((void *)a, i * a_stride); \ + simsimd_##output_type##_t *c_row = \ + (simsimd_##output_type##_t *)_simsimd_advance_by_bytes((void *)c, i * c_stride); \ + for (simsimd_size_t j = jj; j < j_max; ++j) { \ + simsimd_##input_type##_t const *b_row = \ + (simsimd_##input_type##_t const *)_simsimd_advance_by_bytes((void *)b, j * b_stride); \ + simsimd_##accumulator_type##_t sum = 0; \ + for (simsimd_size_t k = kk; k < k_max; ++k) { \ + simsimd_##accumulator_type##_t aik = load_and_convert(a_row + k); \ + simsimd_##accumulator_type##_t bjk = load_and_convert(b_row + k); \ + sum += aik * bjk; \ + } \ + convert_and_store(sum, c_row + j); \ + } \ + } \ + } \ + } \ + } \ + } + +SIMSIMD_MAKE_TILED(serial, f64, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT, 16) // simsimd_nxcor_f64_serial +SIMSIMD_MAKE_TILED(serial, f32, f32, f32, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT, 16) // simsimd_nxcor_f32_serial +SIMSIMD_MAKE_TILED(serial, f16, f32, f16, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16, 16) // simsimd_nxcor_f16_serial +SIMSIMD_MAKE_TILED(serial, bf16, f32, bf16, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16, 16) // simsimd_nxcor_bf16_serial +SIMSIMD_MAKE_TILED(serial, i8, i64, i8, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT, 16) // simsimd_nxcor_i8_serial +SIMSIMD_MAKE_TILED(accurate, f32, f64, f32, SIMSIMD_DEREFERENCE, SIMSIMD_EXPORT, 16) // simsimd_nxcor_f32_accurate +SIMSIMD_MAKE_TILED(accurate, f16, f64, f16, SIMSIMD_F16_TO_F32, SIMSIMD_F32_TO_F16, 16) // simsimd_nxcor_f16_accurate +SIMSIMD_MAKE_TILED(accurate, bf16, f64, bf16, SIMSIMD_BF16_TO_F32, SIMSIMD_F32_TO_BF16, + 16) // simsimd_nxcor_bf16_accurate + +#if SIMSIMD_TARGET_ARM +#if SIMSIMD_TARGET_NEON +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+simd") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) + +#pragma clang attribute pop +#pragma GCC pop_options +#endif + +#if SIMSIMD_TARGET_NEON_F16 +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+simd+fp16") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function) + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_NEON_F16 + +#if SIMSIMD_TARGET_NEON_BF16 +#pragma GCC push_options +#pragma GCC target("arch=armv8.6-a+simd+bf16") +#pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function) + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_NEON_BF16 + +#if SIMSIMD_TARGET_SVE + +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+sve") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) + +#pragma clang attribute pop +#pragma GCC pop_options + +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+sve+fp16") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+fp16"))), apply_to = function) + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SVE +#endif // SIMSIMD_TARGET_ARM + +#if SIMSIMD_TARGET_X86 +#if SIMSIMD_TARGET_HASWELL +#pragma GCC push_options +#pragma GCC target("avx2", "f16c", "fma") +#pragma clang attribute push(__attribute__((target("avx2,f16c,fma"))), apply_to = function) + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_HASWELL + +#if SIMSIMD_TARGET_SKYLAKE 
+#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "avx512bw", "bmi2") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,avx512bw,bmi2"))), apply_to = function) + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SKYLAKE + +#if SIMSIMD_TARGET_GENOA +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512bw", "avx512bf16") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512bw,avx512bf16"))), apply_to = function) + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_GENOA + +#if SIMSIMD_TARGET_SAPPHIRE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512bw", "avx512fp16") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512bw,avx512fp16"))), apply_to = function) + +// We are going to implement multiple levels of tiling here. +// One is defined by the AMX tile size, which is (16 x 32) for BF16, 1 KB for each of the 8 tile registers. +// The others can be defined by the CPU cache size, which is: +// - 384 KB of L1 data cache per hyper-threaded core. +// - 16 MB of L2 per hyper-threaded core. +// The L1 cache is enough to store 3x (256 x 256) BF16 matrices, but the last one needs to be larger +// to accommodate F32 values. Moreover, we need to renormalize the values to avoid overflows and +// significant loss of precision. +// - (128 x 128) BF16 tile is 32 KB, so A & B are 64 KB. +// - (128 x 128) F32 tile is 64 KB, so C, A2, & B2 are 192 KB. +// This totals to 256 KB and fits within the L1 data cache. +#ifndef SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE +#define SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE 128 +#endif + +SIMSIMD_PUBLIC void simsimd_nxcor_bf16_sapphire( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_bf16_t const *a, simsimd_size_t a_row_stride_bytes, // + simsimd_bf16_t const *b, simsimd_size_t b_row_stride_bytes, // + simsimd_bf16_t *c, simsimd_size_t c_row_stride_bytes) { + + SIMSIMD_ALIGN64 simsimd_bf16_t a_l1_tile[SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE][SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE]; + SIMSIMD_ALIGN64 simsimd_bf16_t b_l1_tile[SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE][SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE]; + SIMSIMD_ALIGN64 simsimd_f32_t c_l1_tile[SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE][SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE]; + SIMSIMD_ALIGN64 simsimd_f32_t a2_l1_tile[SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE][SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE]; + SIMSIMD_ALIGN64 simsimd_f32_t b2_l1_tile[SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE][SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE]; + + SIMSIMD_STATIC_ASSERT(SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE % 32 == 0, + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE_NOT_MULTIPLE_OF_32); + + // Set up the AMX tile configuration structure. + // There are 8 tile registers from TMM0 to TMM7. Each is 16 rows by 64 bytes, fitting up to 1 KB of data. + // The actual dimensions can be different and are controlled by `rows` and `colsb` - the width in bytes. + SIMSIMD_ALIGN64 simsimd_u8_t amx_tiles_config[64]; + _mm512_store_si512((__m512i *)amx_tiles_config, _mm512_setzero_si512()); // Will fault if the buffer is not aligned. + simsimd_u8_t *amx_palette_id_ptr = &amx_tiles_config[0]; + simsimd_u16_t *amx_tiles_colsb_ptr = (simsimd_u16_t *)(&amx_tiles_config[16]); //! 16-bit integers! + simsimd_u8_t *amx_tiles_rows_ptr = &amx_tiles_config[48]; //! 8-bit integers!
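+    // For reference, the 64-byte tile-configuration buffer that `_tile_loadconfig` expects is laid out as:
+    //   byte 0       - palette_id (must be 1 for the only supported palette)
+    //   byte 1       - start_row, used to resume interrupted tile loads, left at zero here
+    //   bytes 2-15   - reserved, must be zero
+    //   bytes 16-47  - colsb[16], a 16-bit width-in-bytes per tile register (only the first 8 are used)
+    //   bytes 48-63  - rows[16], an 8-bit row count per tile register (only the first 8 are used)
+    // This is why `amx_tiles_colsb_ptr` points at offset 16 and `amx_tiles_rows_ptr` at offset 48.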
+ *amx_palette_id_ptr = 1; // The only palette currently supported + + // When using AMX tiles, we generally want to minimize the number of loads and stores, + // especially for the second argument, as it will involve reordering. + // So ideally, we want to load any AMX tile of B just once, and multiply it by many AMX tiles of A. + // That way we are computing a vertical band of C. Let's assign: + // - A tiles: TMM0, TMM1 - for (16 x 32) `bf16` values. + // - B tiles: TMM2, TMM3 - for (16 x 32) `bf16` values. + // - C tiles: TMM4, TMM5, TMM6, TMM7 - for (16 x 16) `f32` values. + amx_tiles_rows_ptr[0] = amx_tiles_rows_ptr[1] = amx_tiles_rows_ptr[2] = amx_tiles_rows_ptr[3] = + amx_tiles_rows_ptr[4] = amx_tiles_rows_ptr[5] = amx_tiles_rows_ptr[6] = amx_tiles_rows_ptr[7] = 16; + amx_tiles_colsb_ptr[0] = amx_tiles_colsb_ptr[1] = amx_tiles_colsb_ptr[2] = amx_tiles_colsb_ptr[3] = + amx_tiles_colsb_ptr[4] = amx_tiles_colsb_ptr[5] = amx_tiles_colsb_ptr[6] = amx_tiles_colsb_ptr[7] = 64; + _tile_loadconfig(&amx_tiles_config); + _tile_zero(4); // C top left tile. + _tile_zero(5); // C top right tile. + _tile_zero(6); // C bottom left tile. + _tile_zero(7); // C bottom right tile. + + for (simsimd_size_t a_l1_start_row = 0; a_l1_start_row < a_rows; + a_l1_start_row += SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE) { + /// Below comes code for rows: A[a_l1_start_row : a_l1_start_row + a_l1_count_rows]. + simsimd_size_t const a_l1_count_rows = (a_l1_start_row + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE < a_rows) + ? SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE + : a_rows - a_l1_start_row; + + for (simsimd_size_t b_l1_start_row = 0; b_l1_start_row != b_rows; + b_l1_start_row += SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE) { + /// Below comes code for rows: B[b_l1_start_row : b_l1_start_row + b_l1_count_rows] + simsimd_size_t const b_l1_count_rows = (b_l1_start_row + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE < b_rows) + ? SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE + : b_rows - b_l1_start_row; + + // Load the existing values in C tile: + // C[a_l1_start_row:a_l1_start_row+a_l1_count_rows][b_l1_start_row:b_l1_start_row+b_l1_count_rows]. + // Piece of cake with AVX-512, knowing that the data is already aligned. + // Each ZMM register can hold 64 bytes, so we can load 16x BF16 upcasting to 16x F32 elements at once. + simsimd_size_t const c_tail_size = b_l1_count_rows % 16; + // If the `b_l1_count_rows` is not divisible by 16, handle the tail with masked loads in AVX-512. + __mmask16 const c_tail_mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, c_tail_size); + for (simsimd_size_t row_in_l1 = 0; row_in_l1 < a_l1_count_rows; ++row_in_l1) { + simsimd_size_t col_in_l1 = 0; + simsimd_bf16_t const *c_global = (simsimd_bf16_t const *)_simsimd_advance_by_bytes( + (void *)(c + b_l1_start_row), // shift within a row + c_row_stride_bytes * (a_l1_start_row + row_in_l1)); // shift to the right row + + for (; col_in_l1 + 16 <= b_l1_count_rows; col_in_l1 += 16, c_global += 16) + _mm512_store_ps(&c_l1_tile[row_in_l1][col_in_l1], + _simsimd_bf16x16_to_f32x16_skylake(_mm256_lddqu_si256((__m256i *)c_global))); + if (c_tail_size) + _mm512_store_ps( + &c_l1_tile[row_in_l1][col_in_l1], + _simsimd_bf16x16_to_f32x16_skylake(_mm256_maskz_loadu_epi16(c_tail_mask, c_global))); + } + + // At this point we are multiplying a horizontal band of A by a horizontal band of B. + // Both will have up to `SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE` rows and `cols` columns. 
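+            // The loops below form a 2x2 register-level micro-kernel: two (16 x 32) `bf16` tiles of A
+            // in TMM0 and TMM1 and two reordered (16 x 32) `bf16` tiles of B in TMM2 and TMM3 update four
+            // (16 x 16) `f32` accumulators in TMM4..TMM7, so every reordered pair of B tiles is reused
+            // across many A tiles, four `_tile_dpbf16ps` multiplications at a time.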
+ for (simsimd_size_t l1_start_col = 0; l1_start_col < cols; + l1_start_col += SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE) { + /// Below comes code for tiles: + /// A[a_l1_start_row : a_l1_start_row + a_l1_count_rows, l1_start_col : l1_start_col + l1_count_cols] + /// B[b_l1_start_row : b_l1_start_row + b_l1_count_rows, l1_start_col : l1_start_col + l1_count_cols] + simsimd_size_t const l1_count_cols = (l1_start_col + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE < cols) + ? SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE + : cols - l1_start_col; + + // Now we need to load the tiles of A and B. + // Piece of cake with AVX-512, knowing that the data is already aligned. + // Each ZMM register can hold 64 bytes, so we can load 32x BF16 elements at once. + int is_boundary_tile = (l1_start_col + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE > cols) || + (a_l1_start_row + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE > a_rows) || + (b_l1_start_row + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE > b_rows); + if (!is_boundary_tile) { + // Load both A and B tiles in one loop. + for (simsimd_size_t row_in_l1 = 0; row_in_l1 != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE; ++row_in_l1) { + simsimd_bf16_t const *a_global = (simsimd_bf16_t const *)_simsimd_advance_by_bytes( + (void *)(a + l1_start_col), // shift within a row + a_row_stride_bytes * (a_l1_start_row + row_in_l1)); // shift to the right row + + simsimd_bf16_t const *b_global = (simsimd_bf16_t const *)_simsimd_advance_by_bytes( // + (void *)(b + l1_start_col), // shift within a row + b_row_stride_bytes * (b_l1_start_row + row_in_l1)); // shift to the right row + + for (simsimd_size_t col_in_l1 = 0; col_in_l1 != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE; + col_in_l1 += 32, a_global += 32, b_global += 32) { + _mm512_store_si512((__m512i *)&a_l1_tile[row_in_l1][col_in_l1], + _mm512_loadu_si512((__m512i *)a_global)); + _mm512_store_si512((__m512i *)&b_l1_tile[row_in_l1][col_in_l1], + _mm512_loadu_si512((__m512i *)b_global)); + } + } + } + // When dealing with boundary tiles, we need separate logic for the A and B tiles, + // cause those matrices can have a different number of rows. We also need to take care + // of the row tails, in case the number of columns is not divisible by the tile size. + else { + simsimd_size_t const tail_size = l1_count_cols % 32; + __mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_size); + // Load the A tile. + for (simsimd_size_t row_in_l1 = 0; row_in_l1 != a_l1_count_rows; ++row_in_l1) { + simsimd_bf16_t const *a_global = (simsimd_bf16_t const *)_simsimd_advance_by_bytes( + (void *)(a + l1_start_col), // shift within a row + a_row_stride_bytes * (a_l1_start_row + row_in_l1)); // shift to the right row + + simsimd_size_t col_in_l1 = 0; + for (; col_in_l1 + 32 < l1_count_cols; col_in_l1 += 32, a_global += 32) + _mm512_store_si512((__m512i *)&a_l1_tile[row_in_l1][col_in_l1], + _mm512_loadu_si512((__m512i *)a_global)); + if (tail_size) + _mm512_store_si512(&a_l1_tile[row_in_l1][col_in_l1], + _mm512_maskz_loadu_epi16(tail_mask, a_global)); + } + // Load the B tile. 
+ for (simsimd_size_t row_in_l1 = 0; row_in_l1 != b_l1_count_rows; ++row_in_l1) { + simsimd_bf16_t const *b_global = (simsimd_bf16_t const *)_simsimd_advance_by_bytes( + (void *)(b + l1_start_col), // shift within a row + b_row_stride_bytes * (b_l1_start_row + row_in_l1)); // shift to the right row + simsimd_size_t col_in_l1 = 0; + for (; col_in_l1 + 32 < l1_count_cols; col_in_l1 += 32, b_global += 32) + _mm512_store_si512((__m512i *)&b_l1_tile[row_in_l1][col_in_l1], + _mm512_loadu_si512((__m512i *)b_global)); + if (tail_size) + _mm512_store_si512(&b_l1_tile[row_in_l1][col_in_l1], + _mm512_maskz_loadu_epi16(tail_mask, b_global)); + } + } + + // Now we need to view our `SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE`-sided matrices + // as composed of (16 x 16) tiles of 4-byte values, or (16 x 32) tiles of 2-byte values. + // Vertically stacking TMM4 and TMM5, we can see that as a (32 x 32) "pivot" slice of B. + simsimd_bf16_t tmm2_reordered[16][16][2]; + simsimd_bf16_t tmm3_reordered[16][16][2]; + for (simsimd_size_t b_row_in_l1 = 0; b_row_in_l1 != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE; + b_row_in_l1 += 32) { + for (simsimd_size_t b_col_in_l1 = 0; b_col_in_l1 != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE; + b_col_in_l1 += 32) { + + // Load and permute the data from B for AMX, simultaneously transposing! + // TODO: Optimize with AVX-512. + for (simsimd_size_t col_in_amx_tile = 0; col_in_amx_tile != 32; ++col_in_amx_tile) { + for (simsimd_size_t row_in_amx_tile = 0; row_in_amx_tile != 16; ++row_in_amx_tile) { + tmm2_reordered[col_in_amx_tile / 2][row_in_amx_tile][col_in_amx_tile % 2] = + b_l1_tile[b_row_in_l1 + row_in_amx_tile][b_col_in_l1 + col_in_amx_tile]; + tmm3_reordered[col_in_amx_tile / 2][row_in_amx_tile][col_in_amx_tile % 2] = + b_l1_tile[b_row_in_l1 + row_in_amx_tile + 16][b_col_in_l1 + col_in_amx_tile]; + } + } + _tile_loadd(2, &tmm2_reordered[0][0][0], 64); + _tile_loadd(3, &tmm3_reordered[0][0][0], 64); + + // Now we will walk through all the entries in the first 32 columns of the L1 tile of A. + // We will multiply them by TMM4 and TMM5, accumulating into TMM6 and TMM7 respectively. + for (simsimd_size_t a_row_in_l1 = 0; a_row_in_l1 != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE; + a_row_in_l1 += 32) { + _tile_loadd(0, &a_l1_tile[a_row_in_l1][b_col_in_l1], + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_bf16_t)); + _tile_loadd(1, &a_l1_tile[a_row_in_l1 + 16][b_col_in_l1], + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_bf16_t)); + + _tile_loadd(4, &c_l1_tile[a_row_in_l1][b_row_in_l1], + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_f32_t)); + _tile_loadd(5, &c_l1_tile[a_row_in_l1][b_row_in_l1 + 16], + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_f32_t)); + _tile_loadd(6, &c_l1_tile[a_row_in_l1 + 16][b_row_in_l1], + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_f32_t)); + _tile_loadd(7, &c_l1_tile[a_row_in_l1 + 16][b_row_in_l1 + 16], + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_f32_t)); + + // Perform all possible multiplications. + _tile_dpbf16ps(4, 0, 2); + _tile_dpbf16ps(5, 0, 3); + _tile_dpbf16ps(6, 1, 2); + _tile_dpbf16ps(7, 1, 3); + + // Save back the updated C values. 
+ _tile_stored(4, &c_l1_tile[a_row_in_l1][b_row_in_l1], + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_f32_t)); + _tile_stored(5, &c_l1_tile[a_row_in_l1][b_row_in_l1 + 16], + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_f32_t)); + _tile_stored(6, &c_l1_tile[a_row_in_l1 + 16][b_row_in_l1], + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_f32_t)); + _tile_stored(7, &c_l1_tile[a_row_in_l1 + 16][b_row_in_l1 + 16], + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_f32_t)); + } + } + } + } + + // Export the C values to global memory: + // C[a_l1_start_row:a_l1_start_row+a_l1_count_rows][b_l1_start_row:b_l1_start_row+b_l1_count_rows]. + for (simsimd_size_t row_in_l1 = 0; row_in_l1 < a_l1_count_rows; ++row_in_l1) { + simsimd_size_t col_in_l1 = 0; + simsimd_bf16_t const *c_global = (simsimd_bf16_t const *)_simsimd_advance_by_bytes( + (void *)(c + b_l1_start_row), // shift within a row + c_row_stride_bytes * (a_l1_start_row + row_in_l1) // shift to the right row + ); + for (; col_in_l1 + 16 <= b_l1_count_rows; col_in_l1 += 16, c_global += 16) + _mm256_storeu_si256((__m256i *)c_global, _simsimd_f32x16_to_bf16x16_skylake( + _mm512_load_ps(&c_l1_tile[row_in_l1][col_in_l1]))); + if (c_tail_size) + _mm256_mask_storeu_epi16( + (__m256i *)c_global, c_tail_mask, + _simsimd_f32x16_to_bf16x16_skylake(_mm512_load_ps(&c_l1_tile[row_in_l1][col_in_l1]))); + } + } + } +} + +SIMSIMD_PUBLIC void simsimd_nxcor_i8_sapphire( // + simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, // + simsimd_i8_t const *a, simsimd_size_t a_stride, // + simsimd_i8_t const *b, simsimd_size_t b_stride, // + simsimd_i8_t *c, simsimd_size_t c_stride) {} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SAPPHIRE + +#if SIMSIMD_TARGET_ICE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512bw", "avx512vnni") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512bw,avx512vnni"))), apply_to = function) + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_ICE +#endif // SIMSIMD_TARGET_X86 + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/include/simsimd/probability.h b/include/simsimd/probability.h index 2865aa32..a71745fa 100644 --- a/include/simsimd/probability.h +++ b/include/simsimd/probability.h @@ -3,6 +3,7 @@ * @brief SIMD-accelerated Similarity Measures for Probability Distributions. * @author Ash Vardanian * @date October 20, 2023 + * @see https://github.com/ashvardanian/simsimd?tab=readme-ov-file#probability-distributions * * Contains: * - Kullback-Leibler divergence diff --git a/include/simsimd/simsimd.h b/include/simsimd/simsimd.h index c3b59f20..0896d072 100644 --- a/include/simsimd/simsimd.h +++ b/include/simsimd/simsimd.h @@ -107,6 +107,7 @@ #include "dot.h" // Inner (dot) product, and its conjugate #include "elementwise.h" // Weighted Sum, Fused-Multiply-Add #include "geospatial.h" // Haversine and Vincenty +#include "matmul.h" // Normalized Cross Correlation or Matrix Multiplication #include "probability.h" // Kullback-Leibler, Jensen–Shannon #include "sparse.h" // Intersect #include "spatial.h" // L2, Cosine @@ -332,6 +333,40 @@ SIMSIMD_PUBLIC void simsimd_find_kernel_punned( // #if _SIMSIMD_TARGET_X86 +/** + * @brief Helper function that performs the system call on Linux to enable AMX instructions. + * ! This function must be called before invoking any AMX kernels on Linux. 
+ */ +SIMSIMD_INTERNAL int _simsimd_capabilities_x86_enable_amx(void) { +#if defined(SIMSIMD_DEFINED_LINUX) + // Thanks to the good people of the Rust community: + // https://github.com/rust-lang/rust/issues/107795 + int XFEATURE_MASK_XTILECFG = (1 << 17); + int XFEATURE_MASK_XTILEDATA = (1 << 18); + int ARCH_GET_XCOMP_PERM = 0x1022; + int ARCH_REQ_XCOMP_PERM = 0x1023; + int SYS_arch_prctl = 158; + + unsigned long bitmask = 0; + long status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + if (0 != status) return 0; + if (bitmask & XFEATURE_MASK_XTILEDATA) return 1; + + status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, 18); + if (0 != status) return 0; // XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed + status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + + // XFEATURE_XTILEDATA setup is failed, can't use TMUL + if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) return 0; + + // XFEATURE_XTILEDATA set successfully, TMUL usage is allowed + (void)XFEATURE_MASK_XTILECFG; + return 1; +#else + return 0; +#endif +} + /** * @brief Function to determine the SIMD capabilities of the current 64-bit x86 machine at @b runtime. * @return A bitmask of the SIMD capabilities represented as a `simsimd_capability_t` enum value. @@ -373,9 +408,6 @@ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_x86(void) { // Check for AVX512F (Function ID 7, EBX register) // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L155 unsigned supports_avx512f = (info7.named.ebx & 0x00010000) != 0; - // Check for AVX512FP16 (Function ID 7, EDX register) - // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L198C9-L198C23 - unsigned supports_avx512fp16 = (info7.named.edx & 0x00800000) != 0; // Check for AVX512VNNI (Function ID 7, ECX register) unsigned supports_avx512vnni = (info7.named.ecx & 0x00000800) != 0; // Check for AVX512IFMA (Function ID 7, EBX register) @@ -389,6 +421,11 @@ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_x86(void) { // Check for AVX512BF16 (Function ID 7, Sub-leaf 1, EAX register) // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L205 unsigned supports_avx512bf16 = (info7sub1.named.eax & 0x00000020) != 0; + // Check for AVX512FP16 (Function ID 7, EDX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L198C9-L198C23 + unsigned supports_avx512fp16 = (info7.named.edx & 0x00800000) != 0; + unsigned supports_amxbf16 = (info7.named.edx & 0x00400000) != 0; + unsigned supports_amxint8 = (info7.named.edx & 0x02000000) != 0; // Clang doesn't show the VP2INTERSECT flag, but we can get it from QEMU // https://stackoverflow.com/a/68289220/2766161 unsigned supports_avx512vp2intersect = (info7.named.edx & 0x00000100) != 0; @@ -399,7 +436,7 @@ SIMSIMD_PUBLIC simsimd_capability_t _simsimd_capabilities_x86(void) { unsigned supports_ice = supports_avx512vnni && supports_avx512ifma && supports_avx512bitalg && supports_avx512vbmi2 && supports_avx512vpopcntdq; unsigned supports_genoa = supports_avx512bf16; - unsigned supports_sapphire = supports_avx512fp16; + unsigned supports_sapphire = supports_avx512fp16 && supports_amxbf16 && supports_amxint8; // We don't want to accidently enable AVX512VP2INTERSECT on Intel Tiger Lake CPUs unsigned supports_turin = supports_avx512vp2intersect && supports_avx512bf16; 
 unsigned supports_sierra = 0;
diff --git a/include/simsimd/sparse.h b/include/simsimd/sparse.h
index 414cfb22..cfb7c34a 100644
--- a/include/simsimd/sparse.h
+++ b/include/simsimd/sparse.h
@@ -3,6 +3,7 @@
  * @brief SIMD-accelerated functions for Sparse Vectors.
  * @author Ash Vardanian
  * @date March 21, 2024
+ * @see https://github.com/ashvardanian/simsimd?tab=readme-ov-file#sparse-vectors
  *
  * Contains:
  * - Set Intersection ~ Jaccard Distance
diff --git a/include/simsimd/spatial.h b/include/simsimd/spatial.h
index e03e5a78..9575c5d5 100644
--- a/include/simsimd/spatial.h
+++ b/include/simsimd/spatial.h
@@ -3,6 +3,7 @@
  * @brief SIMD-accelerated Spatial Similarity Measures.
  * @author Ash Vardanian
  * @date March 14, 2023
+ * @see https://github.com/ashvardanian/simsimd?tab=readme-ov-file#spatial-affinity
  *
  * Contains:
  * - L2 (Euclidean) regular and squared distance
diff --git a/include/simsimd/types.h b/include/simsimd/types.h
index 8c2fc5cb..edeabc08 100644
--- a/include/simsimd/types.h
+++ b/include/simsimd/types.h
@@ -265,6 +265,17 @@
 #define SIMSIMD_F16_DIVISION_EPSILON (1e-3)
 #endif
 
+#ifdef _MSC_VER
+#define SIMSIMD_ALIGN64 __declspec(align(64))
+#elif defined(__GNUC__) || defined(__clang__)
+#define SIMSIMD_ALIGN64 __attribute__((aligned(64)))
+#endif
+
+/**
+ * @brief Similar to `static_assert`, but compatible with C99.
+ */
+#define SIMSIMD_STATIC_ASSERT(expr, msg) typedef char static_assert_##msg[(expr) ? 1 : -1]
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -561,6 +572,13 @@ SIMSIMD_PUBLIC void simsimd_f32_to_bf16(simsimd_f32_t x, simsimd_bf16_t *result_
     *(unsigned short *)result_ptr = (unsigned short)conv.i;
 }
 
+/**
+ * @brief Helper function for implementing strided matrix row lookups, with @b single-byte-level pointer math.
+ */
+SIMSIMD_INTERNAL void *_simsimd_advance_by_bytes(void *ptr, simsimd_size_t bytes) {
+    return (void *)((simsimd_u8_t *)ptr + bytes);
+}
+
 SIMSIMD_PUBLIC simsimd_u32_t simsimd_u32_rol(simsimd_u32_t x, int n) { return (x << n) | (x >> (32 - n)); }
 SIMSIMD_PUBLIC simsimd_u16_t simsimd_u16_rol(simsimd_u16_t x, int n) { return (x << n) | (x >> (16 - n)); }
 SIMSIMD_PUBLIC simsimd_u8_t simsimd_u8_rol(simsimd_u8_t x, int n) { return (x << n) | (x >> (8 - n)); }
diff --git a/scripts/bench.cxx b/scripts/bench.cxx
index a31a5483..37fb72e4 100644
--- a/scripts/bench.cxx
+++ b/scripts/bench.cxx
@@ -1,3 +1,4 @@
+#include <array>   // `std::array`
 #include <cmath>   // `std::sqrt`
 #include <cstdlib> // `std::aligned_alloc`
 #include <cstring> // `std::memcpy`
@@ -25,6 +26,8 @@
 #define SIMSIMD_NATIVE_BF16 1
 #include <simsimd/simsimd.h>
 
+#include "matrix.hpp" // `matrix_gt`
+
 constexpr std::size_t default_seconds = 10;
 constexpr std::size_t default_threads = 1;
 constexpr simsimd_distance_t signaling_distance = std::numeric_limits<simsimd_distance_t>::signaling_NaN();
@@ -33,9 +36,12 @@ constexpr simsimd_distance_t signaling_distance = std::numeric_limits<simsimd_distance_t>::signaling_NaN();
 
+constexpr std::array<std::size_t, 3> matmul_sizes = {256, 1024, 4096};
+
 namespace bm = benchmark;
+namespace av = ashvardanian::simsimd;
 
 // clang-format off
 template struct datatype_enum_to_type_gt { using value_t = void; };
@@ -496,7 +502,6 @@ void measure_fma(bm::State &state, kernel_at kernel, kernel_at baseline, l2_metr
     using pair_t = pair_at;
     using vector_t = typename pair_at::vector_t;
 
-    constexpr simsimd_distance_t alpha = 0.2;
     constexpr simsimd_distance_t beta = 0.3;
     static_assert(function_args_count(kernel_at {}) >= 6 && function_args_count(kernel_at {}) <= 7,
@@ -567,6 +572,69 @@ void measure_fma(bm::State &state, kernel_at kernel, kernel_at baseline, l2_metr
     state.counters["pairs"] = bm::Counter(iterations, bm::Counter::kIsRate);
} +/** + * @brief Measures the performance of a @b matrix-multiplication function against a baseline using Google Benchmark. + * @tparam pair_at The type representing the vector pair used in the measurement. + * @tparam metric_at The type of the metric function (default is void). + * @param state The benchmark state object provided by Google Benchmark. + * @param kernel The kernel function to benchmark. + * @param baseline The baseline function to compare against. + * @param side The side length of the matrix. + */ +template +void measure_matmul(bm::State &state, metric_at metric, metric_at baseline, std::size_t side) { + //! TODO: Compare the values of matrices against each other! + auto call_baseline = [&](pair_t const &inputs, vector_t &c) -> double { + baseline(inputs.a.data(), side, inputs.b.data(), side, side, side, side, c.data(), side); + return std::accumulate(c.data(), c.data() + side * side, 0.0) / (side * side); + }; + auto call_contender = [&](pair_t const &inputs, vector_t &c) -> double { + metric(inputs.a.data(), side, inputs.b.data(), side, side, side, side, c.data(), side); + return std::accumulate(c.data(), c.data() + side * side, 0.0) / (side * side); + }; + + // Let's average the distance results over many pairs. + constexpr std::size_t inputs_count = 8; + std::vector inputs(inputs_count); + std::vector outputs(inputs_count); + for (std::size_t i = 0; i != inputs.size(); ++i) { + pair_t &input = inputs[i]; + vector_t &output = outputs[i]; + input.a = input.b = output = vector_t(side * side); + input.a.randomize(), input.b.randomize(), output.randomize(); + } + + // Initialize the output buffers for distance calculations. + std::vector results_baseline(inputs.size()); + std::vector results_contender(inputs.size()); + for (std::size_t i = 0; i != inputs.size(); ++i) + results_baseline[i] = call_baseline(inputs[i], outputs[i]), + results_contender[i] = call_contender(inputs[i], outputs[i]); + + // The actual benchmarking loop. + std::size_t iterations = 0; + for (auto _ : state) + bm::DoNotOptimize((results_contender[iterations & (inputs_count - 1)] = call_contender( + inputs[iterations & (inputs_count - 1)], outputs[iterations & (inputs_count - 1)]))), + iterations++; + + // Measure the mean absolute delta and relative error. + double mean_delta = 0, mean_relative_error = 0; + for (std::size_t i = 0; i != inputs.size(); ++i) { + auto abs_delta = std::abs(results_contender[i] - results_baseline[i]); + mean_delta += abs_delta; + double error = abs_delta != 0 && results_baseline[i] != 0 ? 
abs_delta / results_baseline[i] : 0; + mean_relative_error += error; + } + mean_delta /= inputs.size(); + mean_relative_error /= inputs.size(); + state.counters["abs_delta"] = mean_delta; + state.counters["relative_error"] = mean_relative_error; + state.counters["ops/s"] = bm::Counter(iterations * side * side * side * 2, bm::Counter::kIsRate); + state.counters["bytes"] = bm::Counter(iterations * side * side * 2, bm::Counter::kIsRate); + state.counters["inputs"] = bm::Counter(iterations, bm::Counter::kIsRate); +} + template void dense_(std::string name, metric_at *distance_func, metric_at *baseline_func) { using pair_t = vectors_pair_gt; @@ -623,12 +691,23 @@ void curved_(std::string name, metric_at *distance_func, metric_at *baseline_fun ->Threads(default_threads); } +template +void matmul_(std::string name, metric_at *distance_func, metric_at *baseline_func) { + + using pair_t = vectors_pair_gt; + for (std::size_t size : matmul_sizes) { + std::string name_dims = name + "_" + std::to_string(size) + "✕" + std::to_string(size); + bm::RegisterBenchmark(name_dims.c_str(), measure_matmul, distance_func, baseline_func, + size) + ->MinTime(default_seconds); + } +} + #if SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS void dot_f32_blas(simsimd_f32_t const *a, simsimd_f32_t const *b, simsimd_size_t n, simsimd_distance_t *result) { *result = cblas_sdot((int)n, a, 1, b, 1); } - void dot_f64_blas(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size_t n, simsimd_distance_t *result) { *result = cblas_ddot((int)n, a, 1, b, 1); } @@ -655,9 +734,21 @@ void vdot_f64c_blas(simsimd_f64_t const *a, simsimd_f64_t const *b, simsimd_size cblas_zdotc_sub((int)n / 2, a, 1, b, 1, result); } +void nxcor_f32_blas(simsimd_size_t a_rows, simsimd_size_t b_rows, simsimd_size_t cols, simsimd_f32_t const *a, + simsimd_size_t a_stride, simsimd_f32_t const *b, simsimd_size_t ldb, simsimd_f32_t *c, + simsimd_size_t ldc) { + cblas_sgemm( // + CblasRowMajor, CblasNoTrans, CblasTrans, // + (int)a_rows, (int)b_rows, (int)cols, 1.0f, // + a, (int)a_stride / sizeof(simsimd_f32_t), // + b, (int)b_stride / sizeof(simsimd_f32_t), 0.0f, // + c, (int)c_stride / sizeof(simsimd_f32_t)); +} + #endif int main(int argc, char **argv) { + simsimd_capability_t runtime_caps = simsimd_capabilities(); // Log supported functionality @@ -697,6 +788,15 @@ int main(int argc, char **argv) { std::printf("- x86 Sierra Forest support enabled: %s\n", flags[(runtime_caps & simsimd_cap_sierra_k) != 0]); std::printf("\n"); +#if defined(SIMSIMD_DEFINED_LINUX) + if ((runtime_caps & simsimd_cap_sapphire_k) != 0) { + if (!_simsimd_capabilities_x86_enable_amx()) { + std::printf("Error: AMX can't be enabled\n"); + return 1; + } + } +#endif + // Run the benchmarks bm::Initialize(&argc, argv); if (bm::ReportUnrecognizedArguments(argc, argv)) return 1; @@ -720,6 +820,8 @@ int main(int argc, char **argv) { constexpr simsimd_datatype_t f16c_k = simsimd_datatype_f16c_k; constexpr simsimd_datatype_t bf16c_k = simsimd_datatype_bf16c_k; + // curved_("bilinear_bf16_sapphire", simsimd_bilinear_bf16_sapphire, simsimd_bilinear_bf16_accurate); + #if SIMSIMD_BUILD_BENCHMARKS_WITH_CBLAS dense_("dot_f32_blas", dot_f32_blas, simsimd_dot_f32_accurate); @@ -729,6 +831,8 @@ int main(int argc, char **argv) { dense_("vdot_f32c_blas", vdot_f32c_blas, simsimd_vdot_f32c_accurate); dense_("vdot_f64c_blas", vdot_f64c_blas, simsimd_vdot_f64c_serial); + matmul_("nxcor_f32_blas", nxcor_f32_blas, simsimd_nxcor_f32_accurate); + #endif #if SIMSIMD_TARGET_NEON @@ -1059,6 +1163,13 @@ int main(int argc, 
char **argv) { fma_("fma_i8_serial", simsimd_fma_i8_serial, simsimd_fma_i8_accurate, simsimd_l2_i8_serial); fma_("wsum_i8_serial", simsimd_wsum_i8_serial, simsimd_wsum_i8_accurate, simsimd_l2_i8_serial); + // matmul_("matmul_f32_cpp", ashvardanian::simsimd::matmul_f32_cpp, simsimd_nxcor_f32_serial); + // matmul_("matmul_f16_cpp", ashvardanian::simsimd::matmul_f16_cpp, simsimd_nxcor_f16_serial); + // matmul_("matmul_bf16_cpp", ashvardanian::simsimd::matmul_bf16_cpp, simsimd_nxcor_bf16_serial); + // matmul_("matmul_f32_serial", simsimd_nxcor_f32_serial, simsimd_nxcor_f32_serial); + // matmul_("matmul_f16_serial", simsimd_nxcor_f16_serial, simsimd_nxcor_f16_serial); + // matmul_("matmul_bf16_serial", simsimd_nxcor_bf16_serial, simsimd_nxcor_bf16_serial); + bm::RunSpecifiedBenchmarks(); bm::Shutdown(); return 0; diff --git a/scripts/matrix.hpp b/scripts/matrix.hpp new file mode 100644 index 00000000..282cc528 --- /dev/null +++ b/scripts/matrix.hpp @@ -0,0 +1,385 @@ +/** + * @file matrix.hpp + * @brief Helper structures for loading, multiplying, tiling, and logging matrices. + * @author Ash Vardanian + * @date September 14, 2024 + */ +#ifndef SIMSIMD_MATRIX_HPP +#define SIMSIMD_MATRIX_HPP + +#include // `syscall` +#include // `syscall` + +#include // `std::printf` +#include // `std::rand` +#include // `typeid` + +#include + +namespace ashvardanian { +namespace simsimd { + +/** + * @brief Generic matrix structure, with row-major layout and loop-unrolled tile loading and unloading operations. + */ +template struct matrix_gt { + using scalar_t = scalar_at; + using mutable_scalar_t = std::remove_const_t; + + scalar_at* data_{}; + std::size_t rows_{}; + std::size_t cols_{}; + std::size_t stride_bytes_{}; + + matrix_gt() noexcept = default; + matrix_gt(scalar_at* data, std::size_t rows, std::size_t cols, std::size_t stride_bytes) noexcept + : data_(data), rows_(rows), cols_(cols), stride_bytes_(stride_bytes) {} + + template + matrix_gt(scalar_at (&data)[rows_ak][cols_ak]) noexcept + : data_(reinterpret_cast(data)), rows_(rows_ak), cols_(cols_ak), stride_bytes_(cols_ak * sizeof(scalar_at)) {} + + matrix_gt(matrix_gt const&) = default; + matrix_gt& operator=(matrix_gt const&) = default; + + std::size_t rows() const noexcept { return rows_; } + std::size_t cols() const noexcept { return cols_; } + + scalar_t* row_data(std::size_t row) noexcept { return reinterpret_cast(reinterpret_cast(data_) + row * stride_bytes_); } + scalar_t const* row_data(std::size_t row) const noexcept { + return reinterpret_cast(reinterpret_cast(data_) + row * stride_bytes_); + } + + matrix_gt submatrix(std::size_t row_offset, std::size_t col_offset, std::size_t rows, std::size_t cols) const noexcept { + return {row_data(row_offset), rows, cols, stride_bytes_}; + } + + scalar_t& operator()(std::size_t row, std::size_t col) noexcept { return row_data(row)[col]; } + scalar_t operator()(std::size_t row, std::size_t col) const noexcept { return row_data(row)[col]; } + scalar_t& at(std::size_t row, std::size_t col) noexcept { return row_data(row)[col]; } + scalar_t at(std::size_t row, std::size_t col) const noexcept { return row_data(row)[col]; } + + void fill(scalar_t value) noexcept { + for (std::size_t i = 0; i < rows_; i++) + for (std::size_t j = 0; j < cols_; j++) + at(i, j) = value; + } + + /** + * @brief Fills the diagonal of the matrix with a given value, similar to NumPy. 
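+     * For example, `m.fill(0)` followed by `m.fill_diagonal(1)` turns `m` into an identity matrix.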
+ * https://numpy.org/doc/2.0/reference/generated/numpy.fill_diagonal.html + */ + void fill_diagonal(scalar_t value) noexcept { + for (std::size_t i = 0; i < (std::min)(rows_, cols_); i++) + at(i, i) = value; + } + + void fill_random(std::int64_t min = -1, std::int64_t max = 1) noexcept { + for (std::size_t i = 0; i < rows_; i++) + for (std::size_t j = 0; j < cols_; j++) + at(i, j) = std::rand() % (max - min + 1) + min; + } + + void print() const noexcept { + for (std::size_t i = 0; i < rows_; i++) { + for (std::size_t j = 0; j < cols_; j++) { + if constexpr (std::is_unsigned_v) + std::printf("%8llu ", (unsigned long long)at(i, j)); + else if constexpr (std::is_signed_v) + std::printf("%8lld ", (long long)at(i, j)); + else + std::printf("%8.2f ", (float)at(i, j)); + } + std::printf("\n"); + } + } + + template + void export_internal_tile(std::size_t tile_row, std::size_t tile_col, mutable_scalar_t* tile) const noexcept { +#pragma omp unroll + for (std::size_t i = 0; i != tile_height_ak; ++i) { + auto data_row = row_data(tile_row + i); + for (std::size_t j = 0; j != tile_width_ak; ++j) + tile[i * tile_width_ak + j] = data_row[tile_col + j]; + } + } + + template + void export_bounding_tile(std::size_t tile_row, std::size_t tile_col, mutable_scalar_t* tile) const noexcept { + std::memset(tile, 0, tile_height_ak * tile_width_ak * sizeof(mutable_scalar_t)); + std::size_t tile_rows = (std::min)(tile_height_ak, rows_ - tile_row); + std::size_t tile_cols = (std::min)(tile_width_ak, cols_ - tile_col); + for (std::size_t i = 0; i != tile_rows; ++i) { + auto data_row = row_data(tile_row + i); + for (std::size_t j = 0; j != tile_cols; ++j) + tile[i * tile_width_ak + j] = data_row[tile_col + j]; + } + } + + template + void import_internal_tile(std::size_t tile_row, std::size_t tile_col, mutable_scalar_t const* tile) noexcept { +#pragma omp unroll + for (std::size_t i = 0; i != tile_height_ak; ++i) { + auto data_row = row_data(tile_row + i); + for (std::size_t j = 0; j != tile_width_ak; ++j) + data_row[tile_col + j] = tile[i * tile_width_ak + j]; + } + } + + template + void import_bounding_tile(std::size_t tile_row, std::size_t tile_col, mutable_scalar_t const* tile) noexcept { + std::size_t tile_rows = (std::min)(tile_height_ak, rows_ - tile_row); + std::size_t tile_cols = (std::min)(tile_width_ak, cols_ - tile_col); + for (std::size_t i = 0; i != tile_rows; ++i) { + auto data_row = row_data(tile_row + i); + for (std::size_t j = 0; j != tile_cols; ++j) + data_row[tile_col + j] = tile[i * tile_width_ak + j]; + } + } +}; + +/** + * @brief Baseline serial Cross-Corellation (nxcor) implementation for dense matrices. + */ +template +void nxcor(matrix_gt const& a, matrix_gt const& b, matrix_gt& c) { + + using scalar_result_t = std::remove_const_t; + using scalar_first_t = std::remove_const_t; + using scalar_second_t = std::remove_const_t; + + constexpr std::size_t a_tile_rows_k = 16; // Mostly influenced by CPU cache size. + constexpr std::size_t b_tile_rows_k = 16; // Mostly influenced by CPU cache size and cache line width. + constexpr std::size_t tile_depth_k = 16; // Mostly influenced by CPU register width. 
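+
+    // The result is computed one 16x16 tile of C at a time: each (i, j) output tile accumulates
+    // partial dot-products over the shared dimension in `tile_depth_k`-wide slices of A and B.
+    // Interior tiles go through the unrolled `export_internal_tile` path, while tiles touching
+    // the matrix border are zero-padded by the `export_bounding_tile` helpers.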
+ + struct tiles_t { + alignas(64) scalar_first_t a[a_tile_rows_k][tile_depth_k]; + alignas(64) scalar_second_t b[b_tile_rows_k][tile_depth_k]; + alignas(64) scalar_result_t c[a_tile_rows_k][b_tile_rows_k]; + }; + + auto multiply_tile = [](tiles_t& tiles) noexcept { + for (std::size_t i = 0; i < a_tile_rows_k; ++i) + for (std::size_t j = 0; j < b_tile_rows_k; ++j) + for (std::size_t k = 0; k < tile_depth_k; ++k) + tiles.c[i][j] += tiles.a[i][k] * tiles.b[j][k]; + }; + +#pragma omp parallel for collapse(2) schedule(static) + for (std::size_t i_tile_offset = 0; i_tile_offset < a.rows(); i_tile_offset += a_tile_rows_k) { + for (std::size_t j_tile_offset = 0; j_tile_offset < b.cols(); j_tile_offset += b_tile_rows_k) { + bool is_last_in_a = i_tile_offset + a_tile_rows_k >= a.rows(); + bool is_last_in_b = j_tile_offset + b_tile_rows_k >= b.cols(); + bool is_bounding_tile = is_last_in_a || is_last_in_b; + + // Load a tile of C. + tiles_t tiles; + c.template export_bounding_tile(i_tile_offset, j_tile_offset, &tiles.c[0][0]); + + // Progress through columns of A and B. + std::size_t k_tile_offset = 0; + if (is_bounding_tile) { + for (; k_tile_offset + tile_depth_k <= a.cols(); k_tile_offset += tile_depth_k) { + a.template export_bounding_tile( // + i_tile_offset, k_tile_offset, &tiles.a[0][0]); + b.template export_bounding_tile( // + j_tile_offset, k_tile_offset, &tiles.b[0][0]); + multiply_tile(tiles); + } + } else { + for (; k_tile_offset + tile_depth_k <= a.cols(); k_tile_offset += tile_depth_k) { + a.template export_internal_tile( // + i_tile_offset, k_tile_offset, &tiles.a[0][0]); + b.template export_internal_tile( // + j_tile_offset, k_tile_offset, &tiles.b[0][0]); + multiply_tile(tiles); + } + } + + // Don't forget the tail of each row, if the number of columns is not divisible by the `tile_depth_k`. + if (k_tile_offset < a.cols()) { + a.template export_bounding_tile( // + i_tile_offset, k_tile_offset, &tiles.a[0][0]); + b.template export_bounding_tile( // + j_tile_offset, k_tile_offset, &tiles.b[0][0]); + multiply_tile(tiles); + } + + // Store C back. + c.template import_bounding_tile(i_tile_offset, j_tile_offset, &tiles.c[0][0]); + } + } +} + +/** + * @brief Multiplies two matrices using the Intel AMX instruction set. + * @tparam input_type The type of the input matrices ("bf16" or "i8" or "u8"). + * @tparam output_type The type of the output matrix ("f32" or "i32"). + * @tparam tile_m The number of rows in the first matrix (default is 16). + * @tparam tile_k The number of columns in the first matrix (default is 32). + * @tparam tile_n The number of columns in the second matrix (default is 16). + */ +template +void nxcor_amx( // + input_type (&matrix_a)[tile_m][tile_k], // 16 rows * 32 cols * 2 bytes/scalar = 1024 bytes + input_type (&matrix_b)[tile_n][tile_k], // transposed(32 rows * 16 cols) * 2 bytes/scalar = 1024 bytes + output_type (&result_matrix)[tile_m][tile_n] // 16 rows * 16 cols * 4 bytes/scalar = 1024 bytes +) { + + static_assert( // + sizeof(input_type) * tile_m * tile_k == 1024 && // + sizeof(input_type) * tile_k * tile_n == 1024 && // + sizeof(output_type) * tile_m * tile_n == 1024, + "Choose a simple tile size and thank me later"); + + // Set up the tile configuration structure + // There are 8 tile registers from TMM0 to TMM7. + // Each is 16 rows by 64 bytes, fitting up to 1 KB of data. + // The actual dimensions can be different and are controlled + // by `rows` and `colsb` - the width in bytes. 
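+    // The 64-byte configuration blob below is laid out as follows: byte 0 holds the palette id,
+    // bytes 16..47 hold the 16-bit `colsb` (bytes per row) of each tile, and bytes 48..63 hold
+    // the 8-bit row counts - exactly the offsets the pointers below index into.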
+ alignas(64) std::uint8_t tilecfg[64]; + std::memset(tilecfg, 0, sizeof(tilecfg)); + std::uint8_t* palette_id_ptr = &tilecfg[0]; + std::uint16_t* tiles_colsb_ptr = (std::uint16_t*)(&tilecfg[16]); + std::uint8_t* tiles_rows_ptr = &tilecfg[48]; + + *palette_id_ptr = 1; // The only palette currently supported + + // Important to note, AMX doesn't care about the real shape of our matrix, + // it only cares about it's own tile shape. Keep it simple, otherwise + // the next person reading this will be painting the walls with their brains. + tiles_rows_ptr[0] = 16; + tiles_rows_ptr[1] = 16; + tiles_rows_ptr[2] = 16; + tiles_colsb_ptr[0] = 64; + tiles_colsb_ptr[1] = 64; + tiles_colsb_ptr[2] = 64; + + _tile_loadconfig(&tilecfg); + _tile_zero(2); + _tile_loadd(0, &matrix_a[0][0], 64); + + // The second matrix must be reordered to fit the tile shape + constexpr int tile_k_pack = (4 / sizeof(input_type)); // Vertical K packing into Dword + input_type matrix_b_reordered[tile_k / tile_k_pack][tile_n][tile_k_pack]; // Re-laid B matrix + for (int k = 0; k < tile_k; ++k) + for (int n = 0; n < tile_n; ++n) + // We are shrinking the number of rows in the second matrix by 2x for `bf16` and by 4x for `i8` and `u8`. + // We are also practically making the rows longer by 2x for `bf16` and by 4x for `i8` and `u8`. + matrix_b_reordered[k / tile_k_pack][n][k % tile_k_pack] = matrix_b[n][k]; + _tile_loadd(1, &matrix_b_reordered[0][0][0], 64); + + // Here are the shape constraints: + // + // • #UD if srcdest.colbytes mod 4 ≠ 0. + // • #UD if src1.colbytes mod 4 ≠ 0. + // • #UD if src2.colbytes mod 4 ≠ 0. + // • #UD if srcdest.colbytes ≠ src2.colbytes - + // why the hell the row width of `f32` destination should + // be equal to the row width of `bfloat16` source?! + // • #UD if src1.colbytes / 4 ≠ src2.rows. + // so this practically means that the second matrix must have 2x + // fewer rows than the first one, meaning the number of columns in the + // first matrix must be 2x smaller than the number of rows in it! + // • #UD if srcdest.rows ≠ src1.rows. 
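+
+    // Pick the tile dot-product flavor by the input width: 2-byte `bf16` inputs go through
+    // `_tile_dpbf16ps` and accumulate into `f32`, while 1-byte signed integers go through
+    // `_tile_dpbssd` and accumulate into `i32`.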
+ if constexpr (sizeof(input_type) == 2) { + _tile_dpbf16ps(2, 0, 1); + } else { + _tile_dpbssd(2, 0, 1); + } + + // Store the result back into the result matrix + _tile_stored(2, result_matrix, 64); + + // Zero out the tile registers + _tile_release(); +} + +template void try_amx() { + std::printf("\n\n\n"); + std::printf("Trying AMX with %d x %d x %d matrix multiplication of type %s -> %s\n", tile_m, tile_k, tile_n, typeid(input_type).name(), + typeid(output_type).name()); + + input_type buffer_a[tile_m][tile_k]; + input_type buffer_b[tile_n][tile_k]; + output_type buffer_c[tile_m][tile_n] = {0}; + + matrix_gt matrix_a{&buffer_a[0][0], tile_m, tile_k, tile_k * sizeof(input_type)}; + matrix_gt matrix_b{&buffer_b[0][0], tile_n, tile_k, tile_k * sizeof(input_type)}; + matrix_gt matrix_c{&buffer_c[0][0], tile_m, tile_n, tile_n * sizeof(output_type)}; + + // Initialize the matrices with values + // std::iota(&buffer_a[0][0], &buffer_a[tile_m - 1][tile_k - 1] + 1, 1); + // std::iota(&buffer_b[0][0], &buffer_b[tile_n - 1][tile_k - 1] + 1, 1); + for (std::size_t row = 0; row != tile_m; ++row) + for (std::size_t col = 0; col != tile_k; ++col) + buffer_a[row][col] = row; + for (std::size_t row = 0; row != tile_n; ++row) + for (std::size_t col = 0; col != tile_k; ++col) + buffer_b[row][col] = -(__bf16)row; + + // Perform matrix multiplication using AMX-BF16 inline assembly + nxcor_amx(buffer_a, buffer_b, buffer_c); + std::printf("Resulting 16x16 matrix with AMX:\n"); + matrix_c.print(); + (void)matrix_a; + (void)matrix_b; + + // Compare this to naive multiplication + output_type buffer_c_serial[tile_m][tile_n] = {0}; + matrix_gt matrix_c_serial{&buffer_c_serial[0][0], tile_m, tile_n, tile_n * sizeof(output_type)}; + for (int i = 0; i < tile_m; i++) + for (int j = 0; j < tile_n; j++) + for (int k = 0; k < tile_k; k++) + buffer_c_serial[i][j] += (output_type)buffer_a[i][k] * (output_type)buffer_b[j][k]; + std::printf("Resulting 16x16 matrix after naive multiplication:\n"); + matrix_c_serial.print(); +} + +#if 0 +void check_tiled_amx() { + simsimd_bf16_t buffer_a[SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE][SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE]; + simsimd_bf16_t buffer_b[SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE][SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE]; + simsimd_bf16_t buffer_c[SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE][SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE]; + simsimd_bf16_t buffer_c_serial[SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE][SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE]; + + // The `bf16` resolution can accurately represent integers between -256 and 256, which makes it problematic + // for testing purposes. 
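+    // That is why `fill_random` below draws only from {-1, 0, 1}: those values and their small
+    // integer dot-products are exactly representable in `bf16`, so the serial and AMX results
+    // can be compared element-by-element without a tolerance.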
+ av::matrix_gt<__bf16> a(buffer_a); + av::matrix_gt<__bf16> b(buffer_b); + av::matrix_gt<__bf16> c(buffer_c); + av::matrix_gt<__bf16> c_serial(buffer_c_serial); + + // a.fill(1.0f); + // b.fill(1.0f); + // for (std::size_t row = 0; row != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE; ++row) + // for (std::size_t col = 0; col != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE; ++col) + // buffer_a[row][col] = row, buffer_b[row][col] = -(__bf16)row; + a.fill_random(); + b.fill_random(); + av::nxcor(a, b, c_serial); + c_serial.print(); + + simsimd_nxcor_bf16_sapphire( // + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE, SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE, + SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE, // + &buffer_a[0][0], SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_bf16_t), // + &buffer_b[0][0], SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_bf16_t), // + &buffer_c[0][0], SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE * sizeof(simsimd_bf16_t)); + c.print(); + + // Find the first mismatch position + std::size_t row = 0, col = 0; + for (; row != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE; ++row) + for (col = 0; col != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE; ++col) + if (buffer_c[row][col] != buffer_c_serial[row][col]) + break; + if (row != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE && col != SIMSIMD_NXCOR_BF16_AMX_L1_TILE_SIZE) + std::printf("Mismatch at row %zu, col %zu\n", row, col); +} +#endif + +} // namespace simsimd +} // namespace ashvardanian + +#endif