From c8f385cf8f4c965bf3dbe3afac3f7f14c9b2000c Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Mon, 3 Feb 2025 14:32:08 +0100 Subject: [PATCH 1/4] Update cubecl (#2764) * Update cubecl * Update to scope merge * Fix bitwise shift * Update * Update lock for OpenSSL fix --- Cargo.lock | 396 +++++++++--------- Cargo.toml | 4 +- crates/burn-jit/src/fusion/matmul/args.rs | 2 +- crates/burn-jit/src/fusion/on_write/ir.rs | 4 +- .../src/kernel/conv/conv2d/gemm/launch.rs | 12 +- .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 38 +- crates/burn-jit/src/kernel/conv/error.rs | 11 +- crates/burn-jit/src/ops/int_ops.rs | 16 +- .../src/tensor/quantization/scheme.rs | 2 +- crates/burn-tensor/src/tests/ops/bitwise.rs | 4 + 10 files changed, 253 insertions(+), 236 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cb75c0ac6e..4151733570 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,7 +41,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy", @@ -174,7 +174,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -249,18 +249,18 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "async-trait" -version = "0.1.85" +version = "0.1.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" +checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -547,9 +547,9 @@ checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b" [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "burn" @@ -589,7 +589,7 @@ version = "0.17.0" dependencies = [ "cubecl-common", "dashmap", - "getrandom", + "getrandom 0.2.15", "indicatif", "rayon", "reqwest", @@ -690,7 +690,7 @@ dependencies = [ "derive-new 0.7.0", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -741,7 +741,7 @@ dependencies = [ "rust-format", "serde", "serde_json", - "syn 2.0.96", + "syn 2.0.98", "thiserror 2.0.11", "tracing-core", "tracing-subscriber", @@ -929,7 +929,7 @@ checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -946,9 +946,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" dependencies = [ "serde", ] @@ -1057,9 +1057,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.9" +version = "1.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" +checksum = "e4730490333d58093109dc02c23174c3f4d490998c3fed3cc8e82d57afedb9cf" dependencies = [ "jobserver", "libc", @@ -1164,7 +1164,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1184,9 +1184,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.52" +version = "0.1.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c682c223677e0e5b6b7f63a64b9351844c3f1b1678a68b7ee617e30fb082620e" +checksum = "e24a03c8b52922d68a1589ad61032f2c1aa5a8158d2aa0d93c6e9534944bbad6" dependencies = [ "cc", ] @@ -1336,9 +1336,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -1437,9 +1437,9 @@ dependencies = [ [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "crypto-common" @@ -1475,7 +1475,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1490,14 +1490,14 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "derive-new 0.6.0", "derive_more 1.0.0", "embassy-futures", "futures-lite", - "getrandom", + "getrandom 0.2.15", "half", "log", "num-traits", @@ -1511,7 +1511,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1532,7 +1532,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-common", @@ -1546,7 +1546,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-common", @@ -1562,7 +1562,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-common", @@ -1578,9 +1578,9 @@ dependencies = [ [[package]] name = "cubecl-hip-sys" -version = "6.3.1000" +version = "6.3.1001" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d987c1720eab39c72c515377a8001f683a4c4d99232a29fc0de389d9a8ce4f" +checksum = "c7e92df7f9feff6a469932fc4d4b349d28000af9e6f34e583eb4f8df70038d48" dependencies = [ "libc", ] @@ -1588,22 +1588,25 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-common", "cubecl-macros-internal", "derive_more 1.0.0", "float-ord", + "fnv", "half", + "hashbrown 0.14.5", "num-traits", + "portable-atomic", "serde", - "type_hash", + "variadics_please", ] [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-core", @@ -1615,7 +1618,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-common", "darling", @@ -1624,24 +1627,24 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "cubecl-macros-internal" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1657,7 +1660,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1667,7 +1670,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "async-channel", "async-lock", @@ -1689,7 +1692,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1704,7 +1707,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "ash", "async-channel", @@ -1724,9 +1727,9 @@ dependencies = [ [[package]] name = "cudarc" -version = "0.12.2" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd76de2aa3a7bdb9a65941ea5a3c688d941688f736a81b2fc5beb88747a7f25" +checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384" dependencies = [ "half", "libloading", @@ -1821,7 +1824,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1832,7 +1835,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1878,7 +1881,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1889,7 +1892,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1900,7 +1903,7 @@ checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1921,7 +1924,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1931,7 +1934,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1942,7 +1945,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1962,7 +1965,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "unicode-xid", ] @@ -2018,7 +2021,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2042,9 +2045,9 @@ dependencies = [ [[package]] name = "dyn-clone" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" [[package]] name = "dyn-stack" @@ -2092,7 +2095,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2104,7 +2107,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2334,7 +2337,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2427,7 +2430,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2613,10 +2616,22 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", +] + [[package]] name = "gif" version = "0.13.1" @@ -2684,15 +2699,15 @@ dependencies = [ [[package]] name = "gix-trace" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04bdde120c29f1fc23a24d3e115aeeea3d60d8e65bab92cc5f9d90d9302eb952" +checksum = "7c396a2036920c69695f760a65e7f2677267ccf483f25046977d87e4cb2665f7" [[package]] name = "gix-utils" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba427e3e9599508ed98a6ddf8ed05493db114564e338e41f6a996d2e4790335f" +checksum = "ff08f24e03ac8916c478c8419d7d3c33393da9bb41fa4c24455d5406aeefd35f" dependencies = [ "fastrand", "unicode-normalization", @@ -2998,9 +3013,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.9.5" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" [[package]] name = "httpdate" @@ -3016,9 +3031,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "1.5.2" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" dependencies = [ "bytes", "futures-channel", @@ -3225,7 +3240,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -3331,9 +3346,9 @@ checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -3342,9 +3357,9 @@ dependencies = [ [[package]] name = "indicatif" -version = "0.17.9" +version = "0.17.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" dependencies = [ "console", "number_prefix", @@ -3378,7 +3393,7 @@ dependencies = [ "indoc", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -3398,14 +3413,14 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "ipnet" -version = "2.10.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is_terminal_polyfill" @@ -3514,9 +3529,9 @@ checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libfuzzer-sys" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b9569d2f74e257076d8c6bfa73fb505b46b851e51ddaecc825944aa3bed17fa" +checksum = "cf78f52d400cf2d84a3a973a78a592b4adc535739e0a5597a0da6f0c357adc75" dependencies = [ "arbitrary", "cc", @@ -3791,9 +3806,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" dependencies = [ "adler2", "simd-adler32", @@ -3807,7 +3822,7 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -3858,7 +3873,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -3893,9 +3908,9 @@ dependencies = [ [[package]] name = "native-tls" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +checksum = "0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9fda365ade4fab353196c" dependencies = [ "libc", "log", @@ -4061,7 +4076,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -4133,7 +4148,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -4242,9 +4257,9 @@ dependencies = [ [[package]] name = "objc2-encode" -version = "4.0.3" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7891e71393cd1f227313c9379a26a584ff3d7e6e7159e988851f0934c993f0f8" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" [[package]] name = "objc2-foundation" @@ -4395,9 +4410,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.68" +version = "0.10.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" +checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6" dependencies = [ "bitflags 2.8.0", "cfg-if", @@ -4416,20 +4431,20 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-sys" -version = "0.9.104" +version = "0.9.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc" dependencies = [ "cc", "libc", @@ -4671,7 +4686,7 @@ version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72571dde488ecccbe799798bf99ab7308ebdb7cf5d95bcc498dbd5a132f0da4d" dependencies = [ - "getrandom", + "getrandom 0.2.15", "polars-arrow", "polars-core", "polars-error", @@ -4699,7 +4714,7 @@ dependencies = [ "dyn-clone", "either", "ethnum", - "getrandom", + "getrandom 0.2.15", "hashbrown 0.15.2", "itoa", "lz4", @@ -5144,7 +5159,7 @@ dependencies = [ "once_cell", "polars-error", "rand", - "raw-cpuid 11.2.0", + "raw-cpuid 11.3.0", "rayon", "stacker", "sysinfo 0.33.1", @@ -5156,6 +5171,9 @@ name = "portable-atomic" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +dependencies = [ + "serde", +] [[package]] name = "portable-atomic-util" @@ -5204,7 +5222,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" dependencies = [ "proc-macro2", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -5241,7 +5259,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30" dependencies = [ "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -5411,7 +5429,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -5426,9 +5444,9 @@ dependencies = [ [[package]] name = "range-alloc" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8a99fddc9f0ba0a85884b8d14e3592853e787d581ca1816c91349b10e4eeab" +checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde" [[package]] name = "ratatui" @@ -5513,9 +5531,9 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "11.2.0" +version = "11.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" +checksum = "c6928fa44c097620b706542d428957635951bade7143269085389d42c8a4927e" dependencies = [ "bitflags 2.8.0", ] @@ -5586,7 +5604,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -5613,7 +5631,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.15", "libredox", "thiserror 1.0.69", ] @@ -5736,7 +5754,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", "spin", "untrusted", @@ -5800,7 +5818,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.96", + "syn 2.0.98", "unicode-ident", ] @@ -5851,9 +5869,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.43" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags 2.8.0", "errno", @@ -5864,9 +5882,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.21" +version = "0.23.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" +checksum = "9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7" dependencies = [ "log", "once_cell", @@ -5901,9 +5919,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" [[package]] name = "rustls-webpki" @@ -5924,9 +5942,9 @@ checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "safetensors" @@ -6040,9 +6058,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "seq-macro" @@ -6087,14 +6105,14 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "serde_json" -version = "1.0.137" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "930cfb6e6abf99298aaad7d29abbef7a9999a9a8806a40088f55f0dcec03146b" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "itoa", "memchr", @@ -6165,7 +6183,7 @@ checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6443,7 +6461,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6459,15 +6477,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", - "quote", "unicode-ident", ] [[package]] name = "syn" -version = "2.0.96" +version = "2.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" +checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" dependencies = [ "proc-macro2", "quote", @@ -6491,7 +6508,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6620,13 +6637,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.15.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" +checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ "cfg-if", "fastrand", - "getrandom", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys 0.59.0", @@ -6709,7 +6726,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6720,7 +6737,7 @@ checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6820,7 +6837,7 @@ dependencies = [ "aho-corasick", "derive_builder", "esaxx-rs", - "getrandom", + "getrandom 0.2.15", "hf-hub", "itertools 0.12.1", "lazy_static", @@ -6867,7 +6884,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6938,9 +6955,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.22" +version = "0.22.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +checksum = "02a8b472d1a3d7c18e2d61a489aee3453fd9031c33e4f55bd533f4a7adca1bee" dependencies = [ "indexmap", "serde", @@ -7019,7 +7036,7 @@ checksum = "5a3a646485f7cd8f580749ab94718ad3d344bcc0cc5b0fefe43c15fdd898bb96" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7054,7 +7071,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7129,37 +7146,6 @@ dependencies = [ "rustc-hash", ] -[[package]] -name = "type_hash" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c86f48f11992d3e379358c63cb25736c0b23944ff000d1583bbccad2b0b7c6" -dependencies = [ - "type_hash_core", - "type_hash_macros", -] - -[[package]] -name = "type_hash_core" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87b1e93e2cd97790892dbe2d2813fbaa6eebaeb960265f59e363e79e51e4997a" -dependencies = [ - "fnv", -] - -[[package]] -name = "type_hash_macros" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "746fc164e076483ef087b3989f7aa80ffd9320fa558f3cb72cecfb9bb1dbc41e" -dependencies = [ - "either", - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "typenum" version = "1.17.0" @@ -7210,9 +7196,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-normalization" @@ -7349,7 +7335,7 @@ version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" dependencies = [ - "getrandom", + "getrandom 0.2.15", "rand", ] @@ -7366,9 +7352,9 @@ dependencies = [ [[package]] name = "valuable" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] name = "variadics_please" @@ -7378,7 +7364,7 @@ checksum = "41b6d82be61465f97d42bd1d15bf20f3b0a3a0905018f38f9d6f6962055b0b5c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7424,6 +7410,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -7446,7 +7441,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "wasm-bindgen-shared", ] @@ -7481,7 +7476,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -7543,9 +7538,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.7" +version = "0.26.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" +checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" dependencies = [ "rustls-pki-types", ] @@ -7778,7 +7773,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7789,7 +7784,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7800,7 +7795,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7811,7 +7806,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8003,13 +7998,22 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.24" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" +checksum = "7e49d2d35d3fad69b39b94139037ecfb4f359f08958b9c11e7315ce770462419" dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags 2.8.0", +] + [[package]] name = "wrapcenum-derive" version = "0.4.1" @@ -8019,7 +8023,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8116,7 +8120,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "synstructure", ] @@ -8138,7 +8142,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8158,7 +8162,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "synstructure", ] @@ -8179,7 +8183,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8201,7 +8205,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 09b944468c..7263dec57a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/fusion/matmul/args.rs b/crates/burn-jit/src/fusion/matmul/args.rs index 1dbbf3baea..bba18e88f9 100644 --- a/crates/burn-jit/src/fusion/matmul/args.rs +++ b/crates/burn-jit/src/fusion/matmul/args.rs @@ -247,7 +247,7 @@ impl CubeType for FusedMatmulState { } impl Init for FusedMatmulStateExpand { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _context: &mut Scope) -> Self { self } } diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index 0cec2d29c7..36c8e402a0 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -45,13 +45,13 @@ impl CubeType for Arg { } impl Init for Arg { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _context: &mut Scope) -> Self { self } } impl IntoRuntime for Arg { - fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType { + fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType { self } } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs index f36f89bdf5..ad70a9b825 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -7,7 +7,7 @@ use burn_tensor::{ use cubecl::{ flex32, ir::{Elem, FloatKind}, - linalg::matmul::{self}, + linalg::matmul::{self, kernels::MatmulLaunchError}, tensor_line_size, tf32, Feature, }; use half::{bf16, f16}; @@ -195,18 +195,14 @@ where let cube_count = Alg::cube_count(&selection, &problem); let advanced_config = Default::default(); - let config = match Alg::make_config( + let config = Alg::make_config( config_input, &problem, &cube_dim, &cube_count, &advanced_config, - ) { - Ok(val) => val, - Err(err) => { - panic!("Can't launch conv kernel because of an invalid config: {err}") - } - }; + ) + .map_err(MatmulLaunchError::InvalidConfig)?; let bias = bias.unwrap_or_else(|| { empty_device::(input.client.clone(), input.device.clone(), Shape::new([1])) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index f74cdaf8bc..09ce56898b 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -98,25 +98,38 @@ fn im2col_kernel( } #[cfg(not(test))] -pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - let cube_count_per_batch = (out_h * out_w).div_ceil(burn_common::PLANE_DIM_APPROX); +pub(crate) fn batches_per_run( + batch_size: usize, + out_h: usize, + out_w: usize, +) -> Result { + use cubecl::linalg::matmul::kernels::MatmulAvailabilityError; + + let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::PLANE_DIM_APPROX); let max_cube_count = u16::MAX as usize; let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size); if max_simultaneous == 0 { - return None; + return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static( + cube_count_per_batch as u32, + 1, + 1, + )) + .into()); } - Some( - (0..=max_simultaneous) - .rev() - .find(|per_run| batch_size % per_run == 0) - .expect("Logically not possible"), - ) + Ok((0..=max_simultaneous) + .rev() + .find(|per_run| batch_size % per_run == 0) + .expect("Logically not possible")) } #[cfg(test)] #[allow(unused)] -pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - Some(1) +pub(crate) fn batches_per_run( + batch_size: usize, + out_h: usize, + out_w: usize, +) -> Result { + Ok(1) } fn im2col( @@ -214,8 +227,7 @@ pub fn conv2d_im2col( return execute_1x1_kernel::(input, weight, bias, options); } - let batches_per_run = batches_per_run(batch_size, out_h, out_w) - .expect("Image too large to run even one batch at once"); + let batches_per_run = batches_per_run(batch_size, out_h, out_w)?; let matmul_shape = Shape::new([groups, out_c_per_group, batches_per_run * out_h * out_w]); let mut out = if batches_per_run != batch_size { diff --git a/crates/burn-jit/src/kernel/conv/error.rs b/crates/burn-jit/src/kernel/conv/error.rs index 99c91fc751..2654a20e24 100644 --- a/crates/burn-jit/src/kernel/conv/error.rs +++ b/crates/burn-jit/src/kernel/conv/error.rs @@ -1,5 +1,8 @@ use core::fmt::Debug; -use cubecl::{linalg::matmul::kernels::MatmulLaunchError, tune::AutotuneError}; +use cubecl::{ + linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}, + tune::AutotuneError, +}; pub enum ConvLaunchError { Matmul(MatmulLaunchError), @@ -30,6 +33,12 @@ impl From for ConvLaunchError { } } +impl From for ConvLaunchError { + fn from(value: MatmulAvailabilityError) -> Self { + Self::Matmul(MatmulLaunchError::Unavailable(value)) + } +} + #[allow(clippy::from_over_into)] impl Into for ConvLaunchError { fn into(self) -> AutotuneError { diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 068c1269d9..8da778d1e8 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -328,26 +328,18 @@ where } fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = kernel::cast::(rhs); - launch_binop_int::(lhs_cast, rhs_cast) + launch_binop_int::(lhs, rhs) } fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = rhs.elem::(); - launch_scalar_binop_int::(lhs_cast, rhs_cast) + launch_scalar_binop_int::(lhs, rhs) } fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = kernel::cast::(rhs); - launch_binop_int::(lhs_cast, rhs_cast) + launch_binop_int::(lhs, rhs) } fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = rhs.elem::(); - launch_scalar_binop_int::(lhs_cast, rhs_cast) + launch_scalar_binop_int::(lhs, rhs) } } diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs index fb141ee16d..27fa996ad6 100644 --- a/crates/burn-tensor/src/tensor/quantization/scheme.rs +++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs @@ -37,7 +37,7 @@ impl CubeType for QuantizationScheme { } #[cfg(feature = "cubecl")] impl cubecl::frontend::Init for QuantizationScheme { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _scope: &mut cubecl::ir::Scope) -> Self { self } } diff --git a/crates/burn-tensor/src/tests/ops/bitwise.rs b/crates/burn-tensor/src/tests/ops/bitwise.rs index 73702a716e..c85f5edcc5 100644 --- a/crates/burn-tensor/src/tests/ops/bitwise.rs +++ b/crates/burn-tensor/src/tests/ops/bitwise.rs @@ -124,6 +124,10 @@ mod tests { #[test] fn should_apply_bitwise_left_shift_2d() { + if (IntType::MAX as u32) < 512 { + return; + } + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); From e0c641934fb44b67e6d6d143e5b72d06cd68ba41 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 08:32:58 -0500 Subject: [PATCH 2/4] Bump indicatif from 0.17.9 to 0.17.11 (#2769) Bumps [indicatif](https://github.com/console-rs/indicatif) from 0.17.9 to 0.17.11. - [Release notes](https://github.com/console-rs/indicatif/releases) - [Commits](https://github.com/console-rs/indicatif/compare/0.17.9...0.17.11) --- updated-dependencies: - dependency-name: indicatif dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 7263dec57a..7287eae729 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ globwalk = "0.9.1" hashbrown = "0.15.2" hound = "3.5.1" image = "0.25.5" -indicatif = "0.17.9" +indicatif = "0.17.11" js-sys = "0.3.72" libm = "0.2.11" log = { default-features = false, version = "0.4.25" } From 6b2e66bd36bd4ddc0e2cd1e94690ece6212562fa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 09:19:27 -0500 Subject: [PATCH 3/4] Bump sysinfo from 0.32.1 to 0.33.1 (#2771) * Bump sysinfo from 0.32.1 to 0.33.1 Bumps [sysinfo](https://github.com/GuillaumeGomez/sysinfo) from 0.32.1 to 0.33.1. - [Changelog](https://github.com/GuillaumeGomez/sysinfo/blob/master/CHANGELOG.md) - [Commits](https://github.com/GuillaumeGomez/sysinfo/compare/v0.32.1...v0.33.1) --- updated-dependencies: - dependency-name: sysinfo dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Fix Hip backend name * Fix refresh kind methods --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Guillaume Lagrange --- Cargo.lock | 25 +++++-------------- Cargo.toml | 2 +- .../src/persistence/system_info.rs | 2 +- crates/burn-train/src/metric/cpu_use.rs | 4 ++- .../examples/ag-news-train.rs | 4 +-- 5 files changed, 13 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4151733570..ab9eddb0bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -393,7 +393,7 @@ dependencies = [ "serial_test", "strum", "strum_macros", - "sysinfo 0.32.1", + "sysinfo", "tracing-subscriber", "wgpu", "wsl", @@ -893,7 +893,7 @@ dependencies = [ "ratatui", "rstest", "serde", - "sysinfo 0.32.1", + "sysinfo", "systemstat", "tracing-appender", "tracing-core", @@ -3544,7 +3544,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -5162,7 +5162,7 @@ dependencies = [ "raw-cpuid 11.3.0", "rayon", "stacker", - "sysinfo 0.33.1", + "sysinfo", "version_check", ] @@ -6525,21 +6525,6 @@ dependencies = [ "walkdir", ] -[[package]] -name = "sysinfo" -version = "0.32.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c33cd241af0f2e9e3b5c32163b873b29956890b5342e6745b917ce9d490f4af" -dependencies = [ - "core-foundation-sys", - "libc", - "memchr", - "ntapi", - "rayon", - "serde", - "windows 0.57.0", -] - [[package]] name = "sysinfo" version = "0.33.1" @@ -6550,6 +6535,8 @@ dependencies = [ "libc", "memchr", "ntapi", + "rayon", + "serde", "windows 0.57.0", ] diff --git a/Cargo.toml b/Cargo.toml index 7287eae729..169d668aa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -145,7 +145,7 @@ uuid = { version = "1.12.1", default-features = false } libc = "0.2.169" nvml-wrapper = "0.10.0" -sysinfo = "0.32.1" +sysinfo = "0.33.1" systemstat = "0.2.3" tch = "0.15.0" diff --git a/backend-comparison/src/persistence/system_info.rs b/backend-comparison/src/persistence/system_info.rs index 287b629c21..3fe24bc955 100644 --- a/backend-comparison/src/persistence/system_info.rs +++ b/backend-comparison/src/persistence/system_info.rs @@ -38,7 +38,7 @@ impl BenchmarkSystemInfo { fn enumerate_cpus() -> Vec { let system = sysinfo::System::new_with_specifics( - sysinfo::RefreshKind::new().with_cpu(sysinfo::CpuRefreshKind::everything()), + sysinfo::RefreshKind::nothing().with_cpu(sysinfo::CpuRefreshKind::everything()), ); let cpu_names: HashSet = system .cpus() diff --git a/crates/burn-train/src/metric/cpu_use.rs b/crates/burn-train/src/metric/cpu_use.rs index 2769793088..d06d8429db 100644 --- a/crates/burn-train/src/metric/cpu_use.rs +++ b/crates/burn-train/src/metric/cpu_use.rs @@ -26,7 +26,9 @@ impl CpuUse { } fn refresh(sys: &mut System) -> f64 { - sys.refresh_specifics(RefreshKind::new().with_cpu(CpuRefreshKind::new().with_cpu_usage())); + sys.refresh_specifics( + RefreshKind::nothing().with_cpu(CpuRefreshKind::nothing().with_cpu_usage()), + ); let cpus = sys.cpus(); let num_cpus = cpus.len(); diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index 9a9cab44bd..927c190b2c 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -116,10 +116,10 @@ mod cuda { #[cfg(feature = "hip")] mod hip { use crate::{launch, ElemType}; - use burn::backend::{Autodiff, HipJit}; + use burn::backend::{Autodiff, Hip}; pub fn run() { - launch::>>(vec![Default::default()]); + launch::>>(vec![Default::default()]); } } From 9f003203d05a0b260c3cd5ad44a7460dfcffc67d Mon Sep 17 00:00:00 2001 From: SalvoMcL <64030770+salvomcl@users.noreply.github.com> Date: Mon, 3 Feb 2025 16:05:14 +0100 Subject: [PATCH 4/4] Feat: Add PoissonNLL loss (#2765) * added PoissonNLLLossConfig * added PoissonNLLLoss * added tests * update docs * added requested changes --- burn-book/src/building-blocks/module.md | 1 + crates/burn-core/src/nn/loss/mod.rs | 2 + crates/burn-core/src/nn/loss/poisson.rs | 390 ++++++++++++++++++++++++ 3 files changed, 393 insertions(+) create mode 100644 crates/burn-core/src/nn/loss/poisson.rs diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index 0f5aca7f24..9598d6e39e 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -294,3 +294,4 @@ Burn comes with built-in modules that you can use to build your own modules. | `CrossEntropyLoss` | `nn.CrossEntropyLoss` | | `MseLoss` | `nn.MSELoss` | | `HuberLoss` | `nn.HuberLoss` | +| `PoissonNllLoss` | `nn.PoissonNLLLoss` | diff --git a/crates/burn-core/src/nn/loss/mod.rs b/crates/burn-core/src/nn/loss/mod.rs index cca7b4541b..475364e63b 100644 --- a/crates/burn-core/src/nn/loss/mod.rs +++ b/crates/burn-core/src/nn/loss/mod.rs @@ -2,10 +2,12 @@ mod binary_cross_entropy; mod cross_entropy; mod huber; mod mse; +mod poisson; mod reduction; pub use binary_cross_entropy::*; pub use cross_entropy::*; pub use huber::*; pub use mse::*; +pub use poisson::*; pub use reduction::*; diff --git a/crates/burn-core/src/nn/loss/poisson.rs b/crates/burn-core/src/nn/loss/poisson.rs new file mode 100644 index 0000000000..3cc989ad8e --- /dev/null +++ b/crates/burn-core/src/nn/loss/poisson.rs @@ -0,0 +1,390 @@ +use core::f32::consts::PI; + +use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; +use crate::tensor::backend::Backend; +use crate::tensor::Tensor; +use crate::{config::Config, module::Module}; + +use super::Reduction; + +/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance. +/// +/// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss +/// behavior, such as whether the input is in log-space, whether to include the Stirling +/// approximation term, and a small epsilon value to avoid numerical instability. +#[derive(Config, Debug)] +pub struct PoissonNllLossConfig { + /// If `true`, the predictions are expected to be in log-space. + /// + /// When `log_input` is `true`, the loss is computed as: + /// ```text + /// L(predictions, target) = exp(predictions) - target * predictions + /// ``` + /// When `log_input` is `false`, the loss is computed as: + /// ```text + /// L(predictions, target) = predictions - target * log(predictions + eps) + /// ``` + #[config(default = true)] + pub log_input: bool, + /// Whether to compute the full loss, including the Stirling approximation term. + /// + /// When `full` is `true`, the Stirling approximation term is added to the loss: + /// ```text + /// target * log(target) - target + 0.5 * log(2 * PI * target) + /// ``` + #[config(default = false)] + pub full: bool, + /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`. + /// + /// This epsilon value is added to the predictions to ensure numerical stability + /// when computing the logarithm. + #[config(default = 1e-8)] + pub eps: f64, +} + +impl PoissonNllLossConfig { + /// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration. + /// + /// # Panics + /// - Panics if `eps` is not a positive number. + pub fn init(&self) -> PoissonNllLoss { + self.assertions(); + PoissonNllLoss { + log_input: self.log_input, + full: self.full, + eps: self.eps, + } + } + + /// Validates the configuration parameters. + /// + /// # Panics + /// - Panics if `eps` is not a positive number. + fn assertions(&self) { + assert!( + self.eps > 0., + "eps for PoissonNllLoss must be a positive number." + ); + } +} + +/// Negative Log Likelihood (NLL) loss with a Poisson distribution assumption for the target. +/// +/// This loss function is used when the target values are assumed to follow a Poisson distribution. +/// The loss is defined as: +/// ```text +/// target ~ Poisson(input) +/// L(predictions, target) = predictions - target * log(predictions) + log(target!) +/// ``` +/// The last term (`log(target!)`) can be omitted or approximated using Stirling's formula. +/// The approximation is applied for `target > 1`, while for `target <= 1`, zeros are added to the loss. +/// +/// For more details, see: +/// +#[derive(Module, Debug, Clone)] +#[module(custom_display)] +pub struct PoissonNllLoss { + /// If `true`, the predictions are expected to be in log-space. + pub log_input: bool, + /// Whether to compute the full loss, including the Stirling approximation term. + pub full: bool, + /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`. + pub eps: f64, +} + +impl ModuleDisplay for PoissonNllLoss { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("log_input", &self.log_input) + .add("full", &self.full) + .add("eps", &self.eps) + .optional() + } +} + +impl PoissonNllLoss { + /// Computes the loss element-wise for the given predictions and targets, then reduces + /// the result to a single loss value. + /// + /// # Arguments + /// - `predictions`: The predicted values. + /// - `targets`: The target values. + /// - `reduction`: The reduction method to apply. `Reduction::Auto` behaves as `Reduction::Mean`. + /// + /// # Shapes + /// - `predictions`: `[...dims]` + /// - `targets`: `[...dims]` + /// - `output`: `[1]` + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + pub fn forward( + &self, + predictions: Tensor, + targets: Tensor, + reduction: Reduction, + ) -> Tensor { + let loss = self.forward_no_reduction(predictions, targets); + match reduction { + Reduction::Mean | Reduction::Auto => loss.mean(), + Reduction::Sum => loss.sum(), + } + } + + /// Computes the loss element-wise for the given predictions and targets without reduction. + /// + /// # Arguments + /// - `predictions`: The predicted values. + /// - `targets`: The target values. + /// + /// # Shapes + /// - `predictions`: `[...dims]` + /// - `targets`: `[...dims]` + /// - `output`: `[...dims]` + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + pub fn forward_no_reduction( + &self, + predictions: Tensor, + targets: Tensor, + ) -> Tensor { + self.assertions(&predictions, &targets); + let mut loss; + if self.log_input { + loss = predictions.clone().exp() - targets.clone() * predictions; + } else { + loss = predictions.clone() - targets.clone() * (predictions + self.eps).log(); + } + if self.full { + let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone() + + (targets.clone() * 2. * PI).log() * 0.5; + loss = loss + + log_stirling_term + .mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like()); + } + loss + } + + /// Validates the input tensors for the loss computation. + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + fn assertions( + &self, + predictions: &Tensor, + targets: &Tensor, + ) { + let predictions_dims = predictions.dims(); + let targets_dims = targets.dims(); + assert!( + predictions_dims == targets_dims, + "Shape of targets ({:?}) should correspond to outer shape of predictions ({:?}).", + targets_dims, + predictions_dims + ); + assert!( + targets.clone().greater_equal_elem(0.).all().into_scalar(), + "All the values of `targets` must be non-negative." + ); + if !self.log_input { + assert!( + predictions.clone().greater_equal_elem(0.).all().into_scalar(), + "When `log_input` is `false`, all the values of `predictions` must be non-negative." + ); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::TensorData; + use crate::TestBackend; + type TestTensor = Tensor; + + #[test] + fn test_poisson_nll_loss() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); + let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + let loss_no_reduction = poisson.forward_no_reduction(predictions, targets); + + let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([21.0321]); + loss.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([126.1929]); + loss_sum.into_data().assert_approx_eq(&expected, 5); + } + + #[test] + fn test_poisson_nll_loss_no_log_input() { + let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]); + let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); + + let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + + let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + } + + #[test] + fn test_poisson_nll_loss_full() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_full(true).init(); + + let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); + let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + let loss_no_reduction = poisson.forward_no_reduction(predictions, targets); + + let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([21.9920]); + loss.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([131.9518]); + loss_sum.into_data().assert_approx_eq(&expected, 5); + } + + #[cfg(feature = "std")] + #[test] + fn test_poisson_nll_loss_gradients() { + type TestAutodiffTensor = Tensor; + + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad(); + let predictions2 = predictions1.clone(); + let targets = TestAutodiffTensor::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_full(false).init(); + let poisson_full = PoissonNllLossConfig::new().with_full(true).init(); + + let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum); + let loss_full_sum = + poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum); + + let grads = loss_sum.backward(); + let grads_full = loss_full_sum.backward(); + + let grads_predictions1 = predictions1.grad(&grads).unwrap(); + let grads_predictions2 = predictions2.grad(&grads_full).unwrap(); + + let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]); + + grads_predictions1 + .into_data() + .assert_approx_eq(&expected, 5); + grads_predictions2 + .into_data() + .assert_approx_eq(&expected, 5); + } + + #[test] + #[should_panic = "eps for PoissonNllLoss must be a positive number."] + fn test_negative_eps() { + let _poisson = PoissonNllLossConfig::new().with_eps(0.).init(); + } + + #[test] + #[should_panic = "All the values of `targets` must be non-negative."] + fn test_targets_with_negative_values() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + } + + #[test] + #[should_panic = "Shape of targets"] + fn test_shape_tensors() { + let predictions = TensorData::from([0., 1., 2.]); + let targets = TensorData::from([0., 1.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + } + + #[test] + #[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."] + fn test_exp_predictions_non_negative() { + let predictions = TensorData::from([0.3, -0.1, 0.4]); + let targets = TensorData::from([0., 1., 0.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); + + let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + } + + #[test] + fn display() { + let config = PoissonNllLossConfig::new(); + let loss = config.init(); + + assert_eq!( + alloc::format!("{}", loss), + "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}" + ); + } +}