From e7a84b90dbb237a6954c2bd33575aaaae66bc270 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sat, 25 Jan 2025 12:48:30 +0100 Subject: [PATCH 01/24] Update cubecl --- Cargo.lock | 135 ++++++++++-------- Cargo.toml | 4 +- crates/burn-jit/src/kernel/binary.rs | 4 +- crates/burn-jit/src/kernel/binary_int.rs | 4 +- crates/burn-jit/src/kernel/cast/base.rs | 2 +- crates/burn-jit/src/kernel/comparison.rs | 4 +- .../burn-jit/src/kernel/conv/conv2d/col2im.rs | 2 +- .../burn-jit/src/kernel/conv/conv2d/direct.rs | 2 +- .../src/kernel/conv/conv2d/gemm/launch.rs | 12 +- .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 40 ++++-- .../src/kernel/conv/conv2d/layout_swap.rs | 4 +- .../kernel/conv/conv2d/transpose_direct.rs | 2 +- crates/burn-jit/src/kernel/conv/conv3d.rs | 2 +- .../kernel/conv/deform_conv_transpose2d.rs | 4 +- crates/burn-jit/src/kernel/conv/error.rs | 11 +- crates/burn-jit/src/kernel/index/flip.rs | 2 +- crates/burn-jit/src/kernel/index/gather.rs | 2 +- .../burn-jit/src/kernel/index/repeat_dim.rs | 2 +- crates/burn-jit/src/kernel/index/scatter.rs | 2 +- crates/burn-jit/src/kernel/index/select.rs | 2 +- .../src/kernel/index/select_assign.rs | 2 +- crates/burn-jit/src/kernel/index/slice.rs | 2 +- .../src/kernel/interpolate/bicubic.rs | 2 +- .../src/kernel/interpolate/bilinear.rs | 2 +- .../src/kernel/interpolate/nearest.rs | 2 +- .../kernel/interpolate/nearest_backward.rs | 2 +- crates/burn-jit/src/kernel/mask/mask_fill.rs | 4 +- crates/burn-jit/src/kernel/mask/mask_where.rs | 4 +- .../src/kernel/pool/avg_pool2d_backward.rs | 2 +- .../src/kernel/pool/max_pool2d_backward.rs | 2 +- .../src/kernel/quantization/dequantize.rs | 4 +- .../src/kernel/quantization/quantize.rs | 10 +- crates/burn-jit/src/kernel/unary_float.rs | 2 +- crates/burn-jit/src/kernel/unary_int.rs | 2 +- crates/burn-jit/src/kernel/unary_numeric.rs | 2 +- crates/burn-jit/src/ops/numeric.rs | 2 +- crates/burn-wgpu/src/lib.rs | 2 +- 
examples/custom-cubecl-kernel/src/kernel.rs | 2 +- 38 files changed, 161 insertions(+), 131 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b5a1835a3f..f579c27035 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -585,9 +585,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "4.0.1" +version = "4.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1118,9 +1118,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.9" +version = "1.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" +checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" dependencies = [ "jobserver", "libc", @@ -1225,9 +1225,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", "clap_derive", @@ -1235,9 +1235,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", @@ -1436,9 +1436,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = 
"59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -1537,9 +1537,9 @@ dependencies = [ [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "crypto-common" @@ -1575,7 +1575,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1590,7 +1590,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1611,7 +1611,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1632,7 +1632,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = 
"git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1646,7 +1646,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1662,7 +1662,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1678,9 +1678,9 @@ dependencies = [ [[package]] name = "cubecl-hip-sys" -version = "6.3.1000" +version = "6.3.1001" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d987c1720eab39c72c515377a8001f683a4c4d99232a29fc0de389d9a8ce4f" +checksum = "c7e92df7f9feff6a469932fc4d4b349d28000af9e6f34e583eb4f8df70038d48" dependencies = [ "libc", ] @@ -1688,9 +1688,11 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", + "cubecl-macros-internal", + "derive_more 1.0.0", "float-ord", "half", "num-traits", @@ -1701,7 +1703,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-core", @@ -1713,7 +1715,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", "darling", @@ -1725,10 +1727,21 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "cubecl-macros-internal" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1744,7 +1757,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1754,7 +1767,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "async-channel", "async-lock", @@ -1776,7 +1789,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1791,7 +1804,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "ash", "async-channel", @@ -1811,9 +1824,9 @@ dependencies = [ [[package]] name = "cudarc" -version = "0.12.2" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd76de2aa3a7bdb9a65941ea5a3c688d941688f736a81b2fc5beb88747a7f25" +checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384" dependencies = [ "half", "libloading", @@ -2790,15 +2803,15 @@ dependencies = [ [[package]] name = "gix-trace" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04bdde120c29f1fc23a24d3e115aeeea3d60d8e65bab92cc5f9d90d9302eb952" +checksum = "7c396a2036920c69695f760a65e7f2677267ccf483f25046977d87e4cb2665f7" [[package]] name = "gix-utils" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ba427e3e9599508ed98a6ddf8ed05493db114564e338e41f6a996d2e4790335f" +checksum = "ff08f24e03ac8916c478c8419d7d3c33393da9bb41fa4c24455d5406aeefd35f" dependencies = [ "fastrand", "unicode-normalization", @@ -3454,9 +3467,9 @@ checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -3526,9 +3539,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.10.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is_terminal_polyfill" @@ -3948,9 +3961,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" dependencies = [ "adler2", "simd-adler32", @@ -4421,9 +4434,9 @@ dependencies = [ [[package]] name = "objc2-encode" -version = "4.0.3" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7891e71393cd1f227313c9379a26a584ff3d7e6e7159e988851f0934c993f0f8" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" [[package]] name = "objc2-foundation" @@ -4630,9 +4643,9 @@ dependencies = [ [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-sys" @@ -5397,7 +5410,7 @@ dependencies = [ "once_cell", "polars-error", "pyo3", - "raw-cpuid 11.2.0", + "raw-cpuid 11.3.0", "rayon", "serde", "stacker", @@ -5805,9 +5818,9 @@ dependencies = [ [[package]] name = "range-alloc" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8a99fddc9f0ba0a85884b8d14e3592853e787d581ca1816c91349b10e4eeab" +checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde" [[package]] name = "ratatui" @@ -5892,9 +5905,9 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "11.2.0" +version = "11.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" +checksum = "c6928fa44c097620b706542d428957635951bade7143269085389d42c8a4927e" dependencies = [ "bitflags 2.8.0", ] @@ -6263,9 +6276,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.43" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags 2.8.0", "errno", @@ -6480,9 +6493,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "seq-macro" @@ -7701,9 +7714,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "11cd88e12b17c6494200a9c1b683a04fcac9573ed74cd1b62aeb2727c5592243" [[package]] name = "unicode-normalization" @@ -7842,9 +7855,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4" +checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" dependencies = [ "getrandom", "rand", @@ -7863,9 +7876,9 @@ dependencies = [ [[package]] name = "valuable" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] name = "value-trait" @@ -8088,9 +8101,9 @@ dependencies = [ [[package]] name = "wgpu" -version = "24.0.0" +version = "24.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e41253fc7b660735e2a2d9a58c563f2a047d3cc3445293d8f4095538c9e8afbe" +checksum = "47f55718f85c2fa756edffa0e7f0e0a60aba463d1362b57e23123c58f035e4b6" dependencies = [ "arrayvec", "bitflags 2.8.0", diff --git a/Cargo.toml b/Cargo.toml index 7cf3ddd008..7cc281e7cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. 
### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs index d7c4d789ab..f0da764a7a 100644 --- a/crates/burn-jit/src/kernel/binary.rs +++ b/crates/burn-jit/src/kernel/binary.rs @@ -112,7 +112,7 @@ pub(crate) fn kernel_scalar_binop( output: &mut Tensor>, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); @@ -132,7 +132,7 @@ pub(crate) fn kernel_binop( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/binary_int.rs b/crates/burn-jit/src/kernel/binary_int.rs index 06706a7d28..390bfc479e 100644 --- a/crates/burn-jit/src/kernel/binary_int.rs +++ b/crates/burn-jit/src/kernel/binary_int.rs @@ -85,7 +85,7 @@ pub(crate) fn kernel_scalar_binop_int( output: &mut Tensor>, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); @@ -105,7 +105,7 @@ pub(crate) fn kernel_binop_int( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git 
a/crates/burn-jit/src/kernel/cast/base.rs b/crates/burn-jit/src/kernel/cast/base.rs index 798b79a0f0..43b24f071a 100644 --- a/crates/burn-jit/src/kernel/cast/base.rs +++ b/crates/burn-jit/src/kernel/cast/base.rs @@ -12,7 +12,7 @@ pub(crate) fn cast_element( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } let offset_input = index_offset_with_layout::( diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index e33687fb5a..a6de9025bb 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -82,7 +82,7 @@ pub(crate) fn kernel_scalar_cmp>( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = Line::cast_from(O::execute(input[ABSOLUTE_POS], Line::new(scalar))); @@ -102,7 +102,7 @@ pub(crate) fn kernel_cmp>( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 11fb3b4aee..4f6931f86d 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -241,7 +241,7 @@ fn col2im_kernel( #[comptime] has_bias: bool, ) { if ABSOLUTE_POS >= image.len() { - return; + terminate!(); } let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs index c724cfc3a3..1cd24f7c0c 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -35,7 +35,7 @@ fn direct_conv2d_kernel( #[comptime] kernel_size_1_unroll: Option, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_channels = weight.shape(1); diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs 
b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs index f36f89bdf5..ad70a9b825 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -7,7 +7,7 @@ use burn_tensor::{ use cubecl::{ flex32, ir::{Elem, FloatKind}, - linalg::matmul::{self}, + linalg::matmul::{self, kernels::MatmulLaunchError}, tensor_line_size, tf32, Feature, }; use half::{bf16, f16}; @@ -195,18 +195,14 @@ where let cube_count = Alg::cube_count(&selection, &problem); let advanced_config = Default::default(); - let config = match Alg::make_config( + let config = Alg::make_config( config_input, &problem, &cube_dim, &cube_count, &advanced_config, - ) { - Ok(val) => val, - Err(err) => { - panic!("Can't launch conv kernel because of an invalid config: {err}") - } - }; + ) + .map_err(MatmulLaunchError::InvalidConfig)?; let bias = bias.unwrap_or_else(|| { empty_device::(input.client.clone(), input.device.clone(), Shape::new([1])) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index 6b738ab988..09ce56898b 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -53,7 +53,7 @@ fn im2col_kernel( let out_w = args.out_w; if ABSOLUTE_POS > args.num_elements { - return; + terminate!(); } let out_x = ABSOLUTE_POS % out_w; @@ -98,25 +98,38 @@ fn im2col_kernel( } #[cfg(not(test))] -pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - let cube_count_per_batch = (out_h * out_w).div_ceil(burn_common::PLANE_DIM_APPROX); +pub(crate) fn batches_per_run( + batch_size: usize, + out_h: usize, + out_w: usize, +) -> Result { + use cubecl::linalg::matmul::kernels::MatmulAvailabilityError; + + let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::PLANE_DIM_APPROX); let max_cube_count = u16::MAX as usize; let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size); 
if max_simultaneous == 0 { - return None; + return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static( + cube_count_per_batch as u32, + 1, + 1, + )) + .into()); } - Some( - (0..=max_simultaneous) - .rev() - .find(|per_run| batch_size % per_run == 0) - .expect("Logically not possible"), - ) + Ok((0..=max_simultaneous) + .rev() + .find(|per_run| batch_size % per_run == 0) + .expect("Logically not possible")) } #[cfg(test)] #[allow(unused)] -pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - Some(1) +pub(crate) fn batches_per_run( + batch_size: usize, + out_h: usize, + out_w: usize, +) -> Result { + Ok(1) } fn im2col( @@ -214,8 +227,7 @@ pub fn conv2d_im2col( return execute_1x1_kernel::(input, weight, bias, options); } - let batches_per_run = batches_per_run(batch_size, out_h, out_w) - .expect("Image too large to run even one batch at once"); + let batches_per_run = batches_per_run(batch_size, out_h, out_w)?; let matmul_shape = Shape::new([groups, out_c_per_group, batches_per_run * out_h * out_w]); let mut out = if batches_per_run != batch_size { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs index 62f0e56d8f..7cbe09dbc0 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs @@ -107,7 +107,7 @@ fn nchw_to_nhwc_kernel( let batch = CUBE_POS_Z; if batch >= input.shape(0) { - return; + terminate!(); } let batch_offset = batch * input.stride(0); @@ -163,7 +163,7 @@ fn nchw_to_nhwc_kernel( let hw = base_hw + mat_hw; if hw >= shape_hw { - return; + terminate!(); } let mat_c_start = mat_hw_start; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index d3e91d5947..a8cd1ceb7f 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ 
b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -32,7 +32,7 @@ fn conv_transpose2d_direct_kernel( args: ConvArgs, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_c_per_group = weight.shape(0) / args.groups; diff --git a/crates/burn-jit/src/kernel/conv/conv3d.rs b/crates/burn-jit/src/kernel/conv/conv3d.rs index 157610794b..a616c432b9 100644 --- a/crates/burn-jit/src/kernel/conv/conv3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv3d.rs @@ -41,7 +41,7 @@ fn conv3d_kernel( #[comptime] kernel_size_2_unroll: Option, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_channels = weight.shape(1); diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index ddee1360e4..5840f4dc9f 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -275,7 +275,7 @@ fn deform_col2img_coord_kernel( // Alternatively : [batch, offset_channels, out_h, out_w] if ABSOLUTE_POS >= grad_offset.len() { - return; + terminate!(); } let offset_channels = offset.shape(1); @@ -551,7 +551,7 @@ fn deform_col2img_kernel( ) { // Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] if ABSOLUTE_POS >= columns.len() { - return; + terminate!(); } let n_in_channels = args.in_channels; diff --git a/crates/burn-jit/src/kernel/conv/error.rs b/crates/burn-jit/src/kernel/conv/error.rs index 99c91fc751..2654a20e24 100644 --- a/crates/burn-jit/src/kernel/conv/error.rs +++ b/crates/burn-jit/src/kernel/conv/error.rs @@ -1,5 +1,8 @@ use core::fmt::Debug; -use cubecl::{linalg::matmul::kernels::MatmulLaunchError, tune::AutotuneError}; +use cubecl::{ + linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}, + tune::AutotuneError, +}; pub enum ConvLaunchError { Matmul(MatmulLaunchError), @@ -30,6 +33,12 @@ impl From for ConvLaunchError { } } +impl From for 
ConvLaunchError { + fn from(value: MatmulAvailabilityError) -> Self { + Self::Matmul(MatmulLaunchError::Unavailable(value)) + } +} + #[allow(clippy::from_over_into)] impl Into for ConvLaunchError { fn into(self) -> AutotuneError { diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index 583e0346d3..a682a76eac 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -11,7 +11,7 @@ fn flip_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index 9e9b5685bb..c1aa56072e 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -12,7 +12,7 @@ fn gather_kernel( dim: &u32, ) { if ABSOLUTE_POS >= indices.len() { - return; + terminate!(); } let index = indices[ABSOLUTE_POS]; diff --git a/crates/burn-jit/src/kernel/index/repeat_dim.rs b/crates/burn-jit/src/kernel/index/repeat_dim.rs index 3887bfbd8b..b19f9e2b21 100644 --- a/crates/burn-jit/src/kernel/index/repeat_dim.rs +++ b/crates/burn-jit/src/kernel/index/repeat_dim.rs @@ -4,7 +4,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked)] fn repeat_dim_kernel(input: &Tensor, output: &mut Tensor, dim: u32) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index 4ddd9c00fb..4cca94f824 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -46,7 +46,7 @@ fn scatter_kernel( let should_stop = ABSOLUTE_POS >= num_elems; if should_stop { - return; + terminate!(); } for i in 0..shape_value { diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index 
b104bf504f..fe664ab420 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -10,7 +10,7 @@ fn select_kernel( dim: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index a0fed49dbd..cd4c013f63 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -29,7 +29,7 @@ fn select_assign_kernel( } if ABSOLUTE_POS >= num_elems { - return; + terminate!(); } let strides_tensor_dim = tensor.stride(dim); diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index 7f20f033b8..b6daba8da5 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -52,7 +52,7 @@ fn slice_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index 1d545d79c7..3f77ef1302 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch)] fn interpolate_bicubic_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs index 3557fcdbb8..f0cb95b536 100644 --- a/crates/burn-jit/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch)] fn 
interpolate_bilinear_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs index 0743a13567..0e6ba32552 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch_unchecked)] fn interpolate_nearest_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs index 5ea860a7ae..f0442ec92e 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch_unchecked)] fn interpolate_nearest_backward_kernel(grad: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let out_h = output.shape(2); diff --git a/crates/burn-jit/src/kernel/mask/mask_fill.rs b/crates/burn-jit/src/kernel/mask/mask_fill.rs index 386e7a5039..95096c7994 100644 --- a/crates/burn-jit/src/kernel/mask/mask_fill.rs +++ b/crates/burn-jit/src/kernel/mask/mask_fill.rs @@ -16,7 +16,7 @@ fn mask_fill_readonly_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); @@ -35,7 +35,7 @@ fn mask_fill_inplace_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { - return; + terminate!(); } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, 
true); diff --git a/crates/burn-jit/src/kernel/mask/mask_where.rs b/crates/burn-jit/src/kernel/mask/mask_where.rs index 5518e9648b..99384fde98 100644 --- a/crates/burn-jit/src/kernel/mask/mask_where.rs +++ b/crates/burn-jit/src/kernel/mask/mask_where.rs @@ -16,7 +16,7 @@ fn mask_where_readonly_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); @@ -36,7 +36,7 @@ fn mask_where_inplace_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { - return; + terminate!(); } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index bba68c7166..d2a5a21d0a 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -24,7 +24,7 @@ fn avg_pool2d_backward_kernel( #[comptime] count_include_pad: bool, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index 6da6e2b37c..40259c4573 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -16,7 +16,7 @@ fn max_pool2d_with_indices_backward_kernel( #[comptime] kernel_size_1: i32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 72040d8839..270e32f854 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -48,7 +48,7 
@@ pub(crate) fn dequantize_per_tensor_affine_int8_kernel( ) { // Last two positions contain the qparams if ABSOLUTE_POS >= input.len() - 2 { - return; + terminate!(); } let qparams = QParams::new(scheme); @@ -85,7 +85,7 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( ) { // Last position contains the qparam if ABSOLUTE_POS >= input.len() - 1 { - return; + terminate!(); } let qparams = QParams::new(scheme); diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index e9494aa987..0a7b0ea553 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -34,7 +34,7 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let scale = scale[0]; @@ -43,13 +43,13 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( // Cast the scale to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 1 { output[ABSOLUTE_POS] = u32::bitcast_from(scale); - return; + terminate!(); } // Cast the offset to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 2 { output[ABSOLUTE_POS] = u32::bitcast_from(offset); - return; + terminate!(); } let line_size = comptime!(input.line_size()); @@ -120,7 +120,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let scale = scale[0]; @@ -128,7 +128,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( // Cast the scale to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 1 { output[ABSOLUTE_POS] = u32::bitcast_from(scale); - return; + terminate!(); } let line_size = comptime!(input.line_size()); diff --git a/crates/burn-jit/src/kernel/unary_float.rs b/crates/burn-jit/src/kernel/unary_float.rs index 33a311ecbc..4664d3c0b3 100644 --- 
a/crates/burn-jit/src/kernel/unary_float.rs +++ b/crates/burn-jit/src/kernel/unary_float.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_float( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/kernel/unary_int.rs b/crates/burn-jit/src/kernel/unary_int.rs index 5e60898699..17bced52d1 100644 --- a/crates/burn-jit/src/kernel/unary_int.rs +++ b/crates/burn-jit/src/kernel/unary_int.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_int( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/kernel/unary_numeric.rs b/crates/burn-jit/src/kernel/unary_numeric.rs index 0b8dcb2cbc..aaeadbb685 100644 --- a/crates/burn-jit/src/kernel/unary_numeric.rs +++ b/crates/burn-jit/src/kernel/unary_numeric.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_numeric( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index 2c2c7987ab..cf15916aab 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -31,7 +31,7 @@ pub fn full_device( #[cube(launch)] pub fn full_kernel(tensor: &mut Tensor, value: C) { if ABSOLUTE_POS >= tensor.len() { - return; + terminate!(); } tensor[ABSOLUTE_POS] = value; diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 7aab106b29..3d29d219d0 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -12,8 +12,8 @@ pub use burn_jit::{ pub use burn_jit::{tensor::JitTensor, JitBackend}; pub use burn_jit::{BoolElement, FloatElement, IntElement}; pub use cubecl::flex32; +pub use cubecl::prelude::CubeDim; pub use cubecl::wgpu::*; -pub use cubecl::CubeDim; pub type Wgsl = cubecl::wgpu::WgslCompiler; #[cfg(feature = "spirv")] diff --git 
a/examples/custom-cubecl-kernel/src/kernel.rs b/examples/custom-cubecl-kernel/src/kernel.rs index 0809971327..08d4ded4d7 100644 --- a/examples/custom-cubecl-kernel/src/kernel.rs +++ b/examples/custom-cubecl-kernel/src/kernel.rs @@ -17,7 +17,7 @@ pub fn fused_matmul_add_relu_kernel( let dim_k = rhs.shape(rhs.rank() - 1); if row >= n_rows || col >= n_cols { - return; + terminate!(); } let offset_output = batch * n_rows * n_cols; From 267ec63eb1f8f6cf0196d555a1059ecfd8398a51 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sun, 26 Jan 2025 19:17:07 +0100 Subject: [PATCH 02/24] Update to scope merge --- Cargo.lock | 16 +--------------- Cargo.toml | 8 ++++---- crates/burn-jit/src/fusion/matmul/args.rs | 2 +- crates/burn-jit/src/fusion/on_write/ir.rs | 4 ++-- .../src/tensor/quantization/scheme.rs | 2 +- 5 files changed, 9 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f579c27035..2182d011a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1575,7 +1575,6 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1590,7 +1589,6 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1611,7 +1609,6 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1632,7 +1629,6 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1646,7 +1642,6 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1662,7 +1657,6 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1688,7 +1682,6 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -1697,13 +1690,13 @@ dependencies = [ "half", "num-traits", "serde", + "smallvec", "type_hash", ] [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-core", @@ -1715,7 +1708,6 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", "darling", @@ -1730,7 +1722,6 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "darling", "proc-macro2", @@ -1741,7 +1732,6 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1757,7 +1747,6 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1767,7 +1756,6 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "async-channel", "async-lock", @@ -1789,7 +1777,6 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1804,7 +1791,6 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 7cc281e7cd..f097d8d3ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. 
### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } +# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } +# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } ### For local development. ### -# cubecl = { path = "../cubecl/crates/cubecl", default-features = false } -# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +cubecl = { path = "../cubecl/crates/cubecl", default-features = false } +cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } ### For the release. ### # cubecl = { version = "0.4.0", default-features = false } # cubecl-common = { version = "0.4.0", default-features = false } diff --git a/crates/burn-jit/src/fusion/matmul/args.rs b/crates/burn-jit/src/fusion/matmul/args.rs index 1dbbf3baea..bba18e88f9 100644 --- a/crates/burn-jit/src/fusion/matmul/args.rs +++ b/crates/burn-jit/src/fusion/matmul/args.rs @@ -247,7 +247,7 @@ impl CubeType for FusedMatmulState { } impl Init for FusedMatmulStateExpand { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _context: &mut Scope) -> Self { self } } diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index 0cec2d29c7..36c8e402a0 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -45,13 +45,13 @@ impl CubeType for Arg { } impl Init for Arg { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _context: &mut Scope) -> Self { self } } impl IntoRuntime for Arg { - fn __expand_runtime_method(self, _context: &mut 
CubeContext) -> Self::ExpandType { + fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType { self } } diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs index fb141ee16d..27fa996ad6 100644 --- a/crates/burn-tensor/src/tensor/quantization/scheme.rs +++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs @@ -37,7 +37,7 @@ impl CubeType for QuantizationScheme { } #[cfg(feature = "cubecl")] impl cubecl::frontend::Init for QuantizationScheme { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _scope: &mut cubecl::ir::Scope) -> Self { self } } From 71584e0e164f9d4ac74677731691874bad3c7510 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sun, 26 Jan 2025 20:33:45 +0100 Subject: [PATCH 03/24] Fix bitwise shift --- crates/burn-jit/src/ops/int_ops.rs | 16 ++++------------ crates/burn-tensor/src/tests/ops/bitwise.rs | 4 ++++ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 068c1269d9..8da778d1e8 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -328,26 +328,18 @@ where } fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = kernel::cast::(rhs); - launch_binop_int::(lhs_cast, rhs_cast) + launch_binop_int::(lhs, rhs) } fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = rhs.elem::(); - launch_scalar_binop_int::(lhs_cast, rhs_cast) + launch_scalar_binop_int::(lhs, rhs) } fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = kernel::cast::(rhs); - launch_binop_int::(lhs_cast, rhs_cast) + launch_binop_int::(lhs, rhs) } fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) 
-> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = rhs.elem::(); - launch_scalar_binop_int::(lhs_cast, rhs_cast) + launch_scalar_binop_int::(lhs, rhs) } } diff --git a/crates/burn-tensor/src/tests/ops/bitwise.rs b/crates/burn-tensor/src/tests/ops/bitwise.rs index 73702a716e..c85f5edcc5 100644 --- a/crates/burn-tensor/src/tests/ops/bitwise.rs +++ b/crates/burn-tensor/src/tests/ops/bitwise.rs @@ -124,6 +124,10 @@ mod tests { #[test] fn should_apply_bitwise_left_shift_2d() { + if (IntType::MAX as u32) < 512 { + return; + } + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); From 58383643b51e21209497afb6a1b5fed412acdd8a Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 28 Jan 2025 16:11:24 +0100 Subject: [PATCH 04/24] Initial JIT implementation --- Cargo.lock | 15 +- crates/burn-cuda/Cargo.toml | 5 +- crates/burn-cuda/src/lib.rs | 1 + crates/burn-jit/Cargo.toml | 5 +- crates/burn-jit/src/kernel/mod.rs | 4 + .../kernel/vision/connected_components/bke.rs | 388 ++++++++++++++++ .../hardware_accelerated.rs | 422 ++++++++++++++++++ .../kernel/vision/connected_components/mod.rs | 44 ++ crates/burn-jit/src/kernel/vision/mod.rs | 2 + crates/burn-jit/src/kernel/vision/ops.rs | 35 ++ crates/burn-jit/src/lib.rs | 3 +- crates/burn-jit/src/ops/base.rs | 1 + crates/burn-jit/src/ops/numeric.rs | 24 + crates/burn-jit/src/tensor/base.rs | 3 +- crates/burn-vision/Cargo.toml | 23 + .../src/cpu_impl/connected_components.rs | 21 + crates/burn-vision/src/cpu_impl/mod.rs | 3 + crates/burn-vision/src/lib.rs | 9 + crates/burn-vision/src/ops/base.rs | 93 ++++ crates/burn-vision/src/ops/mod.rs | 3 + crates/burn-vision/src/tensor.rs | 28 ++ .../src/tests/connected_components.rs | 179 ++++++++ crates/burn-vision/src/tests/mod.rs | 42 ++ crates/burn-wgpu/Cargo.toml | 3 + crates/burn-wgpu/src/lib.rs | 1 + 25 files changed, 1352 insertions(+), 5 deletions(-) 
create mode 100644 crates/burn-jit/src/kernel/vision/connected_components/bke.rs create mode 100644 crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs create mode 100644 crates/burn-jit/src/kernel/vision/connected_components/mod.rs create mode 100644 crates/burn-jit/src/kernel/vision/mod.rs create mode 100644 crates/burn-jit/src/kernel/vision/ops.rs create mode 100644 crates/burn-vision/Cargo.toml create mode 100644 crates/burn-vision/src/cpu_impl/connected_components.rs create mode 100644 crates/burn-vision/src/cpu_impl/mod.rs create mode 100644 crates/burn-vision/src/lib.rs create mode 100644 crates/burn-vision/src/ops/base.rs create mode 100644 crates/burn-vision/src/ops/mod.rs create mode 100644 crates/burn-vision/src/tensor.rs create mode 100644 crates/burn-vision/src/tests/connected_components.rs create mode 100644 crates/burn-vision/src/tests/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 2182d011a5..3bfc278cf2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -707,6 +707,7 @@ dependencies = [ "burn-fusion", "burn-jit", "burn-tensor", + "burn-vision", "bytemuck", "cubecl", "derive-new 0.7.0", @@ -822,6 +823,7 @@ dependencies = [ "burn-ndarray", "burn-tensor", "burn-tensor-testgen", + "burn-vision", "bytemuck", "cubecl", "derive-new 0.7.0", @@ -964,6 +966,16 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "burn-vision" +version = "0.17.0" +dependencies = [ + "burn-tensor", + "burn-tensor-testgen", + "cubecl", + "serde", +] + [[package]] name = "burn-wgpu" version = "0.17.0" @@ -971,6 +983,7 @@ dependencies = [ "burn-fusion", "burn-jit", "burn-tensor", + "burn-vision", "cubecl", "half", "paste", @@ -1690,7 +1703,6 @@ dependencies = [ "half", "num-traits", "serde", - "smallvec", "type_hash", ] @@ -1804,6 +1816,7 @@ dependencies = [ "derive-new 0.6.0", "hashbrown 0.14.5", "log", + "sanitize-filename 0.5.0", "web-time", "wgpu", ] diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index 
1a92e695b2..196018ed0e 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -12,8 +12,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda" version.workspace = true [features] -default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"] autotune = ["burn-jit/autotune"] +default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"] doc = ["burn-jit/doc"] fusion = ["burn-fusion", "burn-jit/fusion"] std = ["burn-jit/std", "cubecl/std"] @@ -37,6 +37,9 @@ log = { workspace = true } burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } +burn-vision = { path = "../burn-vision", version = "0.17.0", default-features = false, features = [ + "export_tests", +] } paste = { workspace = true } diff --git a/crates/burn-cuda/src/lib.rs b/crates/burn-cuda/src/lib.rs index 0387da0215..518376d617 100644 --- a/crates/burn-cuda/src/lib.rs +++ b/crates/burn-cuda/src/lib.rs @@ -21,4 +21,5 @@ mod tests { // TODO: Add tests for bf16 burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]); + burn_vision::testgen_all!(); } diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 214b21eef3..0a7f9f2e82 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -13,13 +13,14 @@ version.workspace = true [features] autotune = [] -default = ["autotune", "std", "fusion", "cubecl/default"] +default = ["autotune", "std", "fusion", "vision", "cubecl/default"] doc = ["default"] export_tests = [ "burn-tensor-testgen", "serial_test", "burn-autodiff/export_tests", "burn-tensor/export_tests", + "burn-vision?/export_tests", "burn-ndarray", "fusion", "paste", @@ -27,6 +28,7 @@ export_tests = [ fusion = ["burn-fusion"] fusion-experimental = ["fusion"] std = ["cubecl/std", "burn-tensor/std"] +vision = ["burn-vision"] template = [] @@ -37,6 +39,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = 
"cubecl", "repr", ] } +burn-vision = { path = "../burn-vision", version = "0.17.0", optional = true } cubecl = { workspace = true, features = ["linalg", "reduce"] } bytemuck = { workspace = true } diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index 93d2833976..c7f746061f 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -37,6 +37,10 @@ pub mod quantization; /// Reduction algorithms pub mod reduce; +/// Vision algorithms +#[cfg(feature = "vision")] +pub mod vision; + pub(crate) use clamp::*; pub(crate) use comparison::*; pub(crate) use index::*; diff --git a/crates/burn-jit/src/kernel/vision/connected_components/bke.rs b/crates/burn-jit/src/kernel/vision/connected_components/bke.rs new file mode 100644 index 0000000000..331806c237 --- /dev/null +++ b/crates/burn-jit/src/kernel/vision/connected_components/bke.rs @@ -0,0 +1,388 @@ +//! Block-based komura equivalence, adapted from +//! S. Allegretti, F. Bolelli, C. Grana, +//! "Optimized Block-Based Algorithms to Label Connected Components on GPUs," +//! in IEEE Transactions on Parallel and Distributed Systems, 2019. 
+ +use crate::{ + kernel, + ops::numeric::{empty_device, zeros_device}, + tensor::JitTensor, + tests::burn_tensor::{DType, Shape}, + JitElement, JitRuntime, +}; +use cubecl::cube; +use cubecl::prelude::*; + +mod info { + pub const A: u8 = 0; + pub const B: u8 = 1; + pub const C: u8 = 2; + pub const D: u8 = 3; + pub const Q: u8 = 5; + pub const R: u8 = 6; + pub const S: u8 = 7; +} + +#[cube] +fn has_bit(bitmap: I, pos: u8) -> bool { + bool::cast_from((bitmap >> I::cast_from(pos)) & I::new(1)) +} + +#[cube] +fn set_bit(bitmap: I, pos: u8) -> I { + bitmap | (I::new(1) << I::cast_from(pos)) +} + +#[cube] +fn find_root(s_buf: &Tensor>, n: u32) -> u32 { + let mut n = n; + while Atomic::load(&s_buf[n]) != n { + n = Atomic::load(&s_buf[n]); + } + n +} + +#[cube] +fn find_root_and_compress(s_buf: &mut Tensor, id: u32) -> u32 { + let mut n = id; + while s_buf[n] != n { + n = s_buf[n]; + s_buf[id] = n; + } + n +} + +#[cube] +fn tree_union(s_buf: &Tensor>, a: u32, b: u32) { + let mut a = a; + let mut b = b; + #[allow(unused_assignments)] + let mut done = false; + + loop { + a = find_root(s_buf, a); + b = find_root(s_buf, b); + + #[allow(clippy::comparison_chain, reason = "not supported in cubecl")] + if a < b { + let old = Atomic::min(&s_buf[b], a); + done = old == b; + b = old; + } else if b < a { + let old = Atomic::min(&s_buf[a], b); + done = old == a; + a = old; + } else { + done = true; + } + + if done { + break; + } + } +} + +#[cube(launch)] +fn init_labeling(img: &Tensor, labels: &mut Tensor, last_pixel: &mut Array) { + let batch = ABSOLUTE_POS_Z; + let row = ABSOLUTE_POS_Y * 2; + let col = ABSOLUTE_POS_X * 2; + + if row >= labels.shape(1) || col >= labels.shape(2) { + terminate!(); + } + + let img_rows = img.shape(2); + let img_cols = img.shape(3); + let img_stride = img.stride(2); + let labels_stride = labels.stride(1); + + let img_index = batch * img.stride(0) + row * img_stride + col * img.stride(3); + let labels_index = batch * labels.stride(0) + row * 
labels_stride + col * labels.stride(2); + + let mut p = 0u16; + + // Bitmask representing two kinds of information + // Bits 0, 1, 2, 3 are set if pixel a, b, c, d are foreground, respectively + // Bits 4, 5, 6, 7 are set if block P, Q, R, S need to be merged to X in Merge phase + let mut info = 0u8; + + let mut buffer = Array::::new(4); + #[unroll] + for i in 0..4 { + buffer[i] = 0; + } + + if col + 1 < img_cols { + buffer[0] = img[img_index]; + buffer[1] = img[img_index + 1]; + + if row + 1 < img_rows { + buffer[2] = img[img_index + img_stride]; + buffer[3] = img[img_index + img_stride + 1]; + } + } else { + buffer[0] = img[img_index]; + + if row + 1 < img_rows { + buffer[2] = img[img_index + img_stride]; + } + } + + if buffer[0] != 0 { + p |= 0x777; + info = set_bit::(info, info::A); + } + if buffer[1] != 0 { + p |= 0x777 << 1; + info = set_bit::(info, info::B); + } + if buffer[2] != 0 { + p |= 0x777 << 4; + info = set_bit::(info, info::C); + } + if buffer[3] != 0 { + info = set_bit::(info, info::D); + } + + if col == 0 { + p &= 0xeeee; + } + if col + 1 >= img_cols { + p &= 0x3333; + } else if col + 2 >= img_cols { + p &= 0x7777; + } + + if row == 0 { + p &= 0xfff0; + } + if row + 1 >= img_rows { + p &= 0x00ff; + } else if row + 2 >= img_rows { + p &= 0x0fff; + } + + // P is now ready to be used to find neighbor blocks + // P value avoids range errors + + let mut father_offset = 0i32; + + // P square + if has_bit::(p, 0) && img[img_index - img_stride - 1] != 0 { + father_offset = -(2 * labels_stride as i32 + 2); + } + + // Q square + if (has_bit::(p, 1) && img[img_index - img_stride] != 0) + || (has_bit::(p, 2) && img[img_index + 1 - img_stride] != 0) + { + if father_offset == 0 { + father_offset = -(2 * labels_stride as i32); + } else { + info = set_bit::(info, info::Q); + } + } + + // R square + if has_bit::(p, 3) && img[img_index + 2 - img_stride] != 0 { + if father_offset == 0 { + father_offset = -(2 * labels_stride as i32 - 2); + } else { + info = 
set_bit::(info, info::R); + } + } + + // S square + if (has_bit::(p, 4) && img[img_index - 1] != 0) + || (has_bit::(p, 8) && img[img_index + img_stride - 1] != 0) + { + if father_offset == 0 { + father_offset = -2i32; + } else { + info = set_bit::(info, info::S); + } + } + + labels[labels_index] = labels_index as i32 + father_offset; + if col + 1 < labels.shape(2) { + labels[labels_index + 1] = info as i32; + } else if row + 1 < labels.shape(1) { + labels[labels_index + labels_stride] = info as i32; + } else { + last_pixel[0] = info; + } +} + +#[cube(launch)] +fn merge(labels: &mut Tensor>, last_pixel: &mut Array) { + let batch = ABSOLUTE_POS_Z; + let row = ABSOLUTE_POS_Y * 2; + let col = ABSOLUTE_POS_X * 2; + let rows = labels.shape(1); + let cols = labels.shape(2); + let labels_stride = labels.stride(1); + let labels_index = batch * labels.stride(0) + row * labels_stride + col; + + if row >= labels.shape(1) || col >= labels.shape(2) { + terminate!(); + } + + let info = if col + 1 < cols { + Atomic::load(&labels[labels_index + 1]) as u8 + } else if row + 1 < rows { + Atomic::load(&labels[labels_index + labels_stride]) as u8 + } else { + last_pixel[0] + }; + + if has_bit::(info, info::Q) { + tree_union(labels, labels_index, labels_index - 2 * labels_stride); + } + if has_bit::(info, info::R) { + tree_union(labels, labels_index, labels_index - 2 * labels_stride + 2); + } + if has_bit::(info, info::S) { + tree_union(labels, labels_index, labels_index - 1); + } +} + +#[cube(launch)] +fn compression(labels: &mut Tensor) { + let batch = ABSOLUTE_POS_Z; + let row = ABSOLUTE_POS_Y * 2; + let col = ABSOLUTE_POS_X * 2; + let labels_index = batch * labels.stride(0) + row * labels.stride(1) + col; + + if row < labels.shape(1) && col < labels.shape(2) { + find_root_and_compress(labels, labels_index); + } +} + +#[cube(launch)] +fn final_labeling(img: &Tensor, labels: &mut Tensor) { + let batch = ABSOLUTE_POS_Z; + let row = ABSOLUTE_POS_Y * 2; + let col = ABSOLUTE_POS_X * 2; + 
let rows = labels.shape(1); + let cols = labels.shape(2); + let label_stride = labels.stride(1); + let img_stride = img.stride(2); + let labels_index = batch * labels.stride(0) + row * label_stride + col; + + if row >= labels.shape(1) || col >= labels.shape(2) { + terminate!(); + } + + let mut label = 0; + #[allow(unused_assignments)] + let mut info = 0u8; + let mut buffer = Array::::new(2); + + if col + 1 < cols { + buffer[0] = label[labels_index]; + buffer[1] = label[labels_index + 1]; + label = buffer[0] + 1; + info = buffer[1] as u8; + } else { + label = labels[labels_index] + 1; + if row + 1 < rows { + info = labels[labels_index + label_stride] as u8; + } else { + // Read from the input image + // "a" is already in position 0 + info = img[batch * img.stride(0) + row * img_stride + col]; + } + } + + if col + 1 < cols { + labels[labels_index] = has_bit::(info, info::B) as u32 * label; + labels[labels_index + 1] = has_bit::(info, info::A) as u32 * label; + + if row + 1 < rows { + labels[labels_index + label_stride] = has_bit::(info, info::D) as u32 * label; + labels[labels_index + label_stride + 1] = has_bit::(info, info::C) as u32 * label; + } + } else { + labels[labels_index] = has_bit::(info, info::A) as u32 * label; + + if row + 1 < rows { + labels[labels_index + label_stride] = has_bit::(info, info::C) as u32 * label; + } + } +} + +#[expect( + unused, + reason = "currently broken because kernel reassigns pointers and I need to figure out how to port that" +)] +pub fn block_based_komura_equivalence( + img: JitTensor, +) -> JitTensor { + let img = kernel::cast::(img); + + let [batches, channels, rows, columns] = img.shape.dims(); + assert_eq!(channels, 1, "Channels must be 1 for connected components"); + + let shape = Shape::new([batches, rows, columns]); + let labels = zeros_device::(img.client.clone(), img.device.clone(), shape); + + let last_pixel = if (rows == 1 || columns == 1) && (rows + columns) % 2 == 0 { + empty_device::(img.client.clone(), 
img.device.clone(), Shape::new([1])) + } else { + let offset = (((rows - 2) * labels.strides[2]) + (columns - 2)) * size_of::(); + JitTensor::new_contiguous( + labels.client.clone(), + labels.device.clone(), + Shape::new([1]), + labels.handle.clone().offset_start(offset as u64), + DType::U8, + ) + }; + + let cube_dim = CubeDim::default(); + let cube_count_x = (columns as u32).div_ceil(2).div_ceil(cube_dim.x); + let cube_count_y = (rows as u32).div_ceil(2).div_ceil(cube_dim.y); + let cube_count = CubeCount::Static(cube_count_x, cube_count_y, batches as u32); + + init_labeling::launch( + &img.client, + cube_count.clone(), + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + last_pixel.as_array_arg::(1), + ); + + compression::launch( + &img.client, + cube_count.clone(), + cube_dim, + labels.as_tensor_arg::(1), + ); + + merge::launch( + &img.client, + cube_count.clone(), + cube_dim, + labels.as_tensor_arg::(1), + last_pixel.as_array_arg::(1), + ); + + compression::launch( + &img.client, + cube_count.clone(), + cube_dim, + labels.as_tensor_arg::(1), + ); + + final_labeling::launch( + &img.client, + cube_count.clone(), + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + ); + + labels +} diff --git a/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs b/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs new file mode 100644 index 0000000000..41918dac96 --- /dev/null +++ b/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs @@ -0,0 +1,422 @@ +//! Hardware Accelerated 4-connected, adapted from +//! A. Hennequin, L. Lacassagne, L. Cabaret, Q. Meunier, +//! "A new Direct Connected Component Labeling and Analysis Algorithms for GPUs", +//! 
DASIP, 2018 + +use crate::{ + kernel::vision::connected_components::stats_from_opts, ops::numeric::zeros_device, + tensor::JitTensor, BoolElement, FloatElement, IntElement, JitBackend, JitRuntime, +}; +use burn_tensor::Shape; +use burn_vision::{ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity}; +use cubecl::{prelude::*, Feature}; + +const BLOCK_H: u32 = 4; + +#[cube] +fn merge(labels: &Tensor>, label_1: u32, label_2: u32) { + let mut label_1 = label_1; + let mut label_2 = label_2; + + while label_1 != label_2 && (label_1 != Atomic::load(&labels[label_1]) - 1) { + label_1 = Atomic::load(&labels[label_1]) - 1; + } + while label_1 != label_2 && (label_2 != Atomic::load(&labels[label_2]) - 1) { + label_2 = Atomic::load(&labels[label_2]) - 1; + } + while label_1 != label_2 { + #[allow(clippy::manual_swap)] + if label_1 < label_2 { + let tmp = label_1; + label_1 = label_2; + label_2 = tmp; + } + let label_3 = Atomic::min(&labels[label_1], label_2 + 1) - 1; + if label_1 == label_3 { + label_1 = label_2; + } else { + label_1 = label_3; + } + } +} + +#[cube] +fn start_distance(pixels: u32, tx: u32) -> u32 { + u32::leading_zeros(u32::bitwise_not(pixels << (32 - tx))) +} + +#[cube] +fn end_distance(pixels: u32, tx: u32) -> u32 { + u32::find_first_set(u32::bitwise_not(pixels >> (tx + 1))) +} + +#[cube(launch)] +fn strip_labeling( + img: &Tensor, + labels: &Tensor>, + #[comptime] connectivity: Connectivity, +) { + let mut shared_pixels = SharedMemory::::new(BLOCK_H); + + let batch = ABSOLUTE_POS_Z; + let y = ABSOLUTE_POS_Y; + let rows = labels.shape(1); + let cols = labels.shape(2); + + if y >= rows { + terminate!(); + } + + let img_stride = img.stride(2); + let labels_stride = labels.stride(1); + + let img_line_base = batch * img.stride(0) + y * img_stride + UNIT_POS_X; + let labels_line_base = batch * labels.stride(0) + y * labels.stride(1) + UNIT_POS_X; + + let mut distance_y = 0; + let mut distance_y_1 = 0; + + for i in range_stepped(0, img.shape(3), PLANE_DIM) 
{ + let x = UNIT_POS_X + i; + + if x < cols { + let mut mask = 0xffffffffu32; + let involved_cols = cols - i; + if involved_cols < 32 { + mask >>= 32 - involved_cols; + } + + let img_index = img_line_base + i; + let labels_index = labels_line_base + i; + + let p_y = bool::cast_from(img[img_index]); + + let pixels_y = plane_ballot(p_y)[0] & mask; + let mut s_dist_y = start_distance(pixels_y, UNIT_POS_X); + + if p_y && s_dist_y == 0 { + Atomic::store( + &labels[labels_index], + labels_index - select(UNIT_POS_X == 0, distance_y, 0) + 1, + ); + } + + // Only needed pre-Volta, but we can't check that at present + sync_units(); + + if UNIT_POS_X == 0 { + shared_pixels[UNIT_POS_Y] = pixels_y; + } + + sync_units(); + + // Requires if and not select, because `select` may execute the then branch even if the + // condition is false (on non-CUDA backends), which can lead to OOB reads. + let pixels_y_1 = if UNIT_POS_Y > 0 { + shared_pixels[UNIT_POS_Y - 1] + } else { + 0u32 + }; + + let p_y_1 = (pixels_y_1 >> UNIT_POS_X) & 1 != 0; + let mut s_dist_y_1 = start_distance(pixels_y_1, UNIT_POS_X); + + if UNIT_POS_X == 0 { + s_dist_y = distance_y; + s_dist_y_1 = distance_y_1; + } + + match connectivity { + Connectivity::Four => { + if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) { + let label_1 = labels_index - s_dist_y; + let label_2 = labels_index - s_dist_y_1 - labels_stride; + merge(labels, label_1, label_2); + } + } + Connectivity::Eight => { + let pixels_y_shifted = (pixels_y << 1) | (distance_y > 0) as u32; + let pixels_y_1_shifted = (pixels_y_1 << 1) | (distance_y_1 > 0) as u32; + + if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) { + let label_1 = labels_index - s_dist_y; + let label_2 = labels_index - s_dist_y_1 - labels_stride; + merge(labels, label_1, label_2); + } else if p_y && s_dist_y == 0 && (pixels_y_1_shifted >> UNIT_POS_X) & 1 != 0 { + let s_dist_y_1_prev = select( + UNIT_POS_X == 0, + distance_y_1 - 1, + start_distance(pixels_y_1, UNIT_POS_X - 1), + ); 
+ let label_1 = labels_index; + let label_2 = labels_index - labels_stride - 1 - s_dist_y_1_prev; + merge(labels, label_1, label_2); + } else if p_y_1 && s_dist_y_1 == 0 && (pixels_y_shifted >> UNIT_POS_X) & 1 != 0 + { + let s_dist_y_prev = select( + UNIT_POS_X == 0, + distance_y - 1, + start_distance(pixels_y, UNIT_POS_X - 1), + ); + let label_1 = labels_index - 1 - s_dist_y_prev; + let label_2 = labels_index - labels_stride; + merge(labels, label_1, label_2); + } + } + } + + if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) { + let label_1 = labels_index - s_dist_y; + let label_2 = labels_index - s_dist_y_1 - labels_stride; + merge(labels, label_1, label_2); + } + + let mut d = start_distance(pixels_y_1, 32); + distance_y_1 = d + select(d == 32, distance_y_1, 0); + d = start_distance(pixels_y, 32); + distance_y = d + select(d == 32, distance_y, 0); + } + } +} + +#[cube(launch)] +fn strip_merge(img: &Tensor, labels: &Tensor>) { + let batch = ABSOLUTE_POS_Z; + let y = ABSOLUTE_POS_Y * BLOCK_H; + let x = ABSOLUTE_POS_X; + + let img_step = img.stride(2); + let labels_step = labels.stride(1); + let cols = img.shape(3); + + if y < labels.shape(1) && x < labels.shape(2) && y > 0 { + let mut mask = 0xffffffffu32; + if cols - CUBE_POS_X * CUBE_DIM_X < 32 { + mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X); + } + + let img_index = batch * img.stride(0) + y * img_step + x; + let labels_index = batch * labels.stride(0) + y * labels_step + x; + + let img_index_up = img_index - img_step; + let labels_index_up = labels_index - labels_step; + + let p = bool::cast_from(img[img_index]); + let p_up = bool::cast_from(img[img_index_up]); + + let pixels = plane_ballot(p)[0] & mask; + let pixels_up = plane_ballot(p_up)[0] & mask; + + if p && p_up { + let s_dist = start_distance(pixels, UNIT_POS_X); + let s_dist_up = start_distance(pixels_up, UNIT_POS_X); + if s_dist == 0 || s_dist_up == 0 { + merge(labels, labels_index - s_dist, labels_index_up - s_dist_up); + } + } + } +} + 
+#[cube(launch)] +fn relabeling(img: &Tensor, labels: &mut Tensor) { + let batch = ABSOLUTE_POS_Z; + let y = ABSOLUTE_POS_Y; + let x = ABSOLUTE_POS_X; + + let cols = labels.shape(2); + let rows = labels.shape(1); + let img_step = img.stride(2); + let labels_step = labels.stride(1); + + if x < cols && y < rows { + let mut mask = 0xffffffffu32; + if cols - CUBE_POS_X * CUBE_DIM_X < 32 { + mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X); + } + + let img_index = batch * img.stride(0) + y * img_step + x; + let labels_index = batch * labels.stride(0) + y * labels_step + x; + + let p = bool::cast_from(img[img_index]); + let pixels = plane_ballot(p)[0] & mask; + let s_dist = start_distance(pixels, UNIT_POS_X); + let mut label = 0u32; + + if p && s_dist == 0 { + label = labels[labels_index] - 1; + while label != labels[label] - 1 { + label = labels[label] - 1; + } + } + + label = plane_broadcast(label, UNIT_POS_X - s_dist); + + if p { + labels[labels_index] = label + 1; + } + } +} + +#[cube(launch)] +fn analysis( + img: &Tensor, + labels: &mut Tensor, + area: &mut Tensor>, + top: &mut Tensor>, + left: &mut Tensor>, + right: &mut Tensor>, + bottom: &mut Tensor>, + #[comptime] opts: ConnectedStatsOptions, +) { + let batch = ABSOLUTE_POS_Z; + let y = ABSOLUTE_POS_Y; + let x = ABSOLUTE_POS_X; + + let cols = labels.shape(2); + let rows = labels.shape(1); + let img_step = img.stride(2); + let labels_step = labels.stride(1); + + if x < cols && y < rows { + let mut mask = 0xffffffffu32; + if cols - CUBE_POS_X * CUBE_DIM_X < 32 { + mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X); + } + + let img_index = batch * img.stride(0) + y * img_step + x; + let labels_index = batch * labels.stride(0) + y * labels_step + x; + + let p = bool::cast_from(img[img_index]); + let pixels = plane_ballot(p)[0] & mask; + let s_dist = start_distance(pixels, UNIT_POS_X); + let count = end_distance(pixels, UNIT_POS_X); + let max_x = x + count - 1; + + let mut label = 0u32; + + if p && s_dist == 0 { + label = 
labels[labels_index] - 1; + while label != labels[label] - 1 { + label = labels[label] - 1; + } + + if opts.area_enabled { + Atomic::add(&area[label], count); + } + if opts.left_enabled { + Atomic::min(&left[label], x); + } + if opts.top_enabled { + Atomic::min(&top[label], y); + } + if opts.right_enabled { + Atomic::max(&right[label], max_x); + } + if opts.bottom_enabled { + Atomic::max(&bottom[label], y); + } + } + + label = plane_broadcast(label, UNIT_POS_X - s_dist); + + if p { + labels[labels_index] = label + 1; + } + } +} + +#[allow(clippy::type_complexity)] +pub fn hardware_accelerated( + img: JitTensor, + stats_opt: ConnectedStatsOptions, + connectivity: Connectivity, +) -> Result< + ( + JitTensor, + ConnectedStatsPrimitive>, + ), + String, +> { + let client = img.client.clone(); + let device = img.device.clone(); + + if !client.properties().feature_enabled(Feature::Plane) { + return Err("Requires plane instructions".into()); + } + + let props = client.properties().hardware_properties(); + + if props.plane_size_min != 32 || props.plane_size_min != props.plane_size_max { + return Err( + "Currently only supports 32 wide planes because it's heavily tied to plane op width" + .into(), + ); + } + + let [batches, channels, rows, cols] = img.shape.dims(); + assert_eq!(channels, 1); + + let shape = Shape::new([batches, rows, cols]); + let labels = zeros_device::(client.clone(), device.clone(), shape); + + let warp_size = 32; + let cube_dim = CubeDim::new_2d(warp_size, BLOCK_H); + let cube_count = CubeCount::Static(1, (rows as u32).div_ceil(cube_dim.y), batches as u32); + + strip_labeling::launch::( + &client, + cube_count, + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + connectivity, + ); + + let cube_count = CubeCount::Static( + (cols as u32).div_ceil(cube_dim.x), + (rows as u32).div_ceil(BLOCK_H).div_ceil(cube_dim.y), + batches as u32, + ); + + strip_merge::launch::( + &client, + cube_count, + cube_dim, + img.as_tensor_arg::(1), + 
labels.as_tensor_arg::(1), + ); + + let cube_count = CubeCount::Static( + (cols as u32).div_ceil(cube_dim.x), + (rows as u32).div_ceil(cube_dim.y), + batches as u32, + ); + + let stats = stats_from_opts(labels.clone(), stats_opt); + + if stats_opt == ConnectedStatsOptions::none() { + relabeling::launch::( + &client, + cube_count, + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + ); + } else { + analysis::launch::( + &client, + cube_count, + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + stats.area.as_tensor_arg::(1), + stats.top.as_tensor_arg::(1), + stats.left.as_tensor_arg::(1), + stats.right.as_tensor_arg::(1), + stats.bottom.as_tensor_arg::(1), + stats_opt, + ); + } + + Ok((labels, stats)) +} diff --git a/crates/burn-jit/src/kernel/vision/connected_components/mod.rs b/crates/burn-jit/src/kernel/vision/connected_components/mod.rs new file mode 100644 index 0000000000..668348d574 --- /dev/null +++ b/crates/burn-jit/src/kernel/vision/connected_components/mod.rs @@ -0,0 +1,44 @@ +use crate::{ + ops::numeric::{full_device, zeros_device}, + tensor::JitTensor, + BoolElement, FloatElement, IntElement, JitBackend, JitRuntime, +}; + +mod bke; +mod hardware_accelerated; + +use burn_tensor::Shape; +use burn_vision::{ConnectedStatsOptions, ConnectedStatsPrimitive}; +pub use hardware_accelerated::*; + +pub(crate) fn stats_from_opts( + l: JitTensor, + opts: ConnectedStatsOptions, +) -> ConnectedStatsPrimitive> +where + R: JitRuntime, + F: FloatElement, + I: IntElement, + BT: BoolElement, +{ + let [batches, height, width] = l.shape.dims(); + let shape = Shape::new([batches, height * width]); + let zeros = || zeros_device::(l.client.clone(), l.device.clone(), shape.clone()); + let max = || full_device::(l.client.clone(), shape.clone(), l.device.clone(), u32::MAX); + let dummy = || { + JitTensor::new_contiguous( + l.client.clone(), + l.device.clone(), + shape.clone(), + l.handle.clone(), + l.dtype, + ) + }; + ConnectedStatsPrimitive { + 
area: opts.area_enabled.then(zeros).unwrap_or_else(dummy), + left: opts.left_enabled.then(max).unwrap_or_else(dummy), + top: opts.top_enabled.then(max).unwrap_or_else(dummy), + right: opts.right_enabled.then(zeros).unwrap_or_else(dummy), + bottom: opts.bottom_enabled.then(zeros).unwrap_or_else(dummy), + } +} diff --git a/crates/burn-jit/src/kernel/vision/mod.rs b/crates/burn-jit/src/kernel/vision/mod.rs new file mode 100644 index 0000000000..9d610df49a --- /dev/null +++ b/crates/burn-jit/src/kernel/vision/mod.rs @@ -0,0 +1,2 @@ +mod connected_components; +mod ops; diff --git a/crates/burn-jit/src/kernel/vision/ops.rs b/crates/burn-jit/src/kernel/vision/ops.rs new file mode 100644 index 0000000000..006a2ac012 --- /dev/null +++ b/crates/burn-jit/src/kernel/vision/ops.rs @@ -0,0 +1,35 @@ +use crate::{BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; +use burn_tensor::ops::{BoolTensor, IntTensor}; +use burn_vision::{ + cpu_impl, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, VisionOps, +}; + +use super::connected_components::hardware_accelerated; + +impl VisionOps for JitBackend +where + R: JitRuntime, + F: FloatElement, + I: IntElement, + BT: BoolElement, +{ + fn connected_components(img: BoolTensor, connectivity: Connectivity) -> IntTensor { + hardware_accelerated::( + img.clone(), + ConnectedStatsOptions::none(), + connectivity, + ) + .map(|it| it.0) + .unwrap_or_else(|_| cpu_impl::connected_components::(img, connectivity)) + } + + fn connected_components_with_stats( + img: BoolTensor, + connectivity: Connectivity, + opts: ConnectedStatsOptions, + ) -> (IntTensor, ConnectedStatsPrimitive) { + hardware_accelerated::(img.clone(), opts, connectivity).unwrap_or_else(|_| { + cpu_impl::connected_components_with_stats::(img, connectivity, opts) + }) + } +} diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index acf69d9aec..ae15fb945f 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -7,7 +7,8 @@ 
extern crate derive_new; extern crate alloc; -mod ops; +/// Utilities for implementing JIT kernels +pub mod ops; /// Kernel module pub mod kernel; diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index 645aaf1535..9327e1fc92 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -76,6 +76,7 @@ pub(crate) fn swap_dims( tensor } +/// Permute a tensor's dimensions pub fn permute(mut tensor: JitTensor, axes: &[usize]) -> JitTensor { // remap strides tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect(); diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index cf15916aab..432276ccb6 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -9,6 +9,7 @@ use cubecl::client::ComputeClient; use cubecl::tensor_vectorization_factor; use cubecl::{calculate_cube_count_elemwise, prelude::*}; +/// Create a tensor filled with `value` pub fn full( shape: Shape, device: &R::Device, @@ -19,6 +20,7 @@ pub fn full( full_device::(client, shape, device.clone(), value) } +/// Create a tensor filled with `value` pub fn full_device( client: ComputeClient, shape: Shape, @@ -56,12 +58,14 @@ pub fn full_device( empty } +/// Create a tensor filled with zeros pub fn zeros(shape: Shape, device: &R::Device) -> JitTensor { let client = R::client(device); zeros_device::(client, device.clone(), shape) } +/// Create a tensor filled with zeros pub fn zeros_device( client: ComputeClient, device: R::Device, @@ -70,12 +74,14 @@ pub fn zeros_device( full_device::(client, shape, device, 0.elem()) } +/// Create a tensor filled with ones pub fn ones(shape: Shape, device: &R::Device) -> JitTensor { let client = R::client(device); ones_device::(client, device.clone(), shape) } +/// Create a tensor filled with ones pub fn ones_device( client: ComputeClient, device: R::Device, @@ -84,6 +90,7 @@ pub fn ones_device( full_device::(client, shape, device, 1.elem()) } +/// Create 
a tensor with uninitialized memory pub fn empty_device( client: ComputeClient, device: R::Device, @@ -94,38 +101,47 @@ pub fn empty_device( JitTensor::new_contiguous(client, device, shape, buffer, E::dtype()) } +/// Add two tensors pub fn add(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } +/// Add a tensor and a scalar pub fn add_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } +/// Subtract two tensors pub fn sub(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } +/// Subtract a tensor and a scalar pub fn sub_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } +/// Multiply two tensors pub fn mul(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } +/// Multiply a tensor and a scalar pub fn mul_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } +/// Divide two tensors pub fn div(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::(lhs, rhs) } +/// Divide a tensor by a scalar pub fn div_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } +/// Calculate remainder of two tensors pub fn remainder( lhs: JitTensor, rhs: JitTensor, @@ -133,14 +149,17 @@ pub fn remainder( launch_binop::(lhs, rhs) } +/// Calculate the remainder of a tensor with a scalar pub fn remainder_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop::(lhs, rhs) } +/// Calculate the power of two tensors pub fn pow(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::>(lhs, rhs) } +/// Bitwise and two tensors pub fn bitwise_and( lhs: JitTensor, rhs: JitTensor, @@ -148,10 +167,12 @@ pub fn bitwise_and( launch_binop_int::(lhs, rhs) } +/// Bitwise and with a scalar pub fn bitwise_and_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop_int::(lhs, rhs) } +/// Bitwise or two tensors pub fn bitwise_or( lhs: JitTensor, rhs: JitTensor, @@ -159,10 +180,12 @@ pub fn bitwise_or( 
launch_binop_int::(lhs, rhs) } +/// Bitwise or with a scalar pub fn bitwise_or_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop_int::(lhs, rhs) } +/// Bitwise xor two tensors pub fn bitwise_xor( lhs: JitTensor, rhs: JitTensor, @@ -170,6 +193,7 @@ pub fn bitwise_xor( launch_binop_int::(lhs, rhs) } +/// Bitwise xor with a scalar pub fn bitwise_xor_scalar(lhs: JitTensor, rhs: E) -> JitTensor { launch_scalar_binop_int::(lhs, rhs) } diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index b586c4a6b7..7b72073c06 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -23,7 +23,8 @@ pub struct JitTensor { pub device: R::Device, /// The strides of the tensor. pub strides: Vec, - pub(crate) dtype: DType, + /// The datatype of the tensor. + pub dtype: DType, } impl From> for TensorHandle { diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml new file mode 100644 index 0000000000..04979aaebd --- /dev/null +++ b/crates/burn-vision/Cargo.toml @@ -0,0 +1,23 @@ +[package] +authors = ["nathanielsimard "] +categories = ["science"] +description = "Vision processing operations for burn tensors" +documentation = "https://docs.rs/burn-vision" +edition.workspace = true +keywords = ["deep-learning", "machine-learning", "gpu"] +license.workspace = true +name = "burn-vision" +readme.workspace = true +repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-vision" +version.workspace = true + + +[features] +export_tests = ["burn-tensor-testgen"] + +[dependencies] +burn-tensor = { path = "../burn-tensor" } +cubecl = { workspace = true } +serde = { workspace = true } + +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } diff --git a/crates/burn-vision/src/cpu_impl/connected_components.rs b/crates/burn-vision/src/cpu_impl/connected_components.rs new file mode 100644 index 0000000000..c2065787b3 --- /dev/null +++ 
b/crates/burn-vision/src/cpu_impl/connected_components.rs @@ -0,0 +1,21 @@ +use burn_tensor::{ + backend::Backend, + ops::{BoolTensor, IntTensor}, +}; + +use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity}; + +pub fn connected_components( + _img: BoolTensor, + _connectivity: Connectivity, +) -> IntTensor { + todo!() +} + +pub fn connected_components_with_stats( + _img: BoolTensor, + _connectivity: Connectivity, + _options: ConnectedStatsOptions, +) -> (IntTensor, ConnectedStatsPrimitive) { + todo!() +} diff --git a/crates/burn-vision/src/cpu_impl/mod.rs b/crates/burn-vision/src/cpu_impl/mod.rs new file mode 100644 index 0000000000..6f51d94902 --- /dev/null +++ b/crates/burn-vision/src/cpu_impl/mod.rs @@ -0,0 +1,3 @@ +mod connected_components; + +pub use connected_components::*; diff --git a/crates/burn-vision/src/lib.rs b/crates/burn-vision/src/lib.rs new file mode 100644 index 0000000000..adbce61216 --- /dev/null +++ b/crates/burn-vision/src/lib.rs @@ -0,0 +1,9 @@ +pub mod cpu_impl; +mod ops; +mod tensor; + +#[cfg(feature = "export_tests")] +mod tests; + +pub use ops::*; +pub use tensor::*; diff --git a/crates/burn-vision/src/ops/base.rs b/crates/burn-vision/src/ops/base.rs new file mode 100644 index 0000000000..1b0d46193a --- /dev/null +++ b/crates/burn-vision/src/ops/base.rs @@ -0,0 +1,93 @@ +use burn_tensor::{ + backend::Backend, + ops::{BoolTensor, IntTensor}, + Int, Tensor, +}; + +use crate::cpu_impl; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Connectivity { + Four, + Eight, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct ConnectedStatsOptions { + pub area_enabled: bool, + pub top_enabled: bool, + pub left_enabled: bool, + pub right_enabled: bool, + pub bottom_enabled: bool, +} + +#[derive(Clone, Debug)] +pub struct ConnectedStats { + pub area: Tensor, + pub top: Tensor, + pub left: Tensor, + pub right: Tensor, + pub bottom: Tensor, +} + +pub struct ConnectedStatsPrimitive { + pub area: 
IntTensor, + pub left: IntTensor, + pub top: IntTensor, + pub right: IntTensor, + pub bottom: IntTensor, +} + +impl From> for ConnectedStats { + fn from(value: ConnectedStatsPrimitive) -> Self { + ConnectedStats { + area: Tensor::from_primitive(value.area), + top: Tensor::from_primitive(value.top), + left: Tensor::from_primitive(value.left), + right: Tensor::from_primitive(value.right), + bottom: Tensor::from_primitive(value.bottom), + } + } +} + +impl Default for ConnectedStatsOptions { + fn default() -> Self { + Self::all() + } +} + +impl ConnectedStatsOptions { + pub fn none() -> Self { + Self { + area_enabled: false, + top_enabled: false, + left_enabled: false, + right_enabled: false, + bottom_enabled: false, + } + } + + pub fn all() -> Self { + Self { + area_enabled: true, + top_enabled: true, + left_enabled: true, + right_enabled: true, + bottom_enabled: true, + } + } +} + +pub trait VisionOps { + fn connected_components(img: BoolTensor, connectivity: Connectivity) -> IntTensor { + cpu_impl::connected_components::(img, connectivity) + } + + fn connected_components_with_stats( + img: BoolTensor, + connectivity: Connectivity, + opts: ConnectedStatsOptions, + ) -> (IntTensor, ConnectedStatsPrimitive) { + cpu_impl::connected_components_with_stats(img, connectivity, opts) + } +} diff --git a/crates/burn-vision/src/ops/mod.rs b/crates/burn-vision/src/ops/mod.rs new file mode 100644 index 0000000000..cbcb6ac7e7 --- /dev/null +++ b/crates/burn-vision/src/ops/mod.rs @@ -0,0 +1,3 @@ +mod base; + +pub use base::*; diff --git a/crates/burn-vision/src/tensor.rs b/crates/burn-vision/src/tensor.rs new file mode 100644 index 0000000000..c7432b7d5f --- /dev/null +++ b/crates/burn-vision/src/tensor.rs @@ -0,0 +1,28 @@ +use burn_tensor::{backend::Backend, Bool, Int, Tensor}; + +use crate::{ConnectedStats, ConnectedStatsOptions, Connectivity, VisionOps}; + +pub trait ConnectedComponents { + fn connected_components(self, connectivity: Connectivity) -> Tensor; + fn 
connected_components_with_stats( + self, + connectivity: Connectivity, + options: ConnectedStatsOptions, + ) -> (Tensor, ConnectedStats); +} + +impl> ConnectedComponents for Tensor { + fn connected_components(self, connectivity: Connectivity) -> Tensor { + Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity)) + } + + fn connected_components_with_stats( + self, + connectivity: Connectivity, + options: ConnectedStatsOptions, + ) -> (Tensor, ConnectedStats) { + let (labels, stats) = + B::connected_components_with_stats(self.into_primitive(), connectivity, options); + (Tensor::from_primitive(labels), stats.into()) + } +} diff --git a/crates/burn-vision/src/tests/connected_components.rs b/crates/burn-vision/src/tests/connected_components.rs new file mode 100644 index 0000000000..9216d4714e --- /dev/null +++ b/crates/burn-vision/src/tests/connected_components.rs @@ -0,0 +1,179 @@ +#[burn_tensor_testgen::testgen(connected_components)] +mod tests { + use std::collections::HashMap; + + use super::*; + use burn_tensor::{Tensor, TensorData}; + use burn_vision::{ + as_type, ConnectedComponents, ConnectedStats, ConnectedStatsOptions, Connectivity, + VisionOps, + }; + + fn space_invader() -> [[IntType; 14]; 9] { + as_type!(IntType: [ + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1], + [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1], + [1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + ]) + } + + #[test] + fn should_support_8_connectivity() { + let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<4>(); + + let output = tensor.connected_components(Connectivity::Eight); + let expected = space_invader(); // All pixels are in the same group for 8-connected + let expected = 
TestTensorInt::<2>::from(expected).unsqueeze::<3>(); + + normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false); + } + + #[test] + fn should_support_8_connectivity_with_stats() { + let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<4>(); + + let (output, stats) = tensor + .connected_components_with_stats(Connectivity::Eight, ConnectedStatsOptions::all()); + let expected = space_invader(); // All pixels are in the same group for 8-connected + let expected = TestTensorInt::<2>::from(expected).unsqueeze::<3>(); + + let (area, left, top, right, bottom) = normalize_stats( + stats.area.into_data(), + stats.left.into_data(), + stats.top.into_data(), + stats.right.into_data(), + stats.bottom.into_data(), + ); + + normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false); + + area.assert_eq(&TensorData::from([[58]]), false); + left.assert_eq(&TensorData::from([[0]]), false); + top.assert_eq(&TensorData::from([[0]]), false); + right.assert_eq(&TensorData::from([[13]]), false); + bottom.assert_eq(&TensorData::from([[8]]), false); + } + + #[test] + fn should_support_4_connectivity() { + let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<4>(); + + let output = tensor.connected_components(Connectivity::Four); + let expected = as_type!(IntType: [ + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0], + [0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0], + [0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0], + [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0], + [4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5], + [4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5], + [4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5], + [0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0], + ]); + let expected = TestTensorInt::<2>::from(expected).unsqueeze::<3>(); + + normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false); + } + + #[test] + fn should_support_4_connectivity_with_stats() { + let tensor = 
TestTensorBool::<2>::from(space_invader()).unsqueeze::<4>(); + + let (output, stats) = tensor + .connected_components_with_stats(Connectivity::Four, ConnectedStatsOptions::all()); + let expected = as_type!(IntType: [ + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0], + [0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0], + [0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0], + [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0], + [4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5], + [4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5], + [4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5], + [0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0], + ]); + let expected = TestTensorInt::<2>::from(expected).unsqueeze::<3>(); + + let (area, left, top, right, bottom) = normalize_stats( + stats.area.into_data(), + stats.left.into_data(), + stats.top.into_data(), + stats.right.into_data(), + stats.bottom.into_data(), + ); + + normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false); + + area.assert_eq(&TensorData::from([[1, 1, 46, 5, 5]]), false); + left.assert_eq(&TensorData::from([[3, 10, 1, 0, 12]]), false); + top.assert_eq(&TensorData::from([[0, 0, 1, 5, 5]]), false); + right.assert_eq(&TensorData::from([[3, 10, 12, 1, 13]]), false); + bottom.assert_eq(&TensorData::from([[0, 0, 8, 7, 7]]), false); + } + + /// Normalize labels to sequential since actual labels aren't required to be contiguous and + /// different algorithms can return different numbers even if correct + fn normalize_labels(mut labels: TensorData) -> TensorData { + let mut next_label = 0; + let mut mappings = HashMap::::default(); + let data = labels.as_mut_slice::().unwrap(); + for label in data { + if *label != 0 { + let relabel = mappings.entry(*label).or_insert_with(|| { + next_label += 1; + next_label + }); + *label = *relabel; + } + } + labels + } + + fn normalize_stats( + area: TensorData, + left: TensorData, + top: TensorData, + right: TensorData, + bottom: TensorData, + ) -> (TensorData, TensorData, 
TensorData, TensorData, TensorData) { + let batches = area.shape[0]; + + let area = area.as_slice::().unwrap(); + let left = left.as_slice::().unwrap(); + let top = top.as_slice::().unwrap(); + let right = right.as_slice::().unwrap(); + let bottom = bottom.as_slice::().unwrap(); + + let mut area_new = vec![]; + let mut left_new = vec![]; + let mut top_new = vec![]; + let mut right_new = vec![]; + let mut bottom_new = vec![]; + + for (label, area) in area.iter().enumerate() { + if *area != 0 { + area_new.push(*area); + left_new.push(left[label]); + top_new.push(top[label]); + right_new.push(right[label]); + bottom_new.push(bottom[label]); + } + } + + let shape = [batches, area_new.len() / batches]; + + ( + TensorData::new(area_new, shape.clone()), + TensorData::new(left_new, shape.clone()), + TensorData::new(top_new, shape), + TensorData::new(right_new, shape.clone()), + TensorData::new(bottom_new, shape.clone()), + ) + } +} diff --git a/crates/burn-vision/src/tests/mod.rs b/crates/burn-vision/src/tests/mod.rs new file mode 100644 index 0000000000..1e03c9fad5 --- /dev/null +++ b/crates/burn-vision/src/tests/mod.rs @@ -0,0 +1,42 @@ +mod connected_components; + +#[macro_export] +macro_rules! testgen_all { + () => { + use burn_tensor::{Bool, Float, Int}; + + pub type TestBackend = JitBackend; + + type TestTensor = burn_tensor::Tensor; + type TestTensorInt = burn_tensor::Tensor; + type TestTensorBool = burn_tensor::Tensor; + + pub mod vision { + pub use super::*; + + pub type FloatType = ::FloatElem; + pub type IntType = ::IntElem; + pub type BoolType = ::BoolElem; + + $crate::testgen_connected_components!(); + } + }; +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! 
as_type { + ($ty:ident: [$($elem:tt),*]) => { + [$($crate::as_type![$ty: $elem]),*] + }; + ($ty:ident: [$($elem:tt,)*]) => { + [$($crate::as_type![$ty: $elem]),*] + }; + ($ty:ident: $elem:expr) => { + { + use cubecl::prelude::{Float, Int}; + + $ty::new($elem) + } + }; +} diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index c2e034ada5..7f10ca3966 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -35,6 +35,9 @@ burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } +burn-vision = { path = "../burn-vision", version = "0.17.0", features = [ + "export_tests", +] } half = { workspace = true } paste = { workspace = true } diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 3d29d219d0..c467afce82 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -111,4 +111,5 @@ mod tests { burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]); #[cfg(not(feature = "spirv"))] burn_jit::testgen_all!([f32], [i32], [u32]); + burn_vision::testgen_all!(); } From 9e6515041a7592566d03c41bd52e73dc87c95f14 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 28 Jan 2025 16:48:07 +0100 Subject: [PATCH 05/24] Move testgen to burn-jit --- Cargo.lock | 2 -- crates/burn-cuda/Cargo.toml | 3 --- crates/burn-cuda/src/lib.rs | 1 - crates/burn-jit/Cargo.toml | 3 ++- crates/burn-jit/src/tests/mod.rs | 4 ++++ crates/burn-vision/src/tests/mod.rs | 6 ------ crates/burn-wgpu/Cargo.toml | 7 ++----- crates/burn-wgpu/src/lib.rs | 1 - 8 files changed, 8 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 77060eafa8..7b72b862c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -707,7 +707,6 @@ dependencies = [ "burn-fusion", "burn-jit", "burn-tensor", - "burn-vision", "bytemuck", "cubecl", "derive-new 
0.7.0", @@ -983,7 +982,6 @@ dependencies = [ "burn-fusion", "burn-jit", "burn-tensor", - "burn-vision", "cubecl", "half", "paste", diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index 196018ed0e..290e79b804 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -37,9 +37,6 @@ log = { workspace = true } burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-vision = { path = "../burn-vision", version = "0.17.0", default-features = false, features = [ - "export_tests", -] } paste = { workspace = true } diff --git a/crates/burn-cuda/src/lib.rs b/crates/burn-cuda/src/lib.rs index 518376d617..0387da0215 100644 --- a/crates/burn-cuda/src/lib.rs +++ b/crates/burn-cuda/src/lib.rs @@ -21,5 +21,4 @@ mod tests { // TODO: Add tests for bf16 burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]); - burn_vision::testgen_all!(); } diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 0a7f9f2e82..061aa4d5ea 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -20,9 +20,10 @@ export_tests = [ "serial_test", "burn-autodiff/export_tests", "burn-tensor/export_tests", - "burn-vision?/export_tests", + "burn-vision/export_tests", "burn-ndarray", "fusion", + "vision", "paste", ] fusion = ["burn-fusion"] diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index a79ac3c437..b5e82254e3 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -33,6 +33,7 @@ pub use burn_autodiff; pub use burn_fusion; pub use burn_ndarray; pub use burn_tensor; +pub use burn_vision; pub use serial_test; #[macro_export] @@ -43,7 +44,10 @@ macro_rules! 
testgen_all { }; ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { mod jit { + pub use $crate::tests::burn_vision; + burn_jit::testgen_jit!([$($float),*], [$($int),*], [$($bool),*]); + burn_vision::testgen_all!(); mod kernel { use super::*; diff --git a/crates/burn-vision/src/tests/mod.rs b/crates/burn-vision/src/tests/mod.rs index 1e03c9fad5..c1d6f904e4 100644 --- a/crates/burn-vision/src/tests/mod.rs +++ b/crates/burn-vision/src/tests/mod.rs @@ -5,12 +5,6 @@ macro_rules! testgen_all { () => { use burn_tensor::{Bool, Float, Int}; - pub type TestBackend = JitBackend; - - type TestTensor = burn_tensor::Tensor; - type TestTensorInt = burn_tensor::Tensor; - type TestTensorBool = burn_tensor::Tensor; - pub mod vision { pub use super::*; diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index eb3919f4d7..3f4e2dcc0d 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -21,12 +21,12 @@ std = ["burn-jit/std", "cubecl/std"] template = ["burn-jit/template", "cubecl/template"] # Backends -webgpu = ["cubecl-wgsl"] vulkan = ["cubecl-spirv"] +webgpu = ["cubecl-wgsl"] # Compilers -cubecl-wgsl = [] cubecl-spirv = ["cubecl/wgpu-spirv"] +cubecl-wgsl = [] [dependencies] cubecl = { workspace = true, features = ["wgpu"] } @@ -42,9 +42,6 @@ burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-vision = { path = "../burn-vision", version = "0.17.0", features = [ - "export_tests", -] } half = { workspace = true } paste = { workspace = true } diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 6aa7a5acb9..c11854fcaf 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -124,5 +124,4 @@ mod tests { burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]); #[cfg(not(feature = "vulkan"))] burn_jit::testgen_all!([f32], [i32], 
[u32]); - burn_vision::testgen_all!(); } From 0484f511b8860dabfe9a254a0aec1d739676ee2e Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:35:05 +0100 Subject: [PATCH 06/24] Improve HA4/8 algo --- .../hardware_accelerated.rs | 106 +++++++++++++++--- 1 file changed, 88 insertions(+), 18 deletions(-) diff --git a/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs b/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs index 41918dac96..278a7b3e10 100644 --- a/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs +++ b/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs @@ -180,19 +180,24 @@ fn strip_labeling( } #[cube(launch)] -fn strip_merge(img: &Tensor, labels: &Tensor>) { - let batch = ABSOLUTE_POS_Z; - let y = ABSOLUTE_POS_Y * BLOCK_H; - let x = ABSOLUTE_POS_X; +fn strip_merge( + img: &Tensor, + labels: &Tensor>, + #[comptime] connectivity: Connectivity, +) { + let batch = CUBE_POS_Z; + let plane_start_x = CUBE_POS_X * (CUBE_DIM_X * CUBE_DIM_Z - PLANE_DIM) + UNIT_POS_Z * PLANE_DIM; + let y = (CUBE_POS_Y + 1) * BLOCK_H; + let x = plane_start_x + UNIT_POS_X; let img_step = img.stride(2); let labels_step = labels.stride(1); let cols = img.shape(3); - if y < labels.shape(1) && x < labels.shape(2) && y > 0 { + if y < labels.shape(1) && x < labels.shape(2) { let mut mask = 0xffffffffu32; - if cols - CUBE_POS_X * CUBE_DIM_X < 32 { - mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X); + if cols - plane_start_x < 32 { + mask >>= 32 - (cols - plane_start_x); } let img_index = batch * img.stride(0) + y * img_step + x; @@ -207,11 +212,72 @@ fn strip_merge(img: &Tensor, labels: &Tensor> let pixels = plane_ballot(p)[0] & mask; let pixels_up = plane_ballot(p_up)[0] & mask; - if p && p_up { - let s_dist = start_distance(pixels, UNIT_POS_X); - let s_dist_up = start_distance(pixels_up, UNIT_POS_X); - if s_dist == 0 || 
s_dist_up == 0 { - merge(labels, labels_index - s_dist, labels_index_up - s_dist_up); + match connectivity { + Connectivity::Four => { + if p && p_up { + let s_dist = start_distance(pixels, UNIT_POS_X); + let s_dist_up = start_distance(pixels_up, UNIT_POS_X); + if s_dist == 0 || s_dist_up == 0 { + merge(labels, labels_index - s_dist, labels_index_up - s_dist_up); + } + } + } + Connectivity::Eight => { + let mut last_dist_vec = SharedMemory::::new(32); + let mut last_dist_up_vec = SharedMemory::::new(32); + + let s_dist = start_distance(pixels, UNIT_POS_X); + let s_dist_up = start_distance(pixels_up, UNIT_POS_X); + + if UNIT_POS_PLANE == PLANE_DIM - 1 { + last_dist_vec[UNIT_POS_Z] = start_distance(pixels, 32); + last_dist_up_vec[UNIT_POS_Z] = start_distance(pixels_up, 32); + } + + sync_units(); + + if CUBE_POS_X == 0 || UNIT_POS_Z > 0 { + let last_dist = if UNIT_POS_Z > 0 { + last_dist_vec[UNIT_POS_Z - 1] + } else { + 0u32 + }; + let last_dist_up = if UNIT_POS_Z > 0 { + last_dist_up_vec[UNIT_POS_Z - 1] + } else { + 0u32 + }; + + let p_prev = + select(UNIT_POS_X > 0, (pixels >> (UNIT_POS_X - 1)) & 1, last_dist) != 0; + let p_up_prev = select( + UNIT_POS_X > 0, + (pixels_up >> (UNIT_POS_X - 1)) & 1, + last_dist_up, + ) != 0; + + if p && p_up { + let s_dist = start_distance(pixels, UNIT_POS_X); + let s_dist_up = start_distance(pixels_up, UNIT_POS_X); + if s_dist == 0 || s_dist_up == 0 { + merge(labels, labels_index - s_dist, labels_index_up - s_dist_up); + } + } else if p && p_up_prev && s_dist == 0 { + let s_dist_up_prev = select( + UNIT_POS_X == 0, + last_dist_up - 1, + start_distance(pixels_up, UNIT_POS_X - 1), + ); + merge(labels, labels_index, labels_index_up - 1 - s_dist_up_prev); + } else if p_prev && p_up && s_dist_up == 0 { + let s_dist_prev = select( + UNIT_POS_X == 0, + last_dist - 1, + start_distance(pixels, UNIT_POS_X - 1), + ); + merge(labels, labels_index - 1 - s_dist_prev, labels_index_up); + } + } } } } @@ -220,8 +286,9 @@ fn strip_merge(img: &Tensor, 
labels: &Tensor> #[cube(launch)] fn relabeling(img: &Tensor, labels: &mut Tensor) { let batch = ABSOLUTE_POS_Z; + let plane_start_x = CUBE_POS_X * CUBE_DIM_X; let y = ABSOLUTE_POS_Y; - let x = ABSOLUTE_POS_X; + let x = plane_start_x + UNIT_POS_X; let cols = labels.shape(2); let rows = labels.shape(1); @@ -230,8 +297,8 @@ fn relabeling(img: &Tensor, labels: &mut Tensor) { if x < cols && y < rows { let mut mask = 0xffffffffu32; - if cols - CUBE_POS_X * CUBE_DIM_X < 32 { - mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X); + if cols - plane_start_x < 32 { + mask >>= 32 - (cols - plane_start_x); } let img_index = batch * img.stride(0) + y * img_step + x; @@ -372,18 +439,21 @@ pub fn hardware_accelerated( &client, cube_count, - cube_dim, + cube_dim_merge, img.as_tensor_arg::(1), labels.as_tensor_arg::(1), + connectivity, ); let cube_count = CubeCount::Static( From f62a9ee6f3d9507706d34a112fe8ed227bd497df Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 28 Jan 2025 18:50:36 +0100 Subject: [PATCH 07/24] Terminate units past the predefined 32 plane size --- .../connected_components/hardware_accelerated.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs b/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs index 278a7b3e10..19cec32763 100644 --- a/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs +++ b/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs @@ -56,6 +56,10 @@ fn strip_labeling( labels: &Tensor>, #[comptime] connectivity: Connectivity, ) { + if UNIT_POS_PLANE >= 32 { + terminate!(); + } + let mut shared_pixels = SharedMemory::::new(BLOCK_H); let batch = ABSOLUTE_POS_Z; @@ -413,11 +417,8 @@ pub fn hardware_accelerated(client.clone(), device.clone(), shape); + // Assume 32 wide warp. 
Currently, larger warps are handled by just exiting everything past 32. + // This isn't ideal but we require CUBE_DIM_X == warp_size, and we can't query the actual warp + // size at compile time. let warp_size = 32; let cube_dim = CubeDim::new_2d(warp_size, BLOCK_H); let cube_count = CubeCount::Static(1, (rows as u32).div_ceil(cube_dim.y), batches as u32); From 8edac2bb70cd727a2d0f2b431f58590309700a38 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 30 Jan 2025 14:47:37 +0100 Subject: [PATCH 08/24] move jit backend back into `burn-vision` and make tests work --- Cargo.lock | 40 ++++--------------- crates/burn-jit/Cargo.toml | 6 +-- crates/burn-jit/src/kernel/mod.rs | 4 -- crates/burn-jit/src/ops/mod.rs | 5 ++- crates/burn-jit/src/tests/mod.rs | 4 -- crates/burn-vision/Cargo.toml | 22 ++++++++-- .../cpu}/connected_components.rs | 0 .../src/{cpu_impl => backends/cpu}/mod.rs | 0 .../backends/jit}/connected_components/bke.rs | 4 +- .../hardware_accelerated.rs | 33 +++++++++------ .../backends/jit}/connected_components/mod.rs | 12 +++--- .../src/backends/jit}/mod.rs | 0 .../src/backends/jit}/ops.rs | 12 +++--- crates/burn-vision/src/backends/mod.rs | 3 ++ crates/burn-vision/src/lib.rs | 4 +- crates/burn-vision/src/ops/base.rs | 7 ++-- .../src/tests/connected_components.rs | 7 +--- crates/burn-vision/src/tests/mod.rs | 9 +++-- crates/burn-vision/tests/main.rs | 20 ++++++++++ 19 files changed, 98 insertions(+), 94 deletions(-) rename crates/burn-vision/src/{cpu_impl => backends/cpu}/connected_components.rs (100%) rename crates/burn-vision/src/{cpu_impl => backends/cpu}/mod.rs (100%) rename crates/{burn-jit/src/kernel/vision => burn-vision/src/backends/jit}/connected_components/bke.rs (99%) rename crates/{burn-jit/src/kernel/vision => burn-vision/src/backends/jit}/connected_components/hardware_accelerated.rs (94%) rename crates/{burn-jit/src/kernel/vision => burn-vision/src/backends/jit}/connected_components/mod.rs (94%) 
rename crates/{burn-jit/src/kernel/vision => burn-vision/src/backends/jit}/mod.rs (100%) rename crates/{burn-jit/src/kernel/vision => burn-vision/src/backends/jit}/ops.rs (70%) create mode 100644 crates/burn-vision/src/backends/mod.rs create mode 100644 crates/burn-vision/tests/main.rs diff --git a/Cargo.lock b/Cargo.lock index 7b72b862c0..08d99e254d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -822,7 +822,6 @@ dependencies = [ "burn-ndarray", "burn-tensor", "burn-tensor-testgen", - "burn-vision", "bytemuck", "cubecl", "derive-new 0.7.0", @@ -969,9 +968,14 @@ dependencies = [ name = "burn-vision" version = "0.17.0" dependencies = [ + "burn-cuda", + "burn-jit", + "burn-ndarray", "burn-tensor", "burn-tensor-testgen", + "burn-wgpu", "cubecl", + "derive-new 0.7.0", "serde", ] @@ -1698,10 +1702,11 @@ dependencies = [ "cubecl-macros-internal", "derive_more 1.0.0", "float-ord", + "fnv", "half", "num-traits", "serde", - "type_hash", + "variadics_please", ] [[package]] @@ -7629,37 +7634,6 @@ dependencies = [ "rustc-hash 1.1.0", ] -[[package]] -name = "type_hash" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c86f48f11992d3e379358c63cb25736c0b23944ff000d1583bbccad2b0b7c6" -dependencies = [ - "type_hash_core", - "type_hash_macros", -] - -[[package]] -name = "type_hash_core" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87b1e93e2cd97790892dbe2d2813fbaa6eebaeb960265f59e363e79e51e4997a" -dependencies = [ - "fnv", -] - -[[package]] -name = "type_hash_macros" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "746fc164e076483ef087b3989f7aa80ffd9320fa558f3cb72cecfb9bb1dbc41e" -dependencies = [ - "either", - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "typenum" version = "1.17.0" diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 061aa4d5ea..214b21eef3 100644 --- 
a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -13,23 +13,20 @@ version.workspace = true [features] autotune = [] -default = ["autotune", "std", "fusion", "vision", "cubecl/default"] +default = ["autotune", "std", "fusion", "cubecl/default"] doc = ["default"] export_tests = [ "burn-tensor-testgen", "serial_test", "burn-autodiff/export_tests", "burn-tensor/export_tests", - "burn-vision/export_tests", "burn-ndarray", "fusion", - "vision", "paste", ] fusion = ["burn-fusion"] fusion-experimental = ["fusion"] std = ["cubecl/std", "burn-tensor/std"] -vision = ["burn-vision"] template = [] @@ -40,7 +37,6 @@ burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = "cubecl", "repr", ] } -burn-vision = { path = "../burn-vision", version = "0.17.0", optional = true } cubecl = { workspace = true, features = ["linalg", "reduce"] } bytemuck = { workspace = true } diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index c7f746061f..93d2833976 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -37,10 +37,6 @@ pub mod quantization; /// Reduction algorithms pub mod reduce; -/// Vision algorithms -#[cfg(feature = "vision")] -pub mod vision; - pub(crate) use clamp::*; pub(crate) use comparison::*; pub(crate) use index::*; diff --git a/crates/burn-jit/src/ops/mod.rs b/crates/burn-jit/src/ops/mod.rs index 2e23e3835d..c396bdacdd 100644 --- a/crates/burn-jit/src/ops/mod.rs +++ b/crates/burn-jit/src/ops/mod.rs @@ -7,6 +7,7 @@ mod qtensor; mod transaction; pub(crate) mod base; -pub(crate) use base::*; +pub use base::*; -pub(crate) mod numeric; +/// Numeric utility functions for jit backends +pub mod numeric; diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index b5e82254e3..a79ac3c437 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -33,7 +33,6 @@ pub use burn_autodiff; pub use burn_fusion; pub use 
burn_ndarray; pub use burn_tensor; -pub use burn_vision; pub use serial_test; #[macro_export] @@ -44,10 +43,7 @@ macro_rules! testgen_all { }; ([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => { mod jit { - pub use $crate::tests::burn_vision; - burn_jit::testgen_jit!([$($float),*], [$($int),*], [$($bool),*]); - burn_vision::testgen_all!(); mod kernel { use super::*; diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index 04979aaebd..69dfbaf52f 100644 --- a/crates/burn-vision/Cargo.toml +++ b/crates/burn-vision/Cargo.toml @@ -13,11 +13,25 @@ version.workspace = true [features] -export_tests = ["burn-tensor-testgen"] +default = ["jit-backend"] +export-tests = ["burn-tensor-testgen"] +jit-backend = ["cubecl", "burn-jit"] + +# Test features +cpu = ["export-tests"] +cuda = ["jit-backend", "export-tests"] +vulkan = ["burn-wgpu/vulkan", "wgpu"] +wgpu = ["jit-backend", "export-tests"] [dependencies] -burn-tensor = { path = "../burn-tensor" } -cubecl = { workspace = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", optional = true } +burn-tensor = { path = "../burn-tensor", version = "0.17.0" } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } +cubecl = { workspace = true, optional = true } +derive-new = { workspace = true } serde = { workspace = true } -burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } +[dev-dependencies] +burn-cuda = { path = "../burn-cuda", version = "0.17.0", default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } +burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", default-features = false } diff --git a/crates/burn-vision/src/cpu_impl/connected_components.rs b/crates/burn-vision/src/backends/cpu/connected_components.rs similarity index 100% rename from crates/burn-vision/src/cpu_impl/connected_components.rs rename to 
crates/burn-vision/src/backends/cpu/connected_components.rs diff --git a/crates/burn-vision/src/cpu_impl/mod.rs b/crates/burn-vision/src/backends/cpu/mod.rs similarity index 100% rename from crates/burn-vision/src/cpu_impl/mod.rs rename to crates/burn-vision/src/backends/cpu/mod.rs diff --git a/crates/burn-jit/src/kernel/vision/connected_components/bke.rs b/crates/burn-vision/src/backends/jit/connected_components/bke.rs similarity index 99% rename from crates/burn-jit/src/kernel/vision/connected_components/bke.rs rename to crates/burn-vision/src/backends/jit/connected_components/bke.rs index 331806c237..ae8633cadb 100644 --- a/crates/burn-jit/src/kernel/vision/connected_components/bke.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/bke.rs @@ -3,13 +3,13 @@ //! "Optimized Block-Based Algorithms to Label Connected Components on GPUs," //! in IEEE Transactions on Parallel and Distributed Systems, 2019. -use crate::{ +use burn_jit::{ kernel, ops::numeric::{empty_device, zeros_device}, tensor::JitTensor, - tests::burn_tensor::{DType, Shape}, JitElement, JitRuntime, }; +use burn_tensor::{DType, Shape}; use cubecl::cube; use cubecl::prelude::*; diff --git a/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs similarity index 94% rename from crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs rename to crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs index 19cec32763..3ffb883b11 100644 --- a/crates/burn-jit/src/kernel/vision/connected_components/hardware_accelerated.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs @@ -4,11 +4,14 @@ //! 
DASIP, 2018 use crate::{ - kernel::vision::connected_components::stats_from_opts, ops::numeric::zeros_device, - tensor::JitTensor, BoolElement, FloatElement, IntElement, JitBackend, JitRuntime, + backends::jit::connected_components::stats_from_opts, ConnectedStatsOptions, + ConnectedStatsPrimitive, Connectivity, +}; +use burn_jit::{ + ops::numeric::zeros_device, tensor::JitTensor, BoolElement, FloatElement, IntElement, + JitBackend, JitRuntime, }; use burn_tensor::Shape; -use burn_vision::{ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity}; use cubecl::{prelude::*, Feature}; const BLOCK_H: u32 = 4; @@ -50,16 +53,19 @@ fn end_distance(pixels: u32, tx: u32) -> u32 { u32::find_first_set(u32::bitwise_not(pixels >> (tx + 1))) } +#[cube] +#[expect(unconditional_panic, reason = "clippy thinks PLANE_DIM is always 2")] +fn ballot_dyn(y: u32, pred: bool) -> u32 { + let index = y % (PLANE_DIM / 32); + plane_ballot(pred)[index] +} + #[cube(launch)] fn strip_labeling( img: &Tensor, labels: &Tensor>, #[comptime] connectivity: Connectivity, ) { - if UNIT_POS_PLANE >= 32 { - terminate!(); - } - let mut shared_pixels = SharedMemory::::new(BLOCK_H); let batch = ABSOLUTE_POS_Z; @@ -95,7 +101,7 @@ fn strip_labeling( let p_y = bool::cast_from(img[img_index]); - let pixels_y = plane_ballot(p_y)[0] & mask; + let pixels_y = ballot_dyn(UNIT_POS_Y, p_y) & mask; let mut s_dist_y = start_distance(pixels_y, UNIT_POS_X); if p_y && s_dist_y == 0 { @@ -213,8 +219,8 @@ fn strip_merge( let p = bool::cast_from(img[img_index]); let p_up = bool::cast_from(img[img_index_up]); - let pixels = plane_ballot(p)[0] & mask; - let pixels_up = plane_ballot(p_up)[0] & mask; + let pixels = ballot_dyn(UNIT_POS_Z, p) & mask; + let pixels_up = ballot_dyn(UNIT_POS_Z, p_up) & mask; match connectivity { Connectivity::Four => { @@ -309,7 +315,7 @@ fn relabeling(img: &Tensor, labels: &mut Tensor) { let labels_index = batch * labels.stride(0) + y * labels_step + x; let p = bool::cast_from(img[img_index]); - let 
pixels = plane_ballot(p)[0] & mask; + let pixels = ballot_dyn(UNIT_POS_Y, p) & mask; let s_dist = start_distance(pixels, UNIT_POS_X); let mut label = 0u32; @@ -358,7 +364,7 @@ fn analysis( let labels_index = batch * labels.stride(0) + y * labels_step + x; let p = bool::cast_from(img[img_index]); - let pixels = plane_ballot(p)[0] & mask; + let pixels = ballot_dyn(UNIT_POS_Y, p) & mask; let s_dist = start_distance(pixels, UNIT_POS_X); let count = end_distance(pixels, UNIT_POS_X); let max_x = x + count - 1; @@ -429,7 +435,8 @@ pub fn hardware_accelerated( l: JitTensor, opts: ConnectedStatsOptions, diff --git a/crates/burn-jit/src/kernel/vision/mod.rs b/crates/burn-vision/src/backends/jit/mod.rs similarity index 100% rename from crates/burn-jit/src/kernel/vision/mod.rs rename to crates/burn-vision/src/backends/jit/mod.rs diff --git a/crates/burn-jit/src/kernel/vision/ops.rs b/crates/burn-vision/src/backends/jit/ops.rs similarity index 70% rename from crates/burn-jit/src/kernel/vision/ops.rs rename to crates/burn-vision/src/backends/jit/ops.rs index 006a2ac012..4935de561a 100644 --- a/crates/burn-jit/src/kernel/vision/ops.rs +++ b/crates/burn-vision/src/backends/jit/ops.rs @@ -1,8 +1,8 @@ -use crate::{BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; -use burn_tensor::ops::{BoolTensor, IntTensor}; -use burn_vision::{ - cpu_impl, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, VisionOps, +use crate::{ + backends::cpu, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, VisionOps, }; +use burn_jit::{BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; +use burn_tensor::ops::{BoolTensor, IntTensor}; use super::connected_components::hardware_accelerated; @@ -20,7 +20,7 @@ where connectivity, ) .map(|it| it.0) - .unwrap_or_else(|_| cpu_impl::connected_components::(img, connectivity)) + .unwrap_or_else(|_| cpu::connected_components::(img, connectivity)) } fn connected_components_with_stats( @@ -29,7 +29,7 @@ where opts: 
ConnectedStatsOptions, ) -> (IntTensor, ConnectedStatsPrimitive) { hardware_accelerated::(img.clone(), opts, connectivity).unwrap_or_else(|_| { - cpu_impl::connected_components_with_stats::(img, connectivity, opts) + cpu::connected_components_with_stats::(img, connectivity, opts) }) } } diff --git a/crates/burn-vision/src/backends/mod.rs b/crates/burn-vision/src/backends/mod.rs new file mode 100644 index 0000000000..6886bb4907 --- /dev/null +++ b/crates/burn-vision/src/backends/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod cpu; +#[cfg(feature = "jit-backend")] +mod jit; diff --git a/crates/burn-vision/src/lib.rs b/crates/burn-vision/src/lib.rs index adbce61216..03643b83fb 100644 --- a/crates/burn-vision/src/lib.rs +++ b/crates/burn-vision/src/lib.rs @@ -1,8 +1,8 @@ -pub mod cpu_impl; +pub mod backends; mod ops; mod tensor; -#[cfg(feature = "export_tests")] +#[cfg(feature = "export-tests")] mod tests; pub use ops::*; diff --git a/crates/burn-vision/src/ops/base.rs b/crates/burn-vision/src/ops/base.rs index 1b0d46193a..ddbd479507 100644 --- a/crates/burn-vision/src/ops/base.rs +++ b/crates/burn-vision/src/ops/base.rs @@ -1,11 +1,10 @@ +use crate::backends::cpu; use burn_tensor::{ backend::Backend, ops::{BoolTensor, IntTensor}, Int, Tensor, }; -use crate::cpu_impl; - #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum Connectivity { Four, @@ -80,7 +79,7 @@ impl ConnectedStatsOptions { pub trait VisionOps { fn connected_components(img: BoolTensor, connectivity: Connectivity) -> IntTensor { - cpu_impl::connected_components::(img, connectivity) + cpu::connected_components::(img, connectivity) } fn connected_components_with_stats( @@ -88,6 +87,6 @@ pub trait VisionOps { connectivity: Connectivity, opts: ConnectedStatsOptions, ) -> (IntTensor, ConnectedStatsPrimitive) { - cpu_impl::connected_components_with_stats(img, connectivity, opts) + cpu::connected_components_with_stats(img, connectivity, opts) } } diff --git a/crates/burn-vision/src/tests/connected_components.rs 
b/crates/burn-vision/src/tests/connected_components.rs index 9216d4714e..c299769774 100644 --- a/crates/burn-vision/src/tests/connected_components.rs +++ b/crates/burn-vision/src/tests/connected_components.rs @@ -3,11 +3,8 @@ mod tests { use std::collections::HashMap; use super::*; - use burn_tensor::{Tensor, TensorData}; - use burn_vision::{ - as_type, ConnectedComponents, ConnectedStats, ConnectedStatsOptions, Connectivity, - VisionOps, - }; + use burn_tensor::TensorData; + use burn_vision::{as_type, ConnectedComponents, ConnectedStatsOptions, Connectivity}; fn space_invader() -> [[IntType; 14]; 9] { as_type!(IntType: [ diff --git a/crates/burn-vision/src/tests/mod.rs b/crates/burn-vision/src/tests/mod.rs index c1d6f904e4..11851577ed 100644 --- a/crates/burn-vision/src/tests/mod.rs +++ b/crates/burn-vision/src/tests/mod.rs @@ -5,14 +5,15 @@ macro_rules! testgen_all { () => { use burn_tensor::{Bool, Float, Int}; + pub type TestTensorInt = burn_tensor::Tensor; + pub type TestTensorBool = burn_tensor::Tensor; + pub mod vision { pub use super::*; - pub type FloatType = ::FloatElem; pub type IntType = ::IntElem; - pub type BoolType = ::BoolElem; - $crate::testgen_connected_components!(); + burn_vision::testgen_connected_components!(); } }; } @@ -28,7 +29,7 @@ macro_rules! 
as_type { }; ($ty:ident: $elem:expr) => { { - use cubecl::prelude::{Float, Int}; + use cubecl::prelude::*; $ty::new($elem) } diff --git a/crates/burn-vision/tests/main.rs b/crates/burn-vision/tests/main.rs new file mode 100644 index 0000000000..2819632cc1 --- /dev/null +++ b/crates/burn-vision/tests/main.rs @@ -0,0 +1,20 @@ +#[cfg(all(test, feature = "cpu"))] +mod tests_cpu { + pub type TestBackend = burn_ndarray::NdArray; + + burn_vision::testgen_all!(); +} + +#[cfg(all(test, feature = "wgpu"))] +mod tests_wgpu { + pub type TestBackend = burn_wgpu::Wgpu; + + burn_vision::testgen_all!(); +} + +#[cfg(all(test, feature = "cuda"))] +mod tests_cuda { + pub type TestBackend = burn_cuda::Cuda; + + burn_vision::testgen_all!(); +} From 05b40e3c8dbdb980a62c0c7fb1061217792dc9c4 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 30 Jan 2025 23:57:17 +0100 Subject: [PATCH 09/24] Add initial CPU implementation without stats --- Cargo.lock | 1 + crates/burn-ndarray/src/element.rs | 9 + crates/burn-ndarray/src/lib.rs | 2 +- crates/burn-ndarray/src/ops/conv.rs | 10 +- crates/burn-ndarray/src/ops/deform_conv.rs | 9 +- crates/burn-ndarray/src/ops/maxpool.rs | 14 +- crates/burn-ndarray/src/ops/module.rs | 18 +- crates/burn-ndarray/src/ops/padding.rs | 9 +- crates/burn-vision/Cargo.toml | 2 + .../src/backends/cpu/connected_components.rs | 92 +- .../Spaghetti_center_line_forest_code.rs | 1954 +++++++++++++++++ .../Spaghetti_first_line_forest_code.rs | 223 ++ .../spaghetti/Spaghetti_forest_labels.rs | 191 ++ .../Spaghetti_last_line_forest_code.rs | 787 +++++++ .../Spaghetti_single_line_forest_code.rs | 91 + .../cpu/connected_components/spaghetti/mod.rs | 214 ++ .../Spaghetti4C_center_line_forest_code.rs | 42 + .../Spaghetti4C_first_line_forest_code.rs | 31 + .../spaghetti_4c/Spaghetti4C_forest_labels.rs | 21 + .../connected_components/spaghetti_4c/mod.rs | 81 + crates/burn-vision/src/backends/cpu/mod.rs | 1 + 
crates/burn-vision/src/backends/cpu/ops.rs | 24 + crates/burn-vision/src/lib.rs | 2 + crates/burn-vision/tests/main.rs | 2 +- 24 files changed, 3785 insertions(+), 45 deletions(-) create mode 100644 crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_center_line_forest_code.rs create mode 100644 crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_first_line_forest_code.rs create mode 100644 crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_forest_labels.rs create mode 100644 crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_last_line_forest_code.rs create mode 100644 crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_single_line_forest_code.rs create mode 100644 crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs create mode 100644 crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_center_line_forest_code.rs create mode 100644 crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_first_line_forest_code.rs create mode 100644 crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_forest_labels.rs create mode 100644 crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs create mode 100644 crates/burn-vision/src/backends/cpu/ops.rs diff --git a/Cargo.lock b/Cargo.lock index 08d99e254d..a3ab30f4ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -976,6 +976,7 @@ dependencies = [ "burn-wgpu", "cubecl", "derive-new 0.7.0", + "ndarray 0.16.1", "serde", ] diff --git a/crates/burn-ndarray/src/element.rs b/crates/burn-ndarray/src/element.rs index a700d9e30f..093ecf5e8f 100644 --- a/crates/burn-ndarray/src/element.rs +++ b/crates/burn-ndarray/src/element.rs @@ -16,6 +16,7 @@ where { } +/// An int element for ndarray backend. pub trait IntNdArrayElement: NdArrayElement + Signed {} /// A general element for ndarray backend. 
@@ -34,13 +35,21 @@ pub trait NdArrayElement: /// A element for ndarray backend that supports exp ops. pub trait ExpElement { + /// Exponent fn exp_elem(self) -> Self; + /// Log fn log_elem(self) -> Self; + /// Log1p fn log1p_elem(self) -> Self; + /// Powf fn powf_elem(self, value: f32) -> Self; + /// Powi fn powi_elem(self, value: i32) -> Self; + /// Sqrt fn sqrt_elem(self) -> Self; + /// Abs fn abs_elem(self) -> Self; + /// Abs for int fn int_abs_elem(self) -> Self; } diff --git a/crates/burn-ndarray/src/lib.rs b/crates/burn-ndarray/src/lib.rs index 60c139bd25..95736b5efe 100644 --- a/crates/burn-ndarray/src/lib.rs +++ b/crates/burn-ndarray/src/lib.rs @@ -21,7 +21,7 @@ mod sharing; mod tensor; pub use backend::*; -pub use element::FloatNdArrayElement; +pub use element::*; pub(crate) use sharing::*; pub use tensor::*; diff --git a/crates/burn-ndarray/src/ops/conv.rs b/crates/burn-ndarray/src/ops/conv.rs index 429618826a..8f45b8f8f8 100644 --- a/crates/burn-ndarray/src/ops/conv.rs +++ b/crates/burn-ndarray/src/ops/conv.rs @@ -11,7 +11,7 @@ use ndarray::{ }; use crate::{ - element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, + element::FloatNdArrayElement, ops::padding::{apply_padding_4d, apply_padding_5d}, sharing::UnsafeSharedRef, tensor::NdArrayTensor, @@ -98,7 +98,7 @@ fn conv3d_mad_inner( } } -pub(crate) fn conv2d( +pub(crate) fn conv2d( x: NdArrayTensor, weight: NdArrayTensor, bias: Option>, @@ -126,7 +126,7 @@ pub(crate) fn conv2d(x, options.padding, 0i32.elem()).array; + let x = apply_padding_4d::(x, options.padding, 0i32.elem()).array; // Convert inputs from dynamic indexes to static to improve perf. 
let x = x.into_dimensionality::().unwrap(); @@ -310,7 +310,7 @@ pub(crate) fn conv_transpose2d( NdArrayTensor::new(output.into_dyn().into_shared()) } -pub(crate) fn conv3d( +pub(crate) fn conv3d( x: NdArrayTensor, weight: NdArrayTensor, bias: Option>, @@ -345,7 +345,7 @@ pub(crate) fn conv3d(x, options.padding, 0i32.elem()).array; + let x = apply_padding_5d::(x, options.padding, 0i32.elem()).array; // Convert inputs from dynamic indexes to static to improve perf. let x = x.into_dimensionality::().unwrap(); diff --git a/crates/burn-ndarray/src/ops/deform_conv.rs b/crates/burn-ndarray/src/ops/deform_conv.rs index 56b969a67c..a003e392f3 100644 --- a/crates/burn-ndarray/src/ops/deform_conv.rs +++ b/crates/burn-ndarray/src/ops/deform_conv.rs @@ -11,7 +11,7 @@ use ndarray::{ #[cfg(not(feature = "std"))] use num_traits::Float; -use crate::{element::QuantElement, FloatNdArrayElement, NdArrayTensor}; +use crate::{FloatNdArrayElement, NdArrayTensor}; use super::matmul::matmul; @@ -255,7 +255,6 @@ pub mod backward { #[cfg(target_has_atomic = "32")] use core::sync::atomic::Ordering; - use crate::element::IntNdArrayElement; use atomic_float::AtomicF32; use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4}; @@ -270,11 +269,7 @@ pub mod backward { ); /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. 
- pub(crate) fn deform_conv2d_backward< - F: FloatNdArrayElement, - I: IntNdArrayElement, - Q: QuantElement, - >( + pub(crate) fn deform_conv2d_backward( input: NdArrayTensor, offset: NdArrayTensor, weight: NdArrayTensor, diff --git a/crates/burn-ndarray/src/ops/maxpool.rs b/crates/burn-ndarray/src/ops/maxpool.rs index 90ffe30a95..b7f8e776e3 100644 --- a/crates/burn-ndarray/src/ops/maxpool.rs +++ b/crates/burn-ndarray/src/ops/maxpool.rs @@ -1,5 +1,5 @@ use crate::{ - element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, + element::{FloatNdArrayElement, IntNdArrayElement}, ops::padding::apply_padding_4d, sharing::UnsafeSharedRef, tensor::NdArrayTensor, @@ -9,7 +9,7 @@ use burn_common::{iter_range_par, run_par}; use burn_tensor::{ElementConversion, TensorMetadata}; use ndarray::Array4; -pub(crate) fn max_pool2d( +pub(crate) fn max_pool2d( x: NdArrayTensor, kernel_size: [usize; 2], stride: [usize; 2], @@ -30,7 +30,7 @@ pub(crate) fn max_pool2d(x, padding, inf).array; + let x = apply_padding_4d::(x, padding, inf).array; let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); let unsafe_shared_out = UnsafeSharedRef::new(&mut output); @@ -69,11 +69,7 @@ pub(crate) fn max_pool2d( +pub(crate) fn max_pool2d_with_indices( x: NdArrayTensor, kernel_size: [usize; 2], stride: [usize; 2], @@ -94,7 +90,7 @@ pub(crate) fn max_pool2d_with_indices< / stride_width) + 1; - let x = apply_padding_4d::(x, padding, inf).array; + let x = apply_padding_4d::(x, padding, inf).array; let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf); let mut indices = Array4::::zeros((batch_size, channels, out_height, out_width)); diff --git a/crates/burn-ndarray/src/ops/module.rs b/crates/burn-ndarray/src/ops/module.rs index f0885e52e2..dbceac1934 100644 --- a/crates/burn-ndarray/src/ops/module.rs +++ b/crates/burn-ndarray/src/ops/module.rs @@ -46,11 +46,7 @@ impl ModuleOps, options: ConvOptions<2>, ) -> NdArrayTensorFloat { - 
module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv2d::< - E, - I, - Q, - >( + module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv2d::( x, weight, bias, options ) .into()) @@ -89,7 +85,7 @@ impl ModuleOps( + let (x, offset, weight, mask, bias) = deform_conv2d_backward::( x, offset, weight, @@ -163,7 +159,7 @@ impl ModuleOps FloatTensor { - module_op!(inp(x), opt(), E, |x| max_pool2d::( + module_op!(inp(x), opt(), E, |x| max_pool2d::( x, kernel_size, stride, @@ -182,7 +178,7 @@ impl ModuleOps MaxPool2dWithIndices> { module_op!(inp(x), opt(), E, |x| { let (output, indices) = - max_pool2d_with_indices::(x, kernel_size, stride, padding, dilation); + max_pool2d_with_indices::(x, kernel_size, stride, padding, dilation); MaxPool2dWithIndices::new(output.into(), indices) }) } @@ -282,11 +278,7 @@ impl ModuleOps>, options: ConvOptions<3>, ) -> FloatTensor { - module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::< - E, - I, - Q, - >( + module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::( x, weight, bias, options ) .into()) diff --git a/crates/burn-ndarray/src/ops/padding.rs b/crates/burn-ndarray/src/ops/padding.rs index 99bcef5a3e..ccf3252205 100644 --- a/crates/burn-ndarray/src/ops/padding.rs +++ b/crates/burn-ndarray/src/ops/padding.rs @@ -1,13 +1,10 @@ -use crate::{ - element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, - tensor::NdArrayTensor, -}; +use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor}; use burn_tensor::TensorMetadata; use ndarray::{Array4, Array5}; use super::NdArrayOps; -pub(crate) fn apply_padding_4d( +pub(crate) fn apply_padding_4d( x: NdArrayTensor, padding: [usize; 2], elem: E, @@ -37,7 +34,7 @@ pub(crate) fn apply_padding_4d( +pub(crate) fn apply_padding_5d( x: NdArrayTensor, padding: [usize; 3], elem: E, diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index 69dfbaf52f..f92e2d6be4 100644 --- a/crates/burn-vision/Cargo.toml +++ 
b/crates/burn-vision/Cargo.toml @@ -25,10 +25,12 @@ wgpu = ["jit-backend", "export-tests"] [dependencies] burn-jit = { path = "../burn-jit", version = "0.17.0", optional = true } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } burn-tensor = { path = "../burn-tensor", version = "0.17.0" } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } cubecl = { workspace = true, optional = true } derive-new = { workspace = true } +ndarray = { workspace = true } serde = { workspace = true } [dev-dependencies] diff --git a/crates/burn-vision/src/backends/cpu/connected_components.rs b/crates/burn-vision/src/backends/cpu/connected_components.rs index c2065787b3..638cc0cb9e 100644 --- a/crates/burn-vision/src/backends/cpu/connected_components.rs +++ b/crates/burn-vision/src/backends/cpu/connected_components.rs @@ -1,15 +1,40 @@ +use alloc::vec::Vec; use burn_tensor::{ backend::Backend, ops::{BoolTensor, IntTensor}, + Bool, Int, Shape, Tensor, TensorData, }; +use ndarray::{Array3, Axis}; use crate::{ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity}; +mod spaghetti; +mod spaghetti_4c; + pub fn connected_components( - _img: BoolTensor, - _connectivity: Connectivity, + img: BoolTensor, + connectivity: Connectivity, ) -> IntTensor { - todo!() + let device = B::bool_device(&img); + let img = Tensor::::from_primitive(img); + let [batches, _, height, width] = img.shape().dims(); + let img = img.into_data().convert::().to_vec::().unwrap(); + let img = Array3::from_shape_vec((batches, height, width), img).unwrap(); + + let process = match connectivity { + Connectivity::Four => spaghetti_4c::process::, + Connectivity::Eight => spaghetti::process::, + }; + + let mut out = process(img.index_axis(Axis(0), 0)); + for i in 1..batches { + let batch = process(img.index_axis(Axis(0), i)); + out.append(Axis(0), batch.view()).unwrap(); + } + println!("{out:?}"); + let (data, _) = out.into_raw_vec_and_offset(); + let data = 
TensorData::new(data, Shape::new([batches, height, width])); + Tensor::::from_data(data, &device).into_primitive() } pub fn connected_components_with_stats( @@ -19,3 +44,64 @@ pub fn connected_components_with_stats( ) -> (IntTensor, ConnectedStatsPrimitive) { todo!() } + +pub trait Solver { + fn init(max_labels: usize) -> Self; + /// Hack to get around mutable borrow limitations on methods + fn merge(label_1: u32, label_2: u32, solver: &mut Self) -> u32; + fn new_label(&mut self) -> u32; + fn flatten(&mut self); + fn get_label(&mut self, i_label: u32) -> u32; +} + +pub(crate) struct UnionFind { + labels: Vec, +} + +impl Solver for UnionFind { + fn init(max_labels: usize) -> Self { + let mut labels = Vec::with_capacity(max_labels); + labels.push(0); + Self { labels } + } + + fn merge(mut label_1: u32, mut label_2: u32, solver: &mut Self) -> u32 { + while solver.labels[label_1 as usize] < label_1 { + label_1 = solver.labels[label_1 as usize]; + } + + while solver.labels[label_2 as usize] < label_2 { + label_2 = solver.labels[label_2 as usize]; + } + + if label_1 < label_2 { + solver.labels[label_2 as usize] = label_1; + label_1 + } else { + solver.labels[label_1 as usize] = label_2; + label_2 + } + } + + fn new_label(&mut self) -> u32 { + let len = self.labels.len() as u32; + self.labels.push(len); + len + } + + fn flatten(&mut self) { + let mut k = 1; + for i in 1..self.labels.len() { + if self.labels[i] < i as u32 { + self.labels[i] = self.labels[self.labels[i] as usize]; + } else { + self.labels[i] = k; + k += 1; + } + } + } + + fn get_label(&mut self, i_label: u32) -> u32 { + self.labels[i_label as usize] + } +} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_center_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_center_line_forest_code.rs new file mode 100644 index 0000000000..60bb3c4ba7 --- /dev/null +++ 
b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_center_line_forest_code.rs @@ -0,0 +1,1954 @@ +no_analyze! {{ +use centerLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { + NODE_1=> { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_2); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_7); + } + } + else { + return Some(NODE_3); + } + } + NODE_3=> { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_2); + } + else { + return Some(NODE_4); + } + } + NODE_4=> { + if img_row01[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_2); + } + else { + img_labels_row00[c as usize] = 0; + return Some(cl_tree_1); + } + } + NODE_2=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + return Some(NODE_5); + } + } + NODE_5=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_6); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_3); + } + } + } + NODE_7=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + else { + return Some(NODE_8); + } + } + NODE_9=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_4); + } + else { + img_labels_row00[c as usize] = 
img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_3); + } + } + NODE_10=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + else { + if img_row11[(c - 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver), img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_11); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + } + else { + if img_row11[(c - 1) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), 
img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_8); + } + else { + return Some(NODE_9); + } + } + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_12); + } + else { + return Some(NODE_11); + } + } + } + } + NODE_8=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + } + else { + return Some(NODE_9); + } + } + NODE_12=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + if img_row11[(c - 1) as usize] > 0 { + 
return Some(NODE_13); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_6); + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_14); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(cl_tree_3); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_3); + } + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_10); + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(cl_tree_9); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_9); + } + } + } + } + } + NODE_6=> { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_15=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_5); + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_10); + } + else { + img_labels_row00[c as usize] = 
solver.new_label(); + return Some(cl_tree_9); + } + } + } + } + NODE_11=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_10); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_9); + } + } + NODE_13=> { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_11); + } + } + NODE_14=> { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_16=> { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_17); + } + else { + return Some(NODE_18); + } + } + NODE_18=> { + if img_row11[(c + 2) as usize] > 0 { + return Some(NODE_19); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_8); + } + } + NODE_17=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return 
Some(cl_tree_6); + } + } + } + NODE_20=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + } + NODE_21=> { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_22); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_18); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_12); + } + } + } + NODE_23=> { + if img_row11[(c + 2) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_8); + } + } + NODE_24=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + NODE_25=> { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_7); + } + } + NODE_26=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], 
solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_8); + } + } + NODE_19=> { + if img_row12[(c + 1) as usize] > 0 { + return Some(NODE_20); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_27=> { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_7); + } + else { + return Some(NODE_2); + } + } + else { + return Some(NODE_25); + } + } + else { + return Some(NODE_3); + } + } + NODE_28=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + NODE_29=> { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + if img_row11[(c + 2) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_3); + } + } + } + else { + return Some(NODE_4); + } + } + NODE_30=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_31); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + 
else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_8); + } + } + NODE_22=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + } + NODE_31=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_32=> { + if img_row12[(c - 1) as usize] > 0 { + return Some(NODE_17); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + NODE_33=> { + if img_row12[(c - 1) as usize] > 0 { + return Some(NODE_20); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_34=> { + if img_row12[(c - 1) as usize] > 0 { + return Some(NODE_22); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + NODE_35=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_33); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + 
} + } + else { + if img_row11[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_36); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_37); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_4); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_3); + } + } + } + NODE_37=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_4); + } + } + NODE_38=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_31); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + else { + 
img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_8); + } + } + NODE_39=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + return Some(NODE_33); + } + else { + if img_row11[(c) as usize] > 0 { + return Some(NODE_36); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(cl_tree_5); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + return Some(NODE_37); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_3); + } + } + } + NODE_36=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver), img_labels_row12[(c) as usize], solver); + return Some(cl_tree_5); + } + } + NODE_40=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_10); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_10); + } + } +cl_tree_0 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_0); } else { return Some(cl_break_1_0); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_15); + } + else { + return Some(NODE_1); + } +} +cl_tree_1 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_1); } else { return Some(cl_break_1_1); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_12); + } + else { + return Some(NODE_1); + } +} +cl_tree_2 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_2); } else { return 
Some(cl_break_1_2); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_10); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_7); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + return Some(NODE_3); + } + } +} +cl_tree_3 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_3); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_23); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_12); + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + else { + return Some(NODE_23); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + return Some(NODE_29); + } + } +} +cl_tree_4 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_4); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_28); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_30); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return 
Some(cl_tree_12); + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_24); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + else { + return Some(NODE_30); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + return Some(NODE_29); + } + } +} +cl_tree_5 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_5); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_26); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_12); + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + return Some(NODE_26); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_6); + } + else { + if img_row11[(c + 2) as usize] > 0 { + return Some(NODE_6); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + } + } + else { + return Some(NODE_4); + } + } + } +} +cl_tree_6 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_6); } } + if 
img_row00[(c) as usize] > 0 { + return Some(NODE_21); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_16); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + return Some(NODE_3); + } + } +} +cl_tree_7 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_4); } else { return Some(cl_break_1_7); } } + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_10); + } + else { + return Some(NODE_12); + } + } + else { + return Some(NODE_27); + } +} +cl_tree_8 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_8); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_28); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_38); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_12); + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_24); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as 
usize], solver); + return Some(cl_tree_6); + } + } + else { + return Some(NODE_38); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + } + else { + return Some(NODE_29); + } + } +} +cl_tree_9 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_5); } else { return Some(cl_break_1_9); } } + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_8); + } + else { + return Some(NODE_11); + } + } + } + else { + return Some(NODE_15); + } + } + else { + return Some(NODE_27); + } +} +cl_tree_10 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_6); } else { return Some(cl_break_1_10); } } + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_34); + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_39); + } + else { + if img_row11[(c) as usize] > 0 { + return Some(NODE_40); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_9); + } + } + } + } + else { + return Some(NODE_15); + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_32); + } + else { + return Some(NODE_39); + } + } + else { + return Some(NODE_2); + } + } + else { + return Some(NODE_25); + } + } + else { + return Some(NODE_3); + } + } +} +cl_tree_11 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_7); } else { return Some(cl_break_1_11); } } + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return 
Some(NODE_21); + } + else { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_21); + } + else { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_11); + } + else { + return Some(NODE_13); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_6); + } + else { + return Some(NODE_14); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(cl_tree_3); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_10); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(cl_tree_9); + } + } + } + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_16); + } + else { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_17); + } + else { + if img_row11[(c + 2) as usize] > 0 { + return Some(NODE_19); + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(cl_tree_4); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_3); + } + } + } + } + else { + return Some(NODE_2); + } + } + } + else { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + else { + if img_row00[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_7); + } + else { + img_labels_row00[c as 
usize] = solver.new_label(); + return Some(cl_tree_7); + } + } + } + } + else { + return Some(NODE_3); + } + } +} +cl_tree_12 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_8); } else { return Some(cl_break_1_12); } } + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_34); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_11); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_35); + } + else { + if img_row11[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_40); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_10); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(cl_tree_9); + } + } + } + } + else { + return Some(NODE_15); + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_32); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(cl_tree_6); + } + } + else { + return Some(NODE_35); + } + } + else { + return Some(NODE_2); + } + } + else { + return Some(NODE_25); + } + } + else { + return Some(NODE_3); + } + } +} + NODE_41=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + 
NODE_42=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + NODE_43=> { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_44=> { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_45=> { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c - 2) as usize], solver); + } + else { + return Some(NODE_46); + } + } + NODE_47=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + NODE_48=> { + if img_row01[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_46=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } +cl_break_0_0 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_47); + } + else { + return Some(NODE_43); + } + return None;} +cl_break_0_1 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_41); + } + else { + return Some(NODE_43); + } + return None;} +cl_break_0_2 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_45); + } + else { + return Some(NODE_44); + } + return None;} +cl_break_0_3 => 
{ + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + return Some(NODE_44); + } + return None;} +cl_break_0_4 => { + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_45); + } + else { + return Some(NODE_41); + } + } + else { + return Some(NODE_48); + } + return None;} +cl_break_0_5 => { + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_46); + } + else { + return Some(NODE_47); + } + } + else { + return Some(NODE_48); + } + return None;} +cl_break_0_6 => { + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_42); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_47); + } + } + else { + return Some(NODE_48); + } + return None;} +cl_break_0_7 => { + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row00[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + else { + img_labels_row00[c as usize] = 0; + } + } + return None;} +cl_break_0_8 => { + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + if img_row11[(c - 2) as 
usize] > 0 { + return Some(NODE_42); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_47); + } + } + else { + return Some(NODE_48); + } + return None;} + NODE_49=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_50); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + NODE_51=> { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_52); + } + else { + if img_row11[(c) as usize] > 0 { + return Some(NODE_53); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + } + else { + return Some(NODE_54); + } + } + NODE_53=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + NODE_55=> { + if img_row01[(c - 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_52); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + if img_row11[(c) as usize] > 0 { + return Some(NODE_50); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + } + else { + return Some(NODE_54); + } + } + NODE_56=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_53); + } + else { + img_labels_row00[c as usize] = 
LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + NODE_57=> { + if img_row11[(c + 1) as usize] > 0 { + return Some(NODE_58); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + NODE_58=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + } + NODE_59=> { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + NODE_60=> { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + else { + return Some(NODE_61); + } + } + NODE_62=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + NODE_52=> { + if img_row12[(c - 1) as usize] > 0 { + return Some(NODE_58); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + NODE_63=> { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_64); + } + else { + return Some(NODE_59); + } + } + else { + return Some(NODE_65); + } + } + NODE_64=> { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_66); + } + else { + 
return Some(NODE_54); + } + } + NODE_50=> { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_53); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + NODE_61=> { + if img_row01[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_67=> { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + NODE_65=> { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_54); + } + else { + return Some(NODE_61); + } + } + NODE_68=> { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_54); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + else { + return Some(NODE_65); + } + } + NODE_66=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + return Some(NODE_67); + } + } + NODE_69=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + if img_row11[(c - 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver), img_labels_row12[(c - 2) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c 
- 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c - 2) as usize], solver); + } + else { + return Some(NODE_67); + } + } + } + NODE_70=> { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c) as usize], img_labels_row12[(c - 2) as usize], solver); + } + } + NODE_54=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + NODE_71=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_70); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + } +cl_break_1_0 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_54); + } + else { + return Some(NODE_68); + } + return None;} +cl_break_1_1 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_71); + } + else { + return Some(NODE_68); + } + return None;} +cl_break_1_2 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_69); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_66); + } + else { + 
img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_65); + } + } + return None;} +cl_break_1_3 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_62); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_62); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_60); + } + } + return None;} +cl_break_1_4 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_56); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_56); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_60); + } + } + return None;} +cl_break_1_5 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + return Some(NODE_61); + } + } + } + return None;} +cl_break_1_6 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_57); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_57); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_65); + } + } + return None;} +cl_break_1_7 => { + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_69); + } + else { + return Some(NODE_71); + } + } + else { + return Some(NODE_63); + } + return None;} +cl_break_1_8 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_49); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_49); + } + else { + img_labels_row00[c 
as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_60); + } + } + return None;} +cl_break_1_9 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_64); + } + else { + return Some(NODE_63); + } + return None;} +cl_break_1_10 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_51); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_51); + } + else { + return Some(NODE_59); + } + } + else { + return Some(NODE_65); + } + } + return None;} +cl_break_1_11 => { + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_57); + } + else { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_57); + } + else { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + return Some(NODE_70); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + } + } + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_57); + } + else { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_57); + } + else { + return Some(NODE_54); + } + } + } + else { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row00[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + } + else { + return Some(NODE_65); + } + } + return None;} +cl_break_1_12 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_55); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_55); + } + else { + 
return Some(NODE_59); + } + } + else { + return Some(NODE_65); + } + } + return None;} + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_first_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_first_line_forest_code.rs new file mode 100644 index 0000000000..4cc475d836 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_first_line_forest_code.rs @@ -0,0 +1,223 @@ +no_analyze!{{ +use firstLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { + NODE_72=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(fl_tree_2); + } + } + NODE_73=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_2); + } + } + NODE_74=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_1); + } + else { + if img_row01[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = 0; + return Some(fl_tree_0); + } + } + } +fl_tree_0 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_0); } else { return Some(fl_break_1_0); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_73); + } + else { + if img_row01[(c) as usize] > 0 { + return Some(NODE_73); + } + else { + return Some(NODE_74); + } + } +} +fl_tree_1 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_1); } else { return Some(fl_break_1_1); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_72); + } + 
else { + if img_row01[(c) as usize] > 0 { + return Some(NODE_72); + } + else { + return Some(NODE_74); + } + } +} +fl_tree_2 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_2); } else { return Some(fl_break_1_2); } } + if img_row00[(c) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + return Some(NODE_72); + } + else { + return Some(NODE_73); + } + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_1); + } + } + else { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(fl_tree_2); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_2); + } + } + } + else { + return Some(NODE_74); + } + } +} + NODE_75=> { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } +fl_break_0_0 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + } + return None;} +fl_break_0_1 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = 0; + } + } + return None;} +fl_break_0_2 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_75); + } + else { + if img_row01[(c) as usize] > 0 { + return Some(NODE_75); + } + else { + img_labels_row00[c as usize] = 0; + } + } + return 
None;} + NODE_76=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + if img_row01[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + } + } + NODE_77=> { + if img_row01[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } +fl_break_1_0 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + return Some(NODE_76); + } + } + return None;} +fl_break_1_1 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row01[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + return Some(NODE_76); + } + } + return None;} +fl_break_1_2 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_77); + } + else { + if img_row01[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + return Some(NODE_77); + } + else { + return Some(NODE_77); + } + } + else { + return Some(NODE_76); + } + } + return None;} +fl_ => {}, + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_forest_labels.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_forest_labels.rs new file mode 100644 index 0000000000..6c994fc7d3 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_forest_labels.rs @@ -0,0 +1,191 @@ +/// Workaround for rust-analyzer bug that causes invalid errors on the `include!`. +macro_rules! 
no_analyze { + ($tokens:tt) => { + $tokens + }; +} + +pub(crate) use no_analyze; + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum centerLabels { + NODE_1, + NODE_2, + NODE_3, + NODE_4, + NODE_5, + NODE_6, + NODE_7, + NODE_8, + NODE_9, + NODE_10, + NODE_11, + NODE_12, + NODE_13, + NODE_14, + NODE_15, + NODE_16, + NODE_17, + NODE_18, + NODE_19, + NODE_20, + NODE_21, + NODE_22, + NODE_23, + NODE_24, + NODE_25, + NODE_26, + NODE_27, + NODE_28, + NODE_29, + NODE_30, + NODE_31, + NODE_32, + NODE_33, + NODE_34, + NODE_35, + NODE_36, + NODE_37, + NODE_38, + NODE_39, + NODE_40, + NODE_41, + NODE_42, + NODE_43, + NODE_44, + NODE_45, + NODE_46, + NODE_47, + NODE_48, + NODE_49, + NODE_50, + NODE_51, + NODE_52, + NODE_53, + NODE_54, + NODE_55, + NODE_56, + NODE_57, + NODE_58, + NODE_59, + NODE_60, + NODE_61, + NODE_62, + NODE_63, + NODE_64, + NODE_65, + NODE_66, + NODE_67, + NODE_68, + NODE_69, + NODE_70, + NODE_71, + cl_tree_0, + cl_tree_1, + cl_tree_2, + cl_tree_3, + cl_tree_4, + cl_tree_5, + cl_tree_6, + cl_tree_7, + cl_tree_8, + cl_tree_9, + cl_tree_10, + cl_tree_11, + cl_tree_12, + cl_break_0_0, + cl_break_0_1, + cl_break_0_2, + cl_break_0_3, + cl_break_0_4, + cl_break_0_5, + cl_break_0_6, + cl_break_0_7, + cl_break_0_8, + cl_break_1_0, + cl_break_1_1, + cl_break_1_2, + cl_break_1_3, + cl_break_1_4, + cl_break_1_5, + cl_break_1_6, + cl_break_1_7, + cl_break_1_8, + cl_break_1_9, + cl_break_1_10, + cl_break_1_11, + cl_break_1_12, +} + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum firstLabels { + NODE_72, + NODE_73, + NODE_74, + NODE_75, + NODE_76, + NODE_77, + fl_tree_0, + fl_tree_1, + fl_tree_2, + fl_break_0_0, + fl_break_0_1, + fl_break_0_2, + fl_break_1_0, + fl_break_1_1, + fl_break_1_2, + fl_, +} + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum lastLabels { + NODE_78, + NODE_79, + NODE_80, + NODE_81, + NODE_82, + NODE_83, + NODE_84, + NODE_85, + NODE_86, + NODE_87, + NODE_88, + NODE_89, + NODE_90, + NODE_91, + 
NODE_92, + ll_tree_0, + ll_tree_1, + ll_tree_2, + ll_tree_3, + ll_tree_4, + ll_tree_5, + ll_tree_6, + ll_tree_7, + ll_break_0_0, + ll_break_0_1, + ll_break_0_2, + ll_break_0_3, + ll_break_1_0, + ll_break_1_1, + ll_break_1_2, + ll_break_1_3, + ll_break_1_4, + ll_break_1_5, + ll_break_1_6, + ll_break_1_7, + ll_, +} + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum singleLabels { + NODE_93, + NODE_94, + sl_tree_0, + sl_tree_1, + sl_break_0_0, + sl_break_0_1, + sl_break_1_0, + sl_break_1_1, + sl_, +} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_last_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_last_line_forest_code.rs new file mode 100644 index 0000000000..945c40f132 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_last_line_forest_code.rs @@ -0,0 +1,787 @@ +no_analyze!{{ +use lastLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { + NODE_78=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + } + NODE_79=> { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(ll_tree_4); + } + } + NODE_80=> { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = 
img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_4); + } + } + NODE_81=> { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_5); + } + else { + if img_row11[(c + 2) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(ll_tree_2); + } + } + } + else { + img_labels_row00[c as usize] = 0; + return Some(ll_tree_1); + } + } + NODE_82=> { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_5); + } + else { + return Some(NODE_83); + } + } + else { + img_labels_row00[c as usize] = 0; + return Some(ll_tree_1); + } + } + NODE_84=> { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(ll_tree_6); + } + } + NODE_83=> { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_80); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_3); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(ll_tree_2); + } + } + } + NODE_85=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + 
else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c + 2) as usize], img_labels_row12[(c - 2) as usize], solver); + return Some(ll_tree_4); + } + } + NODE_86=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_7); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } +ll_tree_0 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_0); } else { return Some(ll_break_1_0); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + if img_row00[(c + 1) as usize] 
> 0 { + return Some(NODE_83); + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_0); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(ll_tree_0); + } + } + } + } + else { + return Some(NODE_82); + } +} +ll_tree_1 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_1); } else { return Some(ll_break_1_1); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_84); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_80); + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_79); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_3); + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(ll_tree_2); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(ll_tree_2); + } + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_0); + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(ll_tree_0); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(ll_tree_0); + } + } + } + } + } + else { + return Some(NODE_82); + } +} 
+ll_tree_2 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_2); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_7); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } + else { + return Some(NODE_81); + } +} +ll_tree_3 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_3); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_78); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_85); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_7); + } + } + else { + img_labels_row00[c 
as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } + else { + return Some(NODE_81); + } +} +ll_tree_4 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_4); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c + 2) as usize]; + return Some(ll_tree_4); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_7); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_5); + } + else { + if img_row11[(c + 2) as usize] > 0 { + return Some(NODE_80); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_3); + } + } + } + else { + img_labels_row00[c as usize] = 0; + return Some(ll_tree_1); + } + } +} +ll_tree_5 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_5); } } + if img_row00[(c) as usize] > 0 { + return Some(NODE_86); + } + else { + return Some(NODE_82); + } +} +ll_tree_6 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_3); } else { return Some(ll_break_1_6); } } + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_86); + } + else { + if img_row11[(c + 1) as usize] 
> 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_6); + } + else { + return Some(NODE_84); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + return Some(NODE_80); + } + else { + return Some(NODE_79); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_3); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(ll_tree_2); + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + return Some(ll_tree_0); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } + } + } + else { + return Some(NODE_82); + } +} +ll_tree_7 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_7); } } + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_78); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + return Some(ll_tree_6); + } + } + else { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 2) as usize] > 0 { + if img_row12[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_85); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + 
img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c + 2) as usize], solver); + return Some(ll_tree_4); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_7); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(ll_tree_0); + } + } + } + else { + return Some(NODE_81); + } +} +ll_break_0_0 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} +ll_break_0_1 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} +ll_break_0_2 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} +ll_break_0_3 => { + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + } + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} + NODE_87=> { + if img_row00[(c + 1) as usize] > 0 
{ + return Some(NODE_88); + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_88=> { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + NODE_89=> { + if img_row00[(c + 1) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + else { + img_labels_row00[c as usize] = 0; + } + } + NODE_90=> { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c) as usize], img_labels_row12[(c - 2) as usize], solver); + } + } + NODE_91=> { + if img_row12[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row12[(c) as usize], img_labels_row12[(c - 2) as usize], solver); + } + } + NODE_92=> { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row12[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } +ll_break_1_0 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_88); + } + else { + return Some(NODE_87); + } + return None;} +ll_break_1_1 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + 
img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + return Some(NODE_90); + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + if img_row11[(c - 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = solver.new_label(); + } + } + } + } + else { + return Some(NODE_87); + } + return None;} +ll_break_1_2 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_89); + } + return None;} +ll_break_1_3 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + return Some(NODE_91); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_89); + } + return None;} +ll_break_1_4 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = 0; + } + } + return None;} +ll_break_1_5 => { + if img_row00[(c) as usize] > 0 { + return Some(NODE_92); + } + else { + return Some(NODE_87); + } + return None;} +ll_break_1_6 => { + if img_row00[(c) as usize] > 0 { + if img_row00[(c - 1) as usize] > 0 { + return Some(NODE_92); + } + else { + if img_row11[(c + 
1) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + return Some(NODE_90); + } + } + else { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row12[(c) as usize]; + } + else { + img_labels_row00[c as usize] = img_labels_row12[(c - 2) as usize]; + } + } + } + } + else { + return Some(NODE_87); + } + return None;} +ll_break_1_7 => { + if img_row00[(c) as usize] > 0 { + if img_row11[(c + 1) as usize] > 0 { + if img_row12[(c) as usize] > 0 { + if img_row11[(c - 2) as usize] > 0 { + return Some(NODE_91); + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 2) as usize], img_labels_row12[(c) as usize], solver); + } + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + } + else { + return Some(NODE_89); + } + return None;} +ll_ => {}, + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_single_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_single_line_forest_code.rs new file mode 100644 index 0000000000..e818d77e16 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_single_line_forest_code.rs @@ -0,0 +1,91 @@ +no_analyze!{{ +use singleLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { + NODE_93=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(sl_tree_1); + } + else { + img_labels_row00[c as usize] = 0; + return Some(sl_tree_0); + } + } +sl_tree_0 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_0); } else { return Some(sl_break_1_0); } } + if img_row00[(c) 
as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(sl_tree_1); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(sl_tree_0); + } + } + else { + return Some(NODE_93); + } +} +sl_tree_1 => { +if ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_1); } else { return Some(sl_break_1_1); } } + if img_row00[(c) as usize] > 0 { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(sl_tree_1); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + return Some(sl_tree_0); + } + } + else { + return Some(NODE_93); + } +} +sl_break_0_0 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} +sl_break_0_1 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + img_labels_row00[c as usize] = 0; + } + return None;} + NODE_94=> { + if img_row00[(c + 1) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + img_labels_row00[c as usize] = 0; + } + } +sl_break_1_0 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + } + else { + return Some(NODE_94); + } + return None;} +sl_break_1_1 => { + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 2) as usize]; + } + else { + return Some(NODE_94); + } + return None;} +sl_ => {}, + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs new file mode 100644 index 0000000000..b39c922072 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs @@ -0,0 +1,214 @@ +//! 
Spaghetti algorithm for connected component labeling +//! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana, +//! "Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling," +//! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019. +//! +//! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN) +//! as described in +//! +//! F. Bolelli, S. Allegretti, C. Grana. +//! "One DAG to Rule Them All." +//! IEEE Transactions on Pattern Analisys and Machine Intelligence, 2021 + +#![allow( + unreachable_code, + clippy::collapsible_else_if, + clippy::if_same_then_else +)] + +use ndarray::{s, Array2, ArrayView2, Axis}; + +#[allow(non_snake_case)] +mod Spaghetti_forest_labels; +pub(crate) use Spaghetti_forest_labels::*; + +use super::Solver; + +pub fn process(img: ArrayView2) -> Array2 { + let (h, w) = img.dim(); + + let e_rows = h as u32 & 0xfffffffe; + let o_rows = h % 2 == 1; + let e_cols = w as u32 & 0xfffffffe; + let o_cols = w % 2 == 1; + + let mut img_labels = Array2::default(img.raw_dim()); + + let mut solver = LabelsSolver::init(((h + 1) / 2) * ((w + 1) / 2) + 1); + + let solver = &mut solver; + + let w = w as i32; + + if h == 1 { + // Single line + let r = 0; + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + + // Row pointers for the output image + let mut img_labels_row00 = img_labels.slice_mut(s![r, ..]); + let mut c = -2i32; + let entry = singleLabels::sl_tree_0; + + include!("Spaghetti_single_line_forest_code.rs"); + } else { + // More than one line + + // First couple of lines + { + let img_row00 = img.index_axis(Axis(0), 0); + let img_row01 = img.index_axis(Axis(0), 1); + let mut img_labels_row00 = img_labels.slice_mut(s![0, ..]); + let mut c = -2i32; + let entry = firstLabels::fl_tree_0; + + include!("Spaghetti_first_line_forest_code.rs"); + } + + // Every other line but the last one if image has an odd number of rows + for r 
in (2..e_rows as usize).step_by(2) { + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + let img_row12 = img.index_axis(Axis(0), r - 2); + let img_row11 = img.index_axis(Axis(0), r - 1); + let img_row01 = img.index_axis(Axis(0), r + 1); + + // Row pointers for the output image + let (mut img_labels_row00, img_labels_row12) = + img_labels.multi_slice_mut((s![r, ..], s![r - 2, ..])); + + let mut c = -2; + let entry = centerLabels::cl_tree_0; + + include!("Spaghetti_center_line_forest_code.rs"); + } + + if o_rows { + let r = h - 1; + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + let img_row12 = img.index_axis(Axis(0), r - 2); + let img_row11 = img.index_axis(Axis(0), r - 1); + + // Row pointers for the output image + let (mut img_labels_row00, img_labels_row12) = + img_labels.multi_slice_mut((s![r, ..], s![r - 2, ..])); + let mut c = -2; + let entry = lastLabels::ll_tree_0; + + include!("Spaghetti_last_line_forest_code.rs"); + } + } + + solver.flatten(); + + for r in (0..e_rows as usize).step_by(2) { + //Pointers: + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + let img_row01 = img.index_axis(Axis(0), r + 1); + + // Row pointers for the output image + let (mut img_labels_row00, mut img_labels_row01) = + img_labels.multi_slice_mut((s![r, ..], s![r + 1, ..])); + + for c in (0..e_cols as usize).step_by(2) { + let mut i_label = img_labels_row00[c]; + if i_label > 0 { + i_label = solver.get_label(i_label); + if img_row00[c] > 0 { + img_labels_row00[c] = i_label; + } else { + img_labels_row00[c] = 0; + } + if img_row00[c + 1] > 0 { + img_labels_row00[c + 1] = i_label; + } else { + img_labels_row00[c + 1] = 0; + } + if img_row01[c] > 0 { + img_labels_row01[c] = i_label; + } else { + img_labels_row01[c] = 0; + } + if img_row01[c + 1] > 0 { + img_labels_row01[c + 1] = i_label; + } else { + img_labels_row01[c + 1] = 0; + } + } else { + img_labels_row00[c] = 0; + 
img_labels_row00[c + 1] = 0; + img_labels_row01[c] = 0; + img_labels_row01[c + 1] = 0; + } + } + if o_cols { + let c = e_cols as usize; + let mut i_label = img_labels_row00[c]; + if i_label > 0 { + i_label = solver.get_label(i_label); + if img_row00[c] > 0 { + img_labels_row00[c] = i_label; + } else { + img_labels_row00[c] = 0; + } + if img_row01[c] > 0 { + img_labels_row01[c] = i_label; + } else { + img_labels_row01[c] = 0; + } + } else { + img_labels_row00[c] = 0; + img_labels_row01[c] = 0; + } + } + } + + if o_rows { + let r = e_rows as usize; + + // Row pointers for the input image + let img_row00 = img.index_axis(Axis(0), r); + + // Row pointers for the output image + let mut img_labels_row00 = img_labels.slice_mut(s![r, ..]); + + for c in (0..e_cols as usize).step_by(2) { + let mut i_label = img_labels_row00[c]; + if i_label > 0 { + i_label = solver.get_label(i_label); + if img_row00[c] > 0 { + img_labels_row00[c] = i_label; + } else { + img_labels_row00[c] = 0; + } + if img_row00[c + 1] > 0 { + img_labels_row00[c + 1] = i_label; + } else { + img_labels_row00[c + 1] = 0; + } + } else { + img_labels_row00[c] = 0; + img_labels_row00[c + 1] = 0; + } + } + if o_cols { + let c = e_cols as usize; + let mut i_label = img_labels_row00[c]; + if i_label > 0 { + i_label = solver.get_label(i_label); + if img_row00[c] > 0 { + img_labels_row00[c] = i_label; + } else { + img_labels_row00[c] = 0; + } + } else { + img_labels_row00[c] = 0; + } + } + } + + img_labels +} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_center_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_center_line_forest_code.rs new file mode 100644 index 0000000000..1c1b932334 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_center_line_forest_code.rs @@ -0,0 +1,42 @@ +no_analyze!{{ +use centerLabels::*;let mut label = entry; +while let Some(next) = 
(|label| -> Option { match label { +cl_tree_0 => { +if ({c+=1; c} >= w) { return None; } + if img_row00[(c) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row11[(c) as usize]; + return Some(cl_tree_1); + } + else { + img_labels_row00[c as usize] = solver.new_label(); + return Some(cl_tree_1); + } + } + else { + img_labels_row00[c as usize] = 0; + return Some(cl_tree_0); + } +} +cl_tree_1 => { +if ({c+=1; c} >= w) { return None; } + if img_row00[(c) as usize] > 0 { + if img_row11[(c) as usize] > 0 { + img_labels_row00[c as usize] = LabelsSolver::merge(img_labels_row00[(c - 1) as usize], img_labels_row11[(c) as usize], solver); + return Some(cl_tree_1); + } + else { + img_labels_row00[c as usize] = img_labels_row00[(c - 1) as usize]; + return Some(cl_tree_1); + } + } + else { + img_labels_row00[c as usize] = 0; + return Some(cl_tree_0); + } +} + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_first_line_forest_code.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_first_line_forest_code.rs new file mode 100644 index 0000000000..5deff2941b --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_first_line_forest_code.rs @@ -0,0 +1,31 @@ +no_analyze!{{ +use firstLabels::*;let mut label = entry; +while let Some(next) = (|label| -> Option { match label { +fl_tree_0 => { +if ({c+=1; c} >= w) { return None; } + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = solver.new_label(); + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = 0; + return Some(fl_tree_0); + } +} +fl_tree_1 => { +if ({c+=1; c} >= w) { return None; } + if img_row00[(c) as usize] > 0 { + img_labels_row00[c as usize] = img_labels_row00[(c - 1) as usize]; + return Some(fl_tree_1); + } + else { + img_labels_row00[c as usize] = 0; + return Some(fl_tree_0); + 
} +} +fl_ => {}, + }; None})(label) +{ +label = next; +} +}} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_forest_labels.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_forest_labels.rs new file mode 100644 index 0000000000..70e89e8ab1 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_forest_labels.rs @@ -0,0 +1,21 @@ +/// Workaround for rust-analyzer bug that causes invalid errors on the `include!`. +macro_rules! no_analyze { + ($tokens:tt) => { + $tokens + }; +} + +pub(crate) use no_analyze; + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum centerLabels { + cl_tree_0, + cl_tree_1, +} + +#[allow(non_snake_case, non_camel_case_types, unused)] +pub enum firstLabels { + fl_tree_0, + fl_tree_1, + fl_, +} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs new file mode 100644 index 0000000000..9085293086 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs @@ -0,0 +1,81 @@ +//! Spaghetti algorithm for connected component labeling, modified for 4-connectivity using the +//! 4-connected Rosenfeld mask. +//! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana, +//! "Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling," +//! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019. +//! +//! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN) +//! as described in +//! +//! F. Bolelli, S. Allegretti, C. Grana. +//! "One DAG to Rule Them All." +//! 
IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021
b/crates/burn-vision/src/backends/cpu/mod.rs @@ -1,3 +1,4 @@ mod connected_components; +mod ops; pub use connected_components::*; diff --git a/crates/burn-vision/src/backends/cpu/ops.rs b/crates/burn-vision/src/backends/cpu/ops.rs new file mode 100644 index 0000000000..157d56aed7 --- /dev/null +++ b/crates/burn-vision/src/backends/cpu/ops.rs @@ -0,0 +1,24 @@ +use crate::{ + backends::cpu, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, VisionOps, +}; +use burn_ndarray::{FloatNdArrayElement, IntNdArrayElement, NdArray, QuantElement}; +use burn_tensor::ops::{BoolTensor, IntTensor}; + +impl VisionOps for NdArray +where + E: FloatNdArrayElement, + I: IntNdArrayElement, + Q: QuantElement, +{ + fn connected_components(img: BoolTensor, connectivity: Connectivity) -> IntTensor { + cpu::connected_components::(img, connectivity) + } + + fn connected_components_with_stats( + img: BoolTensor, + connectivity: Connectivity, + opts: ConnectedStatsOptions, + ) -> (IntTensor, ConnectedStatsPrimitive) { + cpu::connected_components_with_stats::(img, connectivity, opts) + } +} diff --git a/crates/burn-vision/src/lib.rs b/crates/burn-vision/src/lib.rs index 03643b83fb..be3e9b2684 100644 --- a/crates/burn-vision/src/lib.rs +++ b/crates/burn-vision/src/lib.rs @@ -1,3 +1,5 @@ +extern crate alloc; + pub mod backends; mod ops; mod tensor; diff --git a/crates/burn-vision/tests/main.rs b/crates/burn-vision/tests/main.rs index 2819632cc1..f31361f3cf 100644 --- a/crates/burn-vision/tests/main.rs +++ b/crates/burn-vision/tests/main.rs @@ -1,6 +1,6 @@ #[cfg(all(test, feature = "cpu"))] mod tests_cpu { - pub type TestBackend = burn_ndarray::NdArray; + pub type TestBackend = burn_ndarray::NdArray; burn_vision::testgen_all!(); } From 77089939f6e57e94c635c363c5f12bd32c045dc9 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 31 Jan 2025 14:59:17 +0100 Subject: [PATCH 10/24] Implement stats --- .../src/backends/cpu/connected_components.rs 
| 140 ++++++++++++++++-- .../cpu/connected_components/spaghetti/mod.rs | 35 ++++- .../connected_components/spaghetti_4c/mod.rs | 13 +- 3 files changed, 167 insertions(+), 21 deletions(-) diff --git a/crates/burn-vision/src/backends/cpu/connected_components.rs b/crates/burn-vision/src/backends/cpu/connected_components.rs index 638cc0cb9e..30a2b0bfaf 100644 --- a/crates/burn-vision/src/backends/cpu/connected_components.rs +++ b/crates/burn-vision/src/backends/cpu/connected_components.rs @@ -15,34 +15,51 @@ pub fn connected_components( img: BoolTensor, connectivity: Connectivity, ) -> IntTensor { + run::(img, connectivity, || NoOp).0 +} + +pub fn connected_components_with_stats( + img: BoolTensor, + connectivity: Connectivity, + _options: ConnectedStatsOptions, +) -> (IntTensor, ConnectedStatsPrimitive) { + let device = B::bool_device(&img); + let (labels, stats) = run::(img, connectivity, ConnectedStatsOp::default); + println!("{stats:?}"); + let stats = finalize_stats(&device, stats); + (labels, stats) +} + +fn run( + img: BoolTensor, + connectivity: Connectivity, + stats: impl Fn() -> Stats, +) -> (IntTensor, Vec) { let device = B::bool_device(&img); let img = Tensor::::from_primitive(img); let [batches, _, height, width] = img.shape().dims(); let img = img.into_data().convert::().to_vec::().unwrap(); let img = Array3::from_shape_vec((batches, height, width), img).unwrap(); + let mut stats_res = Vec::with_capacity(batches); let process = match connectivity { Connectivity::Four => spaghetti_4c::process::, Connectivity::Eight => spaghetti::process::, }; - let mut out = process(img.index_axis(Axis(0), 0)); + let mut stats_0 = stats(); + let mut out = process(img.index_axis(Axis(0), 0), &mut stats_0); + stats_res.push(stats_0); for i in 1..batches { - let batch = process(img.index_axis(Axis(0), i)); + let mut stats_i = stats(); + let batch = process(img.index_axis(Axis(0), i), &mut stats_i); out.append(Axis(0), batch.view()).unwrap(); + stats_res.push(stats_i); } - 
println!("{out:?}"); let (data, _) = out.into_raw_vec_and_offset(); let data = TensorData::new(data, Shape::new([batches, height, width])); - Tensor::::from_data(data, &device).into_primitive() -} - -pub fn connected_components_with_stats( - _img: BoolTensor, - _connectivity: Connectivity, - _options: ConnectedStatsOptions, -) -> (IntTensor, ConnectedStatsPrimitive) { - todo!() + let labels = Tensor::::from_data(data, &device).into_primitive(); + (labels, stats_res) } pub trait Solver { @@ -50,7 +67,7 @@ pub trait Solver { /// Hack to get around mutable borrow limitations on methods fn merge(label_1: u32, label_2: u32, solver: &mut Self) -> u32; fn new_label(&mut self) -> u32; - fn flatten(&mut self); + fn flatten(&mut self) -> u32; fn get_label(&mut self, i_label: u32) -> u32; } @@ -89,7 +106,7 @@ impl Solver for UnionFind { len } - fn flatten(&mut self) { + fn flatten(&mut self) -> u32 { let mut k = 1; for i in 1..self.labels.len() { if self.labels[i] < i as u32 { @@ -99,9 +116,104 @@ impl Solver for UnionFind { k += 1; } } + k } fn get_label(&mut self, i_label: u32) -> u32 { self.labels[i_label as usize] } } + +pub trait StatsOp { + fn init(&mut self, num_labels: u32); + fn update(&mut self, row: usize, column: usize, label: u32); + fn finish(&mut self); +} + +struct NoOp; + +impl StatsOp for NoOp { + fn init(&mut self, _num_labels: u32) {} + + fn update(&mut self, _row: usize, _column: usize, _label: u32) {} + + fn finish(&mut self) {} +} + +#[derive(Default, Debug)] +struct ConnectedStatsOp { + pub area: Vec, + pub left: Vec, + pub top: Vec, + pub right: Vec, + pub bottom: Vec, +} + +impl StatsOp for ConnectedStatsOp { + fn init(&mut self, num_labels: u32) { + let num_labels = num_labels as usize; + self.area = vec![0; num_labels]; + self.left = vec![u32::MAX; num_labels]; + self.top = vec![u32::MAX; num_labels]; + self.right = vec![0; num_labels]; + self.bottom = vec![0; num_labels]; + } + + fn update(&mut self, row: usize, column: usize, label: u32) { + let 
l = label as usize; + self.area[l] += 1; + self.left[l] = self.left[l].min(column as u32); + self.top[l] = self.top[l].min(row as u32); + self.right[l] = self.right[l].max(column as u32); + self.bottom[l] = self.bottom[l].max(row as u32); + } + + fn finish(&mut self) { + // Background shouldn't have stats + self.area[0] = 0; + self.left[0] = 0; + self.right[0] = 0; + self.top[0] = 0; + self.bottom[0] = 0; + } +} + +fn finalize_stats( + device: &B::Device, + stats: Vec, +) -> ConnectedStatsPrimitive { + let batches = stats.len(); + let max_len = stats.iter().map(|it| it.area.len()).max().unwrap_or(1); + let mut area = Vec::with_capacity(batches * max_len); + let mut left = Vec::with_capacity(batches * max_len); + let mut top = Vec::with_capacity(batches * max_len); + let mut right = Vec::with_capacity(batches * max_len); + let mut bottom = Vec::with_capacity(batches * max_len); + + for mut stats in stats { + stats.area.resize(max_len, 0); + stats.left.resize(max_len, 0); + stats.top.resize(max_len, 0); + stats.right.resize(max_len, 0); + stats.bottom.resize(max_len, 0); + + area.extend(stats.area); + left.extend(stats.left); + top.extend(stats.top); + right.extend(stats.right); + bottom.extend(stats.bottom); + } + + let into_prim = |data: Vec| { + let data = TensorData::new(data, Shape::new([batches, max_len])); + Tensor::::from_data(data, device).into_primitive() + }; + + ConnectedStatsPrimitive { + area: into_prim(area), + left: into_prim(left), + top: into_prim(top), + right: into_prim(right), + bottom: into_prim(bottom), + } +} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs index b39c922072..7b7c7c8c9b 100644 --- a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs @@ -22,9 +22,9 @@ use ndarray::{s, Array2, ArrayView2, Axis}; mod 
Spaghetti_forest_labels; pub(crate) use Spaghetti_forest_labels::*; -use super::Solver; +use super::{Solver, StatsOp}; -pub fn process(img: ArrayView2) -> Array2 { +pub fn process(img: ArrayView2, stats: &mut impl StatsOp) -> Array2 { let (h, w) = img.dim(); let e_rows = h as u32 & 0xfffffffe; @@ -101,7 +101,8 @@ pub fn process(img: ArrayView2) -> Array2 { } } - solver.flatten(); + let n_labels = solver.flatten(); + stats.init(n_labels); for r in (0..e_rows as usize).step_by(2) { //Pointers: @@ -119,29 +120,41 @@ pub fn process(img: ArrayView2) -> Array2 { i_label = solver.get_label(i_label); if img_row00[c] > 0 { img_labels_row00[c] = i_label; + stats.update(r, c, i_label); } else { img_labels_row00[c] = 0; + stats.update(r, c, 0); } if img_row00[c + 1] > 0 { img_labels_row00[c + 1] = i_label; + stats.update(r, c + 1, i_label); } else { img_labels_row00[c + 1] = 0; + stats.update(r, c + 1, 0); } if img_row01[c] > 0 { img_labels_row01[c] = i_label; + stats.update(r + 1, c, i_label); } else { img_labels_row01[c] = 0; + stats.update(r + 1, c, 0); } if img_row01[c + 1] > 0 { img_labels_row01[c + 1] = i_label; + stats.update(r + 1, c + 1, i_label); } else { img_labels_row01[c + 1] = 0; + stats.update(r + 1, c + 1, 0); } } else { img_labels_row00[c] = 0; + stats.update(r, c, 0); img_labels_row00[c + 1] = 0; + stats.update(r, c + 1, 0); img_labels_row01[c] = 0; + stats.update(r + 1, c, 0); img_labels_row01[c + 1] = 0; + stats.update(r + 1, c + 1, 0); } } if o_cols { @@ -151,17 +164,23 @@ pub fn process(img: ArrayView2) -> Array2 { i_label = solver.get_label(i_label); if img_row00[c] > 0 { img_labels_row00[c] = i_label; + stats.update(r, c, i_label); } else { img_labels_row00[c] = 0; + stats.update(r, c, 0); } if img_row01[c] > 0 { img_labels_row01[c] = i_label; + stats.update(r + 1, c, i_label); } else { img_labels_row01[c] = 0; + stats.update(r + 1, c, 0); } } else { img_labels_row00[c] = 0; + stats.update(r, c, 0); img_labels_row01[c] = 0; + stats.update(r + 1, c, 0); 
} } } @@ -181,17 +200,23 @@ pub fn process(img: ArrayView2) -> Array2 { i_label = solver.get_label(i_label); if img_row00[c] > 0 { img_labels_row00[c] = i_label; + stats.update(r, c, i_label); } else { img_labels_row00[c] = 0; + stats.update(r, c, 0); } if img_row00[c + 1] > 0 { img_labels_row00[c + 1] = i_label; + stats.update(r, c + 1, i_label); } else { img_labels_row00[c + 1] = 0; + stats.update(r, c + 1, 0); } } else { img_labels_row00[c] = 0; + stats.update(r, c, 0); img_labels_row00[c + 1] = 0; + stats.update(r, c + 1, 0); } } if o_cols { @@ -201,14 +226,18 @@ pub fn process(img: ArrayView2) -> Array2 { i_label = solver.get_label(i_label); if img_row00[c] > 0 { img_labels_row00[c] = i_label; + stats.update(r, c, i_label); } else { img_labels_row00[c] = 0; + stats.update(r, c, 0); } } else { img_labels_row00[c] = 0; + stats.update(r, c, i_label); } } } + stats.finish(); img_labels } diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs index 9085293086..dbcd6d757a 100644 --- a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs @@ -15,13 +15,13 @@ use ndarray::{s, Array2, ArrayView2, Axis}; -use super::Solver; +use super::{Solver, StatsOp}; #[allow(non_snake_case)] mod Spaghetti4C_forest_labels; pub(crate) use Spaghetti4C_forest_labels::*; -pub fn process(img: ArrayView2) -> Array2 { +pub fn process(img: ArrayView2, stats: &mut impl StatsOp) -> Array2 { let (h, w) = img.dim(); let mut img_labels = Array2::default(img.raw_dim()); @@ -73,9 +73,14 @@ pub fn process(img: ArrayView2) -> Array2 { include!("Spaghetti4C_center_line_forest_code.rs"); } - solver.flatten(); + let n_labels = solver.flatten(); + stats.init(n_labels); - img_labels.map_inplace(|label| *label = solver.get_label(*label)); + img_labels.indexed_iter_mut().for_each(|((r, c), 
label)| { + *label = solver.get_label(*label); + stats.update(r, c, *label); + }); + stats.finish(); img_labels } From aeea3a8225cd435ea936bc812374776c03f2f515 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 31 Jan 2025 15:17:13 +0100 Subject: [PATCH 11/24] Implement all backends except fusion --- Cargo.lock | 3 ++ crates/burn-candle/src/element.rs | 3 ++ crates/burn-candle/src/lib.rs | 1 + crates/burn-vision/Cargo.toml | 11 ++++-- crates/burn-vision/src/backends/cpu/ops.rs | 39 +++++++++++----------- 5 files changed, 35 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a3ab30f4ec..e8b5340f20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -968,9 +968,12 @@ dependencies = [ name = "burn-vision" version = "0.17.0" dependencies = [ + "burn-autodiff", + "burn-candle", "burn-cuda", "burn-jit", "burn-ndarray", + "burn-tch", "burn-tensor", "burn-tensor-testgen", "burn-wgpu", diff --git a/crates/burn-candle/src/element.rs b/crates/burn-candle/src/element.rs index df5c2dc756..ebe4a056c2 100644 --- a/crates/burn-candle/src/element.rs +++ b/crates/burn-candle/src/element.rs @@ -4,8 +4,11 @@ use burn_tensor::Element; use candle_core::{FloatDType, Tensor, WithDType}; use half::{bf16, f16}; +/// Candle element pub trait CandleElement: Element + WithDType {} +/// Candle float element pub trait FloatCandleElement: CandleElement + FloatDType {} +/// Candle int element pub trait IntCandleElement: CandleElement {} impl CandleElement for f64 {} diff --git a/crates/burn-candle/src/lib.rs b/crates/burn-candle/src/lib.rs index 78923fad03..64a6d05330 100644 --- a/crates/burn-candle/src/lib.rs +++ b/crates/burn-candle/src/lib.rs @@ -13,6 +13,7 @@ mod ops; mod tensor; pub use backend::*; +pub use element::*; pub use tensor::*; #[cfg(test)] diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index f92e2d6be4..a044b2c81e 100644 --- a/crates/burn-vision/Cargo.toml +++ 
b/crates/burn-vision/Cargo.toml @@ -13,9 +13,13 @@ version.workspace = true [features] -default = ["jit-backend"] +autodiff = ["burn-autodiff"] +candle = ["burn-candle"] +default = ["ndarray"] export-tests = ["burn-tensor-testgen"] jit-backend = ["cubecl", "burn-jit"] +ndarray = ["burn-ndarray"] +tch = ["burn-tch"] # Test features cpu = ["export-tests"] @@ -24,8 +28,11 @@ vulkan = ["burn-wgpu/vulkan", "wgpu"] wgpu = ["jit-backend", "export-tests"] [dependencies] +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } +burn-candle = { path = "../burn-candle", version = "0.17.0", optional = true } burn-jit = { path = "../burn-jit", version = "0.17.0", optional = true } -burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true } +burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true } burn-tensor = { path = "../burn-tensor", version = "0.17.0" } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } cubecl = { workspace = true, optional = true } diff --git a/crates/burn-vision/src/backends/cpu/ops.rs b/crates/burn-vision/src/backends/cpu/ops.rs index 157d56aed7..3d75e83157 100644 --- a/crates/burn-vision/src/backends/cpu/ops.rs +++ b/crates/burn-vision/src/backends/cpu/ops.rs @@ -1,24 +1,23 @@ -use crate::{ - backends::cpu, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, VisionOps, -}; +use crate::VisionOps; + +#[cfg(feature = "autodiff")] +use burn_autodiff::{checkpoint::strategy::CheckpointStrategy, Autodiff}; +#[cfg(feature = "candle")] +use burn_candle::{Candle, FloatCandleElement, IntCandleElement}; +#[cfg(feature = "ndarray")] use burn_ndarray::{FloatNdArrayElement, IntNdArrayElement, NdArray, QuantElement}; -use burn_tensor::ops::{BoolTensor, IntTensor}; +#[cfg(feature = "tch")] +use burn_tch::{LibTorch, TchElement}; -impl VisionOps for NdArray -where - E: FloatNdArrayElement, - 
I: IntNdArrayElement, - Q: QuantElement, +#[cfg(feature = "ndarray")] +impl VisionOps + for NdArray { - fn connected_components(img: BoolTensor, connectivity: Connectivity) -> IntTensor { - cpu::connected_components::(img, connectivity) - } - - fn connected_components_with_stats( - img: BoolTensor, - connectivity: Connectivity, - opts: ConnectedStatsOptions, - ) -> (IntTensor, ConnectedStatsPrimitive) { - cpu::connected_components_with_stats::(img, connectivity, opts) - } } + +#[cfg(feature = "candle")] +impl VisionOps for Candle {} +#[cfg(feature = "tch")] +impl VisionOps for LibTorch {} +#[cfg(feature = "autodiff")] +impl VisionOps for Autodiff {} From a994ca742950fab140793dcec5b8f6ba797cf765 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 31 Jan 2025 15:25:51 +0100 Subject: [PATCH 12/24] Fix autodiff to use GPU when available --- crates/burn-vision/Cargo.toml | 3 ++- crates/burn-vision/src/backends/cpu/ops.rs | 30 +++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index a044b2c81e..745fcb1e36 100644 --- a/crates/burn-vision/Cargo.toml +++ b/crates/burn-vision/Cargo.toml @@ -15,7 +15,7 @@ version.workspace = true [features] autodiff = ["burn-autodiff"] candle = ["burn-candle"] -default = ["ndarray"] +default = ["ndarray", "autodiff"] export-tests = ["burn-tensor-testgen"] jit-backend = ["cubecl", "burn-jit"] ndarray = ["burn-ndarray"] @@ -44,3 +44,4 @@ serde = { workspace = true } burn-cuda = { path = "../burn-cuda", version = "0.17.0", default-features = false } burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", default-features = false } +cubecl = { workspace = true } diff --git a/crates/burn-vision/src/backends/cpu/ops.rs b/crates/burn-vision/src/backends/cpu/ops.rs index 3d75e83157..c682703106 100644 --- 
a/crates/burn-vision/src/backends/cpu/ops.rs +++ b/crates/burn-vision/src/backends/cpu/ops.rs @@ -20,4 +20,32 @@ impl VisionOps for Candle VisionOps for LibTorch {} #[cfg(feature = "autodiff")] -impl VisionOps for Autodiff {} +impl, C: CheckpointStrategy> VisionOps + for Autodiff +{ + fn connected_components( + img: burn_tensor::ops::BoolTensor, + connectivity: crate::Connectivity, + ) -> burn_tensor::ops::IntTensor { + B::connected_components(img, connectivity) + } + + fn connected_components_with_stats( + img: burn_tensor::ops::BoolTensor, + connectivity: crate::Connectivity, + opts: crate::ConnectedStatsOptions, + ) -> ( + burn_tensor::ops::IntTensor, + crate::ConnectedStatsPrimitive, + ) { + let (labels, stats) = B::connected_components_with_stats(img, connectivity, opts); + let stats = crate::ConnectedStatsPrimitive:: { + area: stats.area, + left: stats.left, + top: stats.top, + right: stats.right, + bottom: stats.bottom, + }; + (labels, stats) + } +} From 866307b1669518044cc7b7cb83befad777ba6b94 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 31 Jan 2025 17:10:04 +0100 Subject: [PATCH 13/24] Fixes and cleanup --- Cargo.lock | 1 + crates/burn-vision/Cargo.toml | 2 + .../src/backends/cpu/connected_components.rs | 7 + .../cpu/connected_components/spaghetti/mod.rs | 6 +- .../connected_components/spaghetti_4c/mod.rs | 6 +- .../backends/jit/connected_components/bke.rs | 388 ------------------ .../backends/jit/connected_components/mod.rs | 1 - crates/burn-vision/src/backends/jit/ops.rs | 127 ++++++ 8 files changed, 145 insertions(+), 393 deletions(-) delete mode 100644 crates/burn-vision/src/backends/jit/connected_components/bke.rs diff --git a/Cargo.lock b/Cargo.lock index e8b5340f20..ac051eef94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -971,6 +971,7 @@ dependencies = [ "burn-autodiff", "burn-candle", "burn-cuda", + "burn-fusion", "burn-jit", "burn-ndarray", "burn-tch", diff --git 
a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index 745fcb1e36..acebb4317a 100644 --- a/crates/burn-vision/Cargo.toml +++ b/crates/burn-vision/Cargo.toml @@ -17,6 +17,7 @@ autodiff = ["burn-autodiff"] candle = ["burn-candle"] default = ["ndarray", "autodiff"] export-tests = ["burn-tensor-testgen"] +fusion = ["burn-fusion", "burn-cuda/fusion", "burn-wgpu/fusion"] jit-backend = ["cubecl", "burn-jit"] ndarray = ["burn-ndarray"] tch = ["burn-tch"] @@ -30,6 +31,7 @@ wgpu = ["jit-backend", "export-tests"] [dependencies] burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } burn-candle = { path = "../burn-candle", version = "0.17.0", optional = true } +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } burn-jit = { path = "../burn-jit", version = "0.17.0", optional = true } burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true } burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true } diff --git a/crates/burn-vision/src/backends/cpu/connected_components.rs b/crates/burn-vision/src/backends/cpu/connected_components.rs index 30a2b0bfaf..9415050354 100644 --- a/crates/burn-vision/src/backends/cpu/connected_components.rs +++ b/crates/burn-vision/src/backends/cpu/connected_components.rs @@ -217,3 +217,10 @@ fn finalize_stats( bottom: into_prim(bottom), } } + +pub fn max_labels(h: usize, w: usize, conn: Connectivity) -> usize { + match conn { + Connectivity::Four => ((h * w + 1) / 2) + 1, + Connectivity::Eight => ((h + 1) / 2) * ((w + 1) / 2) + 1, + } +} diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs index 7b7c7c8c9b..66c79ae99c 100644 --- a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs @@ -22,7 +22,9 @@ use ndarray::{s, Array2, ArrayView2, 
Axis}; mod Spaghetti_forest_labels; pub(crate) use Spaghetti_forest_labels::*; -use super::{Solver, StatsOp}; +use crate::Connectivity; + +use super::{max_labels, Solver, StatsOp}; pub fn process(img: ArrayView2, stats: &mut impl StatsOp) -> Array2 { let (h, w) = img.dim(); @@ -34,7 +36,7 @@ pub fn process(img: ArrayView2, stats: &mut impl Stats let mut img_labels = Array2::default(img.raw_dim()); - let mut solver = LabelsSolver::init(((h + 1) / 2) * ((w + 1) / 2) + 1); + let mut solver = LabelsSolver::init(max_labels(h, w, Connectivity::Eight)); let solver = &mut solver; diff --git a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs index dbcd6d757a..d1a9ab4304 100644 --- a/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs +++ b/crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs @@ -15,7 +15,9 @@ use ndarray::{s, Array2, ArrayView2, Axis}; -use super::{Solver, StatsOp}; +use crate::Connectivity; + +use super::{max_labels, Solver, StatsOp}; #[allow(non_snake_case)] mod Spaghetti4C_forest_labels; @@ -34,7 +36,7 @@ pub fn process(img: ArrayView2, stats: &mut impl Stats // 0 1 0 1 0... // 1 0 1 0 1... // ............ - let max_labels = ((h * w + 1) / 2) + 1; + let max_labels = max_labels(h, w, Connectivity::Four); let mut solver = LabelsSolver::init(max_labels); let solver = &mut solver; diff --git a/crates/burn-vision/src/backends/jit/connected_components/bke.rs b/crates/burn-vision/src/backends/jit/connected_components/bke.rs deleted file mode 100644 index ae8633cadb..0000000000 --- a/crates/burn-vision/src/backends/jit/connected_components/bke.rs +++ /dev/null @@ -1,388 +0,0 @@ -//! Block-based komura equivalence, adapted from -//! S. Allegretti, F. Bolelli, C. Grana, -//! "Optimized Block-Based Algorithms to Label Connected Components on GPUs," -//! 
in IEEE Transactions on Parallel and Distributed Systems, 2019. - -use burn_jit::{ - kernel, - ops::numeric::{empty_device, zeros_device}, - tensor::JitTensor, - JitElement, JitRuntime, -}; -use burn_tensor::{DType, Shape}; -use cubecl::cube; -use cubecl::prelude::*; - -mod info { - pub const A: u8 = 0; - pub const B: u8 = 1; - pub const C: u8 = 2; - pub const D: u8 = 3; - pub const Q: u8 = 5; - pub const R: u8 = 6; - pub const S: u8 = 7; -} - -#[cube] -fn has_bit(bitmap: I, pos: u8) -> bool { - bool::cast_from((bitmap >> I::cast_from(pos)) & I::new(1)) -} - -#[cube] -fn set_bit(bitmap: I, pos: u8) -> I { - bitmap | (I::new(1) << I::cast_from(pos)) -} - -#[cube] -fn find_root(s_buf: &Tensor>, n: u32) -> u32 { - let mut n = n; - while Atomic::load(&s_buf[n]) != n { - n = Atomic::load(&s_buf[n]); - } - n -} - -#[cube] -fn find_root_and_compress(s_buf: &mut Tensor, id: u32) -> u32 { - let mut n = id; - while s_buf[n] != n { - n = s_buf[n]; - s_buf[id] = n; - } - n -} - -#[cube] -fn tree_union(s_buf: &Tensor>, a: u32, b: u32) { - let mut a = a; - let mut b = b; - #[allow(unused_assignments)] - let mut done = false; - - loop { - a = find_root(s_buf, a); - b = find_root(s_buf, b); - - #[allow(clippy::comparison_chain, reason = "not supported in cubecl")] - if a < b { - let old = Atomic::min(&s_buf[b], a); - done = old == b; - b = old; - } else if b < a { - let old = Atomic::min(&s_buf[a], b); - done = old == a; - a = old; - } else { - done = true; - } - - if done { - break; - } - } -} - -#[cube(launch)] -fn init_labeling(img: &Tensor, labels: &mut Tensor, last_pixel: &mut Array) { - let batch = ABSOLUTE_POS_Z; - let row = ABSOLUTE_POS_Y * 2; - let col = ABSOLUTE_POS_X * 2; - - if row >= labels.shape(1) || col >= labels.shape(2) { - terminate!(); - } - - let img_rows = img.shape(2); - let img_cols = img.shape(3); - let img_stride = img.stride(2); - let labels_stride = labels.stride(1); - - let img_index = batch * img.stride(0) + row * img_stride + col * img.stride(3); - 
let labels_index = batch * labels.stride(0) + row * labels_stride + col * labels.stride(2); - - let mut p = 0u16; - - // Bitmask representing two kinds of information - // Bits 0, 1, 2, 3 are set if pixel a, b, c, d are foreground, respectively - // Bits 4, 5, 6, 7 are set if block P, Q, R, S need to be merged to X in Merge phase - let mut info = 0u8; - - let mut buffer = Array::::new(4); - #[unroll] - for i in 0..4 { - buffer[i] = 0; - } - - if col + 1 < img_cols { - buffer[0] = img[img_index]; - buffer[1] = img[img_index + 1]; - - if row + 1 < img_rows { - buffer[2] = img[img_index + img_stride]; - buffer[3] = img[img_index + img_stride + 1]; - } - } else { - buffer[0] = img[img_index]; - - if row + 1 < img_rows { - buffer[2] = img[img_index + img_stride]; - } - } - - if buffer[0] != 0 { - p |= 0x777; - info = set_bit::(info, info::A); - } - if buffer[1] != 0 { - p |= 0x777 << 1; - info = set_bit::(info, info::B); - } - if buffer[2] != 0 { - p |= 0x777 << 4; - info = set_bit::(info, info::C); - } - if buffer[3] != 0 { - info = set_bit::(info, info::D); - } - - if col == 0 { - p &= 0xeeee; - } - if col + 1 >= img_cols { - p &= 0x3333; - } else if col + 2 >= img_cols { - p &= 0x7777; - } - - if row == 0 { - p &= 0xfff0; - } - if row + 1 >= img_rows { - p &= 0x00ff; - } else if row + 2 >= img_rows { - p &= 0x0fff; - } - - // P is now ready to be used to find neighbor blocks - // P value avoids range errors - - let mut father_offset = 0i32; - - // P square - if has_bit::(p, 0) && img[img_index - img_stride - 1] != 0 { - father_offset = -(2 * labels_stride as i32 + 2); - } - - // Q square - if (has_bit::(p, 1) && img[img_index - img_stride] != 0) - || (has_bit::(p, 2) && img[img_index + 1 - img_stride] != 0) - { - if father_offset == 0 { - father_offset = -(2 * labels_stride as i32); - } else { - info = set_bit::(info, info::Q); - } - } - - // R square - if has_bit::(p, 3) && img[img_index + 2 - img_stride] != 0 { - if father_offset == 0 { - father_offset = -(2 * 
labels_stride as i32 - 2); - } else { - info = set_bit::(info, info::R); - } - } - - // S square - if (has_bit::(p, 4) && img[img_index - 1] != 0) - || (has_bit::(p, 8) && img[img_index + img_stride - 1] != 0) - { - if father_offset == 0 { - father_offset = -2i32; - } else { - info = set_bit::(info, info::S); - } - } - - labels[labels_index] = labels_index as i32 + father_offset; - if col + 1 < labels.shape(2) { - labels[labels_index + 1] = info as i32; - } else if row + 1 < labels.shape(1) { - labels[labels_index + labels_stride] = info as i32; - } else { - last_pixel[0] = info; - } -} - -#[cube(launch)] -fn merge(labels: &mut Tensor>, last_pixel: &mut Array) { - let batch = ABSOLUTE_POS_Z; - let row = ABSOLUTE_POS_Y * 2; - let col = ABSOLUTE_POS_X * 2; - let rows = labels.shape(1); - let cols = labels.shape(2); - let labels_stride = labels.stride(1); - let labels_index = batch * labels.stride(0) + row * labels_stride + col; - - if row >= labels.shape(1) || col >= labels.shape(2) { - terminate!(); - } - - let info = if col + 1 < cols { - Atomic::load(&labels[labels_index + 1]) as u8 - } else if row + 1 < rows { - Atomic::load(&labels[labels_index + labels_stride]) as u8 - } else { - last_pixel[0] - }; - - if has_bit::(info, info::Q) { - tree_union(labels, labels_index, labels_index - 2 * labels_stride); - } - if has_bit::(info, info::R) { - tree_union(labels, labels_index, labels_index - 2 * labels_stride + 2); - } - if has_bit::(info, info::S) { - tree_union(labels, labels_index, labels_index - 1); - } -} - -#[cube(launch)] -fn compression(labels: &mut Tensor) { - let batch = ABSOLUTE_POS_Z; - let row = ABSOLUTE_POS_Y * 2; - let col = ABSOLUTE_POS_X * 2; - let labels_index = batch * labels.stride(0) + row * labels.stride(1) + col; - - if row < labels.shape(1) && col < labels.shape(2) { - find_root_and_compress(labels, labels_index); - } -} - -#[cube(launch)] -fn final_labeling(img: &Tensor, labels: &mut Tensor) { - let batch = ABSOLUTE_POS_Z; - let row = 
ABSOLUTE_POS_Y * 2; - let col = ABSOLUTE_POS_X * 2; - let rows = labels.shape(1); - let cols = labels.shape(2); - let label_stride = labels.stride(1); - let img_stride = img.stride(2); - let labels_index = batch * labels.stride(0) + row * label_stride + col; - - if row >= labels.shape(1) || col >= labels.shape(2) { - terminate!(); - } - - let mut label = 0; - #[allow(unused_assignments)] - let mut info = 0u8; - let mut buffer = Array::::new(2); - - if col + 1 < cols { - buffer[0] = label[labels_index]; - buffer[1] = label[labels_index + 1]; - label = buffer[0] + 1; - info = buffer[1] as u8; - } else { - label = labels[labels_index] + 1; - if row + 1 < rows { - info = labels[labels_index + label_stride] as u8; - } else { - // Read from the input image - // "a" is already in position 0 - info = img[batch * img.stride(0) + row * img_stride + col]; - } - } - - if col + 1 < cols { - labels[labels_index] = has_bit::(info, info::B) as u32 * label; - labels[labels_index + 1] = has_bit::(info, info::A) as u32 * label; - - if row + 1 < rows { - labels[labels_index + label_stride] = has_bit::(info, info::D) as u32 * label; - labels[labels_index + label_stride + 1] = has_bit::(info, info::C) as u32 * label; - } - } else { - labels[labels_index] = has_bit::(info, info::A) as u32 * label; - - if row + 1 < rows { - labels[labels_index + label_stride] = has_bit::(info, info::C) as u32 * label; - } - } -} - -#[expect( - unused, - reason = "currently broken because kernel reassigns pointers and I need to figure out how to port that" -)] -pub fn block_based_komura_equivalence( - img: JitTensor, -) -> JitTensor { - let img = kernel::cast::(img); - - let [batches, channels, rows, columns] = img.shape.dims(); - assert_eq!(channels, 1, "Channels must be 1 for connected components"); - - let shape = Shape::new([batches, rows, columns]); - let labels = zeros_device::(img.client.clone(), img.device.clone(), shape); - - let last_pixel = if (rows == 1 || columns == 1) && (rows + columns) % 2 
== 0 { - empty_device::(img.client.clone(), img.device.clone(), Shape::new([1])) - } else { - let offset = (((rows - 2) * labels.strides[2]) + (columns - 2)) * size_of::(); - JitTensor::new_contiguous( - labels.client.clone(), - labels.device.clone(), - Shape::new([1]), - labels.handle.clone().offset_start(offset as u64), - DType::U8, - ) - }; - - let cube_dim = CubeDim::default(); - let cube_count_x = (columns as u32).div_ceil(2).div_ceil(cube_dim.x); - let cube_count_y = (rows as u32).div_ceil(2).div_ceil(cube_dim.y); - let cube_count = CubeCount::Static(cube_count_x, cube_count_y, batches as u32); - - init_labeling::launch( - &img.client, - cube_count.clone(), - cube_dim, - img.as_tensor_arg::(1), - labels.as_tensor_arg::(1), - last_pixel.as_array_arg::(1), - ); - - compression::launch( - &img.client, - cube_count.clone(), - cube_dim, - labels.as_tensor_arg::(1), - ); - - merge::launch( - &img.client, - cube_count.clone(), - cube_dim, - labels.as_tensor_arg::(1), - last_pixel.as_array_arg::(1), - ); - - compression::launch( - &img.client, - cube_count.clone(), - cube_dim, - labels.as_tensor_arg::(1), - ); - - final_labeling::launch( - &img.client, - cube_count.clone(), - cube_dim, - img.as_tensor_arg::(1), - labels.as_tensor_arg::(1), - ); - - labels -} diff --git a/crates/burn-vision/src/backends/jit/connected_components/mod.rs b/crates/burn-vision/src/backends/jit/connected_components/mod.rs index 0fa8bcc07e..e8c740695b 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/mod.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/mod.rs @@ -1,4 +1,3 @@ -mod bke; mod hardware_accelerated; use burn_jit::{ diff --git a/crates/burn-vision/src/backends/jit/ops.rs b/crates/burn-vision/src/backends/jit/ops.rs index 4935de561a..0ac628f77f 100644 --- a/crates/burn-vision/src/backends/jit/ops.rs +++ b/crates/burn-vision/src/backends/jit/ops.rs @@ -1,8 +1,15 @@ use crate::{ backends::cpu, ConnectedStatsOptions, ConnectedStatsPrimitive, 
Connectivity, VisionOps, }; +#[cfg(feature = "fusion")] +use burn_fusion::{client::FusionClient, stream::Operation, Fusion, FusionBackend, FusionRuntime}; use burn_jit::{BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_tensor::ops::{BoolTensor, IntTensor}; +#[cfg(feature = "fusion")] +use burn_tensor::{ + repr::{CustomOpDescription, HandleContainer, OperationDescription}, + Element, +}; use super::connected_components::hardware_accelerated; @@ -33,3 +40,123 @@ where }) } } + +#[cfg(feature = "fusion")] +impl> VisionOps for Fusion { + fn connected_components(img: BoolTensor, conn: Connectivity) -> IntTensor { + let batches = img.shape[0]; + let height = img.shape[2]; + let width = img.shape[3]; + let client = img.client.clone(); + + #[derive(derive_new::new)] + struct ConnComp { + desc: CustomOpDescription, + conn: Connectivity, + _b: core::marker::PhantomData, + } + + impl> Operation for ConnComp { + fn execute( + self: Box, + handles: &mut HandleContainer<::FusionHandle>, + ) { + let ([img], [labels]) = self.desc.consume(); + let input = handles.get_bool_tensor::(&img); + let output = B1::connected_components(input, self.conn); + + handles.register_int_tensor::(&labels.id, output); + } + } + + let stream = img.stream; + let out = client.tensor_uninitialized(vec![batches, height, width], B::IntElem::dtype()); + + let desc = CustomOpDescription::new( + "connected_components", + &[img.into_description()], + &[out.to_description_out()], + ); + client.register( + vec![stream], + OperationDescription::Custom(desc.clone()), + ConnComp::::new(desc, conn), + ); + + out + } + + fn connected_components_with_stats( + img: BoolTensor, + conn: Connectivity, + opts: ConnectedStatsOptions, + ) -> (IntTensor, ConnectedStatsPrimitive) { + let batches = img.shape[0]; + let height = img.shape[2]; + let width = img.shape[3]; + let client = img.client.clone(); + + #[derive(derive_new::new)] + struct ConnCompStats { + desc: CustomOpDescription, + conn: 
Connectivity, + opts: ConnectedStatsOptions, + _b: core::marker::PhantomData, + } + + impl> Operation for ConnCompStats { + fn execute( + self: Box, + handles: &mut HandleContainer<::FusionHandle>, + ) { + let ([img], [labels, area, left, top, right, bottom]) = self.desc.consume(); + let input = handles.get_bool_tensor::(&img); + let (output, stats) = + B1::connected_components_with_stats(input, self.conn, self.opts); + + handles.register_int_tensor::(&labels.id, output); + handles.register_int_tensor::(&area.id, stats.area); + handles.register_int_tensor::(&left.id, stats.left); + handles.register_int_tensor::(&top.id, stats.top); + handles.register_int_tensor::(&right.id, stats.right); + handles.register_int_tensor::(&bottom.id, stats.bottom); + } + } + + let stream = img.stream; + let out = client.tensor_uninitialized(vec![batches, height, width], B::IntElem::dtype()); + let area = client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + let left = client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + let top = client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + let right = client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + let bottom = + client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + + let desc = CustomOpDescription::new( + "connected_components", + &[img.into_description()], + &[ + out.to_description_out(), + area.to_description_out(), + left.to_description_out(), + top.to_description_out(), + right.to_description_out(), + bottom.to_description_out(), + ], + ); + client.register( + vec![stream], + OperationDescription::Custom(desc.clone()), + ConnCompStats::::new(desc, conn, opts), + ); + + let stats = ConnectedStatsPrimitive { + area, + left, + top, + right, + bottom, + }; + (out, stats) + } +} From a8e3994eaeeca5ac660e67982d28f83b54bed263 Mon Sep 17 00:00:00 2001 From: Genna Wingert 
<1058083+wingertge@users.noreply.github.com> Date: Fri, 31 Jan 2025 18:30:43 +0100 Subject: [PATCH 14/24] Add docs --- Cargo.lock | 5 + crates/burn-vision/Cargo.toml | 2 +- .../src/backends/cpu/connected_components.rs | 4 +- .../hardware_accelerated.rs | 91 +++++++++---------- .../backends/jit/connected_components/mod.rs | 3 +- crates/burn-vision/src/lib.rs | 14 +++ crates/burn-vision/src/ops/base.rs | 38 ++++++++ crates/burn-vision/src/tensor.rs | 13 ++- .../src/tests/connected_components.rs | 8 +- 9 files changed, 122 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ac051eef94..719014a5aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1709,7 +1709,9 @@ dependencies = [ "float-ord", "fnv", "half", + "hashbrown 0.14.5", "num-traits", + "portable-atomic", "serde", "variadics_please", ] @@ -5429,6 +5431,9 @@ name = "portable-atomic" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +dependencies = [ + "serde", +] [[package]] name = "portable-atomic-util" diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index acebb4317a..03f9ad0eb9 100644 --- a/crates/burn-vision/Cargo.toml +++ b/crates/burn-vision/Cargo.toml @@ -15,7 +15,7 @@ version.workspace = true [features] autodiff = ["burn-autodiff"] candle = ["burn-candle"] -default = ["ndarray", "autodiff"] +default = ["ndarray", "autodiff", "jit-backend"] export-tests = ["burn-tensor-testgen"] fusion = ["burn-fusion", "burn-cuda/fusion", "burn-wgpu/fusion"] jit-backend = ["cubecl", "burn-jit"] diff --git a/crates/burn-vision/src/backends/cpu/connected_components.rs b/crates/burn-vision/src/backends/cpu/connected_components.rs index 9415050354..fff61aeeaa 100644 --- a/crates/burn-vision/src/backends/cpu/connected_components.rs +++ b/crates/burn-vision/src/backends/cpu/connected_components.rs @@ -36,8 +36,8 @@ fn run( stats: impl Fn() -> Stats, ) -> (IntTensor, Vec) { 
let device = B::bool_device(&img); - let img = Tensor::::from_primitive(img); - let [batches, _, height, width] = img.shape().dims(); + let img = Tensor::::from_primitive(img); + let [batches, height, width] = img.shape().dims(); let img = img.into_data().convert::().to_vec::().unwrap(); let img = Array3::from_shape_vec((batches, height, width), img).unwrap(); let mut stats_res = Vec::with_capacity(batches); diff --git a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs index 3ffb883b11..c3dc243686 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs @@ -11,21 +11,20 @@ use burn_jit::{ ops::numeric::zeros_device, tensor::JitTensor, BoolElement, FloatElement, IntElement, JitBackend, JitRuntime, }; -use burn_tensor::Shape; use cubecl::{prelude::*, Feature}; const BLOCK_H: u32 = 4; #[cube] -fn merge(labels: &Tensor>, label_1: u32, label_2: u32) { +fn merge(labels: &Tensor>, label_1: u32, label_2: u32) { let mut label_1 = label_1; let mut label_2 = label_2; - while label_1 != label_2 && (label_1 != Atomic::load(&labels[label_1]) - 1) { - label_1 = Atomic::load(&labels[label_1]) - 1; + while label_1 != label_2 && (label_1 != u32::cast_from(Atomic::load(&labels[label_1])) - 1) { + label_1 = u32::cast_from(Atomic::load(&labels[label_1])) - 1; } - while label_1 != label_2 && (label_2 != Atomic::load(&labels[label_2]) - 1) { - label_2 = Atomic::load(&labels[label_2]) - 1; + while label_1 != label_2 && (label_2 != u32::cast_from(Atomic::load(&labels[label_2])) - 1) { + label_2 = u32::cast_from(Atomic::load(&labels[label_2])) - 1; } while label_1 != label_2 { #[allow(clippy::manual_swap)] @@ -34,7 +33,7 @@ fn merge(labels: &Tensor>, label_1: u32, label_2: u32) { label_1 = label_2; label_2 = tmp; } - let label_3 = 
Atomic::min(&labels[label_1], label_2 + 1) - 1; + let label_3 = u32::cast_from(Atomic::min(&labels[label_1], I::cast_from(label_2 + 1))) - 1; if label_1 == label_3 { label_1 = label_2; } else { @@ -61,9 +60,9 @@ fn ballot_dyn(y: u32, pred: bool) -> u32 { } #[cube(launch)] -fn strip_labeling( +fn strip_labeling( img: &Tensor, - labels: &Tensor>, + labels: &Tensor>, #[comptime] connectivity: Connectivity, ) { let mut shared_pixels = SharedMemory::::new(BLOCK_H); @@ -77,7 +76,7 @@ fn strip_labeling( terminate!(); } - let img_stride = img.stride(2); + let img_stride = img.stride(1); let labels_stride = labels.stride(1); let img_line_base = batch * img.stride(0) + y * img_stride + UNIT_POS_X; @@ -86,7 +85,7 @@ fn strip_labeling( let mut distance_y = 0; let mut distance_y_1 = 0; - for i in range_stepped(0, img.shape(3), PLANE_DIM) { + for i in range_stepped(0, img.shape(2), PLANE_DIM) { let x = UNIT_POS_X + i; if x < cols { @@ -107,7 +106,7 @@ fn strip_labeling( if p_y && s_dist_y == 0 { Atomic::store( &labels[labels_index], - labels_index - select(UNIT_POS_X == 0, distance_y, 0) + 1, + I::cast_from(labels_index - select(UNIT_POS_X == 0, distance_y, 0) + 1), ); } @@ -190,9 +189,9 @@ fn strip_labeling( } #[cube(launch)] -fn strip_merge( +fn strip_merge( img: &Tensor, - labels: &Tensor>, + labels: &Tensor>, #[comptime] connectivity: Connectivity, ) { let batch = CUBE_POS_Z; @@ -200,9 +199,9 @@ fn strip_merge( let y = (CUBE_POS_Y + 1) * BLOCK_H; let x = plane_start_x + UNIT_POS_X; - let img_step = img.stride(2); + let img_step = img.stride(1); let labels_step = labels.stride(1); - let cols = img.shape(3); + let cols = img.shape(2); if y < labels.shape(1) && x < labels.shape(2) { let mut mask = 0xffffffffu32; @@ -294,7 +293,7 @@ fn strip_merge( } #[cube(launch)] -fn relabeling(img: &Tensor, labels: &mut Tensor) { +fn relabeling(img: &Tensor, labels: &mut Tensor) { let batch = ABSOLUTE_POS_Z; let plane_start_x = CUBE_POS_X * CUBE_DIM_X; let y = ABSOLUTE_POS_Y; @@ -302,7 
+301,7 @@ fn relabeling(img: &Tensor, labels: &mut Tensor) { let cols = labels.shape(2); let rows = labels.shape(1); - let img_step = img.stride(2); + let img_step = img.stride(1); let labels_step = labels.stride(1); if x < cols && y < rows { @@ -320,29 +319,29 @@ fn relabeling(img: &Tensor, labels: &mut Tensor) { let mut label = 0u32; if p && s_dist == 0 { - label = labels[labels_index] - 1; - while label != labels[label] - 1 { - label = labels[label] - 1; + label = u32::cast_from(labels[labels_index]) - 1; + while label != u32::cast_from(labels[label]) - 1 { + label = u32::cast_from(labels[label]) - 1; } } label = plane_broadcast(label, UNIT_POS_X - s_dist); if p { - labels[labels_index] = label + 1; + labels[labels_index] = I::cast_from(label + 1); } } } #[cube(launch)] -fn analysis( +fn analysis( img: &Tensor, - labels: &mut Tensor, - area: &mut Tensor>, - top: &mut Tensor>, - left: &mut Tensor>, - right: &mut Tensor>, - bottom: &mut Tensor>, + labels: &mut Tensor, + area: &mut Tensor>, + top: &mut Tensor>, + left: &mut Tensor>, + right: &mut Tensor>, + bottom: &mut Tensor>, #[comptime] opts: ConnectedStatsOptions, ) { let batch = ABSOLUTE_POS_Z; @@ -351,7 +350,7 @@ fn analysis( let cols = labels.shape(2); let rows = labels.shape(1); - let img_step = img.stride(2); + let img_step = img.stride(1); let labels_step = labels.stride(1); if x < cols && y < rows { @@ -372,32 +371,32 @@ fn analysis( let mut label = 0u32; if p && s_dist == 0 { - label = labels[labels_index] - 1; - while label != labels[label] - 1 { - label = labels[label] - 1; + label = u32::cast_from(labels[labels_index]) - 1; + while label != u32::cast_from(labels[label]) - 1 { + label = u32::cast_from(labels[label]) - 1; } if opts.area_enabled { - Atomic::add(&area[label], count); + Atomic::add(&area[label], I::cast_from(count)); } if opts.left_enabled { - Atomic::min(&left[label], x); + Atomic::min(&left[label], I::cast_from(x)); } if opts.top_enabled { - Atomic::min(&top[label], y); + 
Atomic::min(&top[label], I::cast_from(y)); } if opts.right_enabled { - Atomic::max(&right[label], max_x); + Atomic::max(&right[label], I::cast_from(max_x)); } if opts.bottom_enabled { - Atomic::max(&bottom[label], y); + Atomic::max(&bottom[label], I::cast_from(y)); } } label = plane_broadcast(label, UNIT_POS_X - s_dist); if p { - labels[labels_index] = label + 1; + labels[labels_index] = I::cast_from(label + 1); } } } @@ -427,11 +426,9 @@ pub fn hardware_accelerated(client.clone(), device.clone(), shape); + let labels = zeros_device::(client.clone(), device.clone(), img.shape.clone()); // Assume 32 wide warp. Currently, larger warps are handled by just exiting everything past 32. // This isn't ideal but we require CUBE_DIM_X == warp_size, and we can't query the actual warp @@ -441,7 +438,7 @@ pub fn hardware_accelerated( + strip_labeling::launch::( &client, cube_count, cube_dim, @@ -458,7 +455,7 @@ pub fn hardware_accelerated( + strip_merge::launch::( &client, cube_count, cube_dim_merge, @@ -476,7 +473,7 @@ pub fn hardware_accelerated( + relabeling::launch::( &client, cube_count, cube_dim, @@ -484,7 +481,7 @@ pub fn hardware_accelerated(1), ); } else { - analysis::launch::( + analysis::launch::( &client, cube_count, cube_dim, diff --git a/crates/burn-vision/src/backends/jit/connected_components/mod.rs b/crates/burn-vision/src/backends/jit/connected_components/mod.rs index e8c740695b..4ca53c936c 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/mod.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/mod.rs @@ -23,7 +23,8 @@ where let [batches, height, width] = l.shape.dims(); let shape = Shape::new([batches, height * width]); let zeros = || zeros_device::(l.client.clone(), l.device.clone(), shape.clone()); - let max = || full_device::(l.client.clone(), shape.clone(), l.device.clone(), u32::MAX); + let max = I::max_value(); + let max = || full_device::(l.client.clone(), shape.clone(), l.device.clone(), max); let dummy = || { 
JitTensor::new_contiguous( l.client.clone(), diff --git a/crates/burn-vision/src/lib.rs b/crates/burn-vision/src/lib.rs index be3e9b2684..a384cde51f 100644 --- a/crates/burn-vision/src/lib.rs +++ b/crates/burn-vision/src/lib.rs @@ -1,10 +1,24 @@ +//! Vision ops for burn, with GPU acceleration where possible. +//! +//! # Operations +//! Operation names are based on `opencv` wherever applicable. +//! +//! Currently implemented are: +//! - `connected_components` +//! - `connected_components_with_stats` +//! + +#![warn(missing_docs)] + extern crate alloc; +/// Backend implementations for JIT and CPU pub mod backends; mod ops; mod tensor; #[cfg(feature = "export-tests")] +#[allow(missing_docs)] mod tests; pub use ops::*; diff --git a/crates/burn-vision/src/ops/base.rs b/crates/burn-vision/src/ops/base.rs index ddbd479507..7d20e88ec9 100644 --- a/crates/burn-vision/src/ops/base.rs +++ b/crates/burn-vision/src/ops/base.rs @@ -5,35 +5,61 @@ use burn_tensor::{ Int, Tensor, }; +/// Connected components connectivity #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum Connectivity { + /// Four-connected (only connected in cardinal directions) Four, + /// Eight-connected (connected if any of the surrounding 8 pixels are in the foreground) Eight, } +/// Which stats should be enabled for `connected_components_with_stats`. +/// Currently only used by the GPU implementation to save on atomic operations for unneeded stats. 
+/// +/// Disabled stats are aliased to the labels tensor #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ConnectedStatsOptions { + /// Whether to collect the area of each component pub area_enabled: bool, + /// Whether to find the top (minimum y) of each component pub top_enabled: bool, + /// Whether to find the left (minimum x) of each component pub left_enabled: bool, + /// Whether to find the right (max x) of each component pub right_enabled: bool, + /// Whether to find the bottom (max y) of each component pub bottom_enabled: bool, } +/// Stats collected by the connected components analysis +/// +/// Disabled analyses may be aliased to labels #[derive(Clone, Debug)] pub struct ConnectedStats { + /// Total area of each component pub area: Tensor, + /// Topmost y coordinate in the component pub top: Tensor, + /// Leftmost x coordinate in the component pub left: Tensor, + /// Rightmost x coordinate in the component pub right: Tensor, + /// Bottommost y coordinate in the component pub bottom: Tensor, } +/// Primitive version of [`ConnectedStats`], to be returned by the backend pub struct ConnectedStatsPrimitive { + /// Total area of each component pub area: IntTensor, + /// Leftmost x coordinate in the component pub left: IntTensor, + /// Topmost y coordinate in the component pub top: IntTensor, + /// Rightmost x coordinate in the component pub right: IntTensor, + /// Bottommost y coordinate in the component pub bottom: IntTensor, } @@ -56,6 +82,7 @@ impl Default for ConnectedStatsOptions { } impl ConnectedStatsOptions { + /// Don't collect any stats pub fn none() -> Self { Self { area_enabled: false, @@ -66,6 +93,7 @@ impl ConnectedStatsOptions { } } + /// Collect all stats pub fn all() -> Self { Self { area_enabled: true, @@ -77,11 +105,21 @@ impl ConnectedStatsOptions { } } +/// Vision operations, implemented by each backend pub trait VisionOps { + /// Computes the connected components labeled image of boolean image with 4 or 8 way + /// 
connectivity - returns a tensor of the component label of each pixel. + /// + /// `img`- The boolean image tensor in the format [batches, height, width] fn connected_components(img: BoolTensor, connectivity: Connectivity) -> IntTensor { cpu::connected_components::(img, connectivity) } + /// Computes the connected components labeled image of boolean image with 4 or 8 way + /// connectivity and collects statistics on each component - returns a tensor of the component + /// label of each pixel, along with stats collected for each component. + /// + /// `img`- The boolean image tensor in the format [batches, height, width] fn connected_components_with_stats( img: BoolTensor, connectivity: Connectivity, diff --git a/crates/burn-vision/src/tensor.rs b/crates/burn-vision/src/tensor.rs index c7432b7d5f..5b381170a9 100644 --- a/crates/burn-vision/src/tensor.rs +++ b/crates/burn-vision/src/tensor.rs @@ -2,8 +2,19 @@ use burn_tensor::{backend::Backend, Bool, Int, Tensor}; use crate::{ConnectedStats, ConnectedStatsOptions, Connectivity, VisionOps}; +/// Connected components tensor extensions pub trait ConnectedComponents { + /// Computes the connected components labeled image of boolean image with 4 or 8 way + /// connectivity - returns a tensor of the component label of each pixel. + /// + /// `img`- The boolean image tensor in the format [batches, height, width] fn connected_components(self, connectivity: Connectivity) -> Tensor; + + /// Computes the connected components labeled image of boolean image with 4 or 8 way + /// connectivity and collects statistics on each component - returns a tensor of the component + /// label of each pixel, along with stats collected for each component. 
+ /// + /// `img`- The boolean image tensor in the format [batches, height, width] fn connected_components_with_stats( self, connectivity: Connectivity, @@ -11,7 +22,7 @@ pub trait ConnectedComponents { ) -> (Tensor, ConnectedStats); } -impl> ConnectedComponents for Tensor { +impl> ConnectedComponents for Tensor { fn connected_components(self, connectivity: Connectivity) -> Tensor { Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity)) } diff --git a/crates/burn-vision/src/tests/connected_components.rs b/crates/burn-vision/src/tests/connected_components.rs index c299769774..8ebc32dfba 100644 --- a/crates/burn-vision/src/tests/connected_components.rs +++ b/crates/burn-vision/src/tests/connected_components.rs @@ -22,7 +22,7 @@ mod tests { #[test] fn should_support_8_connectivity() { - let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<4>(); + let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<3>(); let output = tensor.connected_components(Connectivity::Eight); let expected = space_invader(); // All pixels are in the same group for 8-connected @@ -33,7 +33,7 @@ mod tests { #[test] fn should_support_8_connectivity_with_stats() { - let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<4>(); + let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<3>(); let (output, stats) = tensor .connected_components_with_stats(Connectivity::Eight, ConnectedStatsOptions::all()); @@ -59,7 +59,7 @@ mod tests { #[test] fn should_support_4_connectivity() { - let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<4>(); + let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<3>(); let output = tensor.connected_components(Connectivity::Four); let expected = as_type!(IntType: [ @@ -80,7 +80,7 @@ mod tests { #[test] fn should_support_4_connectivity_with_stats() { - let tensor = TestTensorBool::<2>::from(space_invader()).unsqueeze::<4>(); + let tensor = 
TestTensorBool::<2>::from(space_invader()).unsqueeze::<3>(); let (output, stats) = tensor .connected_components_with_stats(Connectivity::Four, ConnectedStatsOptions::all()); From 021360bb8b41002b1bdf266ba8e24cf4157d7e7a Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 31 Jan 2025 18:31:17 +0100 Subject: [PATCH 15/24] Update cubecl --- Cargo.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 969ad44a77..3ecd6c856b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } -# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" } ### For local development. ### -cubecl = { path = "../cubecl/crates/cubecl", default-features = false } -cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +# cubecl = { path = "../cubecl/crates/cubecl", default-features = false } +# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } ### For the release. 
### # cubecl = { version = "0.4.0", default-features = false } # cubecl-common = { version = "0.4.0", default-features = false } From 01ff01b378737e6107c4265b7b3ec3e677d74cb7 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sat, 1 Feb 2025 17:21:32 +0100 Subject: [PATCH 16/24] Compact labels for JIT --- crates/burn-jit/src/kernel/index/mod.rs | 2 +- crates/burn-jit/src/kernel/index/slice.rs | 3 +- crates/burn-jit/src/kernel/mod.rs | 2 +- crates/burn-vision/Cargo.toml | 2 +- .../src/backends/cpu/connected_components.rs | 3 + crates/burn-vision/src/backends/cpu/ops.rs | 1 + .../hardware_accelerated.rs | 165 ++++++++++++++---- .../backends/jit/connected_components/mod.rs | 13 +- crates/burn-vision/src/backends/jit/ops.rs | 7 +- crates/burn-vision/src/ops/base.rs | 35 ++-- 10 files changed, 172 insertions(+), 61 deletions(-) diff --git a/crates/burn-jit/src/kernel/index/mod.rs b/crates/burn-jit/src/kernel/index/mod.rs index 828c39c50c..83ce64aff8 100644 --- a/crates/burn-jit/src/kernel/index/mod.rs +++ b/crates/burn-jit/src/kernel/index/mod.rs @@ -11,7 +11,7 @@ pub(crate) use flip::*; pub(crate) use repeat_dim::*; pub(crate) use select::*; pub(crate) use select_assign::*; -pub(crate) use slice::*; +pub use slice::*; pub(crate) use slice_assign::*; pub(crate) use gather::*; diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index b6daba8da5..bca8e00dd9 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -3,7 +3,8 @@ use burn_tensor::Shape; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use std::ops::Range; -pub(crate) fn slice( +/// Slice a jit tensor with a set of ranges +pub fn slice( tensor: JitTensor, indices: &[Range], ) -> JitTensor { diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index 93d2833976..e1a0a3158e 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ 
b/crates/burn-jit/src/kernel/mod.rs @@ -39,4 +39,4 @@ pub mod reduce; pub(crate) use clamp::*; pub(crate) use comparison::*; -pub(crate) use index::*; +pub use index::*; diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index 03f9ad0eb9..a6c0d573a0 100644 --- a/crates/burn-vision/Cargo.toml +++ b/crates/burn-vision/Cargo.toml @@ -15,7 +15,7 @@ version.workspace = true [features] autodiff = ["burn-autodiff"] candle = ["burn-candle"] -default = ["ndarray", "autodiff", "jit-backend"] +default = ["ndarray", "autodiff", "jit-backend", "fusion"] export-tests = ["burn-tensor-testgen"] fusion = ["burn-fusion", "burn-cuda/fusion", "burn-wgpu/fusion"] jit-backend = ["cubecl", "burn-jit"] diff --git a/crates/burn-vision/src/backends/cpu/connected_components.rs b/crates/burn-vision/src/backends/cpu/connected_components.rs index fff61aeeaa..71771b7eff 100644 --- a/crates/burn-vision/src/backends/cpu/connected_components.rs +++ b/crates/burn-vision/src/backends/cpu/connected_components.rs @@ -189,8 +189,10 @@ fn finalize_stats( let mut top = Vec::with_capacity(batches * max_len); let mut right = Vec::with_capacity(batches * max_len); let mut bottom = Vec::with_capacity(batches * max_len); + let mut max_label = Vec::with_capacity(batches); for mut stats in stats { + max_label.push(stats.area.len() as u32 - 1); stats.area.resize(max_len, 0); stats.left.resize(max_len, 0); stats.top.resize(max_len, 0); @@ -215,6 +217,7 @@ fn finalize_stats( top: into_prim(top), right: into_prim(right), bottom: into_prim(bottom), + max_label: into_prim(max_label), } } diff --git a/crates/burn-vision/src/backends/cpu/ops.rs b/crates/burn-vision/src/backends/cpu/ops.rs index c682703106..31973a3667 100644 --- a/crates/burn-vision/src/backends/cpu/ops.rs +++ b/crates/burn-vision/src/backends/cpu/ops.rs @@ -45,6 +45,7 @@ impl, C: CheckpointStrategy> Vis top: stats.top, right: stats.right, bottom: stats.bottom, + max_label: stats.max_label, }; (labels, stats) } diff --git 
a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs index c3dc243686..931320071c 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs @@ -8,10 +8,13 @@ use crate::{ ConnectedStatsPrimitive, Connectivity, }; use burn_jit::{ - ops::numeric::zeros_device, tensor::JitTensor, BoolElement, FloatElement, IntElement, - JitBackend, JitRuntime, + kernel, + ops::{into_data_sync, numeric::zeros_device}, + tensor::JitTensor, + BoolElement, FloatElement, IntElement, JitBackend, JitRuntime, }; -use cubecl::{prelude::*, Feature}; +use burn_tensor::ops::IntTensorOps; +use cubecl::{calculate_cube_count_elemwise, prelude::*, Feature}; const BLOCK_H: u32 = 4; @@ -342,6 +345,7 @@ fn analysis( left: &mut Tensor>, right: &mut Tensor>, bottom: &mut Tensor>, + max_label: &mut Tensor>, #[comptime] opts: ConnectedStatsOptions, ) { let batch = ABSOLUTE_POS_Z; @@ -352,6 +356,7 @@ fn analysis( let rows = labels.shape(1); let img_step = img.stride(1); let labels_step = labels.stride(1); + let b_offs = batch * labels.stride(0); if x < cols && y < rows { let mut mask = 0xffffffffu32; @@ -359,8 +364,8 @@ fn analysis( mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X); } - let img_index = batch * img.stride(0) + y * img_step + x; - let labels_index = batch * labels.stride(0) + y * labels_step + x; + let img_index = b_offs + y * img_step + x; + let labels_index = b_offs + y * labels_step + x; let p = bool::cast_from(img[img_index]); let pixels = ballot_dyn(UNIT_POS_Y, p) & mask; @@ -372,24 +377,20 @@ fn analysis( if p && s_dist == 0 { label = u32::cast_from(labels[labels_index]) - 1; - while label != u32::cast_from(labels[label]) - 1 { - label = u32::cast_from(labels[label]) - 1; + while label != u32::cast_from(labels[b_offs + label]) - 1 { + label = 
u32::cast_from(labels[b_offs + label]) - 1; } - if opts.area_enabled { - Atomic::add(&area[label], I::cast_from(count)); - } - if opts.left_enabled { - Atomic::min(&left[label], I::cast_from(x)); - } - if opts.top_enabled { - Atomic::min(&top[label], I::cast_from(y)); - } - if opts.right_enabled { - Atomic::max(&right[label], I::cast_from(max_x)); + Atomic::add(&area[b_offs + label], I::cast_from(count)); + + if opts.bounds_enabled { + Atomic::min(&left[b_offs + label], I::cast_from(x)); + Atomic::min(&top[b_offs + label], I::cast_from(y)); + Atomic::max(&right[b_offs + label], I::cast_from(max_x)); + Atomic::max(&bottom[b_offs + label], I::cast_from(y)); } - if opts.bottom_enabled { - Atomic::max(&bottom[label], I::cast_from(y)); + if comptime!(opts.max_label_enabled || opts.compact_labels) { + Atomic::max(&max_label[batch], I::cast_from(label)); } } @@ -401,6 +402,60 @@ fn analysis( } } +#[cube(launch)] +fn compact_labels(labels: &mut Tensor, remap: &Tensor) { + let batch = ABSOLUTE_POS_Z; + let x = ABSOLUTE_POS_X; + let y = ABSOLUTE_POS_Y; + + let labels_pos = batch * labels.stride(0) + y * labels.stride(1) + x * labels.stride(2); + + if labels_pos >= labels.len() { + terminate!(); + } + + let label = u32::cast_from(labels[labels_pos]); + if label != 0 { + labels[labels_pos] = remap[label]; + } +} + +#[cube(launch)] +fn compact_stats( + area: &Tensor, + area_new: &mut Tensor, + top: &Tensor, + top_new: &mut Tensor, + left: &Tensor, + left_new: &mut Tensor, + right: &Tensor, + right_new: &mut Tensor, + bottom: &Tensor, + bottom_new: &mut Tensor, + remap: &Tensor, + max_label: u32, + #[comptime] opts: ConnectedStatsOptions, +) { + let label = ABSOLUTE_POS_X; + if label > max_label { + terminate!(); + } + + let area = area[label]; + if area == I::new(0) { + terminate!(); + } + let new_label = u32::cast_from(remap[label]); + + area_new[new_label] = area; + if opts.bounds_enabled { + top_new[new_label] = top[label]; + left_new[new_label] = left[label]; + 
right_new[new_label] = right[label]; + bottom_new[new_label] = bottom[label]; + } +} + #[allow(clippy::type_complexity)] pub fn hardware_accelerated( img: JitTensor, @@ -442,8 +497,8 @@ pub fn hardware_accelerated(1), - labels.as_tensor_arg::(1), + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), connectivity, ); @@ -459,8 +514,8 @@ pub fn hardware_accelerated(1), - labels.as_tensor_arg::(1), + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), connectivity, ); @@ -477,23 +532,67 @@ pub fn hardware_accelerated(1), - labels.as_tensor_arg::(1), + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), ); } else { analysis::launch::( &client, cube_count, cube_dim, - img.as_tensor_arg::(1), - labels.as_tensor_arg::(1), - stats.area.as_tensor_arg::(1), - stats.top.as_tensor_arg::(1), - stats.left.as_tensor_arg::(1), - stats.right.as_tensor_arg::(1), - stats.bottom.as_tensor_arg::(1), + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + stats.area.as_tensor_arg::(1), + stats.top.as_tensor_arg::(1), + stats.left.as_tensor_arg::(1), + stats.right.as_tensor_arg::(1), + stats.bottom.as_tensor_arg::(1), + stats.max_label.as_tensor_arg::(1), stats_opt, ); + if stats_opt.compact_labels { + let max_labels = into_data_sync::(stats.max_label.clone()).convert::(); + let max_label = *max_labels.as_slice::().unwrap().iter().max().unwrap() as usize; + let sliced = kernel::slice::(stats.area.clone(), &[0..batches, 0..max_label + 1]); + let present = JitBackend::::int_not_equal_elem(sliced, I::new(0)); + let relabel = JitBackend::::int_prefix_sum(present); + + let cube_dim = CubeDim::default(); + let cube_count = CubeCount::new_3d( + (cols as u32).div_ceil(cube_dim.x), + (rows as u32).div_ceil(cube_dim.y), + batches as u32, + ); + compact_labels::launch( + &client, + cube_count, + cube_dim, + labels.as_tensor_arg::(1), + relabel.as_tensor_arg::(1), + ); + + let cube_dim = CubeDim::new_1d(256); + let cube_count = + CubeCount::new_3d((rows * cols).div_ceil(256) as u32, 1, batches 
as u32); + compact_stats::launch( + &client, + cube_count, + cube_dim, + stats.area.copy().as_tensor_arg::(1), + stats.area.as_tensor_arg::(1), + stats.top.copy().as_tensor_arg::(1), + stats.top.as_tensor_arg::(1), + stats.left.copy().as_tensor_arg::(1), + stats.left.as_tensor_arg::(1), + stats.right.copy().as_tensor_arg::(1), + stats.right.as_tensor_arg::(1), + stats.bottom.copy().as_tensor_arg::(1), + stats.bottom.as_tensor_arg::(1), + relabel.as_tensor_arg::(1), + ScalarArg::new(max_label as u32), + stats_opt, + ); + } } Ok((labels, stats)) diff --git a/crates/burn-vision/src/backends/jit/connected_components/mod.rs b/crates/burn-vision/src/backends/jit/connected_components/mod.rs index 4ca53c936c..af0ed730e2 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/mod.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/mod.rs @@ -35,10 +35,13 @@ where ) }; ConnectedStatsPrimitive { - area: opts.area_enabled.then(zeros).unwrap_or_else(dummy), - left: opts.left_enabled.then(max).unwrap_or_else(dummy), - top: opts.top_enabled.then(max).unwrap_or_else(dummy), - right: opts.right_enabled.then(zeros).unwrap_or_else(dummy), - bottom: opts.bottom_enabled.then(zeros).unwrap_or_else(dummy), + area: (opts != ConnectedStatsOptions::none()) + .then(zeros) + .unwrap_or_else(dummy), + left: opts.bounds_enabled.then(max).unwrap_or_else(dummy), + top: opts.bounds_enabled.then(max).unwrap_or_else(dummy), + right: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy), + bottom: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy), + max_label: zeros_device::(l.client.clone(), l.device.clone(), Shape::new([1])), } } diff --git a/crates/burn-vision/src/backends/jit/ops.rs b/crates/burn-vision/src/backends/jit/ops.rs index 0ac628f77f..ed0c7afc5f 100644 --- a/crates/burn-vision/src/backends/jit/ops.rs +++ b/crates/burn-vision/src/backends/jit/ops.rs @@ -109,7 +109,8 @@ impl> VisionOps for Fusion { self: Box, handles: &mut HandleContainer<::FusionHandle>, ) 
{ - let ([img], [labels, area, left, top, right, bottom]) = self.desc.consume(); + let ([img], [labels, area, left, top, right, bottom, max_label]) = + self.desc.consume(); let input = handles.get_bool_tensor::(&img); let (output, stats) = B1::connected_components_with_stats(input, self.conn, self.opts); @@ -120,6 +121,7 @@ impl> VisionOps for Fusion { handles.register_int_tensor::(&top.id, stats.top); handles.register_int_tensor::(&right.id, stats.right); handles.register_int_tensor::(&bottom.id, stats.bottom); + handles.register_int_tensor::(&max_label.id, stats.max_label); } } @@ -131,6 +133,7 @@ impl> VisionOps for Fusion { let right = client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); let bottom = client.tensor_uninitialized(vec![batches, height * width], B::IntElem::dtype()); + let max_label = client.tensor_uninitialized(vec![batches], B::IntElem::dtype()); let desc = CustomOpDescription::new( "connected_components", @@ -142,6 +145,7 @@ impl> VisionOps for Fusion { top.to_description_out(), right.to_description_out(), bottom.to_description_out(), + max_label.to_description_out(), ], ); client.register( @@ -156,6 +160,7 @@ impl> VisionOps for Fusion { top, right, bottom, + max_label, }; (out, stats) } diff --git a/crates/burn-vision/src/ops/base.rs b/crates/burn-vision/src/ops/base.rs index 7d20e88ec9..885354f633 100644 --- a/crates/burn-vision/src/ops/base.rs +++ b/crates/burn-vision/src/ops/base.rs @@ -20,16 +20,12 @@ pub enum Connectivity { /// Disabled stats are aliased to the labels tensor #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ConnectedStatsOptions { - /// Whether to collect the area of each component - pub area_enabled: bool, - /// Whether to find the top (minimum y) of each component - pub top_enabled: bool, - /// Whether to find the left (minimum x) of each component - pub left_enabled: bool, - /// Whether to find the right (max x) of each component - pub right_enabled: bool, - /// Whether to find 
the bottom (max y) of each component - pub bottom_enabled: bool, + /// Whether to enable bounding boxes + pub bounds_enabled: bool, + /// Whether to enable the max label + pub max_label_enabled: bool, + /// Whether labels must be contiguous starting at 1 + pub compact_labels: bool, } /// Stats collected by the connected components analysis @@ -47,6 +43,8 @@ pub struct ConnectedStats { pub right: Tensor, /// Bottommost y coordinate in the component pub bottom: Tensor, + /// Scalar tensor of the max label + pub max_label: Tensor, } /// Primitive version of [`ConnectedStats`], to be returned by the backend @@ -61,6 +59,8 @@ pub struct ConnectedStatsPrimitive { pub right: IntTensor, /// Bottommost y coordinate in the component pub bottom: IntTensor, + /// Scalar tensor of the max label + pub max_label: IntTensor, } impl From> for ConnectedStats { @@ -71,6 +71,7 @@ impl From> for ConnectedStats { left: Tensor::from_primitive(value.left), right: Tensor::from_primitive(value.right), bottom: Tensor::from_primitive(value.bottom), + max_label: Tensor::from_primitive(value.max_label), } } } @@ -86,10 +87,9 @@ impl ConnectedStatsOptions { pub fn none() -> Self { Self { area_enabled: false, - top_enabled: false, - left_enabled: false, - right_enabled: false, - bottom_enabled: false, + bounds_enabled: false, + max_label_enabled: false, + compact_labels: false, } } @@ -97,10 +97,9 @@ impl ConnectedStatsOptions { pub fn all() -> Self { Self { area_enabled: true, - top_enabled: true, - left_enabled: true, - right_enabled: true, - bottom_enabled: true, + bounds_enabled: true, + max_label_enabled: true, + compact_labels: true, } } } From d790113b3d50a0d46b8d899ef8709001066522cc Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sun, 2 Feb 2025 16:41:17 +0100 Subject: [PATCH 17/24] Improve JIT backend implementation by adding label compaction --- Cargo.lock | 15 -- Cargo.toml | 8 +- crates/burn-jit/src/ops/base.rs | 3 +- 
crates/burn-vision/Cargo.toml | 2 +- .../hardware_accelerated.rs | 56 ++-- crates/burn-vision/src/backends/jit/mod.rs | 4 + crates/burn-vision/src/backends/jit/ops.rs | 8 +- .../src/backends/jit/prefix_sum.rs | 254 ++++++++++++++++++ crates/burn-vision/src/ops/base.rs | 2 - .../src/tests/connected_components.rs | 37 ++- crates/burn-vision/tests/main.rs | 7 + 11 files changed, 332 insertions(+), 64 deletions(-) create mode 100644 crates/burn-vision/src/backends/jit/prefix_sum.rs diff --git a/Cargo.lock b/Cargo.lock index 5a53f78f0e..9a34ce5455 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1495,7 +1495,6 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1510,7 +1509,6 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1531,7 +1529,6 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1552,7 +1549,6 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "bytemuck", "cubecl-common", @@ -1566,7 +1562,6 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "bytemuck", "cubecl-common", @@ -1582,7 +1577,6 @@ dependencies = [ 
[[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "bytemuck", "cubecl-common", @@ -1608,7 +1602,6 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -1626,7 +1619,6 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "bytemuck", "cubecl-core", @@ -1638,7 +1630,6 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "cubecl-common", "darling", @@ -1653,7 +1644,6 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "darling", "proc-macro2", @@ -1664,7 +1654,6 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1680,7 +1669,6 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1690,7 +1678,6 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" 
-source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "async-channel", "async-lock", @@ -1712,7 +1699,6 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1727,7 +1713,6 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=e0734dadca994b02b7dce3b77a575edb1fb2232e#e0734dadca994b02b7dce3b77a575edb1fb2232e" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 164fd5ab29..1d713e5540 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" } +# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" } +# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" } ### For local development. ### -# cubecl = { path = "../cubecl/crates/cubecl", default-features = false } -# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +cubecl = { path = "../cubecl/crates/cubecl", default-features = false } +cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } ### For the release. 
### # cubecl = { version = "0.4.0", default-features = false } # cubecl-common = { version = "0.4.0", default-features = false } diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index 9327e1fc92..112a11de33 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -136,7 +136,8 @@ pub(crate) fn expand(tensor: JitTensor, target_shape: Shape) - } } -pub(crate) fn reshape(tensor: JitTensor, shape: Shape) -> JitTensor { +/// Reshape a jit tensor to a new shape +pub fn reshape(tensor: JitTensor, shape: Shape) -> JitTensor { // TODO: Not force standard layout all the time (improve performance). let tensor = kernel::into_contiguous(tensor); diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index a6c0d573a0..2ee8935af8 100644 --- a/crates/burn-vision/Cargo.toml +++ b/crates/burn-vision/Cargo.toml @@ -25,7 +25,7 @@ tch = ["burn-tch"] # Test features cpu = ["export-tests"] cuda = ["jit-backend", "export-tests"] -vulkan = ["burn-wgpu/vulkan", "wgpu"] +vulkan = ["burn-wgpu/vulkan", "jit-backend", "export-tests"] wgpu = ["jit-backend", "export-tests"] [dependencies] diff --git a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs index 931320071c..6ed446a1f5 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs @@ -4,8 +4,8 @@ //! 
DASIP, 2018 use crate::{ - backends::jit::connected_components::stats_from_opts, ConnectedStatsOptions, - ConnectedStatsPrimitive, Connectivity, + backends::jit::{connected_components::stats_from_opts, prefix_sum::prefix_sum}, + ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, }; use burn_jit::{ kernel, @@ -13,8 +13,8 @@ use burn_jit::{ tensor::JitTensor, BoolElement, FloatElement, IntElement, JitBackend, JitRuntime, }; -use burn_tensor::ops::IntTensorOps; -use cubecl::{calculate_cube_count_elemwise, prelude::*, Feature}; +use burn_tensor::{ops::IntTensorOps, Shape}; +use cubecl::{prelude::*, Feature}; const BLOCK_H: u32 = 4; @@ -380,6 +380,7 @@ fn analysis( while label != u32::cast_from(labels[b_offs + label]) - 1 { label = u32::cast_from(labels[b_offs + label]) - 1; } + label += 1; Atomic::add(&area[b_offs + label], I::cast_from(count)); @@ -397,13 +398,17 @@ fn analysis( label = plane_broadcast(label, UNIT_POS_X - s_dist); if p { - labels[labels_index] = I::cast_from(label + 1); + labels[labels_index] = I::cast_from(label); } } } #[cube(launch)] -fn compact_labels(labels: &mut Tensor, remap: &Tensor) { +fn compact_labels( + labels: &mut Tensor, + remap: &Tensor, + max_label: &Tensor>, +) { let batch = ABSOLUTE_POS_Z; let x = ABSOLUTE_POS_X; let y = ABSOLUTE_POS_Y; @@ -416,7 +421,9 @@ fn compact_labels(labels: &mut Tensor, remap: &Tensor) { let label = u32::cast_from(labels[labels_pos]); if label != 0 { - labels[labels_pos] = remap[label]; + let new_label = remap[label]; + labels[labels_pos] = new_label; + Atomic::max(&max_label[batch], new_label); } } @@ -433,11 +440,9 @@ fn compact_stats( bottom: &Tensor, bottom_new: &mut Tensor, remap: &Tensor, - max_label: u32, - #[comptime] opts: ConnectedStatsOptions, ) { let label = ABSOLUTE_POS_X; - if label > max_label { + if label >= remap.len() { terminate!(); } @@ -448,12 +453,12 @@ fn compact_stats( let new_label = u32::cast_from(remap[label]); area_new[new_label] = area; - if opts.bounds_enabled { - 
top_new[new_label] = top[label]; - left_new[new_label] = left[label]; - right_new[new_label] = right[label]; - bottom_new[new_label] = bottom[label]; - } + // This should be gated but there's a problem with the Eq bound only being implemented for tuples + // up to 12 elems, so I can't pass the opts. It's not unsafe, but potentially unnecessary work. + top_new[new_label] = top[label]; + left_new[new_label] = left[label]; + right_new[new_label] = right[label]; + bottom_new[new_label] = bottom[label]; } #[allow(clippy::type_complexity)] @@ -525,7 +530,7 @@ pub fn hardware_accelerated( @@ -553,9 +558,13 @@ pub fn hardware_accelerated(stats.max_label.clone()).convert::(); let max_label = *max_labels.as_slice::().unwrap().iter().max().unwrap() as usize; - let sliced = kernel::slice::(stats.area.clone(), &[0..batches, 0..max_label + 1]); + let sliced = kernel::slice::( + stats.area.clone(), + &[0..batches, 0..(max_label + 1).next_multiple_of(4)], + ); let present = JitBackend::::int_not_equal_elem(sliced, I::new(0)); - let relabel = JitBackend::::int_prefix_sum(present); + let present = kernel::cast::(present); + let relabel = prefix_sum::(present); let cube_dim = CubeDim::default(); let cube_count = CubeCount::new_3d( @@ -563,18 +572,21 @@ pub fn hardware_accelerated(client.clone(), device.clone(), Shape::new([batches])); + compact_labels::launch::( &client, cube_count, cube_dim, labels.as_tensor_arg::(1), relabel.as_tensor_arg::(1), + stats.max_label.as_tensor_arg::(1), ); let cube_dim = CubeDim::new_1d(256); let cube_count = CubeCount::new_3d((rows * cols).div_ceil(256) as u32, 1, batches as u32); - compact_stats::launch( + compact_stats::launch::( &client, cube_count, cube_dim, @@ -589,8 +601,6 @@ pub fn hardware_accelerated(1), stats.bottom.as_tensor_arg::(1), relabel.as_tensor_arg::(1), - ScalarArg::new(max_label as u32), - stats_opt, ); } } diff --git a/crates/burn-vision/src/backends/jit/mod.rs b/crates/burn-vision/src/backends/jit/mod.rs index 
9d610df49a..8666645c48 100644 --- a/crates/burn-vision/src/backends/jit/mod.rs +++ b/crates/burn-vision/src/backends/jit/mod.rs @@ -1,2 +1,6 @@ mod connected_components; mod ops; + +/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops +/// to really use it in a general case. Needs more work to use as a normal tensor method. +mod prefix_sum; diff --git a/crates/burn-vision/src/backends/jit/ops.rs b/crates/burn-vision/src/backends/jit/ops.rs index ed0c7afc5f..8d75b37f8d 100644 --- a/crates/burn-vision/src/backends/jit/ops.rs +++ b/crates/burn-vision/src/backends/jit/ops.rs @@ -45,8 +45,8 @@ where impl> VisionOps for Fusion { fn connected_components(img: BoolTensor, conn: Connectivity) -> IntTensor { let batches = img.shape[0]; - let height = img.shape[2]; - let width = img.shape[3]; + let height = img.shape[1]; + let width = img.shape[2]; let client = img.client.clone(); #[derive(derive_new::new)] @@ -92,8 +92,8 @@ impl> VisionOps for Fusion { opts: ConnectedStatsOptions, ) -> (IntTensor, ConnectedStatsPrimitive) { let batches = img.shape[0]; - let height = img.shape[2]; - let width = img.shape[3]; + let height = img.shape[1]; + let width = img.shape[2]; let client = img.client.clone(); #[derive(derive_new::new)] diff --git a/crates/burn-vision/src/backends/jit/prefix_sum.rs b/crates/burn-vision/src/backends/jit/prefix_sum.rs new file mode 100644 index 0000000000..38bd4f49d7 --- /dev/null +++ b/crates/burn-vision/src/backends/jit/prefix_sum.rs @@ -0,0 +1,254 @@ +use burn_tensor::Shape; +use cubecl::prelude::*; + +use burn_jit::{ + ops::{ + numeric::{empty_device, zeros_device}, + reshape, + }, + tensor::JitTensor, + IntElement, JitRuntime, +}; + +const CUBE_SIZE: u32 = 256; +const MIN_SUBGROUP_SIZE: u32 = 4; +const MAX_REDUCE_SIZE: u32 = CUBE_SIZE / MIN_SUBGROUP_SIZE; + +const PART_SIZE: u32 = 4096; + +#[cube(launch_unchecked)] +fn prefix_sum_kernel( + scan_in: &Tensor>, + scan_out: &mut Tensor>, + scan_bump: &Tensor>, + 
reduction: &Tensor>, + cube_count_x: u32, +) { + let mut broadcast = SharedMemory::::new(1); + let mut reduce = SharedMemory::::new(MAX_REDUCE_SIZE); + let batch = CUBE_POS_Z; + let vec4_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size()); + let nums_per_cube = CUBE_SIZE * vec4_spt; + + //acquire partition index + if UNIT_POS_X == 0 { + broadcast[0] = Atomic::add(&scan_bump[batch], I::new(1)); + } + sync_units(); + let part_id = u32::cast_from(broadcast[0]); + + let plane_id = UNIT_POS_X / PLANE_DIM; + let dev_offs = part_id * nums_per_cube; + let plane_offs = plane_id * PLANE_DIM * vec4_spt; + + // Exit if full plane is out of bounds + if dev_offs + plane_offs >= scan_in.shape(1) { + terminate!(); + } + + let zero = I::new(0); + + let flag_reduction = I::new(1); + let flag_inclusive = I::new(2); + let flag_mask = I::new(3); + + let red_offs = batch * reduction.stride(0); + let scan_offs = batch * scan_in.stride(0); + + let mut t_scan = Array::>::vectorized(vec4_spt, scan_in.line_size()); + { + let mut i = dev_offs + plane_offs + UNIT_POS_PLANE; + + if part_id < cube_count_x - 1 { + for k in 0..vec4_spt { + let mut scan = scan_in[i + scan_offs]; + let x = scan[0]; + scan[1] += x; + let y = scan[1]; + scan[2] += y; + let z = scan[2]; + scan[3] += z; + t_scan[k] = scan; + i += PLANE_DIM; + } + } + + if part_id == cube_count_x - 1 { + for k in 0..vec4_spt { + if i < scan_in.shape(1) { + let mut scan = scan_in[i + scan_offs]; + let x = scan[0]; + scan[1] += x; + let y = scan[1]; + scan[2] += y; + let z = scan[2]; + scan[3] += z; + t_scan[k] = scan; + } + i += PLANE_DIM; + } + } + + let mut prev = zero; + let plane_mask = PLANE_DIM - 1; + let circular_shift = (UNIT_POS_PLANE + plane_mask) & plane_mask; + for k in 0..vec4_spt { + let t = plane_broadcast(plane_inclusive_sum(t_scan[k][3]), circular_shift); + t_scan[k] += Line::cast_from(select(UNIT_POS_PLANE != 0, t, zero) + prev); + prev += plane_broadcast(t, 0); + } + + if UNIT_POS_PLANE == 0 { + 
reduce[plane_id] = prev; + } + } + sync_units(); + + //Non-divergent subgroup agnostic inclusive scan across subgroup reductions + let lane_log = count_trailing_zeros(PLANE_DIM); + let spine_size = CUBE_DIM >> lane_log; + { + let mut offset_0 = 0; + let mut offset_1 = 0; + let aligned_size = + 1 << ((count_trailing_zeros(spine_size) + lane_log + 1) / lane_log * lane_log); + let mut j = PLANE_DIM; + while j <= aligned_size { + let i_0 = ((UNIT_POS_X + offset_0) << offset_1) - offset_0; + let pred_0 = i_0 < spine_size; + let t_0 = plane_inclusive_sum(select(pred_0, reduce[i_0], zero)); + if pred_0 { + reduce[i_0] = t_0; + } + sync_units(); + + if j != PLANE_DIM { + let rshift = j >> lane_log; + let i_1 = UNIT_POS_X + rshift; + if (i_1 & (j - 1)) >= rshift { + let pred_1 = i_1 < spine_size; + let t_1 = select(pred_1, reduce[((i_1 >> offset_1) << offset_1) - 1], zero); + if pred_1 && ((i_1 + 1) & (rshift - 1)) != 0 { + reduce[i_1] += t_1; + } + } + } else { + offset_0 += 1; + } + offset_1 += lane_log; + + j <<= lane_log; + } + } + sync_units(); + + //Device broadcast + if UNIT_POS_X == 0 { + Atomic::store( + &reduction[part_id + red_offs], + (reduce[spine_size - 1] << I::new(2)) + | select(part_id != 0, flag_reduction, flag_inclusive), + ) + } + + //Lookback, single thread + if part_id != 0 { + if UNIT_POS_X == 0 { + let mut lookback_id = part_id - 1; + let mut prev_reduction = zero; + loop { + let flag_payload = Atomic::load(&reduction[lookback_id + red_offs]); + if (flag_payload & flag_mask) == flag_inclusive { + prev_reduction += flag_payload >> I::new(2); + Atomic::store( + &reduction[part_id + red_offs], + ((prev_reduction + reduce[spine_size - 1]) << I::new(2)) | flag_inclusive, + ); + broadcast[0] = prev_reduction; + break; + } + + if (flag_payload & flag_mask) == flag_reduction { + prev_reduction += flag_payload >> I::new(2); + lookback_id -= 1; + } + } + } + sync_units(); + } + + { + let prev = if plane_id != 0 { + reduce[plane_id - 1] + } else { + zero + }; + 
let prev = Line::cast_from(broadcast[0] + prev); + let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * vec4_spt; + let dev_offset = part_id * nums_per_cube; + let mut i = s_offset + dev_offset; + + if part_id < cube_count_x - 1 { + for k in 0..vec4_spt { + scan_out[i + scan_offs] = t_scan[k] + prev; + i += PLANE_DIM; + } + } + + if part_id == cube_count_x - 1 { + for k in 0..vec4_spt { + if i < scan_out.shape(1) { + scan_out[i + scan_offs] = t_scan[k] + prev; + } + i += PLANE_DIM; + } + } + } +} + +#[cube] +fn count_trailing_zeros(num: u32) -> u32 { + u32::find_first_set(num) - 1 +} + +/// Compute the prefix sum of a tensor +pub fn prefix_sum(input: JitTensor) -> JitTensor { + let client = input.client.clone(); + let device = input.device.clone(); + let num_elems = input.shape.num_elements() as u32; + let numbers = *input.shape.dims.last().unwrap() as u32; + let batches = num_elems / numbers; + + let input = reshape(input, Shape::new([batches as usize, numbers as usize])); + let out = empty_device::(client.clone(), device.clone(), input.shape.clone()); + + let cubes = numbers.div_ceil(PART_SIZE); + let cube_dim = CubeDim::new_1d(CUBE_SIZE); + let cube_count = CubeCount::new_3d(cubes, 1, batches); + + let bump = zeros_device::( + client.clone(), + device.clone(), + Shape::new([batches as usize]), + ); + let reduction = zeros_device::( + client.clone(), + device.clone(), + Shape::new([batches as usize, cubes as usize]), + ); + + unsafe { + prefix_sum_kernel::launch_unchecked::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg::(4), + out.as_tensor_arg::(4), + bump.as_tensor_arg::(1), + reduction.as_tensor_arg::(1), + ScalarArg::new(cubes), + ) + }; + + out +} diff --git a/crates/burn-vision/src/ops/base.rs b/crates/burn-vision/src/ops/base.rs index 885354f633..f41f777405 100644 --- a/crates/burn-vision/src/ops/base.rs +++ b/crates/burn-vision/src/ops/base.rs @@ -86,7 +86,6 @@ impl ConnectedStatsOptions { /// Don't collect any stats pub fn none() -> 
Self { Self { - area_enabled: false, bounds_enabled: false, max_label_enabled: false, compact_labels: false, @@ -96,7 +95,6 @@ impl ConnectedStatsOptions { /// Collect all stats pub fn all() -> Self { Self { - area_enabled: true, bounds_enabled: true, max_label_enabled: true, compact_labels: true, diff --git a/crates/burn-vision/src/tests/connected_components.rs b/crates/burn-vision/src/tests/connected_components.rs index 8ebc32dfba..368054149c 100644 --- a/crates/burn-vision/src/tests/connected_components.rs +++ b/crates/burn-vision/src/tests/connected_components.rs @@ -40,21 +40,25 @@ mod tests { let expected = space_invader(); // All pixels are in the same group for 8-connected let expected = TestTensorInt::<2>::from(expected).unsqueeze::<3>(); - let (area, left, top, right, bottom) = normalize_stats( - stats.area.into_data(), - stats.left.into_data(), - stats.top.into_data(), - stats.right.into_data(), - stats.bottom.into_data(), + let (area, left, top, right, bottom) = ( + stats.area.slice([0..1, 1..2]).into_data(), + stats.left.slice([0..1, 1..2]).into_data(), + stats.top.slice([0..1, 1..2]).into_data(), + stats.right.slice([0..1, 1..2]).into_data(), + stats.bottom.slice([0..1, 1..2]).into_data(), ); - normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false); + output.into_data().assert_eq(&expected.into_data(), false); area.assert_eq(&TensorData::from([[58]]), false); left.assert_eq(&TensorData::from([[0]]), false); top.assert_eq(&TensorData::from([[0]]), false); right.assert_eq(&TensorData::from([[13]]), false); bottom.assert_eq(&TensorData::from([[8]]), false); + stats + .max_label + .into_data() + .assert_eq(&TensorData::from([1]), false); } #[test] @@ -97,21 +101,26 @@ mod tests { ]); let expected = TestTensorInt::<2>::from(expected).unsqueeze::<3>(); - let (area, left, top, right, bottom) = normalize_stats( - stats.area.into_data(), - stats.left.into_data(), - stats.top.into_data(), - stats.right.into_data(), - 
stats.bottom.into_data(), + // Slice off background and limit to compacted labels + let (area, left, top, right, bottom) = ( + stats.area.slice([0..1, 1..6]).into_data(), + stats.left.slice([0..1, 1..6]).into_data(), + stats.top.slice([0..1, 1..6]).into_data(), + stats.right.slice([0..1, 1..6]).into_data(), + stats.bottom.slice([0..1, 1..6]).into_data(), ); - normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false); + output.into_data().assert_eq(&expected.into_data(), false); area.assert_eq(&TensorData::from([[1, 1, 46, 5, 5]]), false); left.assert_eq(&TensorData::from([[3, 10, 1, 0, 12]]), false); top.assert_eq(&TensorData::from([[0, 0, 1, 5, 5]]), false); right.assert_eq(&TensorData::from([[3, 10, 12, 1, 13]]), false); bottom.assert_eq(&TensorData::from([[0, 0, 8, 7, 7]]), false); + stats + .max_label + .into_data() + .assert_eq(&TensorData::from([5]), false); } /// Normalize labels to sequential since actual labels aren't required to be contiguous and diff --git a/crates/burn-vision/tests/main.rs b/crates/burn-vision/tests/main.rs index f31361f3cf..a9e025afdd 100644 --- a/crates/burn-vision/tests/main.rs +++ b/crates/burn-vision/tests/main.rs @@ -12,6 +12,13 @@ mod tests_wgpu { burn_vision::testgen_all!(); } +#[cfg(all(test, feature = "vulkan"))] +mod tests_wgpu { + pub type TestBackend = burn_wgpu::Vulkan; + + burn_vision::testgen_all!(); +} + #[cfg(all(test, feature = "cuda"))] mod tests_cuda { pub type TestBackend = burn_cuda::Cuda; From 15c431c09f4062bef475bb5fd310e21c0bee6308 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sun, 2 Feb 2025 19:20:19 +0100 Subject: [PATCH 18/24] Use GPU reduction for max label --- .../jit/connected_components/hardware_accelerated.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs 
b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs index 6ed446a1f5..90583f6345 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs @@ -556,8 +556,9 @@ pub fn hardware_accelerated(stats.max_label.clone()).convert::(); - let max_label = *max_labels.as_slice::().unwrap().iter().max().unwrap() as usize; + let max_label = JitBackend::::int_max(stats.max_label); + let max_label = into_data_sync::(max_label).convert::(); + let max_label = max_label.as_slice::().unwrap()[0] as usize; let sliced = kernel::slice::( stats.area.clone(), &[0..batches, 0..(max_label + 1).next_multiple_of(4)], From e3ec0856812fd59099199904e51b4dde926eef27 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sun, 2 Feb 2025 19:36:23 +0100 Subject: [PATCH 19/24] Manually fuse presence and prefix sum --- .../jit/connected_components/hardware_accelerated.rs | 10 +++++----- .../src/backends/jit/connected_components/mod.rs | 4 ++++ .../jit/{ => connected_components}/prefix_sum.rs | 7 +++++-- crates/burn-vision/src/backends/jit/mod.rs | 4 ---- 4 files changed, 14 insertions(+), 11 deletions(-) rename crates/burn-vision/src/backends/jit/{ => connected_components}/prefix_sum.rs (95%) diff --git a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs index 90583f6345..af111aa964 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs @@ -4,8 +4,8 @@ //! 
DASIP, 2018 use crate::{ - backends::jit::{connected_components::stats_from_opts, prefix_sum::prefix_sum}, - ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, + backends::jit::connected_components::stats_from_opts, ConnectedStatsOptions, + ConnectedStatsPrimitive, Connectivity, }; use burn_jit::{ kernel, @@ -16,6 +16,8 @@ use burn_jit::{ use burn_tensor::{ops::IntTensorOps, Shape}; use cubecl::{prelude::*, Feature}; +use super::prefix_sum::prefix_sum; + const BLOCK_H: u32 = 4; #[cube] @@ -563,9 +565,7 @@ pub fn hardware_accelerated::int_not_equal_elem(sliced, I::new(0)); - let present = kernel::cast::(present); - let relabel = prefix_sum::(present); + let relabel = prefix_sum::(sliced); let cube_dim = CubeDim::default(); let cube_count = CubeCount::new_3d( diff --git a/crates/burn-vision/src/backends/jit/connected_components/mod.rs b/crates/burn-vision/src/backends/jit/connected_components/mod.rs index af0ed730e2..53627a077e 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/mod.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/mod.rs @@ -1,5 +1,9 @@ mod hardware_accelerated; +/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops +/// to really use it in a general case. Needs more work to use as a normal tensor method. 
+mod prefix_sum; + use burn_jit::{ ops::numeric::{full_device, zeros_device}, tensor::JitTensor, diff --git a/crates/burn-vision/src/backends/jit/prefix_sum.rs b/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs similarity index 95% rename from crates/burn-vision/src/backends/jit/prefix_sum.rs rename to crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs index 38bd4f49d7..2c4cf663b6 100644 --- a/crates/burn-vision/src/backends/jit/prefix_sum.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs @@ -61,7 +61,8 @@ fn prefix_sum_kernel( if part_id < cube_count_x - 1 { for k in 0..vec4_spt { - let mut scan = scan_in[i + scan_offs]; + // Manually fuse not_equal and cast + let mut scan = Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero))); let x = scan[0]; scan[1] += x; let y = scan[1]; @@ -76,7 +77,9 @@ fn prefix_sum_kernel( if part_id == cube_count_x - 1 { for k in 0..vec4_spt { if i < scan_in.shape(1) { - let mut scan = scan_in[i + scan_offs]; + // Manually fuse not_equal and cast + let mut scan = + Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero))); let x = scan[0]; scan[1] += x; let y = scan[1]; diff --git a/crates/burn-vision/src/backends/jit/mod.rs b/crates/burn-vision/src/backends/jit/mod.rs index 8666645c48..9d610df49a 100644 --- a/crates/burn-vision/src/backends/jit/mod.rs +++ b/crates/burn-vision/src/backends/jit/mod.rs @@ -1,6 +1,2 @@ mod connected_components; mod ops; - -/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops -/// to really use it in a general case. Needs more work to use as a normal tensor method. 
-mod prefix_sum; From 11c8f1f5015f346ad050a9ab1b988a4edf8ab159 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sun, 2 Feb 2025 19:42:40 +0100 Subject: [PATCH 20/24] Make prefix sum more generic over line size --- .../jit/connected_components/prefix_sum.rs | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs b/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs index 2c4cf663b6..f22910f442 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/prefix_sum.rs @@ -27,8 +27,9 @@ fn prefix_sum_kernel( let mut broadcast = SharedMemory::::new(1); let mut reduce = SharedMemory::::new(MAX_REDUCE_SIZE); let batch = CUBE_POS_Z; - let vec4_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size()); - let nums_per_cube = CUBE_SIZE * vec4_spt; + let line_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size()); + let nums_per_cube = CUBE_SIZE * line_spt; + let v_last = comptime!(scan_in.line_size() - 1); //acquire partition index if UNIT_POS_X == 0 { @@ -39,7 +40,7 @@ fn prefix_sum_kernel( let plane_id = UNIT_POS_X / PLANE_DIM; let dev_offs = part_id * nums_per_cube; - let plane_offs = plane_id * PLANE_DIM * vec4_spt; + let plane_offs = plane_id * PLANE_DIM * line_spt; // Exit if full plane is out of bounds if dev_offs + plane_offs >= scan_in.shape(1) { @@ -55,37 +56,35 @@ fn prefix_sum_kernel( let red_offs = batch * reduction.stride(0); let scan_offs = batch * scan_in.stride(0); - let mut t_scan = Array::>::vectorized(vec4_spt, scan_in.line_size()); + let mut t_scan = Array::>::vectorized(line_spt, scan_in.line_size()); { let mut i = dev_offs + plane_offs + UNIT_POS_PLANE; if part_id < cube_count_x - 1 { - for k in 0..vec4_spt { + for k in 0..line_spt { // Manually fuse not_equal and cast let mut scan = 
Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero))); - let x = scan[0]; - scan[1] += x; - let y = scan[1]; - scan[2] += y; - let z = scan[2]; - scan[3] += z; + #[unroll] + for v in 1..scan_in.line_size() { + let prev = scan[v - 1]; + scan[v] += prev; + } t_scan[k] = scan; i += PLANE_DIM; } } if part_id == cube_count_x - 1 { - for k in 0..vec4_spt { + for k in 0..line_spt { if i < scan_in.shape(1) { // Manually fuse not_equal and cast let mut scan = Line::cast_from(scan_in[i + scan_offs].not_equal(Line::new(zero))); - let x = scan[0]; - scan[1] += x; - let y = scan[1]; - scan[2] += y; - let z = scan[2]; - scan[3] += z; + #[unroll] + for v in 1..scan_in.line_size() { + let prev = scan[v - 1]; + scan[v] += prev; + } t_scan[k] = scan; } i += PLANE_DIM; @@ -95,8 +94,8 @@ fn prefix_sum_kernel( let mut prev = zero; let plane_mask = PLANE_DIM - 1; let circular_shift = (UNIT_POS_PLANE + plane_mask) & plane_mask; - for k in 0..vec4_spt { - let t = plane_broadcast(plane_inclusive_sum(t_scan[k][3]), circular_shift); + for k in 0..line_spt { + let t = plane_broadcast(plane_inclusive_sum(t_scan[k][v_last]), circular_shift); t_scan[k] += Line::cast_from(select(UNIT_POS_PLANE != 0, t, zero) + prev); prev += plane_broadcast(t, 0); } @@ -187,19 +186,19 @@ fn prefix_sum_kernel( zero }; let prev = Line::cast_from(broadcast[0] + prev); - let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * vec4_spt; + let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt; let dev_offset = part_id * nums_per_cube; let mut i = s_offset + dev_offset; if part_id < cube_count_x - 1 { - for k in 0..vec4_spt { + for k in 0..line_spt { scan_out[i + scan_offs] = t_scan[k] + prev; i += PLANE_DIM; } } if part_id == cube_count_x - 1 { - for k in 0..vec4_spt { + for k in 0..line_spt { if i < scan_out.shape(1) { scan_out[i + scan_offs] = t_scan[k] + prev; } From e6126c8fc57f66d5f26ba6dea695519b4c8b5e8d Mon Sep 17 00:00:00 2001 From: Genna Wingert 
<1058083+wingertge@users.noreply.github.com> Date: Mon, 3 Feb 2025 17:35:26 +0100 Subject: [PATCH 21/24] Add vision tests to xtask --- xtask/src/commands/test.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs index 5b94b2909e..ff0581089c 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -67,6 +67,15 @@ pub(crate) fn handle_command( "std with features: test-tch,record-item-custom-serde", )?; + // burn-vision + helpers::custom_crates_tests( + vec!["burn-vision"], + vec!["--features", "cpu"], + None, + None, + "std cpu", + )?; + if std::env::var("DISABLE_WGPU").is_err() { helpers::custom_crates_tests( vec!["burn-core"], @@ -75,6 +84,13 @@ pub(crate) fn handle_command( None, "std wgpu", )?; + helpers::custom_crates_tests( + vec!["burn-vision"], + vec!["--features", "wgpu"], + None, + None, + "std wgpu", + )?; // Vulkan isn't available on MacOS #[cfg(not(target_os = "macos"))] if std::env::var("DISABLE_WGPU_SPIRV").is_err() { @@ -85,6 +101,13 @@ pub(crate) fn handle_command( None, "std vulkan", )?; + helpers::custom_crates_tests( + vec!["burn-vision"], + vec!["--features", "vulkan"], + None, + None, + "std vulkan", + )?; } } From 1bbf50a4279c8724ffaca82edf6f17e6e093ea98 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Mon, 3 Feb 2025 17:49:12 +0100 Subject: [PATCH 22/24] Fix CPU and other review stuff --- crates/burn-vision/Cargo.toml | 5 ++++- .../burn-vision/src/backends/cpu/connected_components.rs | 8 ++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index 2ee8935af8..b4d8dbdce9 100644 --- a/crates/burn-vision/Cargo.toml +++ b/crates/burn-vision/Cargo.toml @@ -1,5 +1,8 @@ [package] -authors = ["nathanielsimard "] +authors = [ + "nathanielsimard ", + "wingertge ", +] categories = ["science"] description = "Vision processing 
operations for burn tensors" documentation = "https://docs.rs/burn-vision" diff --git a/crates/burn-vision/src/backends/cpu/connected_components.rs b/crates/burn-vision/src/backends/cpu/connected_components.rs index 71771b7eff..3d78c08fbb 100644 --- a/crates/burn-vision/src/backends/cpu/connected_components.rs +++ b/crates/burn-vision/src/backends/cpu/connected_components.rs @@ -25,7 +25,6 @@ pub fn connected_components_with_stats( ) -> (IntTensor, ConnectedStatsPrimitive) { let device = B::bool_device(&img); let (labels, stats) = run::(img, connectivity, ConnectedStatsOp::default); - println!("{stats:?}"); let stats = finalize_stats(&device, stats); (labels, stats) } @@ -211,13 +210,18 @@ fn finalize_stats( Tensor::::from_data(data, device).into_primitive() }; + let max_label = { + let data = TensorData::new(max_label, Shape::new([batches])); + Tensor::::from_data(data, device).into_primitive() + }; + ConnectedStatsPrimitive { area: into_prim(area), left: into_prim(left), top: into_prim(top), right: into_prim(right), bottom: into_prim(bottom), - max_label: into_prim(max_label), + max_label, } } From 4f174a8875279f393ffda7e6cebbeb2875c85cc5 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 7 Feb 2025 15:22:53 -0500 Subject: [PATCH 23/24] Add publish job --- .github/workflows/publish.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index de956c243e..d9bad9839c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,6 +6,25 @@ on: - "v*" jobs: + publish-burn-vision: + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 + with: + crate: burn-vision + needs: + - publish-burn-autodiff + - publish-burn-candle + - publish-burn-fusion + - publish-burn-jit + - publish-burn-ndarray + - publish-burn-tch + - publish-burn-tensor + - publish-burn-tensor-testgen + # dev dependencies + - publish-burn-wgpu + - publish-burn-cuda + secrets: 
+ CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + publish-burn-router: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 with: From 5a1ada3f1d929cb119bf65d85426a5b3bd0aa61b Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sat, 8 Feb 2025 17:11:05 +0100 Subject: [PATCH 24/24] Review fixes --- crates/burn-vision/Cargo.toml | 8 +- .../hardware_accelerated.rs | 146 ++++++++++-------- crates/burn-vision/tests/main.rs | 8 +- xtask/src/commands/test.rs | 4 +- 4 files changed, 89 insertions(+), 77 deletions(-) diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml index b4d8dbdce9..18f6600345 100644 --- a/crates/burn-vision/Cargo.toml +++ b/crates/burn-vision/Cargo.toml @@ -26,10 +26,10 @@ ndarray = ["burn-ndarray"] tch = ["burn-tch"] # Test features -cpu = ["export-tests"] -cuda = ["jit-backend", "export-tests"] -vulkan = ["burn-wgpu/vulkan", "jit-backend", "export-tests"] -wgpu = ["jit-backend", "export-tests"] +test-cpu = ["export-tests"] +test-cuda = ["jit-backend", "export-tests"] +test-vulkan = ["burn-wgpu/vulkan", "jit-backend", "export-tests"] +test-wgpu = ["jit-backend", "export-tests"] [dependencies] burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } diff --git a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs index af111aa964..e4f89d25cd 100644 --- a/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs +++ b/crates/burn-vision/src/backends/jit/connected_components/hardware_accelerated.rs @@ -64,7 +64,7 @@ fn ballot_dyn(y: u32, pred: bool) -> u32 { plane_ballot(pred)[index] } -#[cube(launch)] +#[cube(launch_unchecked)] fn strip_labeling( img: &Tensor, labels: &Tensor>, @@ -193,7 +193,7 @@ fn strip_labeling( } } -#[cube(launch)] +#[cube(launch_unchecked)] fn strip_merge( img: &Tensor, labels: 
&Tensor>, @@ -297,7 +297,7 @@ fn strip_merge( } } -#[cube(launch)] +#[cube(launch_unchecked)] fn relabeling(img: &Tensor, labels: &mut Tensor) { let batch = ABSOLUTE_POS_Z; let plane_start_x = CUBE_POS_X * CUBE_DIM_X; @@ -338,7 +338,7 @@ fn relabeling(img: &Tensor, labels: &mut Tensor( img: &Tensor, labels: &mut Tensor, @@ -405,7 +405,7 @@ fn analysis( } } -#[cube(launch)] +#[cube(launch_unchecked)] fn compact_labels( labels: &mut Tensor, remap: &Tensor, @@ -429,7 +429,7 @@ fn compact_labels( } } -#[cube(launch)] +#[cube(launch_unchecked)] fn compact_stats( area: &Tensor, area_new: &mut Tensor, @@ -500,14 +500,16 @@ pub fn hardware_accelerated( - &client, - cube_count, - cube_dim, - img.as_tensor_arg::(1), - labels.as_tensor_arg::(1), - connectivity, - ); + unsafe { + strip_labeling::launch_unchecked::( + &client, + cube_count, + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + connectivity, + ) + }; let horizontal_warps = Ord::min((cols as u32).div_ceil(warp_size), 32); let cube_dim_merge = CubeDim::new_3d(warp_size, 1, horizontal_warps); @@ -517,14 +519,16 @@ pub fn hardware_accelerated( - &client, - cube_count, - cube_dim_merge, - img.as_tensor_arg::(1), - labels.as_tensor_arg::(1), - connectivity, - ); + unsafe { + strip_merge::launch_unchecked::( + &client, + cube_count, + cube_dim_merge, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + connectivity, + ) + }; let cube_count = CubeCount::Static( (cols as u32).div_ceil(cube_dim.x), @@ -535,28 +539,32 @@ pub fn hardware_accelerated( - &client, - cube_count, - cube_dim, - img.as_tensor_arg::(1), - labels.as_tensor_arg::(1), - ); + unsafe { + relabeling::launch_unchecked::( + &client, + cube_count, + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + ) + }; } else { - analysis::launch::( - &client, - cube_count, - cube_dim, - img.as_tensor_arg::(1), - labels.as_tensor_arg::(1), - stats.area.as_tensor_arg::(1), - stats.top.as_tensor_arg::(1), - stats.left.as_tensor_arg::(1), 
- stats.right.as_tensor_arg::(1), - stats.bottom.as_tensor_arg::(1), - stats.max_label.as_tensor_arg::(1), - stats_opt, - ); + unsafe { + analysis::launch_unchecked::( + &client, + cube_count, + cube_dim, + img.as_tensor_arg::(1), + labels.as_tensor_arg::(1), + stats.area.as_tensor_arg::(1), + stats.top.as_tensor_arg::(1), + stats.left.as_tensor_arg::(1), + stats.right.as_tensor_arg::(1), + stats.bottom.as_tensor_arg::(1), + stats.max_label.as_tensor_arg::(1), + stats_opt, + ) + }; if stats_opt.compact_labels { let max_label = JitBackend::::int_max(stats.max_label); let max_label = into_data_sync::(max_label).convert::(); @@ -575,34 +583,38 @@ pub fn hardware_accelerated(client.clone(), device.clone(), Shape::new([batches])); - compact_labels::launch::( - &client, - cube_count, - cube_dim, - labels.as_tensor_arg::(1), - relabel.as_tensor_arg::(1), - stats.max_label.as_tensor_arg::(1), - ); + unsafe { + compact_labels::launch_unchecked::( + &client, + cube_count, + cube_dim, + labels.as_tensor_arg::(1), + relabel.as_tensor_arg::(1), + stats.max_label.as_tensor_arg::(1), + ) + }; let cube_dim = CubeDim::new_1d(256); let cube_count = CubeCount::new_3d((rows * cols).div_ceil(256) as u32, 1, batches as u32); - compact_stats::launch::( - &client, - cube_count, - cube_dim, - stats.area.copy().as_tensor_arg::(1), - stats.area.as_tensor_arg::(1), - stats.top.copy().as_tensor_arg::(1), - stats.top.as_tensor_arg::(1), - stats.left.copy().as_tensor_arg::(1), - stats.left.as_tensor_arg::(1), - stats.right.copy().as_tensor_arg::(1), - stats.right.as_tensor_arg::(1), - stats.bottom.copy().as_tensor_arg::(1), - stats.bottom.as_tensor_arg::(1), - relabel.as_tensor_arg::(1), - ); + unsafe { + compact_stats::launch_unchecked::( + &client, + cube_count, + cube_dim, + stats.area.copy().as_tensor_arg::(1), + stats.area.as_tensor_arg::(1), + stats.top.copy().as_tensor_arg::(1), + stats.top.as_tensor_arg::(1), + stats.left.copy().as_tensor_arg::(1), + stats.left.as_tensor_arg::(1), + 
stats.right.copy().as_tensor_arg::(1), + stats.right.as_tensor_arg::(1), + stats.bottom.copy().as_tensor_arg::(1), + stats.bottom.as_tensor_arg::(1), + relabel.as_tensor_arg::(1), + ) + }; } } diff --git a/crates/burn-vision/tests/main.rs b/crates/burn-vision/tests/main.rs index a9e025afdd..6bd8dbfb96 100644 --- a/crates/burn-vision/tests/main.rs +++ b/crates/burn-vision/tests/main.rs @@ -1,25 +1,25 @@ -#[cfg(all(test, feature = "cpu"))] +#[cfg(all(test, feature = "test-cpu"))] mod tests_cpu { pub type TestBackend = burn_ndarray::NdArray; burn_vision::testgen_all!(); } -#[cfg(all(test, feature = "wgpu"))] +#[cfg(all(test, feature = "test-wgpu"))] mod tests_wgpu { pub type TestBackend = burn_wgpu::Wgpu; burn_vision::testgen_all!(); } -#[cfg(all(test, feature = "vulkan"))] +#[cfg(all(test, feature = "test-vulkan"))] mod tests_wgpu { pub type TestBackend = burn_wgpu::Vulkan; burn_vision::testgen_all!(); } -#[cfg(all(test, feature = "cuda"))] +#[cfg(all(test, feature = "test-cuda"))] mod tests_cuda { pub type TestBackend = burn_cuda::Cuda; diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs index ff0581089c..dbd7eab842 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -70,7 +70,7 @@ pub(crate) fn handle_command( // burn-vision helpers::custom_crates_tests( vec!["burn-vision"], - vec!["--features", "cpu"], + vec!["--features", "test-cpu"], None, None, "std cpu", @@ -86,7 +86,7 @@ pub(crate) fn handle_command( )?; helpers::custom_crates_tests( vec!["burn-vision"], - vec!["--features", "wgpu"], + vec!["--features", "test-wgpu"], None, None, "std wgpu",