diff --git a/Cargo.lock b/Cargo.lock index 6eb9713a50..62abf8ac0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,7 +41,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy", @@ -62,21 +62,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" -[[package]] -name = "alloc-no-stdlib" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" - -[[package]] -name = "alloc-stdlib" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" -dependencies = [ - "alloc-no-stdlib", -] - [[package]] name = "allocator-api2" version = "0.2.21" @@ -189,7 +174,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -207,12 +192,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" -[[package]] -name = "arrayref" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" - [[package]] name = "arrayvec" version = "0.7.6" @@ -270,34 +249,25 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "async-trait" -version = "0.1.85" +version = "0.1.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" +checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", -] - -[[package]] -name = "atoi" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" -dependencies = [ - "num-traits", + "syn 2.0.98", ] [[package]] name = "atoi_simd" -version = "0.15.6" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" +checksum = "4790f9e8961209112beb783d85449b508673cf4a6a419c8449b210743ac4dbe9" [[package]] name = "atomic-waker" @@ -423,7 +393,7 @@ dependencies = [ "serial_test", "strum", "strum_macros", - "sysinfo 0.32.1", + "sysinfo", "tracing-subscriber", "wgpu", "wsl", @@ -524,19 +494,6 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" -[[package]] -name = "blake3" -version = "1.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e" -dependencies = [ - "arrayref", - "arrayvec", - "cc", - "cfg-if", - "constant_time_eq 0.3.1", -] - [[package]] name = "blas-src" version = "0.10.0" @@ -572,27 +529,6 @@ dependencies = [ "objc2", ] -[[package]] -name = "brotli" -version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", - "brotli-decompressor", -] - -[[package]] -name = "brotli-decompressor" -version = "4.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", -] - [[package]] name = "bstr" version = "1.11.3" @@ -611,9 +547,9 @@ checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b" [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "burn" @@ -653,7 +589,7 @@ version = "0.17.0" dependencies = [ "cubecl-common", "dashmap", - "getrandom", + "getrandom 0.2.15", "indicatif", "rayon", "reqwest", @@ -754,7 +690,7 @@ dependencies = [ "derive-new 0.7.0", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -805,7 +741,7 @@ dependencies = [ "rust-format", "serde", "serde_json", - "syn 2.0.96", + "syn 2.0.98", "thiserror 2.0.11", "tracing-core", "tracing-subscriber", @@ -957,7 +893,7 @@ dependencies = [ "ratatui", "rstest", "serde", - "sysinfo 0.32.1", + "sysinfo", "systemstat", "tracing-appender", "tracing-core", @@ -993,7 +929,7 @@ checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1010,9 +946,12 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" +dependencies = [ + "serde", +] [[package]] name = "bytesize" @@ -1055,7 +994,7 @@ dependencies = [ "gemm", "half", "libc", - "memmap2 0.9.5", + "memmap2", "metal 0.27.0", "num-traits", "num_cpus", @@ -1118,9 +1057,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.9" +version = "1.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" +checksum = "e4730490333d58093109dc02c23174c3f4d490998c3fed3cc8e82d57afedb9cf" dependencies = [ "jobserver", "libc", @@ -1159,16 +1098,15 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", - "serde", "wasm-bindgen", "windows-targets 0.52.6", ] [[package]] name = "chrono-tz" -version = "0.8.6" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" +checksum = "9c6ac4f2c0bf0f44e9161aec9675e1050aa4a530663c4a9e37e108fa948bca9f" dependencies = [ "chrono", "chrono-tz-build", @@ -1177,42 +1115,14 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.2.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" dependencies = [ "parse-zoneinfo", - "phf", "phf_codegen", ] -[[package]] -name = "ciborium" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" -dependencies = [ - "ciborium-io", - "ciborium-ll", - "serde", -] - -[[package]] -name = "ciborium-io" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" - -[[package]] -name = "ciborium-ll" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" -dependencies = [ - "ciborium-io", - "half", -] - [[package]] name = "cipher" version = "0.4.4" @@ -1225,9 +1135,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", "clap_derive", @@ -1235,9 +1145,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", @@ -1251,10 +1161,10 @@ version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1274,9 +1184,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.52" +version = "0.1.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c682c223677e0e5b6b7f63a64b9351844c3f1b1678a68b7ee617e30fb082620e" +checksum = "e24a03c8b52922d68a1589ad61032f2c1aa5a8158d2aa0d93c6e9534944bbad6" dependencies = [ "cc", ] @@ -1394,16 +1304,6 @@ dependencies = [ "libc", ] -[[package]] -name = "core-foundation" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1417,7 +1317,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c07782be35f9e1140080c6b96f0d44b739e2278479f64e02fdab4e32dfd8b081" dependencies = [ "bitflags 1.3.2", - "core-foundation 0.9.4", + "core-foundation", "core-graphics-types", "foreign-types 0.5.0", "libc", @@ -1430,15 +1330,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" dependencies = [ "bitflags 1.3.2", - "core-foundation 0.9.4", + "core-foundation", "libc", ] [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -1537,9 +1437,9 @@ dependencies = [ [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "crypto-common" @@ -1595,7 +1495,7 @@ dependencies = [ "derive_more 1.0.0", "embassy-futures", "futures-lite", - "getrandom", + "getrandom 0.2.15", "half", "log", "num-traits", @@ -1672,9 +1572,9 @@ dependencies = [ [[package]] name = "cubecl-hip-sys" -version = "6.3.1000" +version = "6.3.1001" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d987c1720eab39c72c515377a8001f683a4c4d99232a29fc0de389d9a8ce4f" +checksum = "c7e92df7f9feff6a469932fc4d4b349d28000af9e6f34e583eb4f8df70038d48" dependencies = [ "libc", ] @@ -1684,11 +1584,16 @@ name = "cubecl-ir" version = "0.5.0" dependencies = [ "cubecl-common", + "cubecl-macros-internal", + "derive_more 1.0.0", "float-ord", + "fnv", "half", + "hashbrown 0.14.5", "num-traits", + "portable-atomic", "serde", - "type_hash", + "variadics_please", ] [[package]] @@ -1713,7 +1618,17 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", +] + +[[package]] +name = "cubecl-macros-internal" +version = "0.5.0" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.98", ] [[package]] @@ -1797,9 +1712,9 @@ dependencies = [ [[package]] name = "cudarc" -version = "0.12.2" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd76de2aa3a7bdb9a65941ea5a3c688d941688f736a81b2fc5beb88747a7f25" +checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384" dependencies = [ "half", "libloading", @@ -1894,7 +1809,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1905,7 +1820,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1951,7 +1866,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1962,7 +1877,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1973,7 +1888,7 @@ checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1994,7 +1909,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2004,7 +1919,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2015,7 +1930,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2035,7 +1950,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "unicode-xid", ] @@ -2091,15 +2006,9 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] -[[package]] -name = "doc-comment" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" - [[package]] name = "document-features" version = "0.2.10" @@ -2121,9 +2030,9 @@ dependencies = [ [[package]] name = "dyn-clone" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" [[package]] name = "dyn-stack" @@ -2140,9 +2049,6 @@ name = "either" version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" -dependencies = [ - "serde", -] [[package]] name = "embassy-futures" @@ -2171,10 +2077,10 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2186,7 +2092,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2305,10 +2211,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" [[package]] -name = "fast-float" -version = "0.2.0" +name = "fast-float2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" +checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55" [[package]] name = "faster-hex" @@ -2416,7 +2322,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2440,16 +2346,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "fs4" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c6b3bd49c37d2aa3f3f2220233b29a7cd23f79d1fe70e5337d25fb390793de" -dependencies = [ - "rustix", - "windows-sys 0.52.0", -] - [[package]] name = "futures" version = "0.3.31" @@ -2519,7 +2415,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2705,10 +2601,22 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", +] + [[package]] name = "gif" version = "0.13.1" @@ -2776,15 +2684,15 @@ dependencies = [ [[package]] name = "gix-trace" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04bdde120c29f1fc23a24d3e115aeeea3d60d8e65bab92cc5f9d90d9302eb952" +checksum = "7c396a2036920c69695f760a65e7f2677267ccf483f25046977d87e4cb2665f7" [[package]] name = "gix-utils" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba427e3e9599508ed98a6ddf8ed05493db114564e338e41f6a996d2e4790335f" +checksum = "ff08f24e03ac8916c478c8419d7d3c33393da9bb41fa4c24455d5406aeefd35f" dependencies = [ "fastrand", "unicode-normalization", @@ -2946,16 +2854,6 @@ dependencies = [ "serde", ] -[[package]] -name = "halfbrown" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8588661a8607108a5ca69cab034063441a0413a0b041c13618a7dd348021ef6f" -dependencies = [ - "hashbrown 0.14.5", - "serde", -] - [[package]] name = "hashbrown" version = "0.13.2" @@ -2999,12 +2897,6 @@ dependencies = [ "hashbrown 0.14.5", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -3106,9 +2998,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.9.5" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" [[package]] name = "httpdate" @@ -3124,9 +3016,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "1.5.2" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" dependencies = [ "bytes", "futures-channel", @@ -3154,7 +3046,6 @@ dependencies = [ "hyper", "hyper-util", "rustls", - "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls", @@ -3334,7 +3225,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -3440,9 +3331,9 @@ checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -3451,9 +3342,9 @@ dependencies = [ [[package]] name = "indicatif" -version = "0.17.9" +version = "0.17.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" dependencies = [ "console", "number_prefix", @@ -3487,7 +3378,7 @@ dependencies = [ "indoc", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -3507,14 +3398,14 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "ipnet" -version = "2.10.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is_terminal_polyfill" @@ -3555,12 +3446,6 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" -[[package]] -name = "itoap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8" - [[package]] name = "jni-sys" version = "0.3.0" @@ -3629,9 +3514,9 @@ checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libfuzzer-sys" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b9569d2f74e257076d8c6bfa73fb505b46b851e51ddaecc825944aa3bed17fa" +checksum = "cf78f52d400cf2d84a3a973a78a592b4adc535739e0a5597a0da6f0c357adc75" dependencies = [ "arbitrary", "cc", @@ -3825,16 +3710,6 @@ dependencies = [ "rayon", ] -[[package]] -name = "md-5" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" -dependencies = [ - "cfg-if", - "digest", -] - [[package]] name = "md5" version = "0.7.0" @@ -3847,15 +3722,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "memmap2" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" -dependencies = [ - "libc", -] - [[package]] name = "memmap2" version = "0.9.5" @@ -3866,15 +3732,6 @@ dependencies = [ "stable_deref_trait", ] -[[package]] -name = "memoffset" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" -dependencies = [ - "autocfg", -] - [[package]] name = "metal" version = "0.27.0" @@ -3934,9 +3791,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" dependencies = [ "adler2", "simd-adler32", @@ -3950,7 +3807,7 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -4001,29 +3858,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", -] - -[[package]] -name = "multiversion" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" -dependencies = [ - "multiversion-macros", - "target-features", -] - -[[package]] -name = "multiversion-macros" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", - "target-features", + "syn 2.0.98", ] [[package]] @@ -4040,7 +3875,7 @@ dependencies = [ "hexf-parse", "indexmap", "log", - "rustc-hash 1.1.0", + "rustc-hash", "spirv 0.3.0+sdk-1.3.268.0", "strum", "termcolor", @@ -4058,9 +3893,9 @@ dependencies = [ [[package]] name = "native-tls" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +checksum = "0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9fda365ade4fab353196c" dependencies = [ "libc", "log", @@ -4068,7 +3903,7 @@ dependencies = [ "openssl-probe", "openssl-sys", "schannel", - "security-framework 2.11.1", + "security-framework", "security-framework-sys", "tempfile", ] @@ -4226,7 +4061,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -4298,7 +4133,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -4407,9 +4242,9 @@ dependencies = [ [[package]] name = "objc2-encode" -version = "4.0.3" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7891e71393cd1f227313c9379a26a584ff3d7e6e7159e988851f0934c993f0f8" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" [[package]] name = "objc2-foundation" @@ -4466,36 +4301,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "object_store" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3" -dependencies = [ - "async-trait", - "base64 0.22.1", - "bytes", - "chrono", - "futures", - "humantime", - "hyper", - "itertools 0.13.0", - "md-5", - "parking_lot 0.12.3", - "percent-encoding", - "quick-xml", - "rand", - "reqwest", - "ring", - "serde", - "serde_json", - "snafu", - "tokio", - "tracing", - "url", - "walkdir", -] - [[package]] name = "once_cell" version = "1.20.2" @@ -4590,9 +4395,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.68" +version = "0.10.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" +checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6" dependencies = [ "bitflags 2.8.0", "cfg-if", @@ -4611,20 +4416,20 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-sys" -version = "0.9.104" +version = "0.9.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc" dependencies = [ "cc", "libc", @@ -4862,11 +4667,11 @@ dependencies = [ [[package]] name = "polars" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65c6aa86d991a64c95416a61202f7952da2f8cccefa448f9a23c1b8f2301ecc" +checksum = "72571dde488ecccbe799798bf99ab7308ebdb7cf5d95bcc498dbd5a132f0da4d" dependencies = [ - "getrandom", + "getrandom 0.2.15", "polars-arrow", "polars-core", "polars-error", @@ -4882,12 +4687,11 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87dbb24d29ddea5abb73d7954df8b8d3d4bb7f02a3e5c96d1519cdad9e816a3d" +checksum = "6611c758d52e799761cc25900666b71552e6c929d88052811bc9daad4b3321a8" dependencies = [ "ahash", - "atoi", "atoi_simd", "bytemuck", "chrono", @@ -4895,21 +4699,16 @@ dependencies = [ "dyn-clone", "either", "ethnum", - "fast-float", - "getrandom", + "getrandom 0.2.15", "hashbrown 0.15.2", "itoa", - "itoap", "lz4", - "multiversion", "num-traits", "parking_lot 0.12.3", "polars-arrow-format", "polars-error", "polars-schema", "polars-utils", - "ryu", - "serde", "simdutf8", "streaming-iterator", "strength_reduce", @@ -4930,25 +4729,30 @@ dependencies = [ [[package]] name = "polars-compute" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbdb1071147452a4c4b25560f23d2fbaffef255b04757291131b22fc2c0d35b2" +checksum = "332f2547dbb27599a8ffe68e56159f5996ba03d1dad0382ccb62c109ceacdeb6" dependencies = [ + "atoi_simd", "bytemuck", + "chrono", "either", + "fast-float2", + "itoa", "num-traits", "polars-arrow", "polars-error", "polars-utils", + "ryu", "strength_reduce", "version_check", ] [[package]] name = "polars-core" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd5df9b55e614088a3270b06f8649dce76537c268d6b1ca4d9c37008b2be5949" +checksum = "796d06eae7e6e74ed28ea54a8fccc584ebac84e6cf0e1e9ba41ffc807b169a01" dependencies = [ "ahash", "bitflags 2.8.0", @@ -4960,6 +4764,7 @@ dependencies = [ "hashbrown 0.14.5", "hashbrown 0.15.2", "indexmap", + "itoa", "num-traits", "once_cell", "polars-arrow", @@ -4972,32 +4777,29 @@ dependencies = [ "rand_distr", "rayon", "regex", - "serde", - "serde_json", "strum_macros", - "thiserror 1.0.69", + "thiserror 2.0.11", "version_check", "xxhash-rust", ] [[package]] name = "polars-error" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4643898a644f30c83737db85f942f8c8956b0c11190b39afec745218eae1746b" +checksum = "19d6529cae0d1db5ed690e47de41fac9b35ae0c26d476830c2079f130887b847" dependencies = [ - "object_store", "polars-arrow-format", "regex", "simdutf8", - "thiserror 1.0.69", + "thiserror 2.0.11", ] [[package]] name = "polars-expr" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea1b431ed816cba1120cff200f06b962748001bbb2e615ce53cfbbdf701cc136" +checksum = "c8e639991a8ad4fb12880ab44bcc3cf44a5703df003142334d9caf86d77d77e7" dependencies = [ "ahash", "bitflags 2.8.0", @@ -5019,82 +4821,52 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2fab2c016635cb416b49461fd6419b0208c6c13a4fd065bd65e4a87dbb66314" +checksum = "719a77e94480f6be090512da196e378cbcbeb3584c6fe1134c600aee906e38ab" dependencies = [ "ahash", "async-trait", "atoi_simd", - "blake3", "bytes", "chrono", - "fast-float", - "fs4", + "fast-float2", "futures", "glob", "hashbrown 0.15.2", "home", "itoa", "memchr", - "memmap2 0.7.1", + "memmap2", "num-traits", - "object_store", "once_cell", "percent-encoding", "polars-arrow", "polars-core", "polars-error", - "polars-json", "polars-parquet", "polars-schema", "polars-time", "polars-utils", - "pyo3", "rayon", "regex", - "reqwest", "ryu", - "serde", - "serde_json", - "simd-json", "simdutf8", "tokio", "tokio-util", - "url", ] [[package]] -name = "polars-json" -version = "0.44.2" +name = "polars-lazy" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5c8c057ef04feaf34b6ce52096bdea3a766fa4725f50442078c8a4ee86397bf" +checksum = "a0a731a672dfc8ac38c1f73c9a4b2ae38d2fc8ac363bfb64c5f3a3e072ffc5ad" dependencies = [ "ahash", + "bitflags 2.8.0", "chrono", - "fallible-streaming-iterator", - "hashbrown 0.15.2", - "indexmap", - "itoa", - "num-traits", - "polars-arrow", - "polars-error", - "polars-utils", - "ryu", - "simd-json", - "streaming-iterator", -] - -[[package]] -name = "polars-lazy" -version = "0.44.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a8ca74f42e7b47cad241b36b98d991cc7fbb51b8d0695a055eb937588d1f310" -dependencies = [ - "ahash", - "bitflags 2.8.0", - "memchr", - "once_cell", + "memchr", + "once_cell", "polars-arrow", "polars-core", "polars-expr", @@ -5112,32 +4884,28 @@ dependencies = [ [[package]] name = "polars-mem-engine" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a32614e5b52c9b83856d80c7e2880b79d83055bfd59969bd1d0b148f9cfdc7a" +checksum = "33442189bcbf2e2559aa7914db3835429030a13f4f18e43af5fba9d1b018cf12" dependencies = [ - "futures", - "memmap2 0.7.1", + "memmap2", "polars-arrow", "polars-core", "polars-error", "polars-expr", "polars-io", - "polars-json", "polars-ops", "polars-plan", "polars-time", "polars-utils", - "pyo3", "rayon", - "tokio", ] [[package]] name = "polars-ops" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "035c800fbe5bbd820afeb8313713ed345853bb014e0f821a4025d40cf0d60e1a" +checksum = "cbb83218b0c216104f0076cd1a005128be078f958125f3d59b094ee73d78c18e" dependencies = [ "ahash", "argminmax", @@ -5151,6 +4919,7 @@ dependencies = [ "indexmap", "memchr", "num-traits", + "once_cell", "polars-arrow", "polars-compute", "polars-core", @@ -5160,39 +4929,33 @@ dependencies = [ "rayon", "regex", "regex-syntax 0.8.5", - "serde", "strum_macros", + "unicode-normalization", "unicode-reverse", "version_check", ] [[package]] name = "polars-parquet" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91dcf1d9f048079376949eaf2e24e240b313ff4a102fb83b57c9a5f807cdca52" +checksum = "5c60ee85535590a38db6c703a21be4cb25342e40f573f070d1e16f9d84a53ac7" dependencies = [ "ahash", "async-stream", "base64 0.22.1", - "brotli", "bytemuck", "ethnum", - "flate2", "futures", "hashbrown 0.15.2", - "lz4", "num-traits", "polars-arrow", "polars-compute", "polars-error", "polars-parquet-format", "polars-utils", - "serde", "simdutf8", - "snap", "streaming-decompression", - "zstd 0.13.2", ] [[package]] @@ -5207,15 +4970,16 @@ dependencies = [ [[package]] name = "polars-pipe" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05936f2b3981eecb2fe74d8ef092bb75a93d2a056b3e4f339f4ac20c71c9e331" +checksum = "42d238fb76698f56e51ddfa89b135e4eda56a4767c6e8859eed0ab78386fcd52" dependencies = [ "crossbeam-channel", "crossbeam-queue", "enum_dispatch", "hashbrown 0.15.2", "num-traits", + "once_cell", "polars-arrow", "polars-compute", "polars-core", @@ -5232,9 +4996,9 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23de436f33f4d1134c58f24e7059a221b957ec20730807e0ef0c80c8e4b3d06a" +checksum = "4f03533a93aa66127fcb909a87153a3c7cfee6f0ae59f497e73d7736208da54c" dependencies = [ "ahash", "bitflags 2.8.0", @@ -5242,65 +5006,59 @@ dependencies = [ "bytes", "chrono", "chrono-tz", - "ciborium", "either", - "futures", "hashbrown 0.15.2", - "memmap2 0.7.1", + "memmap2", "num-traits", "once_cell", "percent-encoding", "polars-arrow", + "polars-compute", "polars-core", "polars-io", - "polars-json", "polars-ops", - "polars-parquet", "polars-time", "polars-utils", - "pyo3", "rayon", "recursive", "regex", - "serde", "strum_macros", "version_check", ] [[package]] name = "polars-row" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3823d3de3e614509bba6929798f1f3d5ae05c1cdfc4eb7029d2ec6ad77201da2" +checksum = "6bf47f7409f8e75328d7d034be390842924eb276716d0458607be0bddb8cc839" dependencies = [ + "bitflags 2.8.0", "bytemuck", "polars-arrow", + "polars-compute", "polars-error", "polars-utils", ] [[package]] name = "polars-schema" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d88667f770291cefa2e8cd366a54f29dc6fe362e9a263914c903db411a58ac1d" +checksum = "416621ae82b84466cf4ff36838a9b0aeb4a67e76bd3065edc8c9cb7da19b1bc7" dependencies = [ "indexmap", "polars-error", "polars-utils", - "serde", "version_check", ] [[package]] name = "polars-sql" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69451f08363bb497407f6ebebe00bc01972a51716d20d115b75f9b5326f1f3c8" +checksum = "edaab553b90aa4d6743bb538978e1982368acb58a94408d7dd3299cad49c7083" dependencies = [ "hex", - "once_cell", - "polars-arrow", "polars-core", "polars-error", "polars-lazy", @@ -5309,22 +5067,22 @@ dependencies = [ "polars-time", "polars-utils", "rand", + "regex", "serde", - "serde_json", "sqlparser", ] [[package]] name = "polars-stream" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "188622b0a4bc4530cf91a288134254ffa065d18932e261075377914225e757c2" +checksum = "498997b656c779610c1496b3d96a59fe569ef22a5b81ccfe5325cb3df8dff2fd" dependencies = [ "atomic-waker", "crossbeam-deque", "crossbeam-utils", "futures", - "memmap2 0.7.1", + "memmap2", "parking_lot 0.12.3", "pin-project-lite", "polars-core", @@ -5332,6 +5090,7 @@ dependencies = [ "polars-expr", "polars-io", "polars-mem-engine", + "polars-ops", "polars-parquet", "polars-plan", "polars-utils", @@ -5345,31 +5104,33 @@ dependencies = [ [[package]] name = "polars-time" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90f36e4d6b19f2c406faea585b9a1814f422fc5b310f65ccf8a55216df0754ef" +checksum = "d192efbdab516d28b3fab1709a969e3385bd5cda050b7c9aa9e2502a01fda879" dependencies = [ - "atoi", + "atoi_simd", "bytemuck", "chrono", "chrono-tz", "now", + "num-traits", "once_cell", "polars-arrow", + "polars-compute", "polars-core", "polars-error", "polars-ops", "polars-utils", + "rayon", "regex", - "serde", "strum_macros", ] [[package]] name = "polars-utils" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96186b70bda00c90b5027bf2f69193c5c40571e80d3e8ec505c22cdc8e3e39aa" +checksum = "a8f6c8166a4a7fbc15b87c81645ed9e1f0651ff2e8c96cafc40ac5bf43441a10" dependencies = [ "ahash", "bytemuck", @@ -5378,16 +5139,15 @@ dependencies = [ "hashbrown 0.15.2", "indexmap", "libc", - "memmap2 0.7.1", + "memmap2", "num-traits", "once_cell", "polars-error", - "pyo3", - "raw-cpuid 11.2.0", + "rand", + "raw-cpuid 11.3.0", "rayon", - "serde", "stacker", - "sysinfo 0.31.4", + "sysinfo", "version_check", ] @@ -5396,6 +5156,9 @@ name = "portable-atomic" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +dependencies = [ + "serde", +] [[package]] name = "portable-atomic-util" @@ -5444,7 +5207,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" dependencies = [ "proc-macro2", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -5481,7 +5244,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30" dependencies = [ "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -5557,69 +5320,6 @@ dependencies = [ "reborrow", ] -[[package]] -name = "pyo3" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "parking_lot 0.12.3", - "portable-atomic", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" -dependencies = [ - "once_cell", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn 2.0.96", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "pyo3-build-config", - "quote", - "syn 2.0.96", -] - [[package]] name = "pytorch-import" version = "0.17.0" @@ -5656,68 +5356,6 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" -[[package]] -name = "quick-xml" -version = "0.36.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7649a7b4df05aed9ea7ec6f628c67c9953a43869b8bc50929569b2999d443fe" -dependencies = [ - "memchr", - "serde", -] - -[[package]] -name = "quinn" -version = "0.11.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" -dependencies = [ - "bytes", - "pin-project-lite", - "quinn-proto", - "quinn-udp", - "rustc-hash 2.1.0", - "rustls", - "socket2", - "thiserror 2.0.11", - "tokio", - "tracing", -] - -[[package]] -name = "quinn-proto" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" -dependencies = [ - "bytes", - "getrandom", - "rand", - "ring", - "rustc-hash 2.1.0", - "rustls", - "rustls-pki-types", - "slab", - "thiserror 2.0.11", - "tinyvec", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-udp" -version = "0.5.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904" -dependencies = [ - "cfg_aliases", - "libc", - "once_cell", - "socket2", - "tracing", - "windows-sys 0.59.0", -] - [[package]] name = "quote" version = "1.0.38" @@ -5776,7 +5414,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -5791,9 +5429,9 @@ dependencies = [ [[package]] name = "range-alloc" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8a99fddc9f0ba0a85884b8d14e3592853e787d581ca1816c91349b10e4eeab" +checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde" [[package]] name = "ratatui" @@ -5878,9 +5516,9 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "11.2.0" +version = "11.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" +checksum = "c6928fa44c097620b706542d428957635951bade7143269085389d42c8a4927e" dependencies = [ "bitflags 2.8.0", ] @@ -5951,7 +5589,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -5978,31 +5616,11 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.15", "libredox", "thiserror 1.0.69", ] -[[package]] -name = "ref-cast" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccf0a6f84d5f1d581da8b41b47ec8600871962f2a528115b542b362d4b744931" -dependencies = [ - "ref-cast-impl", -] - -[[package]] -name = "ref-cast-impl" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.96", -] - [[package]] name = "regex" version = "1.11.1" @@ -6087,11 +5705,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "quinn", - "rustls", - "rustls-native-certs 0.8.1", "rustls-pemfile", - "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", @@ -6099,14 +5713,11 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", - "tokio-rustls", - "tokio-util", "tower", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-streams", "web-sys", "windows-registry", ] @@ -6128,7 +5739,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", "spin", "untrusted", @@ -6162,7 +5773,7 @@ name = "rspirv" version = "0.12.0+sdk-1.3.296.0" source = "git+https://github.com/gfx-rs/rspirv.git?rev=e19c11fdb30295127cff1d018189bd436892415e#e19c11fdb30295127cff1d018189bd436892415e" dependencies = [ - "rustc-hash 1.1.0", + "rustc-hash", "spirv 0.3.0+sdk-1.3.296.0", ] @@ -6192,7 +5803,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.96", + "syn 2.0.98", "unicode-ident", ] @@ -6232,12 +5843,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" -[[package]] -name = "rustc-hash" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" - [[package]] name = "rustc_version" version = "0.4.1" @@ -6249,9 +5854,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.43" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags 2.8.0", "errno", @@ -6262,9 +5867,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.21" +version = "0.23.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" +checksum = "9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7" dependencies = [ "log", "once_cell", @@ -6285,19 +5890,7 @@ dependencies = [ "rustls-pemfile", "rustls-pki-types", "schannel", - "security-framework 2.11.1", -] - -[[package]] -name = "rustls-native-certs" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" -dependencies = [ - "openssl-probe", - "rustls-pki-types", - "schannel", - "security-framework 3.2.0", + "security-framework", ] [[package]] @@ -6311,12 +5904,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" -dependencies = [ - "web-time", -] +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" [[package]] name = "rustls-webpki" @@ -6337,9 +5927,9 @@ checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "safetensors" @@ -6435,20 +6025,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.8.0", - "core-foundation 0.9.4", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" -dependencies = [ - "bitflags 2.8.0", - "core-foundation 0.10.0", + "core-foundation", "core-foundation-sys", "libc", "security-framework-sys", @@ -6466,9 +6043,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "seq-macro" @@ -6513,14 +6090,14 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "serde_json" -version = "1.0.137" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "930cfb6e6abf99298aaad7d29abbef7a9999a9a8806a40088f55f0dcec03146b" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "itoa", "memchr", @@ -6591,7 +6168,7 @@ checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6675,23 +6252,6 @@ version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" -[[package]] -name = "simd-json" -version = "0.14.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2bcf6c6e164e81bc7a5d49fc6988b3d515d9e8c07457d7b74ffb9324b9cd40" -dependencies = [ - "ahash", - "getrandom", - "halfbrown", - "once_cell", - "ref-cast", - "serde", - "serde_json", - "simdutf8", - "value-trait", -] - [[package]] name = "simd_helpers" version = "0.1.0" @@ -6748,34 +6308,6 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -[[package]] -name = "snafu" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" -dependencies = [ - "doc-comment", - "snafu-derive", -] - -[[package]] -name = "snafu-derive" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "snap" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" - [[package]] name = "socket2" version = "0.5.8" @@ -6827,9 +6359,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.49.0" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e" +checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" dependencies = [ "log", ] @@ -6910,11 +6442,11 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6930,15 +6462,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", - "quote", "unicode-ident", ] [[package]] name = "syn" -version = "2.0.96" +version = "2.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" +checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" dependencies = [ "proc-macro2", "quote", @@ -6962,7 +6493,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6981,22 +6512,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.31.4" +version = "0.33.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "355dbe4f8799b304b05e1b0f05fc59b2a18d36645cf169607da45bde2f69a1be" -dependencies = [ - "core-foundation-sys", - "libc", - "memchr", - "ntapi", - "windows 0.57.0", -] - -[[package]] -name = "sysinfo" -version = "0.32.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c33cd241af0f2e9e3b5c32163b873b29956890b5342e6745b917ce9d490f4af" +checksum = "4fc858248ea01b66f19d8e8a6d55f41deaf91e9d495246fd01368d99935c6c01" dependencies = [ "core-foundation-sys", "libc", @@ -7014,7 +6532,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.8.0", - "core-foundation 0.9.4", + "core-foundation", "system-configuration-sys", ] @@ -7035,7 +6553,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" dependencies = [ "cfg-expr", - "heck 0.5.0", + "heck", "pkg-config", "toml", "version-compare", @@ -7066,12 +6584,6 @@ dependencies = [ "xattr", ] -[[package]] -name = "target-features" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" - [[package]] name = "target-lexicon" version = "0.12.16" @@ -7097,13 +6609,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.15.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" +checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ "cfg-if", "fastrand", - "getrandom", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys 0.59.0", @@ -7186,7 +6698,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7197,7 +6709,7 @@ checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7297,7 +6809,7 @@ dependencies = [ "aho-corasick", "derive_builder", "esaxx-rs", - "getrandom", + "getrandom 0.2.15", "hf-hub", "itertools 0.12.1", "lazy_static", @@ -7344,7 +6856,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7415,9 +6927,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.22" +version = "0.22.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +checksum = "02a8b472d1a3d7c18e2d61a489aee3453fd9031c33e4f55bd533f4a7adca1bee" dependencies = [ "indexmap", "serde", @@ -7496,7 +7008,7 @@ checksum = "5a3a646485f7cd8f580749ab94718ad3d344bcc0cc5b0fefe43c15fdd898bb96" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7531,7 +7043,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7603,38 +7115,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "deb68604048ff8fa93347f02441e4487594adc20bb8a084f9e564d2b827a0a9f" dependencies = [ - "rustc-hash 1.1.0", -] - -[[package]] -name = "type_hash" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c86f48f11992d3e379358c63cb25736c0b23944ff000d1583bbccad2b0b7c6" -dependencies = [ - "type_hash_core", - "type_hash_macros", -] - -[[package]] -name = "type_hash_core" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87b1e93e2cd97790892dbe2d2813fbaa6eebaeb960265f59e363e79e51e4997a" -dependencies = [ - "fnv", -] - -[[package]] -name = "type_hash_macros" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "746fc164e076483ef087b3989f7aa80ffd9320fa558f3cb72cecfb9bb1dbc41e" -dependencies = [ - "either", - "proc-macro2", - "quote", - "syn 1.0.109", + "rustc-hash", ] [[package]] @@ -7687,9 +7168,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-normalization" @@ -7759,12 +7240,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" -[[package]] -name = "unindent" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" - [[package]] name = "untrusted" version = "0.9.0" @@ -7783,7 +7258,7 @@ dependencies = [ "native-tls", "once_cell", "rustls", - "rustls-native-certs 0.7.3", + "rustls-native-certs", "rustls-pki-types", "serde", "serde_json", @@ -7828,11 +7303,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4" +checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" dependencies = [ - "getrandom", + "getrandom 0.2.15", "rand", ] @@ -7849,21 +7324,9 @@ dependencies = [ [[package]] name = "valuable" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" - -[[package]] -name = "value-trait" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9170e001f458781e92711d2ad666110f153e4e50bfd5cbd02db6547625714187" -dependencies = [ - "float-cmp", - "halfbrown", - "itoa", - "ryu", -] +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] name = "variadics_please" @@ -7873,7 +7336,7 @@ checksum = "41b6d82be61465f97d42bd1d15bf20f3b0a3a0905018f38f9d6f6962055b0b5c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7919,6 +7382,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -7941,7 +7413,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "wasm-bindgen-shared", ] @@ -7976,7 +7448,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -8001,19 +7473,6 @@ dependencies = [ "web-sys", ] -[[package]] -name = "wasm-streams" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" -dependencies = [ - "futures-util", - "js-sys", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "wasm-timer" version = "0.2.5" @@ -8051,9 +7510,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.7" +version = "0.26.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" +checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" dependencies = [ "rustls-pki-types", ] @@ -8074,9 +7533,9 @@ dependencies = [ [[package]] name = "wgpu" -version = "24.0.0" +version = "24.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e41253fc7b660735e2a2d9a58c563f2a047d3cc3445293d8f4095538c9e8afbe" +checksum = "47f55718f85c2fa756edffa0e7f0e0a60aba463d1362b57e23123c58f035e4b6" dependencies = [ "arrayvec", "bitflags 2.8.0", @@ -8116,7 +7575,7 @@ dependencies = [ "parking_lot 0.12.3", "profiling", "raw-window-handle", - "rustc-hash 1.1.0", + "rustc-hash", "smallvec", "thiserror 2.0.11", "wgpu-hal", @@ -8159,7 +7618,7 @@ dependencies = [ "range-alloc", "raw-window-handle", "renderdoc-sys", - "rustc-hash 1.1.0", + "rustc-hash", "smallvec", "thiserror 2.0.11", "wasm-bindgen", @@ -8286,7 +7745,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8297,7 +7756,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8308,7 +7767,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8319,7 +7778,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8511,13 +7970,22 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.24" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" +checksum = "7e49d2d35d3fad69b39b94139037ecfb4f359f08958b9c11e7315ce770462419" dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags 2.8.0", +] + [[package]] name = "wrapcenum-derive" version = "0.4.1" @@ -8527,7 +7995,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8624,7 +8092,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "synstructure", ] @@ -8646,7 +8114,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8666,7 +8134,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "synstructure", ] @@ -8687,7 +8155,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8709,7 +8177,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c4723d0f7a..e225c304aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ version = "0.17.0" atomic_float = "1" bytemuck = "1.21.0" candle-core = { version = "0.8" } -clap = { version = "4.5.26", features = ["derive"] } +clap = { version = "4.5.27", features = ["derive"] } colored = "2.1.0" console_error_panic_hook = "0.1.7" csv = "1.3.1" @@ -47,14 +47,14 @@ globwalk = "0.9.1" hashbrown = "0.15.2" hound = "3.5.1" image = "0.25.5" -indicatif = "0.17.9" +indicatif = "0.17.11" js-sys = "0.3.72" libm = "0.2.11" log = { default-features = false, version = "0.4.25" } md5 = "0.7.0" paste = "1" percent-encoding = "2.3.1" -polars = { version = "0.44.2", features = ["lazy"] } +polars = { version = "0.46.0", features = ["lazy"] } pretty_assertions = "1.4.1" proc-macro2 = "1.0.93" protobuf = "3.7.1" @@ -101,7 +101,7 @@ ratatui = "0.29.0" # WGPU stuff text_placeholder = "0.5.1" -wgpu = "24.0.0" +wgpu = "24.0.1" # Benchmarks and Burnbench arboard = "3.4.1" @@ -141,11 +141,11 @@ serde = { version = "1.0.217", default-features = false, features = [ "alloc", ] } # alloc is for no_std, derive is needed serde_json = { version = "1.0.137", default-features = false } -uuid = { version = "1.12.0", default-features = false } +uuid = { version = "1.12.1", default-features = false } libc = "0.2.169" nvml-wrapper = "0.10.0" -sysinfo = "0.32.1" +sysinfo = "0.33.1" systemstat = "0.2.3" tch = "0.15.0" diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 265dbeaaf0..821d189fe0 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -15,10 +15,10 @@ candle-accelerate = ["burn/candle", "burn/accelerate"] candle-cpu = ["burn/candle"] candle-cuda = ["burn/candle-cuda"] candle-metal = ["burn/candle", "burn/metal"] -cuda-jit = ["burn/cuda-jit"] -cuda-jit-fusion = ["cuda-jit", "burn/fusion"] +cuda = ["burn/cuda"] +cuda-fusion = ["cuda", "burn/fusion"] default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"] -hip-jit = ["burn/hip-jit"] +hip = ["burn/hip"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] @@ -27,7 +27,7 @@ tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu", "burn/autotune"] wgpu-fusion = ["wgpu", "burn/fusion"] -wgpu-spirv = ["burn/wgpu-spirv", "burn/autotune"] +wgpu-spirv = ["burn/vulkan", "burn/autotune"] wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"] [dependencies] diff --git a/backend-comparison/src/burnbenchapp/base.rs b/backend-comparison/src/burnbenchapp/base.rs index 9eba1485b3..4fb31edab8 100644 --- a/backend-comparison/src/burnbenchapp/base.rs +++ b/backend-comparison/src/burnbenchapp/base.rs @@ -62,6 +62,13 @@ enum BackendValues { CandleCuda, #[strum(to_string = "candle-metal")] CandleMetal, + #[strum(to_string = "cuda")] + Cuda, + #[strum(to_string = "cuda-fusion")] + CudaFusion, + #[cfg(target_os = "linux")] + #[strum(to_string = "hip")] + Hip, #[strum(to_string = "ndarray")] Ndarray, #[strum(to_string = "ndarray-blas-accelerate")] @@ -82,13 +89,6 @@ enum BackendValues { WgpuSpirv, #[strum(to_string = "wgpu-spirv-fusion")] WgpuSpirvFusion, - #[strum(to_string = "cuda-jit")] - CudaJit, - #[strum(to_string = "cuda-jit-fusion")] - CudaJitFusion, - #[cfg(target_os = "linux")] - #[strum(to_string = "hip-jit")] - HipJit, } #[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)] diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index 26b08bc3b8..b3351e9dd5 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -91,12 +91,12 @@ macro_rules! bench_on_backend { let feature_name = "wgpu-spirv"; #[cfg(feature = "wgpu-spirv-fusion")] let feature_name = "wgpu-spirv-fusion"; - #[cfg(feature = "cuda-jit")] - let feature_name = "cuda-jit"; - #[cfg(feature = "cuda-jit-fusion")] - let feature_name = "cuda-jit-fusion"; - #[cfg(feature = "hip-jit")] - let feature_name = "hip-jit"; + #[cfg(feature = "cuda")] + let feature_name = "cuda"; + #[cfg(feature = "cuda-fusion")] + let feature_name = "cuda-fusion"; + #[cfg(feature = "hip")] + let feature_name = "hip"; #[cfg(any(feature = "wgpu"))] { @@ -172,16 +172,16 @@ macro_rules! bench_on_backend { $fn_name::(&device, feature_name, url, token); } - #[cfg(feature = "cuda-jit")] + #[cfg(feature = "cuda")] { - use burn::backend::cuda_jit::{Cuda, CudaDevice}; + use burn::backend::cuda::{Cuda, CudaDevice}; $fn_name::>(&CudaDevice::default(), feature_name, url, token); } - #[cfg(feature = "hip-jit")] + #[cfg(feature = "hip")] { - use burn::backend::hip_jit::{Hip, HipDevice}; + use burn::backend::hip::{Hip, HipDevice}; $fn_name::>(&HipDevice::default(), feature_name, url, token); } diff --git a/backend-comparison/src/persistence/system_info.rs b/backend-comparison/src/persistence/system_info.rs index 287b629c21..3fe24bc955 100644 --- a/backend-comparison/src/persistence/system_info.rs +++ b/backend-comparison/src/persistence/system_info.rs @@ -38,7 +38,7 @@ impl BenchmarkSystemInfo { fn enumerate_cpus() -> Vec { let system = sysinfo::System::new_with_specifics( - sysinfo::RefreshKind::new().with_cpu(sysinfo::CpuRefreshKind::everything()), + sysinfo::RefreshKind::nothing().with_cpu(sysinfo::CpuRefreshKind::everything()), ); let cpu_names: HashSet = system .cpus() diff --git a/burn-book/src/advanced/no-std.md b/burn-book/src/advanced/no-std.md index 5f5621cc51..e55afc904d 100644 --- a/burn-book/src/advanced/no-std.md +++ b/burn-book/src/advanced/no-std.md @@ -68,7 +68,7 @@ We are using ndarray, so we just need to define the NdArray backend as usual use burn::{backend::NdArray, tensor::Tensor}; type Backend = NdArray; -type BackendDeice = ::Device; +type BackendDevice = ::Device; ``` Then inside the `main` function add @@ -76,7 +76,7 @@ Then inside the `main` function add use your_model::Model; // Get a default device for the backend -let device = BackendDeice::default(); +let device = BackendDevice::default(); // Create a new model and load the state let model: Model = Model::default(); diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index 8a7c01bbc9..3713ee2571 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -131,47 +131,47 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. -| Burn | PyTorch Equivalent | -| ------------------------------------- | ------------------------------------------------------------------------- | -| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` | -| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` | -| `Tensor::from_primitive(primitive)` | N/A | -| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` | -| `tensor.all()` | `tensor.all()` | -| `tensor.all_dim(dim)` | `tensor.all(dim)` | -| `tensor.any()` | `tensor.any()` | -| `tensor.any_dim(dim)` | `tensor.any(dim)` | -| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` | -| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` | -| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` | -| `tensor.device()` | `tensor.device` | -| `tensor.dtype()` | `tensor.dtype` | -| `tensor.dims()` | `tensor.size()` | -| `tensor.equal(other)` | `x == y` | -| `tensor.expand(shape)` | `tensor.expand(shape)` | -| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | -| `tensor.flip(axes)` | `tensor.flip(axes)` | -| `tensor.into_data()` | N/A | -| `tensor.into_primitive()` | N/A | -| `tensor.into_scalar()` | `tensor.item()` | -| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | -| `tensor.not_equal(other)` | `x != y` | -| `tensor.permute(axes)` | `tensor.permute(axes)` | -| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | -| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | -| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | -| `tensor.reshape(shape)` | `tensor.view(shape)` | -| `tensor.shape()` | `tensor.shape` | -| `tensor.slice(ranges)` | `tensor[(*ranges,)]` | -| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` | -| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` | -| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | -| `tensor.to_data()` | N/A | -| `tensor.to_device(device)` | `tensor.to(device)` | -| `tensor.transpose()` | `tensor.T` | -| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` | -| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | -| `tensor.unsqueeze_dims(dims)` | N/A | +| Burn | PyTorch Equivalent | +| ------------------------------------------- | ------------------------------------------------------------------------- | +| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` | +| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` | +| `Tensor::from_primitive(primitive)` | N/A | +| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` | +| `tensor.all()` | `tensor.all()` | +| `tensor.all_dim(dim)` | `tensor.all(dim)` | +| `tensor.any()` | `tensor.any()` | +| `tensor.any_dim(dim)` | `tensor.any(dim)` | +| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` | +| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` | +| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` | +| `tensor.device()` | `tensor.device` | +| `tensor.dtype()` | `tensor.dtype` | +| `tensor.dims()` | `tensor.size()` | +| `tensor.equal(other)` | `x == y` | +| `tensor.expand(shape)` | `tensor.expand(shape)` | +| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | +| `tensor.flip(axes)` | `tensor.flip(axes)` | +| `tensor.into_data()` | N/A | +| `tensor.into_primitive()` | N/A | +| `tensor.into_scalar()` | `tensor.item()` | +| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | +| `tensor.not_equal(other)` | `x != y` | +| `tensor.permute(axes)` | `tensor.permute(axes)` | +| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | +| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | +| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | +| `tensor.reshape(shape)` | `tensor.view(shape)` | +| `tensor.shape()` | `tensor.shape` | +| `tensor.slice(ranges)` | `tensor[(*ranges,)]` | +| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` | +| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` | +| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | +| `tensor.to_data()` | N/A | +| `tensor.to_device(device)` | `tensor.to(device)` | +| `tensor.transpose()` | `tensor.T` | +| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` | +| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | +| `tensor.unsqueeze_dims(dims)` | N/A | ### Numeric Operations @@ -258,32 +258,32 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. Those operations are only available for `Float` tensors. -| Burn API | PyTorch Equivalent | -| --------------------------------------------- | ---------------------------------- | -| `tensor.cast(dtype)` | `tensor.to(dtype)` | -| `tensor.ceil()` | `tensor.ceil()` | -| `tensor.cos()` | `tensor.cos()` | -| `tensor.erf()` | `tensor.erf()` | -| `tensor.exp()` | `tensor.exp()` | -| `tensor.floor()` | `tensor.floor()` | -| `tensor.from_floats(floats, device)` | N/A | -| `tensor.from_full_precision(tensor)` | N/A | -| `tensor.int()` | Similar to `tensor.to(torch.long)` | -| `tensor.log()` | `tensor.log()` | -| `tensor.log1p()` | `tensor.log1p()` | -| `tensor.matmul(other)` | `tensor.matmul(other)` | -| `tensor.random(shape, distribution, device)` | N/A | -| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform | -| `tensor.recip()` | `tensor.reciprocal()` | -| `tensor.round()` | `tensor.round()` | -| `tensor.sin()` | `tensor.sin()` | -| `tensor.sqrt()` | `tensor.sqrt()` | -| `tensor.tanh()` | `tensor.tanh()` | -| `tensor.to_full_precision()` | `tensor.to(torch.float)` | -| `tensor.var(dim)` | `tensor.var(dim)` | -| `tensor.var_bias(dim)` | N/A | -| `tensor.var_mean(dim)` | N/A | -| `tensor.var_mean_bias(dim)` | N/A | +| Burn API | PyTorch Equivalent | +| -------------------------------------------- | ---------------------------------- | +| `tensor.cast(dtype)` | `tensor.to(dtype)` | +| `tensor.ceil()` | `tensor.ceil()` | +| `tensor.cos()` | `tensor.cos()` | +| `tensor.erf()` | `tensor.erf()` | +| `tensor.exp()` | `tensor.exp()` | +| `tensor.floor()` | `tensor.floor()` | +| `tensor.from_floats(floats, device)` | N/A | +| `tensor.from_full_precision(tensor)` | N/A | +| `tensor.int()` | Similar to `tensor.to(torch.long)` | +| `tensor.log()` | `tensor.log()` | +| `tensor.log1p()` | `tensor.log1p()` | +| `tensor.matmul(other)` | `tensor.matmul(other)` | +| `tensor.random(shape, distribution, device)` | N/A | +| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform | +| `tensor.recip()` | `tensor.reciprocal()` | +| `tensor.round()` | `tensor.round()` | +| `tensor.sin()` | `tensor.sin()` | +| `tensor.sqrt()` | `tensor.sqrt()` | +| `tensor.tanh()` | `tensor.tanh()` | +| `tensor.to_full_precision()` | `tensor.to(torch.float)` | +| `tensor.var(dim)` | `tensor.var(dim)` | +| `tensor.var_bias(dim)` | N/A | +| `tensor.var_mean(dim)` | N/A | +| `tensor.var_mean_bias(dim)` | N/A | ### Int Operations @@ -293,6 +293,17 @@ Those operations are only available for `Int` tensors. | ------------------------------------------------ | ------------------------------------------------------- | | `Tensor::arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` | | `Tensor::arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` | +| `tensor.bitwise_and(other)` | `torch.bitwise_and(tensor, other)` | +| `tensor.bitwise_and_scalar(scalar)` | `torch.bitwise_and(tensor, scalar)` | +| `tensor.bitwise_not()` | `torch.bitwise_not(tensor)` | +| `tensor.bitwise_left_shift(other)` | `torch.bitwise_left_shift(tensor, other)` | +| `tensor.bitwise_left_shift_scalar(scalar)` | `torch.bitwise_left_shift(tensor, scalar)` | +| `tensor.bitwise_right_shift(other)` | `torch.bitwise_right_shift(tensor, other)` | +| `tensor.bitwise_right_shift_scalar(scalar)` | `torch.bitwise_right_shift(tensor, scalar)` | +| `tensor.bitwise_or(other)` | `torch.bitwise_or(tensor, other)` | +| `tensor.bitwise_or_scalar(scalar)` | `torch.bitwise_or(tensor, scalar)` | +| `tensor.bitwise_xor(other)` | `torch.bitwise_xor(tensor, other)` | +| `tensor.bitwise_xor_scalar(scalar)` | `torch.bitwise_xor(tensor, scalar)` | | `tensor.float()` | `tensor.to(torch.float)` | | `tensor.from_ints(ints)` | N/A | | `tensor.int_random(shape, distribution, device)` | N/A | @@ -328,7 +339,7 @@ strategies. | Burn API | PyTorch Equivalent | | ------------------------------------------------ | -------------------------------------------------- | | `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` | -| `activation::hard_sigmoid(tensor, alpha, beta) | `nn.functional.hardsigmoid(tensor)` | +| `activation::hard_sigmoid(tensor, alpha, beta)` | `nn.functional.hardsigmoid(tensor)` | | `activation::leaky_relu(tensor, negative_slope)` | `nn.functional.leaky_relu(tensor, negative_slope)` | | `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` | | `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` | diff --git a/burn-book/src/saving-and-loading.md b/burn-book/src/saving-and-loading.md index 77f7c863d6..24b52dd22a 100644 --- a/burn-book/src/saving-and-loading.md +++ b/burn-book/src/saving-and-loading.md @@ -4,7 +4,7 @@ Saving your trained machine learning model is quite easy, no matter the output f mentioned in the [Record](./building-blocks/record.md) section, different formats are supported to serialize/deserialize models. By default, we use the `NamedMpkFileRecorder` which uses the [MessagePack](https://msgpack.org/) binary serialization format with the help of -[smp_serde](https://docs.rs/rmp-serde/). +[rmp_serde](https://docs.rs/rmp-serde/). ```rust, ignore // Save model in MessagePack format with full precision diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index 4aad98bb46..f3439d1cad 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -348,4 +348,48 @@ impl IntTensorOps for Autodiff { fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { B::int_argsort(tensor, dim, descending) } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_and(lhs, rhs) + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_and_scalar(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_or(lhs, rhs) + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_xor(lhs, rhs) + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + B::bitwise_not(tensor) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_left_shift(lhs, rhs) + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_left_shift_scalar(lhs, rhs) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_right_shift(lhs, rhs) + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_right_shift_scalar(lhs, rhs) + } } diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index 4ae0c53de7..08b84251fa 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -372,4 +372,47 @@ impl IntTensorOps for Candle) -> IntTensor { sign(tensor) } + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_and is not implemented for Candle IntTensor"); + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_or is not implemented for Candle IntTensor"); + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_xor is not implemented for Candle IntTensor"); + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unimplemented!("bitwise_not is not implemented for Candle IntTensor"); + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_left_shift is not implemented for Candle IntTensor"); + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_right_shift is not implemented for Candle IntTensor"); + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_left_shift_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_right_shift_scalar is not implemented for Candle IntTensor"); + } } diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index e895cc4572..423dc784d8 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -36,8 +36,8 @@ doc = [ "ndarray", "tch", "wgpu", - "cuda-jit", - "hip-jit", + "cuda", + "hip", "audio", "vision", "autodiff", @@ -88,7 +88,7 @@ fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"] ## Backend features accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"] -autotune = ["burn-wgpu?/autotune"] +autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"] blas-netlib = ["burn-ndarray?/blas-netlib"] metal = ["burn-candle?/metal"] openblas = ["burn-ndarray?/blas-openblas"] @@ -100,12 +100,13 @@ template = ["burn-wgpu?/template"] candle = ["burn-candle"] candle-cuda = ["candle", "burn-candle/cuda"] -cuda-jit = ["burn-cuda"] -hip-jit = ["burn-hip"] +cuda = ["burn-cuda"] +hip = ["burn-hip"] ndarray = ["burn-ndarray"] tch = ["burn-tch"] wgpu = ["burn-wgpu"] -wgpu-spirv = ["wgpu", "burn-wgpu/spirv"] +vulkan = ["wgpu", "burn-wgpu/vulkan"] +webgpu = ["wgpu", "burn-wgpu/webgpu"] # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files. record-item-custom-serde = ["thiserror", "regex"] @@ -113,13 +114,13 @@ record-item-custom-serde = ["thiserror", "regex"] # Serialization formats experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] -test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray. -test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray. +test-cuda = ["cuda"] # To use cuda during testing, default uses ndarray. +test-hip = ["hip"] # To use hip during testing, default uses ndarray. test-tch = ["tch"] # To use tch during testing, default uses ndarray. test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray. test-wgpu-spirv = [ "test-wgpu", - "wgpu-spirv", + "vulkan", ] # To use wgpu-spirv during testing, default uses ndarray. [dependencies] diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs index bd4c959302..31ac3a8c41 100644 --- a/crates/burn-core/src/backend.rs +++ b/crates/burn-core/src/backend.rs @@ -21,11 +21,17 @@ pub use burn_wgpu as wgpu; #[cfg(feature = "wgpu")] pub use burn_wgpu::Wgpu; -#[cfg(feature = "cuda-jit")] -pub use burn_cuda as cuda_jit; +#[cfg(feature = "webgpu")] +pub use burn_wgpu::WebGpu; -#[cfg(feature = "cuda-jit")] -pub use burn_cuda::Cuda as CudaJit; +#[cfg(feature = "vulkan")] +pub use burn_wgpu::Vulkan; + +#[cfg(feature = "cuda")] +pub use burn_cuda as cuda; + +#[cfg(feature = "cuda")] +pub use burn_cuda::Cuda; #[cfg(feature = "candle")] pub use burn_candle as candle; @@ -33,11 +39,11 @@ pub use burn_candle as candle; #[cfg(feature = "candle")] pub use burn_candle::Candle; -#[cfg(feature = "hip-jit")] -pub use burn_hip as hip_jit; +#[cfg(feature = "hip")] +pub use burn_hip as hip; -#[cfg(feature = "hip-jit")] -pub use burn_hip::Hip as HipJit; +#[cfg(feature = "hip")] +pub use burn_hip::Hip; #[cfg(feature = "tch")] pub use burn_tch as libtorch; diff --git a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs index f645c84fd9..54b80f4f60 100644 --- a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -118,9 +118,9 @@ impl BinaryCrossEntropyLoss { (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits) } else { // - (target * log(input) + (1 - target) * log(1 - input)) - (targets_float.clone() * logits.clone().log() - + (targets_float.neg() + 1.) * (logits.neg() + 1.).log()) - .neg() + // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values + (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0) + - targets_float * logits.log().clamp_min(-100.0) }; if let Some(weights) = &self.weights { @@ -171,6 +171,38 @@ mod tests { use crate::tensor::{activation::sigmoid, TensorData}; use crate::TestBackend; + #[test] + fn test_binary_cross_entropy_preds_all_correct() { + let device = Default::default(); + let preds = Tensor::::from_floats([1.0, 0.0, 1.0, 0.0], &device); + let targets = + Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device); + + let loss_actual = BinaryCrossEntropyLossConfig::new() + .init(&device) + .forward(preds, targets) + .into_data(); + + let loss_expected = TensorData::from([0.000]); + loss_actual.assert_approx_eq(&loss_expected, 3); + } + + #[test] + fn test_binary_cross_entropy_preds_all_incorrect() { + let device = Default::default(); + let preds = Tensor::::from_floats([0.0, 1.0, 0.0, 1.0], &device); + let targets = + Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device); + + let loss_actual = BinaryCrossEntropyLossConfig::new() + .init(&device) + .forward(preds, targets) + .into_data(); + + let loss_expected = TensorData::from([100.000]); // clamped value + loss_actual.assert_approx_eq(&loss_expected, 3); + } + #[test] fn test_binary_cross_entropy() { // import torch diff --git a/crates/burn-dataset/src/dataset/dataframe.rs b/crates/burn-dataset/src/dataset/dataframe.rs index 023b357454..c851e8a3e3 100644 --- a/crates/burn-dataset/src/dataset/dataframe.rs +++ b/crates/burn-dataset/src/dataset/dataframe.rs @@ -269,20 +269,20 @@ mod tests { } fn create_test_dataframe() -> DataFrame { - let s0 = Column::Series(Series::new("int32".into(), &[1i32, 2i32, 3i32])); - let s1 = Column::Series(Series::new("bool".into(), &[true, false, true])); - let s2 = Column::Series(Series::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64])); - let s3 = Column::Series(Series::new("string".into(), &["Boo", "Boo2", "Boo3"])); - let s6 = Column::Series(Series::new("int16".into(), &[1i16, 2i16, 3i16])); - let s8 = Column::Series(Series::new("uint32".into(), &[1u32, 2u32, 3u32])); - let s9 = Column::Series(Series::new("uint64".into(), &[1u64, 2u64, 3u64])); - let s10 = Column::Series(Series::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32])); - let s11 = Column::Series(Series::new("int64".into(), &[1i64, 2i64, 3i64])); - let s12 = Column::Series(Series::new("int8".into(), &[1i8, 2i8, 3i8])); + let s0 = Column::new("int32".into(), &[1i32, 2i32, 3i32]); + let s1 = Column::new("bool".into(), &[true, false, true]); + let s2 = Column::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]); + let s3 = Column::new("string".into(), &["Boo", "Boo2", "Boo3"]); + let s6 = Column::new("int16".into(), &[1i16, 2i16, 3i16]); + let s8 = Column::new("uint32".into(), &[1u32, 2u32, 3u32]); + let s9 = Column::new("uint64".into(), &[1u64, 2u64, 3u64]); + let s10 = Column::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]); + let s11 = Column::new("int64".into(), &[1i64, 2i64, 3i64]); + let s12 = Column::new("int8".into(), &[1i8, 2i8, 3i8]); let binary_data: Vec<&[u8]> = vec![&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]; - let s13 = Column::Series(Series::new("binary".into(), binary_data)); + let s13 = Column::new("binary".into(), binary_data); DataFrame::new(vec![s0, s1, s2, s3, s6, s8, s9, s10, s11, s12, s13]).unwrap() } diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index 1857a85e5b..a37fb98495 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -1,5 +1,6 @@ use burn_tensor::{ ops::{binary_ops_shape, FloatTensor, IntTensor}, + repr::{FromDataOperationDescription, TensorDescription}, DType, Element, TensorData, }; use std::marker::PhantomData; @@ -24,15 +25,32 @@ use burn_tensor::{ impl BoolTensorOps for Fusion { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + #[derive(new)] + struct EmptyOps { + desc: TensorDescription, + device: Device, + } + + impl Operation for EmptyOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::bool_empty(Shape::from(&self.desc.shape), &self.device); + handles.register_bool_tensor::(&self.desc.id, output); + } + } + + let stream = StreamId::current(); let client = get_client::(&device.clone()); - let tensor = B::bool_empty(shape.clone(), device); + let out = client.tensor_uninitialized(shape.dims.clone(), DType::Bool); - client.register_tensor( - B::bool_tensor_handle(tensor), - shape.dims, - StreamId::current(), - DType::Bool, - ) + let desc = out.to_description_out(); + + client.register( + vec![stream], + OperationDescription::BaseBool(BaseOperationDescription::Empty(desc.clone())), + EmptyOps::::new(desc, device.clone()), + ); + + out } async fn bool_into_data(tensor: BoolTensor) -> TensorData { @@ -40,16 +58,35 @@ impl BoolTensorOps for Fusion { } fn bool_from_data(data: burn_tensor::TensorData, device: &Device) -> BoolTensor { + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::bool_from_data(self.desc.data, &self.device); + handles.register_bool_tensor::(&self.desc.out.id, output); + } + } + + let stream = StreamId::current(); let client = get_client::(&device.clone()); - let tensor = B::bool_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); - - client.register_tensor( - B::bool_tensor_handle(tensor), - shape.dims, - StreamId::current(), - DType::Bool, - ) + let out = client.tensor_uninitialized(data.shape.clone(), DType::Bool); + + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseBool(BaseOperationDescription::FromData(desc.clone())), + FromDataOps::::new(desc, device.clone()), + ); + + out } fn bool_into_int(tensor: BoolTensor) -> IntTensor { diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 493224fa36..b798ac618f 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -16,16 +16,35 @@ use std::{marker::PhantomData, ops::Range}; impl FloatTensorOps for Fusion { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::float_from_data(self.desc.data, &self.device); + handles.register_float_tensor::(&self.desc.out.id, output); + } + } + + let stream = StreamId::current(); let client = get_client::(&device.clone()); - let tensor = B::float_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); - - client.register_tensor( - B::float_tensor_handle(tensor), - shape.dims, - StreamId::current(), - B::FloatElem::dtype(), - ) + let out = client.tensor_uninitialized(data.shape.clone(), B::FloatElem::dtype()); + + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::FromData(desc.clone())), + FromDataOps::::new(desc, device.clone()), + ); + + out } fn float_random( @@ -233,16 +252,32 @@ impl FloatTensorOps for Fusion { } fn float_empty(shape: Shape, device: &Device) -> FloatTensor { - let client = get_client::(&device.clone()); + #[derive(new)] + struct EmptyOps { + desc: TensorDescription, + device: Device, + } + + impl Operation for EmptyOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::float_empty(Shape::from(&self.desc.shape), &self.device); + handles.register_float_tensor::(&self.desc.id, output); + } + } + let stream = StreamId::current(); - let tensor = B::float_empty(shape.clone(), device); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(shape.dims.clone(), B::FloatElem::dtype()); - client.register_tensor( - B::float_tensor_handle(tensor), - shape.dims, - stream, - B::FloatElem::dtype(), - ) + let desc = out.to_description_out(); + + client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::Empty(desc.clone())), + EmptyOps::::new(desc, device.clone()), + ); + + out } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 82343f7a1b..4ee4d6e804 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -15,16 +15,32 @@ use std::marker::PhantomData; impl IntTensorOps for Fusion { fn int_empty(shape: Shape, device: &Device) -> IntTensor { - let client = get_client::(&device.clone()); - let tensor = B::int_empty(shape.clone(), device); + #[derive(new)] + struct EmptyOps { + desc: TensorDescription, + device: Device, + } + + impl Operation for EmptyOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::int_empty(Shape::from(&self.desc.shape), &self.device); + handles.register_int_tensor::(&self.desc.id, output); + } + } + let stream = StreamId::current(); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(shape.dims.clone(), B::IntElem::dtype()); - client.register_tensor( - B::int_tensor_handle(tensor), - shape.dims, - stream, - B::IntElem::dtype(), - ) + let desc = out.to_description_out(); + + client.register( + vec![stream], + OperationDescription::BaseInt(BaseOperationDescription::Empty(desc.clone())), + EmptyOps::::new(desc, device.clone()), + ); + + out } async fn int_into_data(tensor: IntTensor) -> TensorData { @@ -32,17 +48,35 @@ impl IntTensorOps for Fusion { } fn int_from_data(data: TensorData, device: &Device) -> IntTensor { - let client = get_client::(&device.clone()); - let tensor = B::int_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::int_from_data(self.desc.data, &self.device); + handles.register_int_tensor::(&self.desc.out.id, output); + } + } + let stream = StreamId::current(); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(data.shape.clone(), B::IntElem::dtype()); - client.register_tensor( - B::int_tensor_handle(tensor), - shape.dims, - stream, - B::IntElem::dtype(), - ) + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseInt(BaseOperationDescription::FromData(desc.clone())), + FromDataOps::::new(desc, device.clone()), + ); + + out } fn int_device(tensor: &IntTensor) -> Device { @@ -1819,4 +1853,267 @@ impl IntTensorOps for Fusion { out } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseAndOps, B::bitwise_and); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseAnd(desc.clone())), + BitwiseAndOps::::new(desc), + ); + + out + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseAndOps, B::bitwise_and_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseAndScalar( + desc.clone(), + )), + BitwiseAndOps::::new(desc), + ); + + out + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseOrOps, B::bitwise_or); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseOr(desc.clone())), + BitwiseOrOps::::new(desc), + ); + + out + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseOrScalar(desc.clone())), + BitwiseOrOps::::new(desc), + ); + + out + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseXorOps, B::bitwise_xor); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseXor(desc.clone())), + BitwiseXorOps::::new(desc), + ); + + out + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseXorScalar( + desc.clone(), + )), + BitwiseXorOps::::new(desc), + ); + + out + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unary_int_ops!(BitwiseNotOps, B::bitwise_not); + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseNot(desc.clone())), + BitwiseNotOps::::new(desc), + ); + + out + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseLeftShift( + desc.clone(), + )), + BitwiseLeftShiftOps::::new(desc), + ); + + out + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseLeftShiftScalar( + desc.clone(), + )), + BitwiseLeftShiftOps::::new(desc), + ); + + out + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseRightShift( + desc.clone(), + )), + BitwiseRightShiftOps::::new(desc), + ); + + out + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseRightShiftScalar( + desc.clone(), + )), + BitwiseRightShiftOps::::new(desc), + ); + + out + } } diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index 41bc7ccde6..1449a485af 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -4,8 +4,9 @@ use burn_tensor::{ ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{QuantizationParametersPrimitive, QuantizationScheme}, repr::{ - DequantizeOperationDescription, FloatOperationDescription, HandleContainer, - OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription, + BaseOperationDescription, DequantizeOperationDescription, FloatOperationDescription, + FromDataOperationDescription, HandleContainer, OperationDescription, + QuantizationParametersDescription, QuantizeOperationDescription, }, DType, Device, Element, Shape, TensorData, }; @@ -19,19 +20,41 @@ use crate::{ impl QTensorOps for Fusion { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::q_from_data(self.desc.data, &self.device); + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + match data.dtype { DType::QFloat(_scheme) => { let dtype = data.dtype; - let client = get_client::(device); - let tensor = B::q_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); - - client.register_tensor( - B::quantized_tensor_handle(tensor), - shape.dims, - StreamId::current(), - dtype, - ) + + let stream = StreamId::current(); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(data.shape.clone(), dtype); + + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::FromData( + desc.clone(), + )), + FromDataOps::::new(desc, device.clone()), + ); + + out } _ => panic!( "Invalid dtype (expected DType::QFloat, got {:?})", diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 1bb0bc5deb..671ecfb473 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -721,6 +721,82 @@ impl RelativeOps for IntOperationDescription { out: desc.out.to_relative(converter), }) } + IntOperationDescription::BitwiseAnd(desc) => { + IntOperationDescription::BitwiseAnd(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseAndScalar(desc) => { + IntOperationDescription::BitwiseAndScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseOr(desc) => { + IntOperationDescription::BitwiseOr(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseOrScalar(desc) => { + IntOperationDescription::BitwiseOrScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseXor(desc) => { + IntOperationDescription::BitwiseXor(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseXorScalar(desc) => { + IntOperationDescription::BitwiseXorScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseNot(desc) => { + IntOperationDescription::BitwiseNot(UnaryOperationDescription { + input: desc.input.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseLeftShift(desc) => { + IntOperationDescription::BitwiseLeftShift(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + IntOperationDescription::BitwiseLeftShiftScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseRightShift(desc) => { + IntOperationDescription::BitwiseRightShift(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + IntOperationDescription::BitwiseRightShiftScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } } } } @@ -1161,6 +1237,12 @@ impl RelativeOps for BaseOperationDescription { BaseOperationDescription::Empty(desc) => { BaseOperationDescription::Empty(desc.to_relative(converter)) } + BaseOperationDescription::FromData(desc) => { + BaseOperationDescription::FromData(FromDataOperationDescription { + data: desc.data.clone(), + out: desc.out.to_relative(converter), + }) + } } } } diff --git a/crates/burn-hip/src/lib.rs b/crates/burn-hip/src/lib.rs index fc8f704e74..13f5239637 100644 --- a/crates/burn-hip/src/lib.rs +++ b/crates/burn-hip/src/lib.rs @@ -26,7 +26,8 @@ pub type Hip = burn_fusion::Fusion Self { + fn init(self, _context: &mut Scope) -> Self { self } } diff --git a/crates/burn-jit/src/fusion/matmul/optimization.rs b/crates/burn-jit/src/fusion/matmul/optimization.rs index 9a020df62c..804628613d 100644 --- a/crates/burn-jit/src/fusion/matmul/optimization.rs +++ b/crates/burn-jit/src/fusion/matmul/optimization.rs @@ -87,7 +87,7 @@ impl MatmulOptimization { fused_matmul_autotune::(self, context); #[cfg(not(feature = "autotune"))] - if self.execute_fused::(context).is_err() { + if self.execute_standard_fused::(context).is_err() { self.execute_fallback::(context); } } diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index ad774cd24d..d189badcdf 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -50,13 +50,13 @@ impl CubeType for Arg { } impl Init for Arg { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _context: &mut Scope) -> Self { self } } impl IntoRuntime for Arg { - fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType { + fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType { self } } diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs index d7c4d789ab..f0da764a7a 100644 --- a/crates/burn-jit/src/kernel/binary.rs +++ b/crates/burn-jit/src/kernel/binary.rs @@ -112,7 +112,7 @@ pub(crate) fn kernel_scalar_binop( output: &mut Tensor>, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); @@ -132,7 +132,7 @@ pub(crate) fn kernel_binop( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/binary_int.rs b/crates/burn-jit/src/kernel/binary_int.rs new file mode 100644 index 0000000000..390bfc479e --- /dev/null +++ b/crates/burn-jit/src/kernel/binary_int.rs @@ -0,0 +1,276 @@ +use crate::{ops::numeric::empty_device, tensor::JitTensor, IntElement, JitRuntime}; +use burn_tensor::Shape; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +use super::into_contiguous; + +pub(crate) trait BinaryOpIntFamily: Send + Sync + 'static { + type BinaryOp: BinaryOpInt; +} + +#[cube] +pub(crate) trait BinaryOpInt: 'static + Send + Sync { + /// Execute a binary operation. + fn execute(lhs: Line, rhs: Line) -> Line; +} + +pub(crate) struct BitwiseAndOp; +pub(crate) struct BitwiseOrOp; +pub(crate) struct BitwiseXorOp; +pub(crate) struct BitwiseShrOp; +pub(crate) struct BitwiseShlOp; + +impl BinaryOpIntFamily for BitwiseAndOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseOrOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseXorOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseShrOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseShlOp { + type BinaryOp = Self; +} + +#[cube] +impl BinaryOpInt for BitwiseAndOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs & rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseOrOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs | rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseXorOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs ^ rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseShrOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs >> rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseShlOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs << rhs + } +} + +#[cube(launch_unchecked)] +pub(crate) fn kernel_scalar_binop_int( + input: &Tensor>, + scalar: C, + output: &mut Tensor>, +) { + if ABSOLUTE_POS >= output.len() { + terminate!(); + } + + output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); +} + +#[cube(launch_unchecked)] +pub(crate) fn kernel_binop_int( + lhs: &Tensor>, + rhs: &Tensor>, + out: &mut Tensor>, + #[comptime] rank: Option, + #[comptime] to_contiguous_lhs: bool, + #[comptime] to_contiguous_rhs: bool, +) { + let offset_out = ABSOLUTE_POS; + let mut offset_lhs = ABSOLUTE_POS; + let mut offset_rhs = ABSOLUTE_POS; + + if offset_out >= out.len() { + terminate!(); + } + + if to_contiguous_lhs { + offset_lhs = index_offset_with_layout::( + lhs, + out, + offset_out, + 0, + rank.unwrap_or_else(|| out.rank()), + rank.is_some(), + ); + } + + if to_contiguous_rhs { + offset_rhs = index_offset_with_layout::( + rhs, + out, + offset_out, + 0, + rank.unwrap_or_else(|| out.rank()), + rank.is_some(), + ); + } + + out[offset_out] = O::BinaryOp::::execute(lhs[offset_lhs], rhs[offset_rhs]); +} + +pub(crate) fn launch_binop_int( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + let ndims = lhs.shape.num_dims(); + let line_size_lhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &lhs.shape.dims, + &lhs.strides, + ndims - 1, + ); + let line_size_rhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &rhs.shape.dims, + &rhs.strides, + ndims - 1, + ); + let line_size = Ord::min(line_size_lhs, line_size_rhs); + + let mut shape_out = vec![0; ndims]; + lhs.shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + + let shape_out = Shape::from(shape_out); + let client = lhs.client.clone(); + let num_elems = shape_out.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if lhs.can_mut_broadcast(&rhs) { + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(0), + None, + false, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + ); + + lhs + } else if rhs.can_mut_broadcast(&lhs) { + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(1), + None, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + false, + ); + + rhs + } else { + let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); + let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; + let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; + + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + None, + to_contiguous_lhs, + to_contiguous_rhs, + ); + + output + } + } +} + +pub(crate) fn launch_scalar_binop_int( + mut tensor: JitTensor, + scalar: E, +) -> JitTensor { + if !tensor.is_contiguous_buffer() { + tensor = into_contiguous(tensor); + } + + // Vectorization is only enabled when the last dimension is contiguous. + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if tensor.can_mut() { + kernel_scalar_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + TensorArg::alias(0), + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + kernel_scalar_binop_int::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + output.as_tensor_arg::(line_size), + ); + + output + } + } +} diff --git a/crates/burn-jit/src/kernel/cast/base.rs b/crates/burn-jit/src/kernel/cast/base.rs index 798b79a0f0..43b24f071a 100644 --- a/crates/burn-jit/src/kernel/cast/base.rs +++ b/crates/burn-jit/src/kernel/cast/base.rs @@ -12,7 +12,7 @@ pub(crate) fn cast_element( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } let offset_input = index_offset_with_layout::( diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index e33687fb5a..a6de9025bb 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -82,7 +82,7 @@ pub(crate) fn kernel_scalar_cmp>( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = Line::cast_from(O::execute(input[ABSOLUTE_POS], Line::new(scalar))); @@ -102,7 +102,7 @@ pub(crate) fn kernel_cmp>( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 11fb3b4aee..4f6931f86d 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -241,7 +241,7 @@ fn col2im_kernel( #[comptime] has_bias: bool, ) { if ABSOLUTE_POS >= image.len() { - return; + terminate!(); } let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs index c724cfc3a3..1cd24f7c0c 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -35,7 +35,7 @@ fn direct_conv2d_kernel( #[comptime] kernel_size_1_unroll: Option, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_channels = weight.shape(1); diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs index f36f89bdf5..ad70a9b825 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -7,7 +7,7 @@ use burn_tensor::{ use cubecl::{ flex32, ir::{Elem, FloatKind}, - linalg::matmul::{self}, + linalg::matmul::{self, kernels::MatmulLaunchError}, tensor_line_size, tf32, Feature, }; use half::{bf16, f16}; @@ -195,18 +195,14 @@ where let cube_count = Alg::cube_count(&selection, &problem); let advanced_config = Default::default(); - let config = match Alg::make_config( + let config = Alg::make_config( config_input, &problem, &cube_dim, &cube_count, &advanced_config, - ) { - Ok(val) => val, - Err(err) => { - panic!("Can't launch conv kernel because of an invalid config: {err}") - } - }; + ) + .map_err(MatmulLaunchError::InvalidConfig)?; let bias = bias.unwrap_or_else(|| { empty_device::(input.client.clone(), input.device.clone(), Shape::new([1])) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index 6b738ab988..09ce56898b 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -53,7 +53,7 @@ fn im2col_kernel( let out_w = args.out_w; if ABSOLUTE_POS > args.num_elements { - return; + terminate!(); } let out_x = ABSOLUTE_POS % out_w; @@ -98,25 +98,38 @@ fn im2col_kernel( } #[cfg(not(test))] -pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - let cube_count_per_batch = (out_h * out_w).div_ceil(burn_common::PLANE_DIM_APPROX); +pub(crate) fn batches_per_run( + batch_size: usize, + out_h: usize, + out_w: usize, +) -> Result { + use cubecl::linalg::matmul::kernels::MatmulAvailabilityError; + + let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::PLANE_DIM_APPROX); let max_cube_count = u16::MAX as usize; let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size); if max_simultaneous == 0 { - return None; + return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static( + cube_count_per_batch as u32, + 1, + 1, + )) + .into()); } - Some( - (0..=max_simultaneous) - .rev() - .find(|per_run| batch_size % per_run == 0) - .expect("Logically not possible"), - ) + Ok((0..=max_simultaneous) + .rev() + .find(|per_run| batch_size % per_run == 0) + .expect("Logically not possible")) } #[cfg(test)] #[allow(unused)] -pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - Some(1) +pub(crate) fn batches_per_run( + batch_size: usize, + out_h: usize, + out_w: usize, +) -> Result { + Ok(1) } fn im2col( @@ -214,8 +227,7 @@ pub fn conv2d_im2col( return execute_1x1_kernel::(input, weight, bias, options); } - let batches_per_run = batches_per_run(batch_size, out_h, out_w) - .expect("Image too large to run even one batch at once"); + let batches_per_run = batches_per_run(batch_size, out_h, out_w)?; let matmul_shape = Shape::new([groups, out_c_per_group, batches_per_run * out_h * out_w]); let mut out = if batches_per_run != batch_size { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs index 62f0e56d8f..7cbe09dbc0 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs @@ -107,7 +107,7 @@ fn nchw_to_nhwc_kernel( let batch = CUBE_POS_Z; if batch >= input.shape(0) { - return; + terminate!(); } let batch_offset = batch * input.stride(0); @@ -163,7 +163,7 @@ fn nchw_to_nhwc_kernel( let hw = base_hw + mat_hw; if hw >= shape_hw { - return; + terminate!(); } let mat_c_start = mat_hw_start; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index d3e91d5947..a8cd1ceb7f 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -32,7 +32,7 @@ fn conv_transpose2d_direct_kernel( args: ConvArgs, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_c_per_group = weight.shape(0) / args.groups; diff --git a/crates/burn-jit/src/kernel/conv/conv3d.rs b/crates/burn-jit/src/kernel/conv/conv3d.rs index 157610794b..a616c432b9 100644 --- a/crates/burn-jit/src/kernel/conv/conv3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv3d.rs @@ -41,7 +41,7 @@ fn conv3d_kernel( #[comptime] kernel_size_2_unroll: Option, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_channels = weight.shape(1); diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index ddee1360e4..5840f4dc9f 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -275,7 +275,7 @@ fn deform_col2img_coord_kernel( // Alternatively : [batch, offset_channels, out_h, out_w] if ABSOLUTE_POS >= grad_offset.len() { - return; + terminate!(); } let offset_channels = offset.shape(1); @@ -551,7 +551,7 @@ fn deform_col2img_kernel( ) { // Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] if ABSOLUTE_POS >= columns.len() { - return; + terminate!(); } let n_in_channels = args.in_channels; diff --git a/crates/burn-jit/src/kernel/conv/error.rs b/crates/burn-jit/src/kernel/conv/error.rs index 99c91fc751..2654a20e24 100644 --- a/crates/burn-jit/src/kernel/conv/error.rs +++ b/crates/burn-jit/src/kernel/conv/error.rs @@ -1,5 +1,8 @@ use core::fmt::Debug; -use cubecl::{linalg::matmul::kernels::MatmulLaunchError, tune::AutotuneError}; +use cubecl::{ + linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}, + tune::AutotuneError, +}; pub enum ConvLaunchError { Matmul(MatmulLaunchError), @@ -30,6 +33,12 @@ impl From for ConvLaunchError { } } +impl From for ConvLaunchError { + fn from(value: MatmulAvailabilityError) -> Self { + Self::Matmul(MatmulLaunchError::Unavailable(value)) + } +} + #[allow(clippy::from_over_into)] impl Into for ConvLaunchError { fn into(self) -> AutotuneError { diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index 583e0346d3..a682a76eac 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -11,7 +11,7 @@ fn flip_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index 9e9b5685bb..c1aa56072e 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -12,7 +12,7 @@ fn gather_kernel( dim: &u32, ) { if ABSOLUTE_POS >= indices.len() { - return; + terminate!(); } let index = indices[ABSOLUTE_POS]; diff --git a/crates/burn-jit/src/kernel/index/repeat_dim.rs b/crates/burn-jit/src/kernel/index/repeat_dim.rs index 3887bfbd8b..b19f9e2b21 100644 --- a/crates/burn-jit/src/kernel/index/repeat_dim.rs +++ b/crates/burn-jit/src/kernel/index/repeat_dim.rs @@ -4,7 +4,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked)] fn repeat_dim_kernel(input: &Tensor, output: &mut Tensor, dim: u32) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index 4ddd9c00fb..4cca94f824 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -46,7 +46,7 @@ fn scatter_kernel( let should_stop = ABSOLUTE_POS >= num_elems; if should_stop { - return; + terminate!(); } for i in 0..shape_value { diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index b104bf504f..fe664ab420 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -10,7 +10,7 @@ fn select_kernel( dim: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index a0fed49dbd..cd4c013f63 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -29,7 +29,7 @@ fn select_assign_kernel( } if ABSOLUTE_POS >= num_elems { - return; + terminate!(); } let strides_tensor_dim = tensor.stride(dim); diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index 7f20f033b8..b6daba8da5 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -52,7 +52,7 @@ fn slice_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index 1d545d79c7..3f77ef1302 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch)] fn interpolate_bicubic_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs index 3557fcdbb8..f0cb95b536 100644 --- a/crates/burn-jit/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch)] fn interpolate_bilinear_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs index 0743a13567..0e6ba32552 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch_unchecked)] fn interpolate_nearest_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs index 5ea860a7ae..f0442ec92e 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch_unchecked)] fn interpolate_nearest_backward_kernel(grad: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let out_h = output.shape(2); diff --git a/crates/burn-jit/src/kernel/mask/mask_fill.rs b/crates/burn-jit/src/kernel/mask/mask_fill.rs index 386e7a5039..95096c7994 100644 --- a/crates/burn-jit/src/kernel/mask/mask_fill.rs +++ b/crates/burn-jit/src/kernel/mask/mask_fill.rs @@ -16,7 +16,7 @@ fn mask_fill_readonly_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); @@ -35,7 +35,7 @@ fn mask_fill_inplace_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { - return; + terminate!(); } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); diff --git a/crates/burn-jit/src/kernel/mask/mask_where.rs b/crates/burn-jit/src/kernel/mask/mask_where.rs index 5518e9648b..99384fde98 100644 --- a/crates/burn-jit/src/kernel/mask/mask_where.rs +++ b/crates/burn-jit/src/kernel/mask/mask_where.rs @@ -16,7 +16,7 @@ fn mask_where_readonly_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); @@ -36,7 +36,7 @@ fn mask_where_inplace_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { - return; + terminate!(); } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index fd23cd2e2d..93d2833976 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -1,4 +1,5 @@ mod binary; +mod binary_int; mod cast; mod clamp; mod comparison; @@ -6,13 +7,16 @@ mod contiguous; mod index; mod mask; mod unary_float; +mod unary_int; mod unary_numeric; pub(crate) use binary::*; +pub(crate) use binary_int::*; pub use cast::*; pub use contiguous::*; pub use mask::*; pub(crate) use unary_float::*; +pub(crate) use unary_int::*; pub(crate) use unary_numeric::*; pub use burn_common::PLANE_DIM_APPROX; diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index bba68c7166..d2a5a21d0a 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -24,7 +24,7 @@ fn avg_pool2d_backward_kernel( #[comptime] count_include_pad: bool, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index 6da6e2b37c..40259c4573 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -16,7 +16,7 @@ fn max_pool2d_with_indices_backward_kernel( #[comptime] kernel_size_1: i32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 72040d8839..270e32f854 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -48,7 +48,7 @@ pub(crate) fn dequantize_per_tensor_affine_int8_kernel( ) { // Last two positions contain the qparams if ABSOLUTE_POS >= input.len() - 2 { - return; + terminate!(); } let qparams = QParams::new(scheme); @@ -85,7 +85,7 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( ) { // Last position contains the qparam if ABSOLUTE_POS >= input.len() - 1 { - return; + terminate!(); } let qparams = QParams::new(scheme); diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index e9494aa987..0a7b0ea553 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -34,7 +34,7 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let scale = scale[0]; @@ -43,13 +43,13 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( // Cast the scale to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 1 { output[ABSOLUTE_POS] = u32::bitcast_from(scale); - return; + terminate!(); } // Cast the offset to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 2 { output[ABSOLUTE_POS] = u32::bitcast_from(offset); - return; + terminate!(); } let line_size = comptime!(input.line_size()); @@ -120,7 +120,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let scale = scale[0]; @@ -128,7 +128,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( // Cast the scale to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 1 { output[ABSOLUTE_POS] = u32::bitcast_from(scale); - return; + terminate!(); } let line_size = comptime!(input.line_size()); diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index 9ab1f5d2b6..ccfcc3ef9e 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -1,31 +1,96 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +#[cfg(feature = "autotune")] +use super::{autotune_reduce, autotune_sum}; +use crate::{ + element::JitElement, + ops::{from_data, numeric::empty_device}, + tensor::JitTensor, + JitRuntime, +}; +use burn_tensor::{Shape, TensorData}; +pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum}; +use cubecl::reduce::shared_sum; -use super::autotune_reduce; +/// Specialize reduce function to compute the sum of all elements of the `input` tensor and return +/// the value into a single-element tensor of shape `1 x 1 x 1 x ...` with the same rank as `input`. +/// +/// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction. +/// +/// Return an error if the `client` doesn't support atomic add for the type `E`. +pub fn sum( + tensor: JitTensor, + cube_count: SumStrategy, +) -> Result, cubecl::reduce::ReduceError> { + let client = tensor.client.clone(); + let device = tensor.device.clone(); -pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum}; + match cube_count { + SumStrategy::OneShot(cube_count) => { + let output = shared_sum::(&client, tensor.as_handle_ref(), cube_count)?; + Ok(from_data::( + TensorData::new(vec![output], vec![1]), + &device, + )) + } + SumStrategy::Chained(strategy) => reduce::(tensor, strategy), + #[cfg(feature = "autotune")] + SumStrategy::Autotune => Ok(autotune_sum::(&client, tensor)), + } +} + +/// Select a strategy to perform a sum. +pub enum SumStrategy { + /// Run a single kernel with many cubes working in parallel to sum all elements. + /// The provided value is the number of elements summed per unit (up-to-rounding ) + OneShot(u32), + /// Use multiple kernels + Chained(ReduceStrategy), + /// Use autotune to find the best cube count given the hardware and the input. + #[cfg(feature = "autotune")] + Autotune, +} + +impl Default for SumStrategy { + fn default() -> Self { + #[cfg(feature = "autotune")] + return Self::Autotune; + + #[cfg(not(feature = "autotune"))] + return Self::OneShot(4); + } +} /// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). /// /// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. -/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid. -/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`. /// /// If there is no error, the output is a tensor with decreasing strides /// where the shape of reduced dim is set to 1 but all shape are similar to the input. pub fn reduce( - mut input: JitTensor, + mut tensor: JitTensor, strategy: ReduceStrategy, ) -> Result, cubecl::reduce::ReduceError> { - input.shape = input.shape.flatten(); - input.strides = vec![1]; - reduce_dim::(input, 0, strategy) + // In practice, it looks like starting by the axis with the smallest shape + // and going in increasing order lead to the fastest calculation. + let sorted_axis = argsort(&tensor.shape.dims); + for axis in sorted_axis { + tensor = reduce_dim::(tensor, axis, strategy)?; + } + // reshape to scalar tensor + tensor.shape = Shape::new([1]); + tensor.strides = vec![1]; + Ok(tensor) +} + +fn argsort(shape: &[usize]) -> Vec { + let mut indices = (0..shape.len()).collect::>(); + indices.sort_by_key(|&i| &shape[i]); + indices } /// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). /// /// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. /// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid. -/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`. /// /// If there is no error, the output is a tensor with decreasing strides /// where the shape of reduced dim is set to 1 but all shape are similar to the input. @@ -58,7 +123,8 @@ pub fn reduce_dim { - autotune_reduce::(&client, input, output.clone(), dim) + autotune_reduce::(&client, input, output.clone(), dim); + Ok(()) } }; result.map(|_| output) diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs index b364907238..cd5cd61157 100644 --- a/crates/burn-jit/src/kernel/reduce/tune.rs +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -12,7 +12,6 @@ use crate::{ kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor, JitAutotuneKey, JitElement, JitRuntime, JitTuneId, }; -use reduce_ops::*; /// Executes autotune on reduce operations. pub fn autotune_reduce< @@ -25,7 +24,9 @@ pub fn autotune_reduce< input: JitTensor, output: JitTensor, dim: usize, -) -> Result<(), cubecl::reduce::ReduceError> { +) { + use reduce_ops::*; + static TUNER: LocalTuner = local_tuner!(); let tunables = TunableSet::new(create_key::, reduce_input_gen::) @@ -40,12 +41,10 @@ pub fn autotune_reduce< &tunables, (input, output, dim), ); - - Ok(()) } #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] -/// Autotune key representative of redue versions +/// Autotune key representative of reduce versions pub struct ReduceAutotuneKey { dtype: burn_tensor::DType, #[autotune(anchor)] @@ -207,3 +206,89 @@ mod reduce_ops { .map_err(|e| format!("{e}")) } } + +/// Executes autotune on reduce operations. +#[cfg(feature = "autotune")] +pub fn autotune_sum( + client: &ComputeClient, + input: JitTensor, +) -> JitTensor { + use sum_ops::*; + + static TUNER: LocalTuner = local_tuner!(); + + let tunables = TunableSet::new(create_key_sum::, sum_input_gen::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_chained::); + + TUNER.execute( + &JitTuneId::new::(&input.device), + client, + &tunables, + input, + ) +} + +pub(crate) fn create_key_sum(input: &JitTensor) -> JitAutotuneKey { + JitAutotuneKey::Sum(SumAutotuneKey::generate(input)) +} + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +/// Autotune key representative of sum versions +pub struct SumAutotuneKey { + dtype: burn_tensor::DType, + #[autotune(anchor)] + length: usize, +} + +impl SumAutotuneKey { + pub(crate) fn generate(input: &JitTensor) -> Self { + let dtype = input.dtype; + let length = input.shape.num_elements(); + Self { dtype, length } + } +} +mod sum_ops { + #![allow(missing_docs)] + + use burn_tensor::TensorData; + use cubecl::reduce::instructions::Sum; + + use crate::ops::from_data; + + use super::*; + + pub(crate) fn sum_input_gen( + _key: &JitAutotuneKey, + input: &JitTensor, + ) -> JitTensor { + let random_bounds: (E, E) = ((-10.0_f32).elem::(), (10.0_f32).elem::()); + random_like_uniform(input, random_bounds.0, random_bounds.1) + } + + pub(crate) fn sum_one_shot( + input: JitTensor, + ) -> Result, String> { + let device = input.device.clone(); + cubecl::reduce::shared_sum::(&input.client, input.as_handle_ref(), C) + .map(|output| from_data::(TensorData::new(vec![output], vec![1]), &device)) + .map_err(|e| e.to_string()) + } + + #[cfg(feature = "autotune")] + pub(crate) fn sum_chained( + input: JitTensor, + ) -> Result, String> { + crate::kernel::reduce::reduce::( + input, + crate::kernel::reduce::ReduceStrategy::Autotune, + ) + .map_err(|e| e.to_string()) + } +} diff --git a/crates/burn-jit/src/kernel/unary_float.rs b/crates/burn-jit/src/kernel/unary_float.rs index 33a311ecbc..4664d3c0b3 100644 --- a/crates/burn-jit/src/kernel/unary_float.rs +++ b/crates/burn-jit/src/kernel/unary_float.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_float( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/kernel/unary_int.rs b/crates/burn-jit/src/kernel/unary_int.rs new file mode 100644 index 0000000000..17bced52d1 --- /dev/null +++ b/crates/burn-jit/src/kernel/unary_int.rs @@ -0,0 +1,148 @@ +use crate::{ops::numeric::empty_device, tensor::JitTensor, IntElement, JitRuntime}; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +pub(crate) trait IntUnaryOpFamily: 'static + Send + Sync { + type Options: LaunchArg; + type Unary: IntUnaryOp>; +} + +#[cube] +pub(crate) trait IntUnaryOp: 'static + Send + Sync { + type Options: LaunchArg; + + fn execute(input: Line, options: &Self::Options) -> Line; +} + +#[cube(launch_unchecked)] +pub(crate) fn unary_int( + input: &Tensor>, + output: &mut Tensor>, + options: &O::Options, + #[comptime] rank: Option, + #[comptime] to_contiguous: bool, +) { + let offset_output = ABSOLUTE_POS; + + if offset_output >= output.len() { + terminate!(); + } + + if comptime![to_contiguous] { + let offset_input = index_offset_with_layout::( + input, + output, + offset_output, + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), + ); + + output[offset_output] = O::Unary::::execute(input[offset_input], options); + } else { + output[offset_output] = O::Unary::::execute(input[offset_output], options); + } +} + +pub(crate) fn launch_unary_int(tensor: JitTensor, args: Args) -> JitTensor +where + for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, + R: JitRuntime, + E: IntElement + Int, + O: IntUnaryOpFamily, +{ + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + let is_contiguous = tensor.is_contiguous(); + + unsafe { + if tensor.can_mut() && tensor.is_contiguous_buffer() { + unary_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + TensorArg::alias(0), + args(&()), + None, + false, + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + unary_int::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + args(&()), + Some(ndims as u32), + !is_contiguous, + ); + output + } + } +} + +pub(crate) mod unary_basic_int { + + use super::*; + + pub(crate) fn launch(tensor: JitTensor, args: Args) -> JitTensor + where + R: JitRuntime, + for<'a> Args: FnOnce(&'a ()) -> &'a BasicIntUnaryKind, + I: IntElement, + { + launch_unary_int::(tensor, |input| { + BasicIntUnaryOptionsLaunch::new(args(input)) + }) + } + + #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] + pub enum BasicIntUnaryKind { + BitwiseNot, + } + + #[derive(CubeLaunch)] + struct BasicIntUnaryOptions { + #[cube(comptime)] + kind: BasicIntUnaryKind, + } + struct BasicIntUnary; + + #[cube] + impl IntUnaryOp for BasicIntUnary { + type Options = BasicIntUnaryOptions; + + fn execute(input: Line, options: &Self::Options) -> Line { + match comptime![options.kind] { + BasicIntUnaryKind::BitwiseNot => Line::bitwise_not(input), + } + } + } + + impl IntUnaryOpFamily for BasicIntUnary { + type Options = BasicIntUnaryOptions; + type Unary = Self; + } +} diff --git a/crates/burn-jit/src/kernel/unary_numeric.rs b/crates/burn-jit/src/kernel/unary_numeric.rs index 0b8dcb2cbc..aaeadbb685 100644 --- a/crates/burn-jit/src/kernel/unary_numeric.rs +++ b/crates/burn-jit/src/kernel/unary_numeric.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_numeric( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index d32de97436..17b775361e 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -355,7 +355,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::reduce::(tensor, Default::default()).unwrap() + reduce::sum::(tensor, Default::default()).unwrap() ) } diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 5702a90849..8da778d1e8 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,5 +1,10 @@ +use self::unary_basic_int::BasicIntUnaryKind; + use super::{expand, numeric, permute}; -use crate::kernel::{launch_unary_numeric, reduce, NumericUnaryOp, NumericUnaryOpFamily}; +use crate::kernel::{ + launch_binop_int, launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int, + BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, +}; use crate::{ element::BoolElement, kernel::prng::{random_bernoulli, random_normal, random_uniform}, @@ -193,7 +198,7 @@ where } fn int_sum(tensor: IntTensor) -> IntTensor { - reduce::reduce::(tensor, Default::default()).unwrap() + reduce::sum::(tensor, Default::default()).unwrap() } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { @@ -293,4 +298,48 @@ where fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { kernel::flip::(tensor, axes) } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_and::(lhs, rhs) + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_and_scalar::(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_or::(lhs, rhs) + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_xor::(lhs, rhs) + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unary_basic_int::launch::(tensor, |_| &BasicIntUnaryKind::BitwiseNot) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + launch_binop_int::(lhs, rhs) + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + launch_scalar_binop_int::(lhs, rhs) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + launch_binop_int::(lhs, rhs) + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + launch_scalar_binop_int::(lhs, rhs) + } } diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index d0d5be8468..cf15916aab 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -1,8 +1,9 @@ use crate::kernel::{ - launch_binop, launch_scalar_binop, AddOp, DivOp, MulOp, PowOp, RemainderOp, SubOp, + launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int, AddOp, + BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp, }; use crate::{element::JitElement, tensor::JitTensor}; -use crate::{FloatElement, JitRuntime}; +use crate::{FloatElement, IntElement, JitRuntime}; use burn_tensor::{ElementConversion, Shape}; use cubecl::client::ComputeClient; use cubecl::tensor_vectorization_factor; @@ -30,7 +31,7 @@ pub fn full_device( #[cube(launch)] pub fn full_kernel(tensor: &mut Tensor, value: C) { if ABSOLUTE_POS >= tensor.len() { - return; + terminate!(); } tensor[ABSOLUTE_POS] = value; @@ -139,3 +140,36 @@ pub fn remainder_scalar(lhs: JitTensor, rhs: E) pub fn pow(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::>(lhs, rhs) } + +pub fn bitwise_and( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_and_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} + +pub fn bitwise_or( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_or_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} + +pub fn bitwise_xor( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_xor_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} diff --git a/crates/burn-jit/src/tune_key.rs b/crates/burn-jit/src/tune_key.rs index cb29e2fe0c..9a86a85483 100644 --- a/crates/burn-jit/src/tune_key.rs +++ b/crates/burn-jit/src/tune_key.rs @@ -1,7 +1,7 @@ use crate::kernel::{ conv::{Conv2dAutotuneKey, ConvTranspose2dAutotuneKey}, matmul::MatmulAutotuneKey, - reduce::ReduceAutotuneKey, + reduce::{ReduceAutotuneKey, SumAutotuneKey}, }; use cubecl::tune::AutotuneKey; use serde::{Deserialize, Serialize}; @@ -14,6 +14,8 @@ pub enum JitAutotuneKey { Matmul(MatmulAutotuneKey), /// Key for reduce dim operations Reduce(ReduceAutotuneKey), + /// Key for sum operations + Sum(SumAutotuneKey), /// Key for convolution operations Conv2d(Conv2dAutotuneKey), /// Key for transpose convolution operations @@ -25,6 +27,7 @@ impl Display for JitAutotuneKey { match self { JitAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), JitAutotuneKey::Reduce(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + JitAutotuneKey::Sum(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), } diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index 9009b5c4a8..43c7cdb100 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -351,4 +351,71 @@ impl IntTensorOps fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { NdArrayOps::expand(tensor, shape) } + + fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() & (b.elem::())).elem() + }) + } + + fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() & rhs.elem::()).elem() + }) + } + + fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() | (b.elem::())).elem() + }) + } + + fn bitwise_or_scalar( + lhs: burn_tensor::ops::IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> burn_tensor::ops::IntTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() | rhs.elem::()).elem() + }) + } + + fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() ^ (b.elem::())).elem() + }) + } + + fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() ^ rhs.elem::()).elem() + }) + } + + fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::()).elem()) + } + + fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() << (b.elem::())).elem() + }) + } + + fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() << rhs.elem::()).elem() + }) + } + + fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() >> (b.elem::())).elem() + }) + } + + fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() >> rhs.elem::()).elem() + }) + } } diff --git a/crates/burn-remote/src/server/session.rs b/crates/burn-remote/src/server/session.rs index 7d32d04b74..3da6b2afa1 100644 --- a/crates/burn-remote/src/server/session.rs +++ b/crates/burn-remote/src/server/session.rs @@ -101,12 +101,12 @@ impl SessionManager { impl Session { fn new(runner: Runner) -> Self { - let (sender, reveiver) = std::sync::mpsc::sync_channel(1); + let (sender, receiver) = std::sync::mpsc::sync_channel(1); Self { runner, streams: Default::default(), sender, - receiver: Some(reveiver), + receiver: Some(receiver), } } diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs index b5ec3660ae..fb3361ec17 100644 --- a/crates/burn-router/src/ops/op_bool.rs +++ b/crates/burn-router/src/ops/op_bool.rs @@ -4,9 +4,9 @@ use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntEle use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, - OperationDescription, PermuteOperationDescription, RepeatDimOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, - UnaryOperationDescription, + FromDataOperationDescription, OperationDescription, PermuteOperationDescription, + RepeatDimOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, + SwapDimsDescription, UnaryOperationDescription, }; use burn_tensor::{DType, Device, Element, Shape, TensorData, TensorMetadata}; @@ -31,7 +31,18 @@ impl BoolTensorOps for BackendRouter { fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { let client = get_client::(device); - client.register_tensor_data(data.convert::()) + let out = client.register_empty_tensor(data.shape.clone(), DType::Bool); + + let desc = FromDataOperationDescription { + data, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::FromData(desc), + )); + + out } fn bool_into_int(tensor: BoolTensor) -> IntTensor { diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs index 10bddc3803..1b17d5a2ad 100644 --- a/crates/burn-router/src/ops/op_float.rs +++ b/crates/burn-router/src/ops/op_float.rs @@ -8,12 +8,13 @@ use burn_tensor::ops::{ use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription, - MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, - PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - RepeatDimOperationDescription, ScalarOperationDescription, ScatterOperationDescription, - SelectAssignOperationDescription, SelectOperationDescription, SliceAssignOperationDescription, - SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, + FloatOperationDescription, FromDataOperationDescription, GatherOperationDescription, + MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, + OperationDescription, PermuteOperationDescription, RandomOperationDescription, + ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ScalarOperationDescription, + ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, + SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, + UnaryOperationDescription, }; use burn_tensor::{ DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, @@ -24,7 +25,18 @@ use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient}; impl FloatTensorOps for BackendRouter { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { let client = get_client::(device); - client.register_tensor_data(data.convert::<::FloatElem>()) + let out = client.register_empty_tensor(data.shape.clone(), FloatElem::::dtype()); + + let desc = FromDataOperationDescription { + data, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::FromData(desc), + )); + + out } fn float_random( diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index 9aa3bf2dc8..eefecd7ef8 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -8,12 +8,13 @@ use burn_tensor::ops::{ use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - GatherOperationDescription, IntOperationDescription, MaskFillOperationDescription, - MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, - PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - RepeatDimOperationDescription, ScalarOperationDescription, ScatterOperationDescription, - SelectAssignOperationDescription, SelectOperationDescription, SliceAssignOperationDescription, - SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, + FromDataOperationDescription, GatherOperationDescription, IntOperationDescription, + MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, + OperationDescription, PermuteOperationDescription, RandomOperationDescription, + ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ScalarOperationDescription, + ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, + SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, + UnaryOperationDescription, }; use burn_tensor::{ DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, @@ -44,7 +45,18 @@ impl IntTensorOps for BackendRouter { fn int_from_data(data: TensorData, device: &Device) -> IntTensor { let client = get_client::(device); - client.register_tensor_data(data.convert::<::IntElem>()) + let out = client.register_empty_tensor(data.shape.clone(), IntElem::::dtype()); + + let desc = FromDataOperationDescription { + data, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::FromData(desc), + )); + + out } fn int_device(tensor: &IntTensor) -> Device { @@ -1172,4 +1184,201 @@ impl IntTensorOps for BackendRouter { out } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseAnd(desc), + )); + + out + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseOr(desc), + )); + + out + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseXor(desc), + )); + + out + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseNot(desc), + )); + + out + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseAndScalar(desc), + )); + + out + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseOrScalar(desc), + )); + + out + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseXorScalar(desc), + )); + + out + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseLeftShift(desc), + )); + + out + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseLeftShiftScalar(desc), + )); + + out + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseRightShift(desc), + )); + + out + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseRightShiftScalar(desc), + )); + + out + } } diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 04f93a4769..7443be94f9 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -245,6 +245,10 @@ impl RunnerClient for Runner { let output = B::float_empty(shape, &self.device); handles.register_float_tensor::(&desc.id, output); } + BaseOperationDescription::FromData(desc) => { + let output = B::float_from_data(desc.data.clone(), &self.device); + handles.register_float_tensor::(&desc.out.id, output); + } }, OperationDescription::BaseInt(op) => match op { BaseOperationDescription::ToDevice(_) => unreachable!(), @@ -316,6 +320,10 @@ impl RunnerClient for Runner { let output = B::int_empty(shape, &self.device); handles.register_int_tensor::(&desc.id, output); } + BaseOperationDescription::FromData(desc) => { + let output = B::int_from_data(desc.data.clone(), &self.device); + handles.register_int_tensor::(&desc.out.id, output); + } }, OperationDescription::BaseBool(op) => match op { BaseOperationDescription::ToDevice(_) => unreachable!(), @@ -391,6 +399,10 @@ impl RunnerClient for Runner { let output = B::bool_empty(shape, &self.device); handles.register_bool_tensor::(&desc.id, output); } + BaseOperationDescription::FromData(desc) => { + let output = B::bool_from_data(desc.data.clone(), &self.device); + handles.register_bool_tensor::(&desc.out.id, output); + } }, OperationDescription::NumericFloat(_dtype, op) => match op { NumericOperationDescription::Add(desc) => { @@ -792,6 +804,39 @@ impl RunnerClient for Runner { let output = B::int_into_float(tensor); handles.register_float_tensor::(&desc.out.id, output); } + IntOperationDescription::BitwiseAnd(desc) => { + binary_int_ops!(handles, desc, B::bitwise_and) + } + IntOperationDescription::BitwiseAndScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_and_scalar) + } + IntOperationDescription::BitwiseOr(desc) => { + binary_int_ops!(handles, desc, B::bitwise_or) + } + IntOperationDescription::BitwiseOrScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_or_scalar) + } + IntOperationDescription::BitwiseXor(desc) => { + binary_int_ops!(handles, desc, B::bitwise_xor) + } + IntOperationDescription::BitwiseXorScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_xor_scalar) + } + IntOperationDescription::BitwiseNot(desc) => { + unary_int_ops!(handles, desc, B::bitwise_not) + } + IntOperationDescription::BitwiseLeftShift(desc) => { + binary_int_ops!(handles, desc, B::bitwise_left_shift) + } + IntOperationDescription::BitwiseRightShift(desc) => { + binary_int_ops!(handles, desc, B::bitwise_right_shift) + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_left_shift_scalar) + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_right_shift_scalar) + } }, OperationDescription::Float(_dtype, op) => match op { FloatOperationDescription::Exp(desc) => { diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index 7b04207871..704c6176cc 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -477,4 +477,118 @@ impl TchOps { pub fn argsort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor { TchTensor::new(tensor.tensor.argsort(dim as i64, descending)) } + + pub fn bitwise_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_and_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_and_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_and_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_and_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_and_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_and(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_or_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_or_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_or_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_or_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_or_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_or(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_xor(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_xor_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_xor_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_xor_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_xor_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_xor_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_xor(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_not(tensor: TchTensor) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_not_().unwrap(), + |tensor| tensor.f_bitwise_not().unwrap(), + ) + } + + pub fn bitwise_left_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_left_shift_(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(), + ) + } + + pub fn bitwise_left_shift_scalar + Clone>( + tensor: TchTensor, + scalar: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| { + tensor + .f_bitwise_left_shift_tensor_scalar_(scalar.clone().into()) + .unwrap() + }, + |tensor| { + tensor + .f_bitwise_left_shift_tensor_scalar(scalar.clone().into()) + .unwrap() + }, + ) + } + + pub fn bitwise_right_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_right_shift_(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(), + ) + } + + pub fn bitwise_right_shift_scalar + Clone>( + tensor: TchTensor, + scalar: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| { + tensor + .f_bitwise_right_shift_tensor_scalar_(scalar.clone().into()) + .unwrap() + }, + |tensor| { + tensor + .f_bitwise_right_shift_tensor_scalar(scalar.clone().into()) + .unwrap() + }, + ) + } } diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index 0da31fe430..0ac829abaf 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -416,4 +416,63 @@ impl IntTensorOps for LibTorch { fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { TchOps::argsort(tensor, dim, descending) } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_and(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_or(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_xor(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + TchOps::bitwise_not(tensor) + } + + fn bitwise_and_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_and_scalar(lhs, rhs) + } + + fn bitwise_or_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_left_shift(lhs, rhs) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_right_shift(lhs, rhs) + } + + fn bitwise_left_shift_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_left_shift_scalar(lhs, rhs) + } + + fn bitwise_right_shift_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_right_shift_scalar(lhs, rhs) + } } diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 58a8d83fce..d3203ea14d 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -6,6 +6,7 @@ use alloc::borrow::ToOwned; use alloc::boxed::Box; use alloc::{string::String, vec, vec::Vec}; +use crate::TensorData; use crate::{ ops::{ ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions, @@ -197,6 +198,12 @@ pub enum ModuleOperationDescription { /// Basic operations that can be done on any tensor type. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum BaseOperationDescription { + /// Operation corresponding to: + /// + /// Float => [from_data](crate::ops::FloatTensorOps::float_from_data). + /// Int => [from_data](crate::ops::IntTensorOps::int_from_data). + /// Bool => [from_data](crate::ops::BoolTensorOps::bool_from_data). + FromData(FromDataOperationDescription), /// Operation corresponding to: /// /// Float => [to device](crate::ops::FloatTensorOps::float_to_device). @@ -272,9 +279,9 @@ pub enum BaseOperationDescription { /// Operation corresponding to: /// - /// Float => [equal](crate::ops::FloatTensorOps::float_empty). - /// Int => [equal](crate::ops::IntTensorOps::int_empty). - /// Bool => [equal](crate::ops::BoolTensorOps::bool_empty). + /// Float => [empty](crate::ops::FloatTensorOps::float_empty). + /// Int => [empty](crate::ops::IntTensorOps::int_empty). + /// Bool => [empty](crate::ops::BoolTensorOps::bool_empty). Empty(TensorDescription), } @@ -520,6 +527,50 @@ pub enum NumericOperationDescription { pub enum IntOperationDescription { /// Operation corresponding to [into float](crate::ops::IntTensorOps::int_into_float). IntoFloat(UnaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise and](crate::ops::IntTensorOps::bitwise_and). + BitwiseAnd(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise and scalar](crate::ops::IntTensorOps::bitwise_and_scalar). + BitwiseAndScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise or](crate::ops::IntTensorOps::bitwise_or). + BitwiseOr(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise or scalar](crate::ops::IntTensorOps::bitwise_or_scalar). + BitwiseOrScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise xor](crate::ops::IntTensorOps::bitwise_xor). + BitwiseXor(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise xor scalar](crate::ops::IntTensorOps::bitwise_xor_scalar). + BitwiseXorScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise not](crate::ops::IntTensorOps::bitwise_not). + BitwiseNot(UnaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise left shift](crate::ops::IntTensorOps::bitwise_left_shift). + BitwiseLeftShift(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise left shift scalar](crate::ops::IntTensorOps::bitwise_left_shift_scalar). + BitwiseLeftShiftScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise right shift](crate::ops::IntTensorOps::bitwise_right_shift). + BitwiseRightShift(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise right shift scalar](crate::ops::IntTensorOps::bitwise_right_shift_scalar). + BitwiseRightShiftScalar(ScalarOperationDescription), } /// Operation description specific to a bool tensor. @@ -586,6 +637,13 @@ pub struct RandomOperationDescription { pub distribution: Distribution, } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[allow(missing_docs)] +pub struct FromDataOperationDescription { + pub out: TensorDescription, + pub data: TensorData, +} + #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ExpandDescription { @@ -1357,6 +1415,7 @@ impl BaseOperationDescription { BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(), BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out], BaseOperationDescription::Empty(desc) => vec![desc], + BaseOperationDescription::FromData(desc) => vec![&desc.out], } } } @@ -1537,6 +1596,39 @@ impl IntOperationDescription { fn nodes(&self) -> Vec<&TensorDescription> { match self { IntOperationDescription::IntoFloat(desc) => vec![&desc.input, &desc.out], + IntOperationDescription::BitwiseAnd(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseAndScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseOr(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseOrScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseXor(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseXorScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseNot(desc) => { + vec![&desc.input, &desc.out] + } + IntOperationDescription::BitwiseLeftShift(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseRightShift(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + vec![&desc.lhs, &desc.out] + } } } } @@ -1670,6 +1762,12 @@ impl ModuleOperationDescription { } } +impl core::hash::Hash for FromDataOperationDescription { + fn hash(&self, state: &mut H) { + self.out.hash(state); + } +} + impl core::hash::Hash for RandomOperationDescription { fn hash(&self, state: &mut H) { self.out.hash(state); diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index e882a107c7..5d65b68ceb 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -99,4 +99,59 @@ where ) -> Tensor { cartesian_grid::(shape, device) } + + /// Applies the bitwise logical and operation with each bit representing the integer. + pub fn bitwise_and(self, other: Self) -> Self { + Self::new(B::bitwise_and(self.primitive, other.primitive)) + } + + /// Applies the bitwise logical or operation with another tensor. + pub fn bitwise_or(self, other: Self) -> Self { + Self::new(B::bitwise_or(self.primitive, other.primitive)) + } + + /// Applies the bitwise logical xor operation with another tensor. + pub fn bitwise_xor(self, other: Self) -> Self { + Self::new(B::bitwise_xor(self.primitive, other.primitive)) + } + + /// Applies the bitwise logical not operation. + pub fn bitwise_not(self) -> Self { + Self::new(B::bitwise_not(self.primitive)) + } + + /// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor. + pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_and_scalar(self.primitive, other)) + } + + /// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor. + pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_or_scalar(self.primitive, other)) + } + + /// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor. + pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_xor_scalar(self.primitive, other)) + } + + /// Applies the bitwise left shift operation with the integers in the tensor. + pub fn bitwise_left_shift(self, other: Self) -> Self { + Self::new(B::bitwise_left_shift(self.primitive, other.primitive)) + } + + /// Applies the bitwise right shift operation with the integers in the tensor. + pub fn bitwise_right_shift(self, other: Self) -> Self { + Self::new(B::bitwise_right_shift(self.primitive, other.primitive)) + } + + /// Applies the bitwise left shift operation with the integers in the tensor. + pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_left_shift_scalar(self.primitive, other)) + } + + /// Applies the bitwise right shift operation with the integers in the tensor. + pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_right_shift_scalar(self.primitive, other)) + } } diff --git a/crates/burn-tensor/src/tensor/backend/conversion.rs b/crates/burn-tensor/src/tensor/backend/conversion.rs index 46b0423b71..6aebe06463 100644 --- a/crates/burn-tensor/src/tensor/backend/conversion.rs +++ b/crates/burn-tensor/src/tensor/backend/conversion.rs @@ -188,7 +188,7 @@ mod tests { } #[test] - fn should_build_indices_2d_complexe() { + fn should_build_indices_2d_complex() { let shape = Shape::new([2, 3]); let indices = build_indices(&shape, Order::Left); @@ -206,7 +206,7 @@ mod tests { } #[test] - fn should_build_indices_3d_complexe() { + fn should_build_indices_3d_complex() { let shape = Shape::new([2, 5, 3]); let indices = build_indices(&shape, Order::Left); diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index abdd2e54ba..81b73eb2dd 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -1185,4 +1185,37 @@ pub trait IntTensorOps { fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { argsort::(tensor, dim, descending) } + + /// Bitwise AND operation for Int Tensors + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise AND operation for Int Tensors with a scalar + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise OR operation for Int Tensors + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise OR operation for Int Tensors with a scalar + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise XOR operation for Int Tensors + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise XOR operation for Int Tensors with a scalar + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise NOT operation for Int Tensors + fn bitwise_not(tensor: IntTensor) -> IntTensor; + + /// Bitwise left shift operation for Int Tensors + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise left shift operation for Int Tensors with a scalar + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise right shift operation for Int Tensors + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise right shift operation for Int Tensors with a scalar + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; } diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs index fb141ee16d..27fa996ad6 100644 --- a/crates/burn-tensor/src/tensor/quantization/scheme.rs +++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs @@ -37,7 +37,7 @@ impl CubeType for QuantizationScheme { } #[cfg(feature = "cubecl")] impl cubecl::frontend::Init for QuantizationScheme { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _scope: &mut cubecl::ir::Scope) -> Self { self } } diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index 8aa41ee24d..ee9aec9fe8 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -310,6 +310,7 @@ macro_rules! testgen_with_int_param { burn_tensor::testgen_sub!(); burn_tensor::testgen_transpose!(); burn_tensor::testgen_gather_scatter!(); + burn_tensor::testgen_bitwise!(); // test stats burn_tensor::testgen_eye!(); diff --git a/crates/burn-tensor/src/tests/ops/bitwise.rs b/crates/burn-tensor/src/tests/ops/bitwise.rs new file mode 100644 index 0000000000..c85f5edcc5 --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/bitwise.rs @@ -0,0 +1,176 @@ +#[burn_tensor_testgen::testgen(bitwise)] +mod tests { + use super::*; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_apply_bitwise_and_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); + + let output = tensor_1.bitwise_and(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[2, 4, 0], [9, 2, 8]]), false); + } + + #[test] + fn should_apply_bitwise_and_1d() { + let tensor_1 = TestTensorInt::<1>::from([13, 7]); + let tensor_2 = TestTensorInt::from([11, 3]); + + let output = tensor_1.bitwise_and(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([9, 3]), false); + } + + #[test] + fn should_apply_bitwise_and_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 5; + + let output = tensor_1.bitwise_and_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[1, 4, 5], [1, 1, 0]]), false); + } + + #[test] + fn should_apply_bitwise_not_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + + let output = tensor_1.bitwise_not(); + + output + .into_data() + .assert_eq(&TensorData::from([[-4, -5, -6], [-10, -4, -9]]), false); + } + + #[test] + fn should_apply_bitwise_or_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 5; + + let output = tensor_1.bitwise_or_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[7, 5, 5], [13, 7, 13]]), false); + } + + #[test] + fn should_apply_bitwise_or_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); + + let output = tensor_1.bitwise_or(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[7, 7, 13], [9, 11, 15]]), false); + } + + #[test] + fn should_apply_bitwise_or_1d() { + let tensor_1 = TestTensorInt::<1>::from([13, 7]); + let tensor_2 = TestTensorInt::from([11, 3]); + + let output = tensor_1.bitwise_or(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([15, 7]), false); + } + + #[test] + fn should_apply_bitwise_xor_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 5; + + let output = tensor_1.bitwise_xor_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[6, 1, 0], [12, 6, 13]]), false); + } + + #[test] + fn should_apply_bitwise_xor_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); + + let output = tensor_1.bitwise_xor(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[5, 3, 13], [0, 9, 7]]), false); + } + + #[test] + fn should_apply_bitwise_xor_1d() { + let tensor_1 = TestTensorInt::<1>::from([13, 7]); + let tensor_2 = TestTensorInt::from([11, 3]); + + let output = tensor_1.bitwise_xor(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([6, 4]), false); + } + + #[test] + fn should_apply_bitwise_left_shift_2d() { + if (IntType::MAX as u32) < 512 { + return; + } + + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); + + let output = tensor_1.bitwise_left_shift(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[6, 16, 40], [144, 96, 512]]), false); + } + + #[test] + fn should_apply_bitwise_left_shift_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 2; + + let output = tensor_1.bitwise_left_shift_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[12, 16, 20], [36, 12, 32]]), false); + } + + #[test] + fn should_apply_bitwise_right_shift_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); + + let output = tensor_1.bitwise_right_shift(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[1, 1, 0], [0, 0, 0]]), false); + } + + #[test] + fn should_apply_bitwise_right_shift_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 2; + + let output = tensor_1.bitwise_right_shift_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[0, 1, 1], [2, 0, 2]]), false); + } +} diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs index b1096e0216..32bdd0f4ba 100644 --- a/crates/burn-tensor/src/tests/ops/mod.rs +++ b/crates/burn-tensor/src/tests/ops/mod.rs @@ -7,6 +7,7 @@ mod arange; mod arange_step; mod arg; mod argwhere_nonzero; +mod bitwise; mod bool; mod cartesian_grid; mod cast; diff --git a/crates/burn-train/src/metric/cpu_use.rs b/crates/burn-train/src/metric/cpu_use.rs index 2769793088..d06d8429db 100644 --- a/crates/burn-train/src/metric/cpu_use.rs +++ b/crates/burn-train/src/metric/cpu_use.rs @@ -26,7 +26,9 @@ impl CpuUse { } fn refresh(sys: &mut System) -> f64 { - sys.refresh_specifics(RefreshKind::new().with_cpu(CpuRefreshKind::new().with_cpu_usage())); + sys.refresh_specifics( + RefreshKind::nothing().with_cpu(CpuRefreshKind::nothing().with_cpu_usage()), + ); let cpus = sys.cpus(); let num_cpus = cpus.len(); diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index c2e034ada5..e0c247172d 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -17,10 +17,17 @@ default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"] doc = ["burn-jit/doc"] exclusive-memory-only = ["cubecl/exclusive-memory-only"] fusion = ["burn-fusion", "burn-jit/fusion"] -spirv = ["cubecl/wgpu-spirv"] std = ["burn-jit/std", "cubecl/std"] template = ["burn-jit/template", "cubecl/template"] +# Backends +webgpu = ["cubecl-wgsl"] +vulkan = ["cubecl-spirv"] + +# Compilers +cubecl-wgsl = [] +cubecl-spirv = ["cubecl/wgpu-spirv"] + [dependencies] cubecl = { workspace = true, features = ["wgpu"] } diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 7aab106b29..c11854fcaf 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -12,21 +12,21 @@ pub use burn_jit::{ pub use burn_jit::{tensor::JitTensor, JitBackend}; pub use burn_jit::{BoolElement, FloatElement, IntElement}; pub use cubecl::flex32; -pub use cubecl::wgpu::*; pub use cubecl::CubeDim; -pub type Wgsl = cubecl::wgpu::WgslCompiler; -#[cfg(feature = "spirv")] -pub type SpirV = cubecl::wgpu::spirv::VkSpirvCompiler; +pub use cubecl::wgpu::{ + init_device, init_setup, init_setup_async, MemoryConfiguration, RuntimeOptions, WgpuDevice, + WgpuResource, WgpuRuntime, WgpuSetup, WgpuStorage, +}; +// Vulkan and WebGpu would have conflicting type names +pub mod graphics { + pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu}; +} -#[cfg(feature = "spirv")] -type Compiler = SpirV; -#[cfg(feature = "spirv")] -type Bool = u8; -#[cfg(not(feature = "spirv"))] -type Compiler = Wgsl; -#[cfg(not(feature = "spirv"))] -type Bool = u32; +#[cfg(feature = "cubecl-spirv")] +pub use cubecl::wgpu::spirv::SpirvCompiler; +#[cfg(feature = "cubecl-wgsl")] +pub use cubecl::wgpu::WgslCompiler; #[cfg(feature = "fusion")] /// Tensor backend that uses the wgpu crate for executing GPU compute shaders. @@ -44,14 +44,14 @@ type Bool = u32; /// ```rust, ignore /// fn custom_init() { /// let device = Default::default(); -/// burn::backend::wgpu::init_sync::( +/// burn::backend::wgpu::init_setup::( /// &device, /// Default::default(), /// ); /// } /// ``` /// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API. -/// It's also possible to use an existing wgpu device, by using `init_existing_device`. +/// It's also possible to use an existing wgpu device, by using `init_device`. /// /// # Notes /// @@ -60,7 +60,7 @@ type Bool = u32; /// /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. -pub type Wgpu = +pub type Wgpu = burn_fusion::Fusion, F, I, B>>; #[cfg(not(feature = "fusion"))] @@ -79,14 +79,14 @@ pub type Wgpu = /// ```rust, ignore /// fn custom_init() { /// let device = Default::default(); -/// burn::backend::wgpu::init_sync::( +/// burn::backend::wgpu::init_setup::( /// &device, /// Default::default(), /// ); /// } /// ``` /// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API. -/// It's also possible to use an existing wgpu device, by using `init_existing_device`. +/// It's also possible to use an existing wgpu device, by using `init_device`. /// /// # Notes /// @@ -95,20 +95,33 @@ pub type Wgpu = /// /// You can enable the `fusion` feature flag to add that functionality, which might improve /// performance. -pub type Wgpu = +pub type Wgpu = JitBackend, F, I, B>; +#[cfg(feature = "vulkan")] +/// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V. +pub type Vulkan = Wgpu; + +#[cfg(feature = "webgpu")] +/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL. +pub type WebGpu = Wgpu; + #[cfg(test)] mod tests { use burn_jit::JitBackend; - #[cfg(feature = "spirv")] + #[cfg(feature = "vulkan")] pub use half::f16; - pub type TestRuntime = cubecl::wgpu::WgpuRuntime; + + #[cfg(feature = "cubecl-spirv")] + type Compiler = cubecl::wgpu::spirv::VkSpirvCompiler; + #[cfg(not(feature = "cubecl-spirv"))] + type Compiler = cubecl::wgpu::WgslCompiler; + pub type TestRuntime = cubecl::wgpu::WgpuRuntime; // Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it // breaks a lot of tests from precision issues - #[cfg(feature = "spirv")] + #[cfg(feature = "vulkan")] burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]); - #[cfg(not(feature = "spirv"))] + #[cfg(not(feature = "vulkan"))] burn_jit::testgen_all!([f32], [i32], [u32]); } diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index cd13682a4b..b0abf7d178 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -50,15 +50,16 @@ openblas-system = ["burn-core/openblas-system"] template = ["burn-core/template"] candle = ["burn-core/candle"] -cuda-jit = ["burn-core/cuda-jit"] -hip-jit = ["burn-core/hip-jit"] +cuda = ["burn-core/cuda"] +hip = ["burn-core/hip"] ndarray = ["burn-core/ndarray"] remote = ["burn-core/remote"] router = ["burn-core/router"] server = ["burn-core/server"] tch = ["burn-core/tch"] wgpu = ["burn-core/wgpu"] -wgpu-spirv = ["burn-core/wgpu-spirv"] +vulkan = ["burn-core/vulkan"] +webgpu = ["burn-core/webgpu"] # Network utils network = ["burn-core/network"] diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs index b0ecf06a71..203d1a802d 100644 --- a/crates/burn/src/lib.rs +++ b/crates/burn/src/lib.rs @@ -76,12 +76,14 @@ //! - `vision`: Enables vision datasets (MnistDataset) //! - Backends //! - `wgpu`: Makes available the WGPU backend -//! - `wgpu-spirv`: Makes available the `wgpu` backend with the alternative SPIR-V compiler +//! - `webgpu`: Makes available the `wgpu` backend with the WebGPU Shading Language (WGSL) compiler +//! - `vulkan`: Makes available the `wgpu` backend with the alternative SPIR-V compiler +//! - `cuda`: Makes available the CUDA backend +//! - `hip`: Makes available the HIP backend //! - `candle`: Makes available the Candle backend //! - `tch`: Makes available the LibTorch backend //! - `ndarray`: Makes available the NdArray backend //! - Backend specifications -//! - `cuda`: If supported, CUDA will be used //! - `accelerate`: If supported, Accelerate will be used //! - `blas-netlib`: If supported, Blas Netlib will be use //! - `openblas`: If supported, Openblas will be use diff --git a/examples/custom-cubecl-kernel/src/kernel.rs b/examples/custom-cubecl-kernel/src/kernel.rs index 0809971327..08d4ded4d7 100644 --- a/examples/custom-cubecl-kernel/src/kernel.rs +++ b/examples/custom-cubecl-kernel/src/kernel.rs @@ -17,7 +17,7 @@ pub fn fused_matmul_add_relu_kernel( let dim_k = rhs.shape(rhs.rank() - 1); if row >= n_rows || col >= n_cols { - return; + terminate!(); } let offset_output = batch * n_rows * n_cols; diff --git a/examples/custom-renderer/examples/custom-renderer.rs b/examples/custom-renderer/examples/custom-renderer.rs index ea580833df..aa344b1d2b 100644 --- a/examples/custom-renderer/examples/custom-renderer.rs +++ b/examples/custom-renderer/examples/custom-renderer.rs @@ -1,5 +1,5 @@ -use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu}; +use burn::backend::{wgpu::WgpuDevice, Autodiff, WebGpu}; fn main() { - custom_renderer::run::>(WgpuDevice::default()); + custom_renderer::run::>(WgpuDevice::default()); } diff --git a/examples/custom-training-loop/Cargo.toml b/examples/custom-training-loop/Cargo.toml index 536307fdba..6e1fca1e92 100644 --- a/examples/custom-training-loop/Cargo.toml +++ b/examples/custom-training-loop/Cargo.toml @@ -7,7 +7,7 @@ publish = false version.workspace = true [dependencies] -burn = {path = "../../crates/burn", features=["autodiff", "wgpu", "vision"]} +burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]} guide = {path = "../guide"} # Serialization diff --git a/examples/custom-training-loop/examples/custom-training-loop.rs b/examples/custom-training-loop/examples/custom-training-loop.rs index a418ede196..ec9d55f42a 100644 --- a/examples/custom-training-loop/examples/custom-training-loop.rs +++ b/examples/custom-training-loop/examples/custom-training-loop.rs @@ -1,5 +1,5 @@ -use burn::backend::{Autodiff, Wgpu}; +use burn::backend::{Autodiff, WebGpu}; fn main() { - custom_training_loop::run::>(Default::default()); + custom_training_loop::run::>(Default::default()); } diff --git a/examples/guide/Cargo.toml b/examples/guide/Cargo.toml index e60b8d45e5..aea61f5e25 100644 --- a/examples/guide/Cargo.toml +++ b/examples/guide/Cargo.toml @@ -10,7 +10,7 @@ version.workspace = true default = ["burn/default"] [dependencies] -burn = {path = "../../crates/burn", features = ["wgpu", "train", "vision"]} +burn = {path = "../../crates/burn", features = ["webgpu", "train", "vision"]} # Serialization log = {workspace = true} diff --git a/examples/guide/src/bin/infer.rs b/examples/guide/src/bin/infer.rs index 6a246d85f0..44c5b1dabc 100644 --- a/examples/guide/src/bin/infer.rs +++ b/examples/guide/src/bin/infer.rs @@ -1,9 +1,9 @@ #![recursion_limit = "131"] -use burn::{backend::Wgpu, data::dataset::Dataset}; +use burn::{backend::WebGpu, data::dataset::Dataset}; use guide::inference; fn main() { - type MyBackend = Wgpu; + type MyBackend = WebGpu; let device = burn::backend::wgpu::WgpuDevice::default(); diff --git a/examples/guide/src/bin/print.rs b/examples/guide/src/bin/print.rs index 9432aa93a4..6f3b710c25 100644 --- a/examples/guide/src/bin/print.rs +++ b/examples/guide/src/bin/print.rs @@ -1,8 +1,8 @@ -use burn::backend::Wgpu; +use burn::backend::WebGpu; use guide::model::ModelConfig; fn main() { - type MyBackend = Wgpu; + type MyBackend = WebGpu; let device = Default::default(); let model = ModelConfig::new(10, 512).init::(&device); diff --git a/examples/guide/src/bin/train.rs b/examples/guide/src/bin/train.rs index 04f1f44146..a4acf02b69 100644 --- a/examples/guide/src/bin/train.rs +++ b/examples/guide/src/bin/train.rs @@ -1,5 +1,5 @@ use burn::{ - backend::{Autodiff, Wgpu}, + backend::{Autodiff, WebGpu}, data::dataset::Dataset, optim::AdamConfig, }; @@ -10,7 +10,7 @@ use guide::{ }; fn main() { - type MyBackend = Wgpu; + type MyBackend = WebGpu; type MyAutodiffBackend = Autodiff; // Create a default Wgpu device diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index 4b20507abc..a9868099f6 100644 --- a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -14,7 +14,7 @@ use burn::{ tensor::activation::softmax, }; -use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; +use burn::backend::wgpu::{graphics::AutoGraphicsApi, WebGpu, WgpuDevice}; use burn_candle::Candle; use serde::Serialize; @@ -37,8 +37,8 @@ pub enum ModelType { /// The model is loaded to the NdArray backend WithNdArrayBackend(Model>), - /// The model is loaded to the Wgpu backend - WithWgpuBackend(Model>), + /// The model is loaded to the WebGpu backend + WithWgpuBackend(Model>), } /// The image is 224x224 pixels with 3 channels (RGB) diff --git a/examples/raspberry-pi-pico/src/bin/main.rs b/examples/raspberry-pi-pico/src/bin/main.rs index 1b7f6acdf0..a502a8193e 100644 --- a/examples/raspberry-pi-pico/src/bin/main.rs +++ b/examples/raspberry-pi-pico/src/bin/main.rs @@ -10,7 +10,7 @@ use embassy_rp as _; use embedded_alloc::Heap; type Backend = NdArray; -type BackendDeice = ::Device; +type BackendDevice = ::Device; #[global_allocator] static HEAP: Heap = Heap::empty(); @@ -25,7 +25,7 @@ async fn main(_spawner: Spawner) { } // Get a default device for the backend - let device = BackendDeice::default(); + let device = BackendDevice::default(); // Create a new model and load the state let model: Model = Model::default(); @@ -47,7 +47,7 @@ async fn main(_spawner: Spawner) { } } -fn run_model<'a>(model: &Model, device: &BackendDeice, input: f32) -> Tensor { +fn run_model<'a>(model: &Model, device: &BackendDevice, input: f32) -> Tensor { // Define the tensor let input = Tensor::::from_floats([[input]], &device); diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index bb4824fba9..f9f80bdb8d 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -7,10 +7,10 @@ publish = false version.workspace = true [features] -default = ["wgpu"] -cuda-jit = ["burn/cuda-jit"] -wgpu = ["burn/wgpu"] -wgpu-spirv = ["wgpu", "burn/wgpu-spirv"] +default = ["webgpu"] +cuda = ["burn/cuda"] +webgpu = ["burn/webgpu"] +vulkan = ["burn/vulkan"] ndarray = ["burn/ndarray"] [dependencies] diff --git a/examples/server/src/lib.rs b/examples/server/src/lib.rs index 70705a0876..014a5e2cf5 100644 --- a/examples/server/src/lib.rs +++ b/examples/server/src/lib.rs @@ -11,10 +11,12 @@ pub fn start() { cfg_if::cfg_if! { if #[cfg(feature = "ndarray")]{ burn::server::start::(Default::default(), port); - } else if #[cfg(feature = "cuda-jit")]{ - burn::server::start::(Default::default(), port); - } else if #[cfg(feature = "wgpu")] { - burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "cuda")]{ + burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "webgpu")] { + burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "vulkan")] { + burn::server::start::(Default::default(), port); } else { panic!("No backend selected, can't start server on port {port}"); } diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml index 4ec5d7c89a..043c61672d 100644 --- a/examples/text-classification/Cargo.toml +++ b/examples/text-classification/Cargo.toml @@ -16,10 +16,10 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] -wgpu-spirv = ["wgpu", "burn/wgpu-spirv"] +vulkan = ["wgpu", "burn/vulkan"] remote = ["burn/remote"] -cuda-jit = ["burn/cuda-jit"] -hip-jit = ["burn/hip-jit"] +cuda = ["burn/cuda"] +hip = ["burn/hip"] [dependencies] # Burn diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md index 8bc611361f..9d62606706 100644 --- a/examples/text-classification/README.md +++ b/examples/text-classification/README.md @@ -102,6 +102,6 @@ cd burn # Use the --release flag to really speed up training. # AG News -cargo run --example ag-news-train --release --features cuda-jit # Train on the ag news dataset -cargo run --example ag-news-infer --release --features cuda-jit # Run inference on the ag news dataset +cargo run --example ag-news-train --release --features cuda # Train on the ag news dataset +cargo run --example ag-news-infer --release --features cuda # Run inference on the ag news dataset ``` diff --git a/examples/text-classification/examples/ag-news-infer.rs b/examples/text-classification/examples/ag-news-infer.rs index 9af5c6c6eb..77626e0b60 100644 --- a/examples/text-classification/examples/ag-news-infer.rs +++ b/examples/text-classification/examples/ag-news-infer.rs @@ -81,13 +81,13 @@ mod wgpu { } } -#[cfg(feature = "cuda-jit")] -mod cuda_jit { +#[cfg(feature = "cuda")] +mod cuda { use crate::{launch, ElemType}; - use burn::backend::{cuda_jit::CudaDevice, CudaJit}; + use burn::backend::{cuda::CudaDevice, Cuda}; pub fn run() { - launch::>(CudaDevice::default()); + launch::>(CudaDevice::default()); } } @@ -105,6 +105,6 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); - #[cfg(feature = "cuda-jit")] - cuda_jit::run(); + #[cfg(feature = "cuda")] + cuda::run(); } diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index 1be9803a15..927c190b2c 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -103,23 +103,23 @@ mod remote { } } -#[cfg(feature = "cuda-jit")] -mod cuda_jit { +#[cfg(feature = "cuda")] +mod cuda { use crate::{launch, ElemType}; - use burn::backend::{Autodiff, CudaJit}; + use burn::backend::{Autodiff, Cuda}; pub fn run() { - launch::>>(vec![Default::default()]); + launch::>>(vec![Default::default()]); } } -#[cfg(feature = "hip-jit")] -mod hip_jit { +#[cfg(feature = "hip")] +mod hip { use crate::{launch, ElemType}; - use burn::backend::{Autodiff, HipJit}; + use burn::backend::{Autodiff, Hip}; pub fn run() { - launch::>>(vec![Default::default()]); + launch::>>(vec![Default::default()]); } } @@ -137,10 +137,10 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); - #[cfg(feature = "cuda-jit")] - cuda_jit::run(); - #[cfg(feature = "hip-jit")] - hip_jit::run(); + #[cfg(feature = "cuda")] + cuda::run(); + #[cfg(feature = "hip")] + hip::run(); #[cfg(feature = "remote")] remote::run(); } diff --git a/examples/text-classification/examples/db-pedia-infer.rs b/examples/text-classification/examples/db-pedia-infer.rs index 490ed3b97e..027eb76122 100644 --- a/examples/text-classification/examples/db-pedia-infer.rs +++ b/examples/text-classification/examples/db-pedia-infer.rs @@ -1,6 +1,6 @@ use text_classification::DbPediaDataset; -use burn::tensor::backend::AutodiffBackend; +use burn::tensor::backend::Backend; #[cfg(not(feature = "f16"))] #[allow(dead_code)] @@ -8,7 +8,7 @@ type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; -pub fn launch(device: B::Device) { +pub fn launch(device: B::Device) { text_classification::inference::infer::( device, "/tmp/text-classification-db-pedia", @@ -34,24 +34,18 @@ pub fn launch(device: B::Device) { feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::{ - ndarray::{NdArray, NdArrayDevice}, - Autodiff, - }; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(NdArrayDevice::Cpu); + launch::>(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::{ - libtorch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::{launch, ElemType}; @@ -61,35 +55,29 @@ mod tch_gpu { #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - launch::>>(device); + launch::>(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::{ - tch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::tch::{LibTorch, LibTorchDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(LibTorchDevice::Cpu); + launch::>(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::{ - wgpu::{Wgpu, WgpuDevice}, - Autodiff, - }; + use burn::backend::wgpu::{Wgpu, WgpuDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(WgpuDevice::default()); + launch::>(WgpuDevice::default()); } } diff --git a/examples/wgan/Cargo.toml b/examples/wgan/Cargo.toml index 48d5680f51..d6ee6345b1 100644 --- a/examples/wgan/Cargo.toml +++ b/examples/wgan/Cargo.toml @@ -11,7 +11,7 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] -cuda-jit = ["burn/cuda-jit"] +cuda = ["burn/cuda"] [dependencies] burn = { path = "../../crates/burn", features=["train", "vision"] } diff --git a/examples/wgan/README.md b/examples/wgan/README.md index d7252ba520..0828145f61 100644 --- a/examples/wgan/README.md +++ b/examples/wgan/README.md @@ -12,7 +12,7 @@ Please note that better performance maybe gained by adopting a convolution layer ```sh # Cuda backend -cargo run --example wgan-mnist --release --features cuda-jit +cargo run --example wgan-mnist --release --features cuda # Wgpu backend cargo run --example wgan-mnist --release --features wgpu @@ -36,5 +36,5 @@ cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32 To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend. ```sh -cargo run --example wgan-generate --release --features cuda-jit +cargo run --example wgan-generate --release --features cuda ``` diff --git a/examples/wgan/examples/wgan-generate.rs b/examples/wgan/examples/wgan-generate.rs index fa66623ca3..1d0a4fd87d 100644 --- a/examples/wgan/examples/wgan-generate.rs +++ b/examples/wgan/examples/wgan-generate.rs @@ -11,24 +11,18 @@ pub fn launch(device: B::Device) { feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::{ - ndarray::{NdArray, NdArrayDevice}, - Autodiff, - }; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::launch; pub fn run() { - launch::>(NdArrayDevice::Cpu); + launch::(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::{ - libtorch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::launch; @@ -38,41 +32,38 @@ mod tch_gpu { #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - launch::>(device); + launch::(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::{ - libtorch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::launch; pub fn run() { - launch::>(LibTorchDevice::Cpu); + launch::(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { use crate::launch; - use burn::backend::{wgpu::Wgpu, Autodiff}; + use burn::backend::wgpu::Wgpu; pub fn run() { - launch::>(Default::default()); + launch::(Default::default()); } } -#[cfg(feature = "cuda-jit")] -mod cuda_jit { +#[cfg(feature = "cuda")] +mod cuda { use crate::launch; - use burn::backend::{Autodiff, CudaJit}; + use burn::backend::Cuda; pub fn run() { - launch::>(Default::default()); + launch::(Default::default()); } } @@ -90,6 +81,6 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); - #[cfg(feature = "cuda-jit")] - cuda_jit::run(); + #[cfg(feature = "cuda")] + cuda::run(); } diff --git a/examples/wgan/examples/wgan-mnist.rs b/examples/wgan/examples/wgan-mnist.rs index d964b07844..787acfec94 100644 --- a/examples/wgan/examples/wgan-mnist.rs +++ b/examples/wgan/examples/wgan-mnist.rs @@ -78,13 +78,13 @@ mod wgpu { } } -#[cfg(feature = "cuda-jit")] -mod cuda_jit { +#[cfg(feature = "cuda")] +mod cuda { use crate::launch; - use burn::backend::{cuda_jit::CudaDevice, Autodiff, CudaJit}; + use burn::backend::{cuda::CudaDevice, Autodiff, Cuda}; pub fn run() { - launch::>(CudaDevice::default()); + launch::>(CudaDevice::default()); } } @@ -102,6 +102,6 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); - #[cfg(feature = "cuda-jit")] - cuda_jit::run(); + #[cfg(feature = "cuda")] + cuda::run(); } diff --git a/examples/wgan/src/model.rs b/examples/wgan/src/model.rs index b9615f5270..755d8e9e1d 100644 --- a/examples/wgan/src/model.rs +++ b/examples/wgan/src/model.rs @@ -83,7 +83,7 @@ impl Discriminator { } } -// Use model config to construct a generative and adverserial model +// Use model config to construct a generative and adversarial model #[derive(Config, Debug)] pub struct ModelConfig { /// Dimensionality of the latent space diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs index 47e50f80ed..5b94b2909e 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -83,7 +83,7 @@ pub(crate) fn handle_command( vec!["--features", "test-wgpu-spirv"], None, None, - "std wgpu-spirv", + "std vulkan", )?; } }