diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..6d0b5bdb0 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# rustfmt codebase (gh-1375) +d07f5f33800e5240e7edb02bdbc4815ab30ef37e diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 000000000..ae74aeb45 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,194 @@ +on: + pull_request: + merge_group: + push: + branches: + - master + +name: Continuous integration + +env: + CARGO_TERM_COLOR: always + HOST: x86_64-unknown-linux-gnu + FEATURES: "test docs" + RUSTFLAGS: "-D warnings" + MSRV: 1.64.0 + BLAS_MSRV: 1.71.1 + +jobs: + clippy: + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - stable + name: clippy/${{ matrix.rust }} + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + components: clippy + - uses: Swatinem/rust-cache@v2 + - run: cargo clippy --features docs + + format: + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - nightly + name: format/${{ matrix.rust }} + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + components: rustfmt + - run: cargo fmt --all --check + + nostd: + runs-on: ubuntu-latest + continue-on-error: ${{ matrix.experimental }} + strategy: + matrix: + include: + - rust: stable + experimental: false + target: thumbv6m-none-eabi + + name: nostd/${{ matrix.target }}/${{ matrix.rust }} + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: ${{ matrix.rust }} + targets: ${{ matrix.target }} + - name: Tests + run: | + cargo rustc "--target=${{ matrix.target }}" --no-default-features --features portable-atomic-critical-section + + tests: + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - stable + - beta + - nightly + - 1.64.0 # MSRV + + name: tests/${{ matrix.rust }} + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + - uses: rui314/setup-mold@v1 + - uses: Swatinem/rust-cache@v2 + - name: Install openblas + run: sudo apt-get install libopenblas-dev gfortran + - run: ./scripts/all-tests.sh "$FEATURES" ${{ matrix.rust }} + + blas-msrv: + runs-on: ubuntu-latest + name: blas-msrv + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: 1.71.1 # BLAS MSRV + - uses: rui314/setup-mold@v1 + - uses: Swatinem/rust-cache@v2 + - name: Install openblas + run: sudo apt-get install libopenblas-dev gfortran + - run: cargo tree -p blas-tests -i openblas-src -F blas-tests/openblas-system + - run: cargo tree -p blas-tests -i openblas-build -F blas-tests/openblas-system + - run: ./scripts/blas-integ-tests.sh $BLAS_MSRV + + miri: + runs-on: ubuntu-latest + name: miri + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@nightly + with: + components: miri + - uses: Swatinem/rust-cache@v2 + - run: ./scripts/miri-tests.sh + + cross_test: + #if: ${{ github.event_name == 'merge_group' }} + runs-on: ubuntu-latest + strategy: + matrix: + include: + - rust: stable + target: s390x-unknown-linux-gnu + - rust: stable + target: i686-unknown-linux-gnu + + name: cross_test/${{ matrix.target }}/${{ matrix.rust }} + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + targets: ${{ matrix.target }} + - uses: rui314/setup-mold@v1 + - uses: Swatinem/rust-cache@v2 + - 
name: Install cross + run: cargo install cross + - run: ./scripts/cross-tests.sh "docs" ${{ matrix.rust }} ${{ matrix.target }} + + cargo-careful: + #if: ${{ github.event_name == 'merge_group' }} + runs-on: ubuntu-latest + name: cargo-careful + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + - uses: Swatinem/rust-cache@v2 + - name: Install cargo-careful + run: cargo install cargo-careful + - run: cargo careful test -Zcareful-sanitizer --features="$FEATURES" + + docs: + #if: ${{ github.event_name == 'merge_group' }} + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - stable + name: docs/${{ matrix.rust }} + env: + RUSTDOCFLAGS: "-Dwarnings" + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + - run: cargo doc --no-deps --all-features + + conclusion: + needs: + - clippy + - format # should format be required? + - nostd + - tests + - miri + - cross_test + - cargo-careful + - docs + if: always() + runs-on: ubuntu-latest + steps: + - name: Result + run: | + jq -C <<< "${needs}" + # Check if all needs were successful or skipped. + "$(jq -r 'all(.result as $result | (["success", "skipped"] | contains([$result])))' <<< "${needs}")" + env: + needs: ${{ toJson(needs) }} diff --git a/.gitignore b/.gitignore index 1e7caa9ea..e9b5ca25b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,7 @@ +# Rust items Cargo.lock target/ + +# Editor settings +.vscode +.idea diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index c46a2c790..000000000 --- a/.travis.yml +++ /dev/null @@ -1,34 +0,0 @@ -language: rust -# use trusty for newer openblas -sudo: required -dist: trusty -matrix: - include: - - rust: 1.37.0 - env: - - FEATURES='test docs' - - RUSTFLAGS='-D warnings' - - rust: stable - env: - - FEATURES='test docs' - - RUSTFLAGS='-D warnings' - - rust: beta - env: - - FEATURES='test docs' - - CHANNEL='beta' - - RUSTFLAGS='-D warnings' - - rust: nightly - env: - - FEATURES='test docs' - - CHANNEL='nightly' -env: - global: - - HOST=x86_64-unknown-linux-gnu - - CARGO_INCREMENTAL=0 -addons: - apt: - packages: - - libopenblas-dev - - gfortran -script: - - ./scripts/all-tests.sh "$FEATURES" "$CHANNEL" diff --git a/Cargo.toml b/Cargo.toml index 3da63d1d7..5c7217025 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,14 @@ [package] name = "ndarray" -version = "0.13.0" +version = "0.16.1" edition = "2018" +rust-version = "1.64" authors = [ - "bluss", + "Ulrik Sverdrup \"bluss\"", "Jim Turner" ] -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" readme = "README-crates.io.md" repository = "https://github.com/rust-ndarray/ndarray" @@ -18,9 +19,8 @@ description = "An n-dimensional array for general elements and for numerics. 
Lig keywords = ["array", "data-structure", "multidimensional", "matrix", "blas"] categories = ["data-structures", "science"] -build = "build.rs" - exclude = ["docgen/images/*"] +resolver = "2" [lib] name = "ndarray" @@ -28,50 +28,94 @@ bench = false test = true [dependencies] -num-integer = "0.1.39" -num-traits = "0.2" -num-complex = "0.2" +num-integer = { workspace = true } +num-traits = { workspace = true } +num-complex = { workspace = true } -rayon = { version = "1.0.3", optional = true } +approx = { workspace = true, optional = true } +rayon = { version = "1.10.0", optional = true } -approx = { version = "0.3.2", optional = true } +# Use via the `blas` crate feature +cblas-sys = { workspace = true, optional = true } +libc = { version = "0.2.82", optional = true } -# Use via the `blas` crate feature! -cblas-sys = { version = "0.1.4", optional = true, default-features = false } -blas-src = { version = "0.2.0", optional = true, default-features = false } +matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] } -matrixmultiply = { version = "0.2.0" } -serde = { version = "1.0", optional = true } +serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } rawpointer = { version = "0.2" } [dev-dependencies] defmac = "0.2" -quickcheck = { version = "0.9", default-features = false } -approx = "0.3.2" -itertools = { version = "0.8.0", default-features = false, features = ["use_std"] } +quickcheck = { workspace = true } +approx = { workspace = true, default-features = true } +itertools = { workspace = true } +ndarray-gen = { workspace = true } [features] +default = ["std"] + # Enable blas usage # See README for more instructions -blas = ["cblas-sys", "blas-src"] +blas = ["dep:cblas-sys", "dep:libc"] +serde = ["dep:serde"] # Old name for the serde feature -serde-1 = ["serde"] +serde-1 = ["dep:serde"] # These features are used for testing -test-blas-openblas-sys = ["blas"] -test = ["test-blas-openblas-sys"] +test = [] # This feature is used for docs docs = ["approx", "serde", "rayon"] -[profile.release] +std = ["num-traits/std", "matrixmultiply/std"] +rayon = ["dep:rayon", "std"] + +matrixmultiply-threading = ["matrixmultiply/threading"] + +portable-atomic-critical-section = ["portable-atomic/critical-section"] + + +[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] +portable-atomic = { version = "1.6.0" } +portable-atomic-util = { version = "0.2.0", features = [ "alloc" ] } + +[workspace] +members = [ + "ndarray-rand", + "crates/*", +] +default-members = [ + ".", + "ndarray-rand", + "crates/ndarray-gen", + "crates/numeric-tests", + "crates/serialization-tests", + # exclude blas-tests and blas-mock-tests that activate "blas" feature +] + +[workspace.dependencies] +ndarray = { version = "0.16", path = ".", default-features = false } +ndarray-rand = { path = "ndarray-rand" } +ndarray-gen = { path = "crates/ndarray-gen" } + +num-integer = { version = "0.1.39", default-features = false } +num-traits = { version = "0.2", default-features = false } +num-complex = { version = "0.4", default-features = false } +approx = { version = "0.5", default-features = false } +quickcheck = { version = "1.0", default-features = false } +rand = { version = "0.8.0", features = ["small_rng"] } +rand_distr = { version = "0.4.0" } +itertools = { version = "0.13.0", default-features = false, features = ["use_std"] } +cblas-sys = { version = "0.1.4", default-features = false } + [profile.bench] debug = true -[workspace] -members = ["ndarray-rand", 
"serialization-tests", "blas-tests"] -exclude = ["numeric-tests"] +[profile.test.package.numeric-tests] +opt-level = 2 +[profile.test.package.blas-tests] +opt-level = 2 [package.metadata.release] no-dev-version = true diff --git a/LICENSE-MIT b/LICENSE-MIT index c87e92dc4..d0af99b04 100644 --- a/LICENSE-MIT +++ b/LICENSE-MIT @@ -1,4 +1,4 @@ -Copyright (c) 2015 - 2018 Ulrik Sverdrup "bluss", +Copyright (c) 2015 - 2021 Ulrik Sverdrup "bluss", Jim Turner, and ndarray developers diff --git a/README-quick-start.md b/README-quick-start.md new file mode 100644 index 000000000..ad13acc72 --- /dev/null +++ b/README-quick-start.md @@ -0,0 +1,622 @@ +# Quickstart tutorial + +If you are familiar with Python Numpy, do check out this [For Numpy User Doc](https://docs.rs/ndarray/0.13.0/ndarray/doc/ndarray_for_numpy_users/index.html) +after you go through this tutorial. + +You can use [play.integer32.com](https://play.integer32.com/) to immediately try out the examples. + +## The Basics + +You can create your first 2x3 floating-point ndarray as such: +```rust +use ndarray::prelude::*; + +fn main() { + let a = array![ + [1.,2.,3.], + [4.,5.,6.], + ]; + assert_eq!(a.ndim(), 2); // get the number of dimensions of array a + assert_eq!(a.len(), 6); // get the number of elements in array a + assert_eq!(a.shape(), [2, 3]); // get the shape of array a + assert_eq!(a.is_empty(), false); // check if the array has zero elements + + println!("{:?}", a); +} +``` +This code will create a simple array, then print it to stdout as such: +``` +[[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0]], shape=[2, 3], strides=[3, 1], layout=C (0x1), const ndim=2 +``` + +## Array Creation + +### Element type and dimensionality + +Now let's create more arrays. A common operation on matrices is to create a matrix full of 0's of certain dimensions. Let's try to do that with dimensions (3, 2, 4) using the `Array::zeros` function: +```rust +use ndarray::prelude::*; +use ndarray::Array; +fn main() { + let a = Array::zeros((3, 2, 4).f()); + println!("{:?}", a); +} +``` +Unfortunately, this code does not compile. +``` +| let a = Array::zeros((3, 2, 4).f()); +| - ^^^^^^^^^^^^ cannot infer type for type parameter `A` +``` +Indeed, note that the compiler needs to infer the element type and dimensionality from context only. In this +case the compiler does not have enough information. To fix the code, we can explicitly give the element type through turbofish syntax, and let it infer the dimensionality type: + +```rust +use ndarray::prelude::*; +use ndarray::Array; +fn main() { + let a = Array::::zeros((3, 2, 4).f()); + println!("{:?}", a); +} +``` +This code now compiles to what we wanted: +``` +[[[0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0]], + + [[0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0]], + + [[0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0]]], shape=[3, 2, 4], strides=[1, 3, 6], layout=F (0x2), const ndim=3 +``` + +We could also specify its dimensionality explicitly `Array::::zeros(...)`, with`Ix3` standing for 3D array type. Phew! We achieved type safety. If you tried changing the code above to `Array::::zeros((3, 2, 4, 5).f());`, which is not of dimension 3 anymore, Rust's type system would gracefully prevent you from compiling the code. 
+
+### Creating arrays with different initial values and/or different types
+
+The [`from_elem`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html#method.from_elem) method allows initializing an array of given dimension to a specific value of any type:
+
+```rust
+use ndarray::{Array, Ix3};
+fn main() {
+    let a = Array::<bool, Ix3>::from_elem((3, 2, 4), false);
+    println!("{:?}", a);
+}
+```
+
+### Some common array initializing helper functions
+`linspace` - Create a 1-D array with 11 elements with values 0., …, 5.
+```rust
+use ndarray::prelude::*;
+use ndarray::{Array, Ix1};
+fn main() {
+    let a = Array::<f64, Ix1>::linspace(0., 5., 11);
+    println!("{:?}", a);
+}
+```
+The output is:
+```
+[0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0], shape=[11], strides=[1], layout=C | F (0x3), const ndim=1
+```
+
+Common array initializing methods include [`range`](https://docs.rs/ndarray/0.13.0/ndarray/struct.ArrayBase.html#method.range), [`logspace`](https://docs.rs/ndarray/0.13.0/ndarray/struct.ArrayBase.html#method.logspace), [`eye`](https://docs.rs/ndarray/0.13.0/ndarray/struct.ArrayBase.html#method.eye), [`ones`](https://docs.rs/ndarray/0.13.0/ndarray/struct.ArrayBase.html#method.ones)...
+
+## Basic operations
+
+Basic operations on arrays are all element-wise; you need to use specific methods for operations such as matrix multiplication (see later section).
+```rust
+use ndarray::prelude::*;
+use ndarray::Array;
+use std::f64::INFINITY as inf;
+
+fn main() {
+    let a = array![
+        [10., 20., 30., 40.,],
+    ];
+    let b = Array::range(0., 4., 1.); // [0., 1., 2., 3.]
+
+    assert_eq!(&a + &b, array![[10., 21., 32., 43.,]]); // Allocates a new array. Note the explicit `&`.
+    assert_eq!(&a - &b, array![[10., 19., 28., 37.,]]);
+    assert_eq!(&a * &b, array![[0., 20., 60., 120.,]]);
+    assert_eq!(&a / &b, array![[inf, 20., 15., 13.333333333333334,]]);
+}
+```
+
+Note that (for any binary operator `@`):
+* `&A @ &A` produces a new `Array`
+* `B @ A` consumes `B`, updates it with the result, and returns it
+* `B @ &A` consumes `B`, updates it with the result, and returns it
+* `C @= &A` performs an arithmetic operation in place
+
+Try removing all the `&` signs in front of `a` and `b` in the last example: it will not compile anymore because of those rules.
+
+For more info, check out https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html#arithmetic-operations
+
+Some operations have `_axis` appended to the function name: they generally take in a parameter of type `Axis` as one of their inputs,
+such as `sum_axis`:
+
+```rust
+use ndarray::{aview0, aview1, arr2, Axis};
+
+fn main() {
+    let a = arr2(&[[1., 2., 3.],
+                   [4., 5., 6.]]);
+    assert!(
+        a.sum_axis(Axis(0)) == aview1(&[5., 7., 9.]) &&
+        a.sum_axis(Axis(1)) == aview1(&[6., 15.]) &&
+
+        a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&21.)
&& + a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&a.sum()) + ); +} +``` + +### Matrix product + +```rust +use ndarray::prelude::*; +use ndarray::Array; + +fn main() { + let a = array![ + [10.,20.,30., 40.,], + ]; + let b = Array::range(0., 4., 1.); // b = [0., 1., 2., 3, ] + println!("a shape {:?}", &a.shape()); + println!("b shape {:?}", &b.shape()); + + let b = b.into_shape_with_order((4,1)).unwrap(); // reshape b to shape [4, 1] + println!("b shape after reshape {:?}", &b.shape()); + + println!("{}", a.dot(&b)); // [1, 4] x [4, 1] -> [1, 1] + println!("{}", a.t().dot(&b.t())); // [4, 1] x [1, 4] -> [4, 4] +} +``` +The output is: +``` +a shape [1, 4] +b shape [4] +b shape after reshape [4, 1] +[[200]] +[[0, 10, 20, 30], + [0, 20, 40, 60], + [0, 30, 60, 90], + [0, 40, 80, 120]] +``` + +## Indexing, Slicing and Iterating +One-dimensional arrays can be indexed, sliced and iterated over, much like `numpy` arrays + +```rust +use ndarray::prelude::*; +use ndarray::Array; + +fn main() { + let a = Array::range(0., 10., 1.); + + let mut a = a.mapv(|a: f64| a.powi(3)); // numpy equivlant of `a ** 3`; https://doc.rust-lang.org/nightly/std/primitive.f64.html#method.powi + + println!("{}", a); + + println!("{}", a[[2]]); + println!("{}", a.slice(s![2])); + + println!("{}", a.slice(s![2..5])); + + a.slice_mut(s![..6;2]).fill(1000.); // numpy equivlant of `a[:6:2] = 1000` + println!("{}", a); + + for i in a.iter() { + print!("{}, ", i.powf(1./3.)) + } +} +``` +The output is: +``` +[0, 1, 8, 27, 64, 125, 216, 343, 512, 729] +8 +8 +[8, 27, 64] +[1000, 1, 1000, 27, 1000, 125, 216, 343, 512, 729] +9.999999999999998, 1, 9.999999999999998, 3, 9.999999999999998, 4.999999999999999, 5.999999999999999, 6.999999999999999, 7.999999999999999, 8.999999999999998, +``` + +For more info about iteration see [Loops, Producers, and Iterators](https://docs.rs/ndarray/0.13.0/ndarray/struct.ArrayBase.html#loops-producers-and-iterators) + +Let's try a iterating over a 3D array with elements of type `isize`. This is how you index it: +```rust +use ndarray::prelude::*; + +fn main() { + let a = array![ + [[ 0, 1, 2], // a 3D array 2 x 2 x 3 + [ 10, 12, 13]], + + [[100,101,102], + [110,112,113]] + ]; + + let a = a.mapv(|a: isize| a.pow(1)); // numpy equivalent of `a ** 1`; + // This line does nothing except illustrating mapv with isize type + println!("a -> \n{}\n", a); + + println!("`a.slice(s![1, .., ..])` -> \n{}\n", a.slice(s![1, .., ..])); + + println!("`a.slice(s![.., .., 2])` -> \n{}\n", a.slice(s![.., .., 2])); + + println!("`a.slice(s![.., 1, 0..2])` -> \n{}\n", a.slice(s![.., 1, 0..2])); + + println!("`a.iter()` ->"); + for i in a.iter() { + print!("{}, ", i) // flat out to every element + } + + println!("\n\n`a.outer_iter()` ->"); + for i in a.outer_iter() { + print!("row: {}, \n", i) // iterate through first dimension + } +} +``` +The output is: +``` +a -> +[[[0, 1, 2], + [10, 12, 13]], + + [[100, 101, 102], + [110, 112, 113]]] + +`a.slice(s![1, .., ..])` -> +[[100, 101, 102], + [110, 112, 113]] + +`a.slice(s![.., .., 2])` -> +[[2, 13], + [102, 113]] + +`a.slice(s![.., 1, 0..2])` -> +[[10, 12], + [110, 112]] + +`a.iter()` -> +0, 1, 2, 10, 12, 13, 100, 101, 102, 110, 112, 113, + +`a.outer_iter()` -> +row: [[0, 1, 2], + [10, 12, 13]], +row: [[100, 101, 102], + [110, 112, 113]], +``` + +## Shape Manipulation + +### Changing the shape of an array +The shape of an array can be changed with the `into_shape_with_order` or `to_shape` method. 
+ +````rust +use ndarray::prelude::*; +use ndarray::Array; +use std::iter::FromIterator; +// use ndarray_rand::RandomExt; +// use ndarray_rand::rand_distr::Uniform; + +fn main() { + // Or you may use ndarray_rand crate to generate random arrays + // let a = Array::random((2, 5), Uniform::new(0., 10.)); + + let a = array![ + [3., 7., 3., 4.], + [1., 4., 2., 2.], + [7., 2., 4., 9.]]; + + println!("a = \n{:?}\n", a); + + // use trait FromIterator to flatten a matrix to a vector + let b = Array::from_iter(a.iter()); + println!("b = \n{:?}\n", b); + + let c = b.into_shape_with_order([6, 2]).unwrap(); // consume b and generate c with new shape + println!("c = \n{:?}", c); +} +```` +The output is: +``` +a = +[[3.0, 7.0, 3.0, 4.0], + [1.0, 4.0, 2.0, 2.0], + [7.0, 2.0, 4.0, 9.0]], shape=[3, 4], strides=[4, 1], layout=C (0x1), const ndim=2 + +b = +[3.0, 7.0, 3.0, 4.0, 1.0, 4.0, 2.0, 2.0, 7.0, 2.0, 4.0, 9.0], shape=[12], strides=[1], layout=C | F (0x3), const ndim=1 + +c = +[[3.0, 7.0], + [3.0, 4.0], + [1.0, 4.0], + [2.0, 2.0], + [7.0, 2.0], + [4.0, 9.0]], shape=[6, 2], strides=[2, 1], layout=C (0x1), const ndim=2 +``` + +### Stacking/concatenating together different arrays + +The `stack!` and `concatenate!` macros are helpful for stacking/concatenating +arrays. The `stack!` macro stacks arrays along a new axis, while the +`concatenate!` macro concatenates arrays along an existing axis: + +```rust +use ndarray::prelude::*; +use ndarray::{concatenate, stack, Axis}; + +fn main() { + let a = array![ + [3., 7., 8.], + [5., 2., 4.], + ]; + + let b = array![ + [1., 9., 0.], + [5., 4., 1.], + ]; + + println!("stack, axis 0:\n{:?}\n", stack![Axis(0), a, b]); + println!("stack, axis 1:\n{:?}\n", stack![Axis(1), a, b]); + println!("stack, axis 2:\n{:?}\n", stack![Axis(2), a, b]); + println!("concatenate, axis 0:\n{:?}\n", concatenate![Axis(0), a, b]); + println!("concatenate, axis 1:\n{:?}\n", concatenate![Axis(1), a, b]); +} +``` +The output is: +``` +stack, axis 0: +[[[3.0, 7.0, 8.0], + [5.0, 2.0, 4.0]], + + [[1.0, 9.0, 0.0], + [5.0, 4.0, 1.0]]], shape=[2, 2, 3], strides=[6, 3, 1], layout=Cc (0x5), const ndim=3 + +stack, axis 1: +[[[3.0, 7.0, 8.0], + [1.0, 9.0, 0.0]], + + [[5.0, 2.0, 4.0], + [5.0, 4.0, 1.0]]], shape=[2, 2, 3], strides=[3, 6, 1], layout=c (0x4), const ndim=3 + +stack, axis 2: +[[[3.0, 1.0], + [7.0, 9.0], + [8.0, 0.0]], + + [[5.0, 5.0], + [2.0, 4.0], + [4.0, 1.0]]], shape=[2, 3, 2], strides=[1, 2, 6], layout=Ff (0xa), const ndim=3 + +concatenate, axis 0: +[[3.0, 7.0, 8.0], + [5.0, 2.0, 4.0], + [1.0, 9.0, 0.0], + [5.0, 4.0, 1.0]], shape=[4, 3], strides=[3, 1], layout=Cc (0x5), const ndim=2 + +concatenate, axis 1: +[[3.0, 7.0, 8.0, 1.0, 9.0, 0.0], + [5.0, 2.0, 4.0, 5.0, 4.0, 1.0]], shape=[2, 6], strides=[1, 2], layout=Ff (0xa), const ndim=2 +``` + +### Splitting one array into several smaller ones + +More to see here [ArrayView::split_at](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html#method.split_at) +```rust +use ndarray::prelude::*; +use ndarray::Axis; + +fn main() { + + let a = array![ + [6., 7., 6., 9., 0., 5., 4., 0., 6., 8., 5., 2.], + [8., 5., 5., 7., 1., 8., 6., 7., 1., 8., 1., 0.]]; + + let (s1, s2) = a.view().split_at(Axis(0), 1); + println!("Split a from Axis(0), at index 1:"); + println!("s1 = \n{}", s1); + println!("s2 = \n{}\n", s2); + + + let (s1, s2) = a.view().split_at(Axis(1), 4); + println!("Split a from Axis(1), at index 4:"); + println!("s1 = \n{}", s1); + println!("s2 = \n{}\n", s2); +} +``` +The output is: +``` +Split a from Axis(0), at index 1: +s1 = +[[6, 7, 
6, 9, 0, 5, 4, 0, 6, 8, 5, 2]] +s2 = +[[8, 5, 5, 7, 1, 8, 6, 7, 1, 8, 1, 0]] + +Split a from Axis(1), at index 4: +s1 = +[[6, 7, 6, 9], + [8, 5, 5, 7]] +s2 = +[[0, 5, 4, 0, 6, 8, 5, 2], + [1, 8, 6, 7, 1, 8, 1, 0]] + +``` + +## Copies and Views +### View, Ref or Shallow Copy + +Rust has ownership, so we cannot simply update an element of an array while we have a shared view of it. This brings guarantees & helps having more robust code. +```rust +use ndarray::prelude::*; +use ndarray::{Array, Axis}; + +fn main() { + + let mut a = Array::range(0., 12., 1.).into_shape_with_order([3 ,4]).unwrap(); + println!("a = \n{}\n", a); + + { + let (s1, s2) = a.view().split_at(Axis(1), 2); + + // with s as a view sharing the ref of a, we cannot update a here + // a.slice_mut(s![1, 1]).fill(1234.); + + println!("Split a from Axis(0), at index 1:"); + println!("s1 = \n{}", s1); + println!("s2 = \n{}\n", s2); + } + + // now we can update a again here, as views of s1, s2 are dropped already + a.slice_mut(s![1, 1]).fill(1234.); + + let (s1, s2) = a.view().split_at(Axis(1), 2); + println!("Split a from Axis(0), at index 1:"); + println!("s1 = \n{}", s1); + println!("s2 = \n{}\n", s2); +} +``` +The output is: +``` +a = +[[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]] + +Split a from Axis(0), at index 1: +s1 = +[[0, 1], + [4, 5], + [8, 9]] +s2 = +[[2, 3], + [6, 7], + [10, 11]] + +Split a from Axis(0), at index 1: +s1 = +[[0, 1], + [4, 1234], + [8, 9]] +s2 = +[[2, 3], + [6, 7], + [10, 11]] +``` + +### Deep Copy +As the usual way in Rust, a `clone()` call will +make a copy of your array: +```rust +use ndarray::prelude::*; +use ndarray::Array; + +fn main() { + + let mut a = Array::range(0., 4., 1.).into_shape_with_order([2 ,2]).unwrap(); + let b = a.clone(); + + println!("a = \n{}\n", a); + println!("b clone of a = \n{}\n", a); + + a.slice_mut(s![1, 1]).fill(1234.); + + println!("a updated..."); + println!("a = \n{}\n", a); + println!("b clone of a = \n{}\n", b); +} +``` + +The output is: +``` +a = +[[0, 1], + [2, 3]] + +b clone of a = +[[0, 1], + [2, 3]] + +a updated... +a = +[[0, 1], + [2, 1234]] + +b clone of a = +[[0, 1], + [2, 3]] +``` + +Notice that using `clone()` (or cloning) an `Array` type also copies the array's elements. It creates an independently owned array of the same type. + +Cloning an `ArrayView` does not clone or copy the underlying elements - it only clones the view reference (as it happens in Rust when cloning a `&` reference). + +## Broadcasting + +Arrays support limited broadcasting, where arithmetic operations with array operands of different sizes can be carried out by repeating the elements of the smaller dimension array. + +```rust +use ndarray::prelude::*; + +fn main() { + let a = array![ + [1., 1.], + [1., 2.], + [0., 3.], + [0., 4.]]; + + let b = array![[0., 1.]]; + + let c = array![ + [1., 2.], + [1., 3.], + [0., 4.], + [0., 5.]]; + + // We can add because the shapes are compatible even if not equal. + // The `b` array is shape 1 × 2 but acts like a 4 × 2 array. + assert!(c == a + b); +} +``` + +See [.broadcast()](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html#method.broadcast) for a more detailed description. 
+ +And here is a short example of it: +```rust +use ndarray::prelude::*; + +fn main() { + let a = array![ + [1., 2.], + [3., 4.], + ]; + + let b = a.broadcast((3, 2, 2)).unwrap(); + println!("shape of a is {:?}", a.shape()); + println!("a is broadcased to 3x2x2 = \n{}", b); +} +``` +The output is: +``` +shape of a is [2, 2] +a is broadcased to 3x2x2 = +[[[1, 2], + [3, 4]], + + [[1, 2], + [3, 4]], + + [[1, 2], + [3, 4]]] +``` + +## Want to learn more? +Please checkout these docs for more information +* [`ArrayBase` doc page](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html) +* [`ndarray` for `numpy` user doc page](https://docs.rs/ndarray/latest/ndarray/doc/ndarray_for_numpy_users/index.html) diff --git a/README.rst b/README.rst index e551987c0..49558b1c1 100644 --- a/README.rst +++ b/README.rst @@ -5,17 +5,28 @@ The ``ndarray`` crate provides an *n*-dimensional container for general elements and for numerics. Please read the `API documentation on docs.rs`__ +or take a look at the `quickstart tutorial <./README-quick-start.md>`_. __ https://docs.rs/ndarray/ -|build_status|_ |crates|_ +|build_status|_ |crates|_ |matrix-chat|_ |irc|_ -.. |build_status| image:: https://api.travis-ci.org/rust-ndarray/ndarray.svg?branch=master -.. _build_status: https://travis-ci.org/rust-ndarray/ndarray +.. |build_status| image:: https://github.com/rust-ndarray/ndarray/actions/workflows/ci.yaml/badge.svg + :alt: CI build status +.. _build_status: https://github.com/rust-ndarray/ndarray/actions -.. |crates| image:: http://meritbadge.herokuapp.com/ndarray +.. |crates| image:: https://img.shields.io/crates/v/ndarray.svg + :alt: ndarray at crates.io .. _crates: https://crates.io/crates/ndarray +.. |matrix-chat| image:: https://img.shields.io/badge/Matrix-%23rust--sci%3Amatrix.org-lightgrey + :alt: Matrix chat at #rust-sci:matrix.org +.. _matrix-chat: https://matrix.to/#/#rust-sci:matrix.org + +.. |irc| image:: https://img.shields.io/badge/IRC-%23rust--sci%20on%20OFTC-lightgrey + :alt: IRC at #rust-sci on OFTC +.. _irc: https://webchat.oftc.net/?channels=rust-sci + Highlights ---------- @@ -47,22 +58,48 @@ Crate Feature Flags The following crate feature flags are available. They are configured in your `Cargo.toml`. +- ``std`` + + - Rust standard library (enabled by default) + + - This crate can be used without the standard library by disabling the + default `std` feature. To do so, use this in your `Cargo.toml`: + + :: + + [dependencies] + ndarray = { version = "0.x.y", default-features = false } + + + - The `geomspace` `linspace` `logspace` `range` `std` `var` `var_axis` and `std_axis` + methods are only available when `std` is enabled. + - ``serde`` - - Optional, compatible with Rust stable - Enables serialization support for serde 1.x - ``rayon`` - - Optional, compatible with Rust stable - Enables parallel iterators, parallelized methods and ``par_azip!``. + - Implies std + +- ``approx`` + + - Implementations of traits from version 0.5 of the [`approx`] crate. - ``blas`` - - Optional and experimental, compatible with Rust stable - Enable transparent BLAS support for matrix multiplication. Uses ``blas-src`` for pluggable backend, which needs to be configured - separately. + separately (see below). 
+ +- ``matrixmultiply-threading`` + + - Enable the ``threading`` feature in the matrixmultiply package + +- ``portable-atomic-critical-section`` + + - Whether ``portable-atomic`` should use ``critical-section`` How to use with cargo --------------------- @@ -70,20 +107,62 @@ How to use with cargo :: [dependencies] - ndarray = "0.13.0" + ndarray = "0.16.0" + +How to enable BLAS integration +------------------------------ + +Blas integration is an optional add-on. Without BLAS, ndarray uses the +``matrixmultiply`` crate for matrix multiplication for ``f64`` and ``f32`` +arrays (and it's always enabled as a fallback since it supports matrices of +arbitrary strides in both dimensions). -How to enable blas integration. Depend on ``blas-src`` directly to pick a blas -provider. Depend on the same ``blas-src`` version as ``ndarray`` does, for the -selection to work. A proposed configuration using system openblas is shown -below. Note that only end-user projects (not libraries) should select -provider:: +Depend and link to ``blas-src`` directly to pick a blas provider. Ndarray +presently requires a blas provider that provides the ``cblas-sys`` interface. If +further feature selection is wanted or needed then you might need to depend directly on +the backend crate's source too. The backend version **must** be the one that +``blas-src`` also depends on. + +An example configuration using system openblas is shown below. Note that only +end-user projects (not libraries) should select provider:: + + [dependencies] + ndarray = { version = "0.16.0", features = ["blas"] } + blas-src = { version = "0.10", features = ["openblas"] } + openblas-src = { version = "0.10", features = ["cblas", "system"] } +Using system-installed dependencies can save a long time building dependencies. +An example configuration using (compiled) netlib is shown below anyway:: [dependencies] - ndarray = { version = "0.13.0", features = ["blas"] } - blas-src = { version = "0.2.0", default-features = false, features = ["openblas"] } - openblas-src = { version = "0.6.0", default-features = false, features = ["cblas", "system"] } + ndarray = { version = "0.16.0", features = ["blas"] } + blas-src = { version = "0.10.0", default-features = false, features = ["netlib"] } +When this is done, your program must also link to ``blas_src`` by using it or +explicitly including it in your code:: + + extern crate blas_src; + +The following versions have been verified to work together. For ndarray 0.15 or later, +there is no tight coupling to the ``blas-src`` version, so version selection is more flexible. + +=========== ============ ================ ============== +``ndarray`` ``blas-src`` ``openblas-src`` ``netlib-src`` +=========== ============ ================ ============== +0.16 0.10 0.10 0.8 +0.15 0.8 0.10 0.8 +0.15 0.7 0.9 0.8 +0.14 0.6.1 0.9.0 +0.13 0.2.0 0.6.0 +=========== ============ ================ ============== + +------------ +BLAS on MSRV +------------ + +Although ``ndarray`` currently maintains an MSRV of 1.64.0, this is separate from the MSRV (either stated or real) of the various BLAS providers. +As of the time of writing, ``openblas`` currently supports MSRV of 1.71.1. +So, while ``ndarray`` and ``openblas-src`` are compatible, they can only work together with toolchains 1.71.1 or above. Recent Changes -------------- @@ -101,4 +180,3 @@ http://opensource.org/licenses/MIT, at your option. This file may not be copied, modified, or distributed except according to those terms. 
- diff --git a/RELEASES.md b/RELEASES.md index 5d8f51a2e..8b4786666 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,3 +1,804 @@ +Version 0.16.1 (2024-08-14) +=========================== + +- Refactor and simplify BLAS gemm call further by [@bluss](https://github.com/bluss) [#1421](https://github.com/rust-ndarray/ndarray/pull/1421) +- Fix infinite recursion and off-by-one error in triu/tril by [@akern40](https://github.com/akern40) [#1418](https://github.com/rust-ndarray/ndarray/pull/1418) +- Fix using BLAS for all compatible cases of memory layout by [@bluss](https://github.com/bluss) [#1419](https://github.com/rust-ndarray/ndarray/pull/1419) +- Use PR check instead of Merge Queue, and check rustdoc by [@bluss](https://github.com/bluss) [#1420](https://github.com/rust-ndarray/ndarray/pull/1420) +- Make iterators covariant in element type by [@bluss](https://github.com/bluss) [#1417](https://github.com/rust-ndarray/ndarray/pull/1417) + +Version 0.16.0 (2024-08-03) +=========================== + +Featured Changes +---------------- + +- Better shape: Deprecate reshape, into_shape by [@bluss](https://github.com/bluss) [#1310](https://github.com/rust-ndarray/ndarray/pull/1310)
+ `.into_shape()` **is now deprecated**. + Use `.into_shape_with_order()` or `.to_shape()` instead, which don't have `into_shape`'s drawbacks. + +New Features and Improvements +----------------------------- + +- Check for aliasing in `RawViewMut::from_shape_ptr` with a debug assertion by [@bluss](https://github.com/bluss) [#1413](https://github.com/rust-ndarray/ndarray/pull/1413) +- Allow aliasing in ArrayView::from_shape by [@bluss](https://github.com/bluss) [#1410](https://github.com/rust-ndarray/ndarray/pull/1410) +- Remove deprecations from 0.15.x by [@bluss](https://github.com/bluss) [#1409](https://github.com/rust-ndarray/ndarray/pull/1409) +- Make `CowArray` an owned storage array, require Clone bound for `into_shared` by [@jturner314](https://github.com/jturner314) [#1028](https://github.com/rust-ndarray/ndarray/pull/1028) +- Change `NdProducer::Dim` of `axis_windows()` to `Ix1` by [@jonasBoss](https://github.com/jonasBoss) [#1305](https://github.com/rust-ndarray/ndarray/pull/1305) +- Add `squeeze()` to dynamic dimension arrays by [@barakugav](https://github.com/barakugav) [#1396](https://github.com/rust-ndarray/ndarray/pull/1396) +- Add `flatten`, `flatten_with_order` and `into_flat` to arrays by [@barakugav](https://github.com/barakugav) [#1397](https://github.com/rust-ndarray/ndarray/pull/1397) +- Make compatible with thumbv6m-none-eabi by [@BjornTheProgrammer](https://github.com/BjornTheProgrammer) [#1384](https://github.com/rust-ndarray/ndarray/pull/1384) +- `is_unique` for `ArcArray` by [@daniellga](https://github.com/daniellga) [#1399](https://github.com/rust-ndarray/ndarray/pull/1399) +- Add `triu` and `tril` methods directly to ArrayBase by [@akern40](https://github.com/akern40) [#1386](https://github.com/rust-ndarray/ndarray/pull/1386) +- Fix styling of the BLAS integration heading. 
by [@adamreichold](https://github.com/adamreichold) [#1390](https://github.com/rust-ndarray/ndarray/pull/1390) +- Implement `product_axis` by [@akern40](https://github.com/akern40) [#1387](https://github.com/rust-ndarray/ndarray/pull/1387) +- Add reserve method for owned arrays by [@ssande7](https://github.com/ssande7) [#1268](https://github.com/rust-ndarray/ndarray/pull/1268) +- Use inline on spit_at and smaller methods by [@bluss](https://github.com/bluss) [#1381](https://github.com/rust-ndarray/ndarray/pull/1381) +- Update to Approx 0.5 by [@bluss](https://github.com/bluss) [#1380](https://github.com/rust-ndarray/ndarray/pull/1380) +- Add .into_raw_vec_with_offset() and deprecate .into_raw_vec() by [@bluss](https://github.com/bluss) [#1379](https://github.com/rust-ndarray/ndarray/pull/1379) +- Add additional array -> array view conversions by [@bluss](https://github.com/bluss) [#1130](https://github.com/rust-ndarray/ndarray/pull/1130) +- implement DoubleEndedIterator for 1d `LanesIter` by [@Muthsera](https://github.com/Muthsera) [#1237](https://github.com/rust-ndarray/ndarray/pull/1237) +- Add Zip::any by [@nilgoyette](https://github.com/nilgoyette) [#1228](https://github.com/rust-ndarray/ndarray/pull/1228) +- Make the aview0, aview1, and aview2 free functions be const fns by [@jturner314](https://github.com/jturner314) [#1132](https://github.com/rust-ndarray/ndarray/pull/1132) +- Add missing safety checks to `From<&[[A; N]]> for ArrayView` and `From<&mut [[A; N]]> for ArrayViewMut` by [@jturner314](https://github.com/jturner314) [#1131](https://github.com/rust-ndarray/ndarray/pull/1131) +- derived Debug for Iter and IterMut by [@biskwikman](https://github.com/biskwikman) [#1353](https://github.com/rust-ndarray/ndarray/pull/1353) +- Fix Miri errors for WindowsIter and ExactChunksIter/Mut by [@jturner314](https://github.com/jturner314) [#1142](https://github.com/rust-ndarray/ndarray/pull/1142) +- Fix Miri failure with -Zmiri-tag-raw-pointers by [@jturner314](https://github.com/jturner314) [#1138](https://github.com/rust-ndarray/ndarray/pull/1138) +- Track-caller panics by [@xd009642](https://github.com/xd009642) [#975](https://github.com/rust-ndarray/ndarray/pull/975) +- Add slice_axis_move method by [@jturner314](https://github.com/jturner314) [#1211](https://github.com/rust-ndarray/ndarray/pull/1211) +- iterators: Re-export IntoIter by [@bluss](https://github.com/bluss) [#1370](https://github.com/rust-ndarray/ndarray/pull/1370) +- Fix unsafe blocks in `s![]` macro by [@jturner314](https://github.com/jturner314) [#1196](https://github.com/rust-ndarray/ndarray/pull/1196) +- Fix comparison with NumPy of slicing with negative step by [@venkat0791](https://github.com/venkat0791) [#1319](https://github.com/rust-ndarray/ndarray/pull/1319) +- Updated Windows `base` Computations to be Safer by [@LazaroHurtado](https://github.com/LazaroHurtado) [#1297](https://github.com/rust-ndarray/ndarray/pull/1297) +- Update README-quick-start.md by [@fumseckk](https://github.com/fumseckk) [#1246](https://github.com/rust-ndarray/ndarray/pull/1246) +- Added stride support to `Windows` by [@LazaroHurtado](https://github.com/LazaroHurtado) [#1249](https://github.com/rust-ndarray/ndarray/pull/1249) +- Added select example to numpy user docs by [@WillAyd](https://github.com/WillAyd) [#1294](https://github.com/rust-ndarray/ndarray/pull/1294) +- Add both approx features to the readme by [@nilgoyette](https://github.com/nilgoyette) [#1289](https://github.com/rust-ndarray/ndarray/pull/1289) +- Add NumPy examples 
combining slicing and assignment by [@jturner314](https://github.com/jturner314) [#1210](https://github.com/rust-ndarray/ndarray/pull/1210) +- Fix contig check for single element arrays by [@bluss](https://github.com/bluss) [#1362](https://github.com/rust-ndarray/ndarray/pull/1362) +- Export Linspace and Logspace iterators by [@johann-cm](https://github.com/johann-cm) [#1348](https://github.com/rust-ndarray/ndarray/pull/1348) +- Use `clone_from()` in two places by [@ChayimFriedman2](https://github.com/ChayimFriedman2) [#1347](https://github.com/rust-ndarray/ndarray/pull/1347) +- Update README-quick-start.md by [@joelchen](https://github.com/joelchen) [#1344](https://github.com/rust-ndarray/ndarray/pull/1344) +- Provide element-wise math functions for floats by [@KmolYuan](https://github.com/KmolYuan) [#1042](https://github.com/rust-ndarray/ndarray/pull/1042) +- Improve example in doc for columns method by [@gkobeaga](https://github.com/gkobeaga) [#1221](https://github.com/rust-ndarray/ndarray/pull/1221) +- Fix description of stack! in quick start by [@jturner314](https://github.com/jturner314) [#1156](https://github.com/rust-ndarray/ndarray/pull/1156) + +Tests, CI and Maintainer tasks +------------------------------ + +- CI: require rustfmt, nostd by [@bluss](https://github.com/bluss) [#1411](https://github.com/rust-ndarray/ndarray/pull/1411) +- Prepare changelog for 0.16.0 by [@bluss](https://github.com/bluss) [#1401](https://github.com/rust-ndarray/ndarray/pull/1401) +- Organize dependencies with workspace = true (cont.) by [@bluss](https://github.com/bluss) [#1407](https://github.com/rust-ndarray/ndarray/pull/1407) +- Update to use dep: for features by [@bluss](https://github.com/bluss) [#1406](https://github.com/rust-ndarray/ndarray/pull/1406) +- Organize the workspace of test crates a bit better by [@bluss](https://github.com/bluss) [#1405](https://github.com/rust-ndarray/ndarray/pull/1405) +- Add rustfmt commit to ignored revisions for git blame by [@lucascolley](https://github.com/lucascolley) [#1376](https://github.com/rust-ndarray/ndarray/pull/1376) +- The minimum amount of work required to fix our CI by [@adamreichold](https://github.com/adamreichold) [#1388](https://github.com/rust-ndarray/ndarray/pull/1388) +- Fixed broke continuous integration badge by [@juhotuho10](https://github.com/juhotuho10) [#1382](https://github.com/rust-ndarray/ndarray/pull/1382) +- Use mold linker to speed up ci by [@bluss](https://github.com/bluss) [#1378](https://github.com/rust-ndarray/ndarray/pull/1378) +- Add rustformat config and CI by [@bluss](https://github.com/bluss) [#1375](https://github.com/rust-ndarray/ndarray/pull/1375) +- Add docs to CI by [@jturner314](https://github.com/jturner314) [#925](https://github.com/rust-ndarray/ndarray/pull/925) +- Test using cargo-careful by [@bluss](https://github.com/bluss) [#1371](https://github.com/rust-ndarray/ndarray/pull/1371) +- Further ci updates - numeric tests, and run all tests on PRs by [@bluss](https://github.com/bluss) [#1369](https://github.com/rust-ndarray/ndarray/pull/1369) +- Setup ci so that most checks run in merge queue only by [@bluss](https://github.com/bluss) [#1368](https://github.com/rust-ndarray/ndarray/pull/1368) +- Use merge queue by [@bluss](https://github.com/bluss) [#1367](https://github.com/rust-ndarray/ndarray/pull/1367) +- Try to make the master branch shipshape by [@adamreichold](https://github.com/adamreichold) [#1286](https://github.com/rust-ndarray/ndarray/pull/1286) +- Update ci - run cross tests only on master by 
[@bluss](https://github.com/bluss) [#1366](https://github.com/rust-ndarray/ndarray/pull/1366) +- ndarray_for_numpy_users some example to code not pointed out to clippy by [@higumachan](https://github.com/higumachan) [#1360](https://github.com/rust-ndarray/ndarray/pull/1360) +- Fix minimum rust version mismatch in lib.rs by [@HoKim98](https://github.com/HoKim98) [#1352](https://github.com/rust-ndarray/ndarray/pull/1352) +- Fix MSRV build by pinning crossbeam crates. by [@adamreichold](https://github.com/adamreichold) [#1345](https://github.com/rust-ndarray/ndarray/pull/1345) +- Fix new rustc lints to make the CI pass. by [@adamreichold](https://github.com/adamreichold) [#1337](https://github.com/rust-ndarray/ndarray/pull/1337) +- Make Clippy happy and fix MSRV build by [@adamreichold](https://github.com/adamreichold) [#1320](https://github.com/rust-ndarray/ndarray/pull/1320) +- small formatting fix in README.rst by [@podusowski](https://github.com/podusowski) [#1199](https://github.com/rust-ndarray/ndarray/pull/1199) +- Fix CI failures (mostly linting with clippy) by [@aganders3](https://github.com/aganders3) [#1171](https://github.com/rust-ndarray/ndarray/pull/1171) +- Remove doc(hidden) attr from items in trait impls by [@jturner314](https://github.com/jturner314) [#1165](https://github.com/rust-ndarray/ndarray/pull/1165) + + +Version 0.15.6 (2022-07-30) +=========================== + +New features +------------ + +- Add `get_ptr` and `get_mut_ptr` methods for getting an element's pointer from + an index, by [@adamreichold]. + + https://github.com/rust-ndarray/ndarray/pull/1151 + +Other changes +------------- + +- Various fixes to resolve compiler and Clippy warnings/errors, by [@aganders3] + and [@jturner314]. + + https://github.com/rust-ndarray/ndarray/pull/1171 + +- Fix description of `stack!` in quick start docs, by [@jturner314]. Thanks to + [@HyeokSuLee] for pointing out the issue. + + https://github.com/rust-ndarray/ndarray/pull/1156 + +- Add MSRV to `Cargo.toml`. + + https://github.com/rust-ndarray/ndarray/pull/1191 + + +Version 0.15.5 (2022-07-30) +=========================== + +Enhancements +------------ + +- The `s!` macro now works in `no_std` environments, by [@makotokato]. + + https://github.com/rust-ndarray/ndarray/pull/1154 + +Other changes +------------- + +- Improve docs and fix typos, by [@steffahn] and [@Rikorose]. + + https://github.com/rust-ndarray/ndarray/pull/1134
+ https://github.com/rust-ndarray/ndarray/pull/1164 + + +Version 0.15.4 (2021-11-23) +=========================== + +The Dr. Turner release 🚀 + +New features +------------ + +- Complex matrix multiplication now uses BLAS ``cgemm``/``zgemm`` when + enabled (and matrix layout allows), by [@ethanhs]. + + https://github.com/rust-ndarray/ndarray/pull/1106 + +- Use `matrixmultiply` as fallback for complex matrix multiplication + when BLAS is not available or the matrix layout requires it by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/1118 + +- Add ``into/to_slice_memory_order`` methods for views, lifetime-preserving + versions of existing similar methods by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1015 + +- ``kron`` function for Kronecker product by [@ethanhs]. + + https://github.com/rust-ndarray/ndarray/pull/1105 + +- ``split_complex`` method for splitting complex arrays into separate + real and imag view parts by [@jturner314] and [@ethanhs]. + + https://github.com/rust-ndarray/ndarray/pull/1107 + +- New method ``try_into_owned_nocopy`` by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1022 + +- New producer and iterable ``axis_windows`` by [@VasanthakumarV] + and [@jturner314]. + + https://github.com/rust-ndarray/ndarray/pull/1022 + +- New method ``Zip::par_fold`` by [@adamreichold] + + https://github.com/rust-ndarray/ndarray/pull/1095 + +- New constructor ``from_diag_elem`` by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1076 + +- ``Parallel::with_min_len`` method for parallel iterators by [@adamreichold] + + https://github.com/rust-ndarray/ndarray/pull/1081 + +- Allocation-preserving map function ``.mapv_into_any()`` added by [@benkay86] + +Enhancements +------------ + +- Improve performance of ``.sum_axis()`` for some cases by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1061 + +Bug fixes +--------- + +- Fix error in calling dgemv (matrix-vector multiplication) with BLAS and + broadcasted arrays, by [@jturner314]. + + https://github.com/rust-ndarray/ndarray/pull/1088 + +API changes +----------- + +- Support approx 0.5 partially alongside the already existing approx 0.4 support. + New feature flag is `approx-0_5`, by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1025 + +- Slice and reference-to-array conversions to CowArray added for by [@jturner314]. + + https://github.com/rust-ndarray/ndarray/pull/1038 + +- Allow trailing comma in stack and concatenate macros by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1044 + +- ``Zip`` now has a ``must_use`` marker to help users by [@adamreichold] + + https://github.com/rust-ndarray/ndarray/pull/1082 + +Other changes +------------- + +- Fixing the crates.io badge on github by [@atouchet] + + https://github.com/rust-ndarray/ndarray/pull/1104 + +- Use intra-doc links in docs by [@LeSeulArtichaut] + + https://github.com/rust-ndarray/ndarray/pull/1033 + +- Clippy fixes by [@adamreichold] + + https://github.com/rust-ndarray/ndarray/pull/1092
+ https://github.com/rust-ndarray/ndarray/pull/1091 + +- Minor fixes in links and punctuation in docs by [@jimblandy] + + https://github.com/rust-ndarray/ndarray/pull/1056 + +- Minor fixes in docs by [@chohner] + + https://github.com/rust-ndarray/ndarray/pull/1119 + +- Update tests to quickcheck 1.0 by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/1114 + + +Version 0.15.3 (2021-06-05) +=========================== + +New features +------------ + +- New methods `.last/_mut()` for arrays and array views by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1013 + +Bug fixes +--------- + +- Fix `as_slice_memory_order_mut()` so that it never changes strides (the + memory layout) of the array when called. + + This was a bug that impacted `ArcArray` (and for example not `Array` or `ArrayView/Mut`), + and multiple methods on `ArcArray` that use `as_slice_memory_order_mut` (for example `map_mut`). + Fix by [@jturner314]. + + https://github.com/rust-ndarray/ndarray/pull/1019 + +API changes +----------- + +- Array1 now implements `From>` by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1016 + +- ArcArray now implements `From>` by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1021 + +- CowArray now implements RawDataSubst by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1020 + +Other changes +------------- + +- Mention unsharing in `.as_mut_ptr` docs by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/1017 + +- Clarify and fix minor errors in push/append method docs by [@bluss] f21c668a + +- Fix several warnings in doc example code by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/1009 + + +Version 0.15.2 (2021-05-17 🇳🇴) +================================ + +New features +------------ + +- New methods for growing/appending to owned `Array`s. These methods allow + building an array efficiently chunk by chunk. By [@bluss]. + + - `.push_row()`, `.push_column()` + - `.push(axis, array)`, `.append(axis, array)` + + `stack`, `concatenate` and `.select()` now support all `Clone`-able elements + as a result. + + https://github.com/rust-ndarray/ndarray/pull/932
+ https://github.com/rust-ndarray/ndarray/pull/990 + +- New reshaping method `.to_shape(...)`, called with new shape and optional + ordering parameter, this is the first improvement for reshaping in terms of + added features and increased consistency, with more to come. By [@bluss]. + + https://github.com/rust-ndarray/ndarray/pull/982 + +- `Array` now implements a by-value iterator, by [@bluss]. + + https://github.com/rust-ndarray/ndarray/pull/986 + +- New methods `.move_into()` and `.move_into_uninit()` which allow assigning + into an array by moving values from an array into another, by [@bluss]. + + https://github.com/rust-ndarray/ndarray/pull/932
+ https://github.com/rust-ndarray/ndarray/pull/997 + +- New method `.remove_index()` for owned arrays by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/967 + +- New constructor `build_uninit` which makes it easier to initialize + uninitialized arrays in a way that's generic over all owned array kinds. + By [@bluss]. + + https://github.com/rust-ndarray/ndarray/pull/1001 + +Enhancements +------------ + +- Preserve the allocation of the input array in some more cases for arithmetic ops by [@SparrowLii] + + https://github.com/rust-ndarray/ndarray/pull/963 + +- Improve broadcasting performance for &array + &array arithmetic ops by [@SparrowLii] + + https://github.com/rust-ndarray/ndarray/pull/965 + +Bug fixes +--------- + +- Fix an error in construction of empty array with negative strides, by [@jturner314]. + + https://github.com/rust-ndarray/ndarray/pull/998 + +- Fix minor performance bug with loop order selection in Zip by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/977 + +API changes +----------- + +- Add dimension getters to `Shape` and `StrideShape` by [@stokhos] + + https://github.com/rust-ndarray/ndarray/pull/978 + +Other changes +------------- + +- Rustdoc now uses the ndarray logo that [@jturner314] created previously + + https://github.com/rust-ndarray/ndarray/pull/981 + +- Minor doc changes by [@stokhos], [@cassiersg] and [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/968
+ https://github.com/rust-ndarray/ndarray/pull/971
+ https://github.com/rust-ndarray/ndarray/pull/974 + +- A little refactoring to reduce generics bloat in a few places by [@bluss]. + + https://github.com/rust-ndarray/ndarray/pull/1004 + + +Version 0.15.1 (2021-03-29) +=========================== + +Enhancements +------------ + +- Arrays and views now have additional PartialEq impls so that it's possible to + compare arrays with references to arrays and vice versa by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/958 + +Bug fixes +--------- + +- Fix panic in creation of `.windows()` producer from negative stride array by + [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/957 + +Other changes +------------- + +- Update BLAS documentation further by @bluss + + https://github.com/rust-ndarray/ndarray/pull/955
+ https://github.com/rust-ndarray/ndarray/pull/959 + + +Version 0.15.0 (2021-03-25) +=========================== + +New features +------------ + +- Support inserting new axes while slicing by [@jturner314]. This is an example: + + ```rust + let view = arr.slice(s![.., -1, 2..;-1, NewAxis]); + ``` + + https://github.com/rust-ndarray/ndarray/pull/570 + +- Support two-sided broadcasting in arithmetic operations with arrays by [@SparrowLii] + + This now allows, for example, addition of a 3 x 1 with a 1 x 3 array; the + operands are in this case broadcast to 3 x 3 which is the shape of the result. + + Note that this means that a new trait bound is required in some places when + mixing dimensionality types of arrays in arithmetic operations. + + https://github.com/rust-ndarray/ndarray/pull/898 + +- Support for compiling ndarray as `no_std` (using core and alloc) by + [@xd009642] and [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/861
+ https://github.com/rust-ndarray/ndarray/pull/889 + +- New methods `.cell_view()` and `ArrayViewMut::into_cell_view` that enable + new ways of working with array elements as if they were in Cells - setting + elements through shared views and broadcast views, by [@bluss]. + + https://github.com/rust-ndarray/ndarray/pull/877 + +- New methods `slice_each_axis/_mut/_inplace` that make it easier to slice + a dynamic number of axes in some situations, by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/913 + +- New method `a.assign_to(b)` with the inverse argument order compared to the + existing `b.assign(a)` and some extra features like assigning into + uninitialized arrays, By [@bluss]. + + https://github.com/rust-ndarray/ndarray/pull/947 + +- New methods `.std()` and `.var()` for standard deviation and variance by + [@kdubovikov] + + https://github.com/rust-ndarray/ndarray/pull/790 + +Enhancements +------------ + +- Ndarray can now correctly determine that arrays can be contiguous, even if + they have negative strides, by [@SparrowLii] + + https://github.com/rust-ndarray/ndarray/pull/885
+ https://github.com/rust-ndarray/ndarray/pull/948 + +- Improvements to `map_inplace` by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/911 + +- `.into_dimensionality` performance was improved for the `IxDyn` to `IxDyn` + case by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/906 + +- Improved performance for scalar + &array and &array + scalar operations by + [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/890 + +API changes +----------- + +- New constructors `Array::from_iter` and `Array::from_vec` by [@bluss]. + No new functionality, just that these constructors are available without trait + imports. + + https://github.com/rust-ndarray/ndarray/pull/921 + +- `NdProducer::raw_dim` is now a documented method by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/918 + +- `AxisDescription` is now a struct with field names, not a tuple struct by + [@jturner314]. Its accessor methods are now deprecated. + + https://github.com/rust-ndarray/ndarray/pull/915 + +- Methods for array comparison `abs_diff_eq` and `relative_eq` are now + exposed as inherent methods too (no trait import needed), still under the approx + feature flag by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/946 + +- Changes to the slicing-related types and macro by [@jturner314] and [@bluss]: + + - Remove the `Dimension::SliceArg` associated type, and add a new `SliceArg` + trait for this purpose. + - Change the return type of the `s![]` macro to an owned `SliceInfo` rather + than a reference. + - Replace the `SliceOrIndex` enum with `SliceInfoElem`, which has an + additional `NewAxis` variant and does not have a `step_by` method. + - Change the type parameters of `SliceInfo` in order to support the `NewAxis` + functionality and remove some tricky `unsafe` code. + - Mark the `SliceInfo::new` method as `unsafe`. The new implementations of + `TryFrom` can be used as a safe alternative. + - Remove the `AsRef> for SliceInfo` + implementation. Add the similar `From<&'a SliceInfo> for + SliceInfo<&'a [SliceInfoElem], Din, Dout>` conversion as an alternative. + - Change the *expr* `;` *step* case in the `s![]` macro to error at compile + time if an unsupported type for *expr* is used, instead of panicking at + runtime. + + https://github.com/rust-ndarray/ndarray/pull/570
+ https://github.com/rust-ndarray/ndarray/pull/940
+ https://github.com/rust-ndarray/ndarray/pull/943
+ https://github.com/rust-ndarray/ndarray/pull/945
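To make the slicing changes above concrete, here is a minimal sketch (added for illustration, not part of the original changelog) assuming the ndarray 0.15 API: `NewAxis` inserts a length-1 axis, and `s![]` now yields an owned `SliceInfo` that can be stored and reused.

```rust
use ndarray::{array, s, NewAxis};

fn main() {
    let a = array![[1, 2, 3], [4, 5, 6]];

    // `NewAxis` inserts an extra axis of length 1, so slicing a 2-D array
    // with `.., NewAxis, ..;-1` yields a 3-D view (last axis reversed).
    let view = a.slice(s![.., NewAxis, ..;-1]);
    assert_eq!(view.shape(), &[2, 1, 3]);
    assert_eq!(view[[0, 0, 0]], 3);

    // `s![]` now returns an owned `SliceInfo`, so the slice spec can be
    // bound to a variable and passed to `.slice()` later.
    let info = s![1.., ..2];
    assert_eq!(a.slice(info), array![[4, 5]]);
}
```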
+ +- Removed already deprecated methods by [@bluss]: + + - Remove deprecated `.all_close()` - use approx feature and methods like `.abs_diff_eq` instead + - Mark `.scalar_sum()` as deprecated - use `.sum()` instead + - Remove deprecated `DataClone` - use `Data + RawDataClone` instead + - Remove deprecated `ArrayView::into_slice` - use `to_slice()` instead. + + https://github.com/rust-ndarray/ndarray/pull/874 + +- Remove already deprecated methods: rows, cols (for row and column count; the + new names are nrows and ncols) by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/872 + +- Renamed `Zip` methods by [@bluss] and [@SparrowLii]: + + - `apply` -> `for_each` + - `apply_collect` -> `map_collect` + - `apply_collect_into` -> `map_collect_into` + - (`par_` prefixed methods renamed accordingly) + + https://github.com/rust-ndarray/ndarray/pull/894
+ https://github.com/rust-ndarray/ndarray/pull/904
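A small sketch (added here for illustration, not from the changelog) of the renamed `Zip` methods, assuming the 0.15 names:

```rust
use ndarray::{array, Array2, Zip};

fn main() {
    let a = array![[1., 2.], [3., 4.]];
    let mut b = Array2::<f64>::zeros(a.dim());

    // `Zip::apply` was renamed to `Zip::for_each`.
    Zip::from(&mut b).and(&a).for_each(|b, &a| *b = 2. * a);

    // `Zip::apply_collect` was renamed to `Zip::map_collect`.
    let c = Zip::from(&a).and(&b).map_collect(|&x, &y| x + y);
    assert_eq!(c, array![[3., 6.], [9., 12.]]);
}
```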
+ +- Deprecate `Array::uninitialized` and revamp its replacement by [@bluss] + + Please use the new `Array::uninit`, which is based on `MaybeUninit` (renamed + from `Array::maybe_uninit`; the old name is also deprecated). + + https://github.com/rust-ndarray/ndarray/pull/902
+ https://github.com/rust-ndarray/ndarray/pull/876 + +- Renamed methods (old names are now deprecated) by [@bluss] and [@jturner314] + + - `genrows/_mut` -> `rows/_mut` + - `gencolumns/_mut` -> `columns/_mut` + - `stack_new_axis` -> `stack` (the new name already existed) + - `visit` -> `for_each` + + https://github.com/rust-ndarray/ndarray/pull/872
+ https://github.com/rust-ndarray/ndarray/pull/937
+ https://github.com/rust-ndarray/ndarray/pull/907
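For illustration (not part of the changelog), a minimal sketch using the new names from the renames above:

```rust
use ndarray::array;

fn main() {
    let a = array![[1, 2, 3], [4, 5, 6]];

    // `genrows()` is now `rows()`; `gencolumns()` is now `columns()`.
    let row_sums: Vec<i32> = a.rows().into_iter().map(|row| row.sum()).collect();
    assert_eq!(row_sums, vec![6, 15]);

    // `visit` is now `for_each`.
    let mut total = 0;
    a.for_each(|&x| total += x);
    assert_eq!(total, 21);
}
```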
+ +- Updated the `matrixmultiply` dependency to 0.3.0 by [@bluss], + adding a new feature flag `matrixmultiply-threading` that enables its threading support + + https://github.com/rust-ndarray/ndarray/pull/888
+ https://github.com/rust-ndarray/ndarray/pull/938
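As an aside (an illustrative sketch, not from the changelog): `matrixmultiply` backs ndarray's pure-Rust `f32`/`f64` matrix multiplication, reached for example through `general_mat_mul`. Enabling the `matrixmultiply-threading` feature in `Cargo.toml` only switches that backend's threading on; the calling code stays the same. Sizes and values below are made up.

```rust
use ndarray::linalg::general_mat_mul;
use ndarray::Array2;

fn main() {
    let a = Array2::<f64>::ones((64, 48));
    let b = Array2::<f64>::ones((48, 32));
    let mut c = Array2::<f64>::zeros((64, 32));

    // c = 1.0 * a.dot(&b) + 0.0 * c; without a BLAS backend this goes
    // through the `matrixmultiply` kernels.
    general_mat_mul(1.0, &a, &b, 0.0, &mut c);
    assert_eq!(c[[0, 0]], 48.0);
}
```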
+ +- Updated `num-complex` dependency to 0.4.0 by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/952 + +Bug fixes +--------- + +- Fix `Zip::indexed` for the 0-dimensional case by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/862 + +- Fix a bug in layout computation that broke parallel collect to an f-order + array in some circumstances by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/900 + +- Fix an unwanted panic in shape overflow checking by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/855 + +- Mark the `SliceInfo::new` method as `unsafe` due to the requirement that + `indices.as_ref()` always return the same value when called multiple times, + by [@bluss] and [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/570 + +Other changes +------------- + +- The way we integrate with BLAS and `blas-src` has changed. Users of BLAS need + to read the README for the updated instructions. Ndarray itself no longer + has a public dependency on `blas-src`. Changes by [@bluss]. + + https://github.com/rust-ndarray/ndarray/pull/891
+ https://github.com/rust-ndarray/ndarray/pull/951 + +- Various improvements to tests and CI by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/934
+ https://github.com/rust-ndarray/ndarray/pull/924
+ +- The `sort-axis.rs` example file's implementation of sort was fixed and now + has tests, by [@dam5h] and [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/916
+ https://github.com/rust-ndarray/ndarray/pull/930 + +- We now link to the #rust-sci room on Matrix in the readme by [@jturner314] + + https://github.com/rust-ndarray/ndarray/pull/619 + +- Internal cleanup with builder-like methods for creating arrays by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/908 + +- Implementation fix of `.swap(i, j)` by [@bluss] + + https://github.com/rust-ndarray/ndarray/pull/903 + +- Minimum supported Rust version (MSRV) is Rust 1.49. + + https://github.com/rust-ndarray/ndarray/pull/902 + +- Minor improvements to docs by [@insideoutclub] + + https://github.com/rust-ndarray/ndarray/pull/887 + + +Version 0.14.0 (2020-11-28) +=========================== + +New features +------------ + +- `Zip::apply_collect` and `Zip::par_apply_collect` now support all + elements (not just `Copy` elements) by [@bluss] + https://github.com/rust-ndarray/ndarray/pull/814 + https://github.com/rust-ndarray/ndarray/pull/817 + +- New function `stack` by [@andrei-papou] + https://github.com/rust-ndarray/ndarray/pull/844 + https://github.com/rust-ndarray/ndarray/pull/850 + +Enhancements +------------ + +- Handle inhomogeneous shape inputs better in Zip, in practice: guess better whether + to prefer c- or f-order for the inner loop by [@bluss] + https://github.com/rust-ndarray/ndarray/pull/809 + +- Improve code sharing in some commonly used code by [@bluss] + https://github.com/rust-ndarray/ndarray/pull/819 + +API changes +----------- + +- The **old function** `stack` has been renamed to `concatenate`. + A new function `stack` with numpy-like semantics has taken its place. + Old usages of `stack` should change to use `concatenate`. + + `concatenate` produces an array with the same number of axes as the inputs. + `stack` produces an array that has one more axis than the inputs. + + This change was unfortunately done without a deprecation period, due to the long period between releases. + + https://github.com/rust-ndarray/ndarray/pull/844 + https://github.com/rust-ndarray/ndarray/pull/850 + +- Enum `ErrorKind` is now properly non-exhaustive and has lost its old placeholder invalid variant, by [@Zuse64]. + https://github.com/rust-ndarray/ndarray/pull/848 + +- Remove deprecated items: + + - Removed `RcArray` (deprecated alias for `ArcArray`) + - Removed `subview_inplace`; use `collapse_axis` + - Removed `subview_mut`; use `index_axis_mut` + - Removed `into_subview`; use `index_axis_move` + - Removed `subview`; use `index_axis` + - Removed `slice_inplace`; use `slice_collapse` + +- Undeprecated `remove_axis` because its replacement is hard to discover on your own. + +- Update public external dependencies to new versions by [@Eijebong] and [@bluss] + + - num-complex 0.3 + - approx 0.4 (optional) + - blas-src 0.6.1 and openblas-src 0.9.0 (optional) + + https://github.com/rust-ndarray/ndarray/pull/810 + https://github.com/rust-ndarray/ndarray/pull/851 + + +Other changes +------------- + +- Minor doc fixes by [@acj] + https://github.com/rust-ndarray/ndarray/pull/834 + +- Minor doc fixes by [@xd009642] + https://github.com/rust-ndarray/ndarray/pull/847 + +- The minimum required Rust version is 1.42.
+ +- Release management by [@bluss] + +Version 0.13.1 (2020-04-21) +=========================== + +New features +------------ + +- New *amazing* slicing methods `multi_slice_*` by [@jturner314] + https://github.com/rust-ndarray/ndarray/pull/717 +- New method `.cast()` for raw views by [@bluss] + https://github.com/rust-ndarray/ndarray/pull/734 +- New aliases `ArcArray1`, `ArcArray2` by [@d-dorazio] + https://github.com/rust-ndarray/ndarray/pull/741 +- New array constructor `from_shape_simple_fn` by [@bluss] + https://github.com/rust-ndarray/ndarray/pull/728 +- `Dimension::Larger` now requires `RemoveAxis` by [@TheLortex] + https://github.com/rust-ndarray/ndarray/pull/792 +- New methods for collecting Zip into an array by [@bluss] + https://github.com/rust-ndarray/ndarray/pull/797 +- New `Array::maybe_uninit` and `.assume_init()` by [@bluss] + https://github.com/rust-ndarray/ndarray/pull/803 + +Enhancements +------------ + +- Remove itertools as dependency by [@bluss] + https://github.com/rust-ndarray/ndarray/pull/730 +- Improve `zip_mut_with` (and thus arithmetic ops) for f-order arrays by [@nilgoyette] + https://github.com/rust-ndarray/ndarray/pull/754 +- Implement `fold` for `IndicesIter` by [@jturner314] + https://github.com/rust-ndarray/ndarray/pull/733 +- New Quick Start readme by [@lifuyang] + https://github.com/rust-ndarray/ndarray/pull/785 + +API changes +----------- + +- Remove alignment restriction on raw views by [@jturner314] + https://github.com/rust-ndarray/ndarray/pull/738 + +Other changes +------------- + +- Fix documentation in ndarray for numpy users by [@jturner314] +- Improve blas version documentation by [@jturner314] +- Doc improvements by [@mockersf] https://github.com/rust-ndarray/ndarray/pull/751 +- Doc and lint related improvements by [@viniciusd] https://github.com/rust-ndarray/ndarray/pull/750 +- Minor fixes related to best practices for unsafe code by [@bluss] + https://github.com/rust-ndarray/ndarray/pull/799 + https://github.com/rust-ndarray/ndarray/pull/802 +- Release management by [@bluss] + + Version 0.13.0 (2019-09-23) =========================== @@ -320,7 +1121,7 @@ Earlier releases - Add `Zip::indexed` - New methods `genrows/_mut, gencolumns/_mut, lanes/_mut` that - return iterable producers (producer means `Zip` compatibile). + return iterable producers (producer means `Zip` compatible). 
- New method `.windows()` by @Robbepop, returns an iterable producer - New function `general_mat_vec_mul` (with fast default and blas acceleration) - `Zip::apply` and `fold_while` now take `self` as the first argument @@ -884,14 +1685,44 @@ Earlier releases - Starting point for evolution to come +[@adamreichold]: https://github.com/adamreichold +[@aganders3]: https://github.com/aganders3 [@bluss]: https://github.com/bluss [@jturner314]: https://github.com/jturner314 [@LukeMathWalker]: https://github.com/LukeMathWalker -[@max-sixty]: https://github.com/max-sixty +[@acj]: https://github.com/acj +[@adamreichold]: https://github.com/adamreichold +[@atouchet]: https://github.com/atouchet +[@andrei-papou]: https://github.com/andrei-papou +[@benkay]: https://github.com/benkay +[@cassiersg]: https://github.com/cassiersg +[@chohner]: https://github.com/chohner +[@dam5h]: https://github.com/dam5h +[@ethanhs]: https://github.com/ethanhs +[@d-dorazio]: https://github.com/d-dorazio +[@Eijebong]: https://github.com/Eijebong +[@HyeokSuLee]: https://github.com/HyeokSuLee +[@insideoutclub]: https://github.com/insideoutclub [@JP-Ellis]: https://github.com/JP-Ellis -[@sebasv]: https://github.com/sebasv -[@andrei-papou]: https://github.com/sebasv +[@jimblandy]: https://github.com/jimblandy +[@LeSeulArtichaut]: https://github.com/LeSeulArtichaut +[@lifuyang]: https://github.com/liufuyang +[@kdubovikov]: https://github.com/kdubovikov +[@makotokato]: https://github.com/makotokato +[@max-sixty]: https://github.com/max-sixty [@mneumann]: https://github.com/mneumann -[@termoshtt]: https://github.com/termoshtt -[@rth]: https://github.com/rth +[@mockersf]: https://github.com/mockersf +[@nilgoyette]: https://github.com/nilgoyette [@nitsky]: https://github.com/nitsky +[@Rikorose]: https://github.com/Rikorose +[@rth]: https://github.com/rth +[@sebasv]: https://github.com/sebasv +[@SparrowLii]: https://github.com/SparrowLii +[@steffahn]: https://github.com/steffahn +[@stokhos]: https://github.com/stokhos +[@termoshtt]: https://github.com/termoshtt +[@TheLortex]: https://github.com/TheLortex +[@viniciusd]: https://github.com/viniciusd +[@VasanthakumarV]: https://github.com/VasanthakumarV +[@xd009642]: https://github.com/xd009642 +[@Zuse64]: https://github.com/Zuse64 diff --git a/benches/append.rs b/benches/append.rs new file mode 100644 index 000000000..a37df256f --- /dev/null +++ b/benches/append.rs @@ -0,0 +1,32 @@ +#![feature(test)] + +extern crate test; +use test::Bencher; + +use ndarray::prelude::*; + +#[bench] +fn select_axis0(bench: &mut Bencher) +{ + let a = Array::::zeros((256, 256)); + let selectable = vec![0, 1, 2, 0, 1, 3, 0, 4, 16, 32, 128, 147, 149, 220, 221, 255, 221, 0, 1]; + bench.iter(|| a.select(Axis(0), &selectable)); +} + +#[bench] +fn select_axis1(bench: &mut Bencher) +{ + let a = Array::::zeros((256, 256)); + let selectable = vec![0, 1, 2, 0, 1, 3, 0, 4, 16, 32, 128, 147, 149, 220, 221, 255, 221, 0, 1]; + bench.iter(|| a.select(Axis(1), &selectable)); +} + +#[bench] +fn select_1d(bench: &mut Bencher) +{ + let a = Array::::zeros(1024); + let mut selectable = (0..a.len()).step_by(17).collect::>(); + selectable.extend(selectable.clone().iter().rev()); + + bench.iter(|| a.select(Axis(0), &selectable)); +} diff --git a/benches/bench1.rs b/benches/bench1.rs index 190ab5065..33185844a 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -1,22 +1,23 @@ #![feature(test)] #![allow(unused_imports)] #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - 
clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] extern crate test; -use ndarray::ShapeBuilder; +use std::mem::MaybeUninit; + use ndarray::{arr0, arr1, arr2, azip, s}; use ndarray::{Array, Array1, Array2, Axis, Ix, Zip}; +use ndarray::{Array3, Array4, ShapeBuilder}; +use ndarray::{Ix1, Ix2, Ix3, Ix5, IxDyn}; use test::black_box; #[bench] -fn iter_sum_1d_regular(bench: &mut test::Bencher) { +fn iter_sum_1d_regular(bench: &mut test::Bencher) +{ let a = Array::::zeros(64 * 64); let a = black_box(a); bench.iter(|| { @@ -29,7 +30,8 @@ fn iter_sum_1d_regular(bench: &mut test::Bencher) { } #[bench] -fn iter_sum_1d_raw(bench: &mut test::Bencher) { +fn iter_sum_1d_raw(bench: &mut test::Bencher) +{ // this is autovectorized to death (= great performance) let a = Array::::zeros(64 * 64); let a = black_box(a); @@ -43,7 +45,8 @@ fn iter_sum_1d_raw(bench: &mut test::Bencher) { } #[bench] -fn iter_sum_2d_regular(bench: &mut test::Bencher) { +fn iter_sum_2d_regular(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let a = black_box(a); bench.iter(|| { @@ -56,12 +59,13 @@ fn iter_sum_2d_regular(bench: &mut test::Bencher) { } #[bench] -fn iter_sum_2d_by_row(bench: &mut test::Bencher) { +fn iter_sum_2d_by_row(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let a = black_box(a); bench.iter(|| { let mut sum = 0; - for row in a.genrows() { + for row in a.rows() { for &elt in row { sum += elt; } @@ -71,7 +75,8 @@ fn iter_sum_2d_by_row(bench: &mut test::Bencher) { } #[bench] -fn iter_sum_2d_raw(bench: &mut test::Bencher) { +fn iter_sum_2d_raw(bench: &mut test::Bencher) +{ // this is autovectorized to death (= great performance) let a = Array::::zeros((64, 64)); let a = black_box(a); @@ -85,7 +90,8 @@ fn iter_sum_2d_raw(bench: &mut test::Bencher) { } #[bench] -fn iter_sum_2d_cutout(bench: &mut test::Bencher) { +fn iter_sum_2d_cutout(bench: &mut test::Bencher) +{ let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); @@ -99,7 +105,8 @@ fn iter_sum_2d_cutout(bench: &mut test::Bencher) { } #[bench] -fn iter_sum_2d_cutout_by_row(bench: &mut test::Bencher) { +fn iter_sum_2d_cutout_by_row(bench: &mut test::Bencher) +{ let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); @@ -115,13 +122,14 @@ fn iter_sum_2d_cutout_by_row(bench: &mut test::Bencher) { } #[bench] -fn iter_sum_2d_cutout_outer_iter(bench: &mut test::Bencher) { +fn iter_sum_2d_cutout_outer_iter(bench: &mut test::Bencher) +{ let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); bench.iter(|| { let mut sum = 0; - for row in a.genrows() { + for row in a.rows() { for &elt in row { sum += elt; } @@ -131,7 +139,8 @@ fn iter_sum_2d_cutout_outer_iter(bench: &mut test::Bencher) { } #[bench] -fn iter_sum_2d_transpose_regular(bench: &mut test::Bencher) { +fn iter_sum_2d_transpose_regular(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((64, 64)); a.swap_axes(0, 1); let a = black_box(a); @@ -145,7 +154,8 @@ fn iter_sum_2d_transpose_regular(bench: &mut test::Bencher) { } #[bench] -fn iter_sum_2d_transpose_by_row(bench: &mut test::Bencher) { +fn iter_sum_2d_transpose_by_row(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((64, 64)); a.swap_axes(0, 1); let a = black_box(a); @@ -161,14 +171,16 @@ fn iter_sum_2d_transpose_by_row(bench: &mut test::Bencher) { } #[bench] -fn sum_2d_regular(bench: &mut test::Bencher) { +fn 
sum_2d_regular(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let a = black_box(a); bench.iter(|| a.sum()); } #[bench] -fn sum_2d_cutout(bench: &mut test::Bencher) { +fn sum_2d_cutout(bench: &mut test::Bencher) +{ let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); @@ -176,14 +188,16 @@ fn sum_2d_cutout(bench: &mut test::Bencher) { } #[bench] -fn sum_2d_float(bench: &mut test::Bencher) { +fn sum_2d_float(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let a = black_box(a.view()); bench.iter(|| a.sum()); } #[bench] -fn sum_2d_float_cutout(bench: &mut test::Bencher) { +fn sum_2d_float_cutout(bench: &mut test::Bencher) +{ let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); @@ -191,7 +205,8 @@ fn sum_2d_float_cutout(bench: &mut test::Bencher) { } #[bench] -fn sum_2d_float_t_cutout(bench: &mut test::Bencher) { +fn sum_2d_float_t_cutout(bench: &mut test::Bencher) +{ let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]).reversed_axes(); let a = black_box(av); @@ -199,13 +214,15 @@ fn sum_2d_float_t_cutout(bench: &mut test::Bencher) { } #[bench] -fn fold_sum_i32_2d_regular(bench: &mut test::Bencher) { +fn fold_sum_i32_2d_regular(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); bench.iter(|| a.fold(0, |acc, &x| acc + x)); } #[bench] -fn fold_sum_i32_2d_cutout(bench: &mut test::Bencher) { +fn fold_sum_i32_2d_cutout(bench: &mut test::Bencher) +{ let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = black_box(av); @@ -213,7 +230,8 @@ fn fold_sum_i32_2d_cutout(bench: &mut test::Bencher) { } #[bench] -fn fold_sum_i32_2d_stride(bench: &mut test::Bencher) { +fn fold_sum_i32_2d_stride(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 128)); let av = a.slice(s![.., ..;2]); let a = black_box(av); @@ -221,14 +239,16 @@ fn fold_sum_i32_2d_stride(bench: &mut test::Bencher) { } #[bench] -fn fold_sum_i32_2d_transpose(bench: &mut test::Bencher) { +fn fold_sum_i32_2d_transpose(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let a = a.t(); bench.iter(|| a.fold(0, |acc, &x| acc + x)); } #[bench] -fn fold_sum_i32_2d_cutout_transpose(bench: &mut test::Bencher) { +fn fold_sum_i32_2d_cutout_transpose(bench: &mut test::Bencher) +{ let a = Array::::zeros((66, 66)); let mut av = a.slice(s![1..-1, 1..-1]); av.swap_axes(0, 1); @@ -239,7 +259,8 @@ fn fold_sum_i32_2d_cutout_transpose(bench: &mut test::Bencher) { const ADD2DSZ: usize = 64; #[bench] -fn add_2d_regular(bench: &mut test::Bencher) { +fn add_2d_regular(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); let bv = b.view(); @@ -249,34 +270,76 @@ fn add_2d_regular(bench: &mut test::Bencher) { } #[bench] -fn add_2d_zip(bench: &mut test::Bencher) { +fn add_2d_zip(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| { - Zip::from(&mut a).and(&b).apply(|a, &b| *a += b); + Zip::from(&mut a).and(&b).for_each(|a, &b| *a += b); }); } #[bench] -fn add_2d_alloc(bench: &mut test::Bencher) { +fn add_2d_alloc_plus(bench: &mut test::Bencher) +{ let a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| &a + &b); } #[bench] -fn add_2d_zip_alloc(bench: &mut test::Bencher) { +fn add_2d_alloc_zip_uninit(bench: &mut test::Bencher) +{ let a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = 
Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| unsafe { - let mut c = Array::uninitialized(a.dim()); - azip!((&a in &a, &b in &b, c in &mut c) *c = a + b); + let mut c = Array::::uninit(a.dim()); + azip!((&a in &a, &b in &b, c in c.raw_view_mut().cast::()) + c.write(a + b) + ); c }); } #[bench] -fn add_2d_assign_ops(bench: &mut test::Bencher) { +fn add_2d_alloc_zip_collect(bench: &mut test::Bencher) +{ + let a = Array::::zeros((ADD2DSZ, ADD2DSZ)); + let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); + bench.iter(|| Zip::from(&a).and(&b).map_collect(|&x, &y| x + y)); +} + +#[bench] +fn vec_string_collect(bench: &mut test::Bencher) +{ + let v = vec![""; 10240]; + bench.iter(|| v.iter().map(|s| s.to_owned()).collect::>()); +} + +#[bench] +fn array_string_collect(bench: &mut test::Bencher) +{ + let v = Array::from(vec![""; 10240]); + bench.iter(|| Zip::from(&v).map_collect(|s| s.to_owned())); +} + +#[bench] +fn vec_f64_collect(bench: &mut test::Bencher) +{ + let v = vec![1.; 10240]; + bench.iter(|| v.iter().map(|s| s + 1.).collect::>()); +} + +#[bench] +fn array_f64_collect(bench: &mut test::Bencher) +{ + let v = Array::from(vec![1.; 10240]); + bench.iter(|| Zip::from(&v).map_collect(|s| s + 1.)); +} + +#[bench] +fn add_2d_assign_ops(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); let bv = b.view(); @@ -288,7 +351,8 @@ fn add_2d_assign_ops(bench: &mut test::Bencher) { } #[bench] -fn add_2d_cutout(bench: &mut test::Bencher) { +fn add_2d_cutout(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ + 2, ADD2DSZ + 2)); let mut acut = a.slice_mut(s![1..-1, 1..-1]); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -299,56 +363,61 @@ fn add_2d_cutout(bench: &mut test::Bencher) { } #[bench] -fn add_2d_zip_cutout(bench: &mut test::Bencher) { +fn add_2d_zip_cutout(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ + 2, ADD2DSZ + 2)); let mut acut = a.slice_mut(s![1..-1, 1..-1]); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| { - Zip::from(&mut acut).and(&b).apply(|a, &b| *a += b); + Zip::from(&mut acut).and(&b).for_each(|a, &b| *a += b); }); } #[bench] #[allow(clippy::identity_op)] -fn add_2d_cutouts_by_4(bench: &mut test::Bencher) { +fn add_2d_cutouts_by_4(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((64 * 1, 64 * 1)); let b = Array::::zeros((64 * 1, 64 * 1)); let chunksz = (4, 4); bench.iter(|| { Zip::from(a.exact_chunks_mut(chunksz)) .and(b.exact_chunks(chunksz)) - .apply(|mut a, b| a += &b); + .for_each(|mut a, b| a += &b); }); } #[bench] #[allow(clippy::identity_op)] -fn add_2d_cutouts_by_16(bench: &mut test::Bencher) { +fn add_2d_cutouts_by_16(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((64 * 1, 64 * 1)); let b = Array::::zeros((64 * 1, 64 * 1)); let chunksz = (16, 16); bench.iter(|| { Zip::from(a.exact_chunks_mut(chunksz)) .and(b.exact_chunks(chunksz)) - .apply(|mut a, b| a += &b); + .for_each(|mut a, b| a += &b); }); } #[bench] #[allow(clippy::identity_op)] -fn add_2d_cutouts_by_32(bench: &mut test::Bencher) { +fn add_2d_cutouts_by_32(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((64 * 1, 64 * 1)); let b = Array::::zeros((64 * 1, 64 * 1)); let chunksz = (32, 32); bench.iter(|| { Zip::from(a.exact_chunks_mut(chunksz)) .and(b.exact_chunks(chunksz)) - .apply(|mut a, b| a += &b); + .for_each(|mut a, b| a += &b); }); } #[bench] -fn add_2d_broadcast_1_to_2(bench: &mut test::Bencher) { +fn add_2d_broadcast_1_to_2(bench: &mut test::Bencher) +{ let mut a = 
Array2::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array1::::zeros(ADD2DSZ); let bv = b.view(); @@ -358,7 +427,8 @@ fn add_2d_broadcast_1_to_2(bench: &mut test::Bencher) { } #[bench] -fn add_2d_broadcast_0_to_2(bench: &mut test::Bencher) { +fn add_2d_broadcast_0_to_2(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros(()); let bv = b.view(); @@ -368,34 +438,55 @@ fn add_2d_broadcast_0_to_2(bench: &mut test::Bencher) { } #[bench] -fn scalar_toowned(bench: &mut test::Bencher) { +fn scalar_toowned(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); bench.iter(|| a.to_owned()); } #[bench] -fn scalar_add_1(bench: &mut test::Bencher) { +fn scalar_add_1(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let n = 1.; bench.iter(|| &a + n); } #[bench] -fn scalar_add_2(bench: &mut test::Bencher) { +fn scalar_add_2(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let n = 1.; bench.iter(|| n + &a); } #[bench] -fn scalar_sub_1(bench: &mut test::Bencher) { +fn scalar_add_strided_1(bench: &mut test::Bencher) +{ + let a = Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32).slice_move(s![.., ..;2]); + let n = 1.; + bench.iter(|| &a + n); +} + +#[bench] +fn scalar_add_strided_2(bench: &mut test::Bencher) +{ + let a = Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32).slice_move(s![.., ..;2]); + let n = 1.; + bench.iter(|| n + &a); +} + +#[bench] +fn scalar_sub_1(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let n = 1.; bench.iter(|| &a - n); } #[bench] -fn scalar_sub_2(bench: &mut test::Bencher) { +fn scalar_sub_2(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let n = 1.; bench.iter(|| n - &a); @@ -403,7 +494,8 @@ fn scalar_sub_2(bench: &mut test::Bencher) { // This is for comparison with add_2d_broadcast_0_to_2 #[bench] -fn add_2d_0_to_2_iadd_scalar(bench: &mut test::Bencher) { +fn add_2d_0_to_2_iadd_scalar(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let n = black_box(0); bench.iter(|| { @@ -412,7 +504,8 @@ fn add_2d_0_to_2_iadd_scalar(bench: &mut test::Bencher) { } #[bench] -fn add_2d_strided(bench: &mut test::Bencher) { +fn add_2d_strided(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ * 2)); let mut a = a.slice_mut(s![.., ..;2]); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -423,7 +516,8 @@ fn add_2d_strided(bench: &mut test::Bencher) { } #[bench] -fn add_2d_regular_dyn(bench: &mut test::Bencher) { +fn add_2d_regular_dyn(bench: &mut test::Bencher) +{ let mut a = Array::::zeros(&[ADD2DSZ, ADD2DSZ][..]); let b = Array::::zeros(&[ADD2DSZ, ADD2DSZ][..]); let bv = b.view(); @@ -433,7 +527,8 @@ fn add_2d_regular_dyn(bench: &mut test::Bencher) { } #[bench] -fn add_2d_strided_dyn(bench: &mut test::Bencher) { +fn add_2d_strided_dyn(bench: &mut test::Bencher) +{ let mut a = Array::::zeros(&[ADD2DSZ, ADD2DSZ * 2][..]); let mut a = a.slice_mut(s![.., ..;2]); let b = Array::::zeros(&[ADD2DSZ, ADD2DSZ][..]); @@ -444,17 +539,19 @@ fn add_2d_strided_dyn(bench: &mut test::Bencher) { } #[bench] -fn add_2d_zip_strided(bench: &mut test::Bencher) { +fn add_2d_zip_strided(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ * 2)); let mut a = a.slice_mut(s![.., ..;2]); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| { - Zip::from(&mut a).and(&b).apply(|a, &b| *a += b); + Zip::from(&mut a).and(&b).for_each(|a, &b| *a += b); }); } #[bench] -fn add_2d_one_transposed(bench: &mut test::Bencher) 
{ +fn add_2d_one_transposed(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -464,17 +561,19 @@ fn add_2d_one_transposed(bench: &mut test::Bencher) { } #[bench] -fn add_2d_zip_one_transposed(bench: &mut test::Bencher) { +fn add_2d_zip_one_transposed(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| { - Zip::from(&mut a).and(&b).apply(|a, &b| *a += b); + Zip::from(&mut a).and(&b).for_each(|a, &b| *a += b); }); } #[bench] -fn add_2d_both_transposed(bench: &mut test::Bencher) { +fn add_2d_both_transposed(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let mut b = Array::::zeros((ADD2DSZ, ADD2DSZ)); @@ -485,18 +584,20 @@ fn add_2d_both_transposed(bench: &mut test::Bencher) { } #[bench] -fn add_2d_zip_both_transposed(bench: &mut test::Bencher) { +fn add_2d_zip_both_transposed(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let mut b = Array::::zeros((ADD2DSZ, ADD2DSZ)); b.swap_axes(0, 1); bench.iter(|| { - Zip::from(&mut a).and(&b).apply(|a, &b| *a += b); + Zip::from(&mut a).and(&b).for_each(|a, &b| *a += b); }); } #[bench] -fn add_2d_f32_regular(bench: &mut test::Bencher) { +fn add_2d_f32_regular(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); let b = Array::::zeros((ADD2DSZ, ADD2DSZ)); let bv = b.view(); @@ -508,7 +609,8 @@ fn add_2d_f32_regular(bench: &mut test::Bencher) { const ADD3DSZ: usize = 16; #[bench] -fn add_3d_strided(bench: &mut test::Bencher) { +fn add_3d_strided(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD3DSZ, ADD3DSZ, ADD3DSZ * 2)); let mut a = a.slice_mut(s![.., .., ..;2]); let b = Array::::zeros(a.dim()); @@ -519,7 +621,8 @@ fn add_3d_strided(bench: &mut test::Bencher) { } #[bench] -fn add_3d_strided_dyn(bench: &mut test::Bencher) { +fn add_3d_strided_dyn(bench: &mut test::Bencher) +{ let mut a = Array::::zeros(&[ADD3DSZ, ADD3DSZ, ADD3DSZ * 2][..]); let mut a = a.slice_mut(s![.., .., ..;2]); let b = Array::::zeros(a.dim()); @@ -532,7 +635,8 @@ fn add_3d_strided_dyn(bench: &mut test::Bencher) { const ADD1D_SIZE: usize = 64 * 64; #[bench] -fn add_1d_regular(bench: &mut test::Bencher) { +fn add_1d_regular(bench: &mut test::Bencher) +{ let mut a = Array::::zeros(ADD1D_SIZE); let b = Array::::zeros(a.dim()); bench.iter(|| { @@ -541,7 +645,8 @@ fn add_1d_regular(bench: &mut test::Bencher) { } #[bench] -fn add_1d_strided(bench: &mut test::Bencher) { +fn add_1d_strided(bench: &mut test::Bencher) +{ let mut a = Array::::zeros(ADD1D_SIZE * 2); let mut av = a.slice_mut(s![..;2]); let b = Array::::zeros(av.dim()); @@ -551,7 +656,8 @@ fn add_1d_strided(bench: &mut test::Bencher) { } #[bench] -fn iadd_scalar_2d_regular(bench: &mut test::Bencher) { +fn iadd_scalar_2d_regular(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ)); bench.iter(|| { a += 1.; @@ -559,7 +665,8 @@ fn iadd_scalar_2d_regular(bench: &mut test::Bencher) { } #[bench] -fn iadd_scalar_2d_strided(bench: &mut test::Bencher) { +fn iadd_scalar_2d_strided(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((ADD2DSZ, ADD2DSZ * 2)); let mut a = a.slice_mut(s![.., ..;2]); bench.iter(|| { @@ -568,7 +675,8 @@ fn iadd_scalar_2d_strided(bench: &mut test::Bencher) { } #[bench] -fn iadd_scalar_2d_regular_dyn(bench: &mut test::Bencher) { +fn 
iadd_scalar_2d_regular_dyn(bench: &mut test::Bencher) +{ let mut a = Array::::zeros(vec![ADD2DSZ, ADD2DSZ]); bench.iter(|| { a += 1.; @@ -576,7 +684,8 @@ fn iadd_scalar_2d_regular_dyn(bench: &mut test::Bencher) { } #[bench] -fn iadd_scalar_2d_strided_dyn(bench: &mut test::Bencher) { +fn iadd_scalar_2d_strided_dyn(bench: &mut test::Bencher) +{ let mut a = Array::::zeros(vec![ADD2DSZ, ADD2DSZ * 2]); let mut a = a.slice_mut(s![.., ..;2]); bench.iter(|| { @@ -585,7 +694,8 @@ fn iadd_scalar_2d_strided_dyn(bench: &mut test::Bencher) { } #[bench] -fn scaled_add_2d_f32_regular(bench: &mut test::Bencher) { +fn scaled_add_2d_f32_regular(bench: &mut test::Bencher) +{ let mut av = Array::::zeros((ADD2DSZ, ADD2DSZ)); let bv = Array::::zeros((ADD2DSZ, ADD2DSZ)); let scalar = std::f32::consts::PI; @@ -595,7 +705,8 @@ fn scaled_add_2d_f32_regular(bench: &mut test::Bencher) { } #[bench] -fn assign_scalar_2d_corder(bench: &mut test::Bencher) { +fn assign_scalar_2d_corder(bench: &mut test::Bencher) +{ let a = Array::zeros((ADD2DSZ, ADD2DSZ)); let mut a = black_box(a); let s = 3.; @@ -603,7 +714,8 @@ fn assign_scalar_2d_corder(bench: &mut test::Bencher) { } #[bench] -fn assign_scalar_2d_cutout(bench: &mut test::Bencher) { +fn assign_scalar_2d_cutout(bench: &mut test::Bencher) +{ let mut a = Array::zeros((66, 66)); let a = a.slice_mut(s![1..-1, 1..-1]); let mut a = black_box(a); @@ -612,7 +724,8 @@ fn assign_scalar_2d_cutout(bench: &mut test::Bencher) { } #[bench] -fn assign_scalar_2d_forder(bench: &mut test::Bencher) { +fn assign_scalar_2d_forder(bench: &mut test::Bencher) +{ let mut a = Array::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let mut a = black_box(a); @@ -621,14 +734,16 @@ fn assign_scalar_2d_forder(bench: &mut test::Bencher) { } #[bench] -fn assign_zero_2d_corder(bench: &mut test::Bencher) { +fn assign_zero_2d_corder(bench: &mut test::Bencher) +{ let a = Array::zeros((ADD2DSZ, ADD2DSZ)); let mut a = black_box(a); bench.iter(|| a.fill(0.)) } #[bench] -fn assign_zero_2d_cutout(bench: &mut test::Bencher) { +fn assign_zero_2d_cutout(bench: &mut test::Bencher) +{ let mut a = Array::zeros((66, 66)); let a = a.slice_mut(s![1..-1, 1..-1]); let mut a = black_box(a); @@ -636,7 +751,8 @@ fn assign_zero_2d_cutout(bench: &mut test::Bencher) { } #[bench] -fn assign_zero_2d_forder(bench: &mut test::Bencher) { +fn assign_zero_2d_forder(bench: &mut test::Bencher) +{ let mut a = Array::zeros((ADD2DSZ, ADD2DSZ)); a.swap_axes(0, 1); let mut a = black_box(a); @@ -644,7 +760,8 @@ fn assign_zero_2d_forder(bench: &mut test::Bencher) { } #[bench] -fn bench_iter_diag(bench: &mut test::Bencher) { +fn bench_iter_diag(bench: &mut test::Bencher) +{ let a = Array::::zeros((1024, 1024)); bench.iter(|| { for elt in a.diag() { @@ -654,7 +771,8 @@ fn bench_iter_diag(bench: &mut test::Bencher) { } #[bench] -fn bench_row_iter(bench: &mut test::Bencher) { +fn bench_row_iter(bench: &mut test::Bencher) +{ let a = Array::::zeros((1024, 1024)); let it = a.row(17); bench.iter(|| { @@ -665,7 +783,8 @@ fn bench_row_iter(bench: &mut test::Bencher) { } #[bench] -fn bench_col_iter(bench: &mut test::Bencher) { +fn bench_col_iter(bench: &mut test::Bencher) +{ let a = Array::::zeros((1024, 1024)); let it = a.column(17); bench.iter(|| { @@ -692,7 +811,7 @@ macro_rules! mat_mul { } )+ } - } + }; } mat_mul! {mat_mul_f32, f32, @@ -735,7 +854,8 @@ mat_mul! 
{mat_mul_i32, i32, } #[bench] -fn create_iter_4d(bench: &mut test::Bencher) { +fn create_iter_4d(bench: &mut test::Bencher) +{ let mut a = Array::from_elem((4, 5, 3, 2), 1.0); a.swap_axes(0, 1); a.swap_axes(2, 1); @@ -745,82 +865,94 @@ fn create_iter_4d(bench: &mut test::Bencher) { } #[bench] -fn bench_to_owned_n(bench: &mut test::Bencher) { +fn bench_to_owned_n(bench: &mut test::Bencher) +{ let a = Array::::zeros((32, 32)); bench.iter(|| a.to_owned()); } #[bench] -fn bench_to_owned_t(bench: &mut test::Bencher) { +fn bench_to_owned_t(bench: &mut test::Bencher) +{ let mut a = Array::::zeros((32, 32)); a.swap_axes(0, 1); bench.iter(|| a.to_owned()); } #[bench] -fn bench_to_owned_strided(bench: &mut test::Bencher) { +fn bench_to_owned_strided(bench: &mut test::Bencher) +{ let a = Array::::zeros((32, 64)); let a = a.slice(s![.., ..;2]); bench.iter(|| a.to_owned()); } #[bench] -fn equality_i32(bench: &mut test::Bencher) { +fn equality_i32(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let b = Array::::zeros((64, 64)); bench.iter(|| a == b); } #[bench] -fn equality_f32(bench: &mut test::Bencher) { +fn equality_f32(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let b = Array::::zeros((64, 64)); bench.iter(|| a == b); } #[bench] -fn equality_f32_mixorder(bench: &mut test::Bencher) { +fn equality_f32_mixorder(bench: &mut test::Bencher) +{ let a = Array::::zeros((64, 64)); let b = Array::::zeros((64, 64).f()); bench.iter(|| a == b); } #[bench] -fn dot_f32_16(bench: &mut test::Bencher) { +fn dot_f32_16(bench: &mut test::Bencher) +{ let a = Array::::zeros(16); let b = Array::::zeros(16); bench.iter(|| a.dot(&b)); } #[bench] -fn dot_f32_20(bench: &mut test::Bencher) { +fn dot_f32_20(bench: &mut test::Bencher) +{ let a = Array::::zeros(20); let b = Array::::zeros(20); bench.iter(|| a.dot(&b)); } #[bench] -fn dot_f32_32(bench: &mut test::Bencher) { +fn dot_f32_32(bench: &mut test::Bencher) +{ let a = Array::::zeros(32); let b = Array::::zeros(32); bench.iter(|| a.dot(&b)); } #[bench] -fn dot_f32_256(bench: &mut test::Bencher) { +fn dot_f32_256(bench: &mut test::Bencher) +{ let a = Array::::zeros(256); let b = Array::::zeros(256); bench.iter(|| a.dot(&b)); } #[bench] -fn dot_f32_1024(bench: &mut test::Bencher) { +fn dot_f32_1024(bench: &mut test::Bencher) +{ let av = Array::::zeros(1024); let bv = Array::::zeros(1024); bench.iter(|| av.dot(&bv)); } #[bench] -fn dot_f32_10e6(bench: &mut test::Bencher) { +fn dot_f32_10e6(bench: &mut test::Bencher) +{ let n = 1_000_000; let av = Array::::zeros(n); let bv = Array::::zeros(n); @@ -828,7 +960,8 @@ fn dot_f32_10e6(bench: &mut test::Bencher) { } #[bench] -fn dot_extended(bench: &mut test::Bencher) { +fn dot_extended(bench: &mut test::Bencher) +{ let m = 10; let n = 33; let k = 10; @@ -849,33 +982,122 @@ fn dot_extended(bench: &mut test::Bencher) { const MEAN_SUM_N: usize = 127; -fn range_mat(m: Ix, n: Ix) -> Array2 { +fn range_mat(m: Ix, n: Ix) -> Array2 +{ assert!(m * n != 0); Array::linspace(0., (m * n - 1) as f32, m * n) - .into_shape((m, n)) + .into_shape_with_order((m, n)) .unwrap() } #[bench] -fn mean_axis0(bench: &mut test::Bencher) { +fn mean_axis0(bench: &mut test::Bencher) +{ let a = range_mat(MEAN_SUM_N, MEAN_SUM_N); bench.iter(|| a.mean_axis(Axis(0))); } #[bench] -fn mean_axis1(bench: &mut test::Bencher) { +fn mean_axis1(bench: &mut test::Bencher) +{ let a = range_mat(MEAN_SUM_N, MEAN_SUM_N); bench.iter(|| a.mean_axis(Axis(1))); } #[bench] -fn sum_axis0(bench: &mut test::Bencher) { +fn sum_axis0(bench: &mut 
test::Bencher) +{ let a = range_mat(MEAN_SUM_N, MEAN_SUM_N); bench.iter(|| a.sum_axis(Axis(0))); } #[bench] -fn sum_axis1(bench: &mut test::Bencher) { +fn sum_axis1(bench: &mut test::Bencher) +{ let a = range_mat(MEAN_SUM_N, MEAN_SUM_N); bench.iter(|| a.sum_axis(Axis(1))); } + +#[bench] +fn into_dimensionality_ix1_ok(bench: &mut test::Bencher) +{ + let a = Array::::zeros(Ix1(10)); + let a = a.view(); + bench.iter(|| a.into_dimensionality::()); +} + +#[bench] +fn into_dimensionality_ix3_ok(bench: &mut test::Bencher) +{ + let a = Array::::zeros(Ix3(10, 10, 10)); + let a = a.view(); + bench.iter(|| a.into_dimensionality::()); +} + +#[bench] +fn into_dimensionality_ix3_err(bench: &mut test::Bencher) +{ + let a = Array::::zeros(Ix3(10, 10, 10)); + let a = a.view(); + bench.iter(|| a.into_dimensionality::()); +} + +#[bench] +fn into_dimensionality_dyn_to_ix3(bench: &mut test::Bencher) +{ + let a = Array::::zeros(IxDyn(&[10, 10, 10])); + let a = a.view(); + bench.iter(|| a.clone().into_dimensionality::()); +} + +#[bench] +fn into_dimensionality_dyn_to_dyn(bench: &mut test::Bencher) +{ + let a = Array::::zeros(IxDyn(&[10, 10, 10])); + let a = a.view(); + bench.iter(|| a.clone().into_dimensionality::()); +} + +#[bench] +fn into_dyn_ix3(bench: &mut test::Bencher) +{ + let a = Array::::zeros(Ix3(10, 10, 10)); + let a = a.view(); + bench.iter(|| a.into_dyn()); +} + +#[bench] +fn into_dyn_ix5(bench: &mut test::Bencher) +{ + let a = Array::::zeros(Ix5(2, 2, 2, 2, 2)); + let a = a.view(); + bench.iter(|| a.into_dyn()); +} + +#[bench] +fn into_dyn_dyn(bench: &mut test::Bencher) +{ + let a = Array::::zeros(IxDyn(&[10, 10, 10])); + let a = a.view(); + bench.iter(|| a.clone().into_dyn()); +} + +#[bench] +fn broadcast_same_dim(bench: &mut test::Bencher) +{ + let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let s = Array4::from_shape_vec((2, 2, 3, 2), s.to_vec()).unwrap(); + let a = s.slice(s![.., ..;-1, ..;2, ..]); + let b = s.slice(s![.., .., ..;2, ..]); + bench.iter(|| &a + &b); +} + +#[bench] +fn broadcast_one_side(bench: &mut test::Bencher) +{ + let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let s2 = [1, 2, 3, 4, 5, 6]; + let a = Array4::from_shape_vec((4, 1, 3, 2), s.to_vec()).unwrap(); + let b = Array3::from_shape_vec((1, 3, 2), s2.to_vec()).unwrap(); + bench.iter(|| &a + &b); +} diff --git a/benches/chunks.rs b/benches/chunks.rs index 5ea9ba466..46780492f 100644 --- a/benches/chunks.rs +++ b/benches/chunks.rs @@ -7,7 +7,8 @@ use ndarray::prelude::*; use ndarray::NdProducer; #[bench] -fn chunk2x2_iter_sum(bench: &mut Bencher) { +fn chunk2x2_iter_sum(bench: &mut Bencher) +{ let a = Array::::zeros((256, 256)); let chunksz = (2, 2); let mut sum = Array::zeros(a.exact_chunks(chunksz).raw_dim()); @@ -19,7 +20,8 @@ fn chunk2x2_iter_sum(bench: &mut Bencher) { } #[bench] -fn chunk2x2_sum(bench: &mut Bencher) { +fn chunk2x2_sum(bench: &mut Bencher) +{ let a = Array::::zeros((256, 256)); let chunksz = (2, 2); let mut sum = Array::zeros(a.exact_chunks(chunksz).raw_dim()); @@ -31,7 +33,8 @@ fn chunk2x2_sum(bench: &mut Bencher) { } #[bench] -fn chunk2x2_sum_get1(bench: &mut Bencher) { +fn chunk2x2_sum_get1(bench: &mut Bencher) +{ let a = Array::::zeros((256, 256)); let chunksz = (2, 2); let mut sum = Array::::zeros(a.exact_chunks(chunksz).raw_dim()); @@ -46,7 +49,8 @@ fn chunk2x2_sum_get1(bench: &mut Bencher) { } #[bench] -fn chunk2x2_sum_uget1(bench: &mut Bencher) { +fn chunk2x2_sum_uget1(bench: &mut Bencher) +{ let a = 
Array::::zeros((256, 256)); let chunksz = (2, 2); let mut sum = Array::::zeros(a.exact_chunks(chunksz).raw_dim()); @@ -64,7 +68,8 @@ fn chunk2x2_sum_uget1(bench: &mut Bencher) { #[bench] #[allow(clippy::identity_op)] -fn chunk2x2_sum_get2(bench: &mut Bencher) { +fn chunk2x2_sum_get2(bench: &mut Bencher) +{ let a = Array::::zeros((256, 256)); let chunksz = (2, 2); let mut sum = Array::::zeros(a.exact_chunks(chunksz).raw_dim()); diff --git a/benches/construct.rs b/benches/construct.rs index 3d77a89e0..278174388 100644 --- a/benches/construct.rs +++ b/benches/construct.rs @@ -1,9 +1,6 @@ #![feature(test)] #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] extern crate test; use test::Bencher; @@ -11,24 +8,32 @@ use test::Bencher; use ndarray::prelude::*; #[bench] -fn default_f64(bench: &mut Bencher) { +fn default_f64(bench: &mut Bencher) +{ bench.iter(|| Array::::default((128, 128))) } #[bench] -fn zeros_f64(bench: &mut Bencher) { +fn zeros_f64(bench: &mut Bencher) +{ bench.iter(|| Array::::zeros((128, 128))) } #[bench] -fn map_regular(bench: &mut test::Bencher) { - let a = Array::linspace(0., 127., 128).into_shape((8, 16)).unwrap(); +fn map_regular(bench: &mut test::Bencher) +{ + let a = Array::linspace(0., 127., 128) + .into_shape_with_order((8, 16)) + .unwrap(); bench.iter(|| a.map(|&x| 2. * x)); } #[bench] -fn map_stride(bench: &mut test::Bencher) { - let a = Array::linspace(0., 127., 256).into_shape((8, 32)).unwrap(); +fn map_stride(bench: &mut test::Bencher) +{ + let a = Array::linspace(0., 127., 256) + .into_shape_with_order((8, 32)) + .unwrap(); let av = a.slice(s![.., ..;2]); bench.iter(|| av.map(|&x| 2. 
* x)); } diff --git a/benches/gemv.rs b/benches/gemv.rs deleted file mode 100644 index 4bca08319..000000000 --- a/benches/gemv.rs +++ /dev/null @@ -1,47 +0,0 @@ -#![feature(test)] -#![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names -)] - -extern crate test; -use test::Bencher; - -use ndarray::prelude::*; - -use ndarray::linalg::general_mat_vec_mul; - -#[bench] -fn gemv_64_64c(bench: &mut Bencher) { - let a = Array::zeros((64, 64)); - let (m, n) = a.dim(); - let x = Array::zeros(n); - let mut y = Array::zeros(m); - bench.iter(|| { - general_mat_vec_mul(1.0, &a, &x, 1.0, &mut y); - }); -} - -#[bench] -fn gemv_64_64f(bench: &mut Bencher) { - let a = Array::zeros((64, 64).f()); - let (m, n) = a.dim(); - let x = Array::zeros(n); - let mut y = Array::zeros(m); - bench.iter(|| { - general_mat_vec_mul(1.0, &a, &x, 1.0, &mut y); - }); -} - -#[bench] -fn gemv_64_32(bench: &mut Bencher) { - let a = Array::zeros((64, 32)); - let (m, n) = a.dim(); - let x = Array::zeros(n); - let mut y = Array::zeros(m); - bench.iter(|| { - general_mat_vec_mul(1.0, &a, &x, 1.0, &mut y); - }); -} diff --git a/benches/gemv_gemm.rs b/benches/gemv_gemm.rs new file mode 100644 index 000000000..2d1642623 --- /dev/null +++ b/benches/gemv_gemm.rs @@ -0,0 +1,77 @@ +#![feature(test)] +#![allow( + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names +)] + +extern crate test; +use test::Bencher; + +use num_complex::Complex; +use num_traits::{Float, One, Zero}; + +use ndarray::prelude::*; + +use ndarray::linalg::general_mat_mul; +use ndarray::linalg::general_mat_vec_mul; +use ndarray::LinalgScalar; + +#[bench] +fn gemv_64_64c(bench: &mut Bencher) +{ + let a = Array::zeros((64, 64)); + let (m, n) = a.dim(); + let x = Array::zeros(n); + let mut y = Array::zeros(m); + bench.iter(|| { + general_mat_vec_mul(1.0, &a, &x, 1.0, &mut y); + }); +} + +#[bench] +fn gemv_64_64f(bench: &mut Bencher) +{ + let a = Array::zeros((64, 64).f()); + let (m, n) = a.dim(); + let x = Array::zeros(n); + let mut y = Array::zeros(m); + bench.iter(|| { + general_mat_vec_mul(1.0, &a, &x, 1.0, &mut y); + }); +} + +#[bench] +fn gemv_64_32(bench: &mut Bencher) +{ + let a = Array::zeros((64, 32)); + let (m, n) = a.dim(); + let x = Array::zeros(n); + let mut y = Array::zeros(m); + bench.iter(|| { + general_mat_vec_mul(1.0, &a, &x, 1.0, &mut y); + }); +} + +#[bench] +fn cgemm_100(bench: &mut Bencher) +{ + cgemm_bench::(100, bench); +} + +#[bench] +fn zgemm_100(bench: &mut Bencher) +{ + cgemm_bench::(100, bench); +} + +fn cgemm_bench(size: usize, bench: &mut Bencher) +where A: LinalgScalar + Float +{ + let (m, k, n) = (size, size, size); + let a = Array::, _>::zeros((m, k)); + + let x = Array::zeros((k, n)); + let mut y = Array::zeros((m, n)); + bench.iter(|| { + general_mat_mul(Complex::one(), &a, &x, Complex::zero(), &mut y); + }); +} diff --git a/benches/higher-order.rs b/benches/higher-order.rs index 2ea0721af..9cc3bd961 100644 --- a/benches/higher-order.rs +++ b/benches/higher-order.rs @@ -1,12 +1,8 @@ #![feature(test)] #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] extern crate test; -use std::iter::FromIterator; use test::black_box; use test::Bencher; @@ -17,19 +13,24 @@ const X: usize = 64; const Y: usize = 16; #[bench] -fn 
map_regular(bench: &mut Bencher) { - let a = Array::linspace(0., 127., N).into_shape((X, Y)).unwrap(); +fn map_regular(bench: &mut Bencher) +{ + let a = Array::linspace(0., 127., N) + .into_shape_with_order((X, Y)) + .unwrap(); bench.iter(|| a.map(|&x| 2. * x)); } -pub fn double_array(mut a: ArrayViewMut2<'_, f64>) { +pub fn double_array(mut a: ArrayViewMut2<'_, f64>) +{ a *= 2.0; } #[bench] -fn map_stride_double_f64(bench: &mut Bencher) { +fn map_stride_double_f64(bench: &mut Bencher) +{ let mut a = Array::linspace(0., 127., N * 2) - .into_shape([X, Y * 2]) + .into_shape_with_order([X, Y * 2]) .unwrap(); let mut av = a.slice_mut(s![.., ..;2]); bench.iter(|| { @@ -38,18 +39,20 @@ fn map_stride_double_f64(bench: &mut Bencher) { } #[bench] -fn map_stride_f64(bench: &mut Bencher) { +fn map_stride_f64(bench: &mut Bencher) +{ let a = Array::linspace(0., 127., N * 2) - .into_shape([X, Y * 2]) + .into_shape_with_order([X, Y * 2]) .unwrap(); let av = a.slice(s![.., ..;2]); bench.iter(|| av.map(|&x| 2. * x)); } #[bench] -fn map_stride_u32(bench: &mut Bencher) { +fn map_stride_u32(bench: &mut Bencher) +{ let a = Array::linspace(0., 127., N * 2) - .into_shape([X, Y * 2]) + .into_shape_with_order([X, Y * 2]) .unwrap(); let b = a.mapv(|x| x as u32); let av = b.slice(s![.., ..;2]); @@ -57,9 +60,10 @@ fn map_stride_u32(bench: &mut Bencher) { } #[bench] -fn fold_axis(bench: &mut Bencher) { +fn fold_axis(bench: &mut Bencher) +{ let a = Array::linspace(0., 127., N * 2) - .into_shape([X, Y * 2]) + .into_shape_with_order([X, Y * 2]) .unwrap(); bench.iter(|| a.fold_axis(Axis(0), 0., |&acc, &elt| acc + elt)); } @@ -68,17 +72,19 @@ const MA: usize = 64; const MASZ: usize = MA * MA; #[bench] -fn map_axis_0(bench: &mut Bencher) { +fn map_axis_0(bench: &mut Bencher) +{ let a = Array::from_iter(0..MASZ as i32) - .into_shape([MA, MA]) + .into_shape_with_order([MA, MA]) .unwrap(); bench.iter(|| a.map_axis(Axis(0), black_box)); } #[bench] -fn map_axis_1(bench: &mut Bencher) { +fn map_axis_1(bench: &mut Bencher) +{ let a = Array::from_iter(0..MASZ as i32) - .into_shape([MA, MA]) + .into_shape_with_order([MA, MA]) .unwrap(); bench.iter(|| a.map_axis(Axis(1), black_box)); } diff --git a/benches/iter.rs b/benches/iter.rs index 289f1fb50..77f511745 100644 --- a/benches/iter.rs +++ b/benches/iter.rs @@ -1,9 +1,6 @@ #![feature(test)] #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] extern crate test; @@ -16,13 +13,15 @@ use ndarray::Slice; use ndarray::{FoldWhile, Zip}; #[bench] -fn iter_sum_2d_regular(bench: &mut Bencher) { +fn iter_sum_2d_regular(bench: &mut Bencher) +{ let a = Array::::zeros((64, 64)); bench.iter(|| a.iter().sum::()); } #[bench] -fn iter_sum_2d_cutout(bench: &mut Bencher) { +fn iter_sum_2d_cutout(bench: &mut Bencher) +{ let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = av; @@ -30,7 +29,8 @@ fn iter_sum_2d_cutout(bench: &mut Bencher) { } #[bench] -fn iter_all_2d_cutout(bench: &mut Bencher) { +fn iter_all_2d_cutout(bench: &mut Bencher) +{ let a = Array::::zeros((66, 66)); let av = a.slice(s![1..-1, 1..-1]); let a = av; @@ -38,44 +38,58 @@ fn iter_all_2d_cutout(bench: &mut Bencher) { } #[bench] -fn iter_sum_2d_transpose(bench: &mut Bencher) { +fn iter_sum_2d_transpose(bench: &mut Bencher) +{ let a = Array::::zeros((66, 66)); let a = a.t(); bench.iter(|| a.iter().sum::()); } #[bench] 
-fn iter_filter_sum_2d_u32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256).into_shape((16, 16)).unwrap(); +fn iter_filter_sum_2d_u32(bench: &mut Bencher) +{ + let a = Array::linspace(0., 1., 256) + .into_shape_with_order((16, 16)) + .unwrap(); let b = a.mapv(|x| (x * 100.) as u32); bench.iter(|| b.iter().filter(|&&x| x < 75).sum::()); } #[bench] -fn iter_filter_sum_2d_f32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256).into_shape((16, 16)).unwrap(); +fn iter_filter_sum_2d_f32(bench: &mut Bencher) +{ + let a = Array::linspace(0., 1., 256) + .into_shape_with_order((16, 16)) + .unwrap(); let b = a * 100.; bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::()); } #[bench] -fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256).into_shape((16, 16)).unwrap(); +fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) +{ + let a = Array::linspace(0., 1., 256) + .into_shape_with_order((16, 16)) + .unwrap(); let b = a.mapv(|x| (x * 100.) as u32); let b = b.slice(s![.., ..;2]); bench.iter(|| b.iter().filter(|&&x| x < 75).sum::()); } #[bench] -fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256).into_shape((16, 16)).unwrap(); +fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) +{ + let a = Array::linspace(0., 1., 256) + .into_shape_with_order((16, 16)) + .unwrap(); let b = a * 100.; let b = b.slice(s![.., ..;2]); bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::()); } #[bench] -fn iter_rev_step_by_contiguous(bench: &mut Bencher) { +fn iter_rev_step_by_contiguous(bench: &mut Bencher) +{ let a = Array::linspace(0., 1., 512); bench.iter(|| { a.iter().rev().step_by(2).for_each(|x| { @@ -85,7 +99,8 @@ fn iter_rev_step_by_contiguous(bench: &mut Bencher) { } #[bench] -fn iter_rev_step_by_discontiguous(bench: &mut Bencher) { +fn iter_rev_step_by_discontiguous(bench: &mut Bencher) +{ let mut a = Array::linspace(0., 1., 1024); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); bench.iter(|| { @@ -98,7 +113,8 @@ fn iter_rev_step_by_discontiguous(bench: &mut Bencher) { const ZIPSZ: usize = 10_000; #[bench] -fn sum_3_std_zip1(bench: &mut Bencher) { +fn sum_3_std_zip1(bench: &mut Bencher) +{ let a = vec![1; ZIPSZ]; let b = vec![1; ZIPSZ]; let c = vec![1; ZIPSZ]; @@ -110,7 +126,8 @@ fn sum_3_std_zip1(bench: &mut Bencher) { } #[bench] -fn sum_3_std_zip2(bench: &mut Bencher) { +fn sum_3_std_zip2(bench: &mut Bencher) +{ let a = vec![1; ZIPSZ]; let b = vec![1; ZIPSZ]; let c = vec![1; ZIPSZ]; @@ -123,7 +140,8 @@ fn sum_3_std_zip2(bench: &mut Bencher) { } #[bench] -fn sum_3_std_zip3(bench: &mut Bencher) { +fn sum_3_std_zip3(bench: &mut Bencher) +{ let a = vec![1; ZIPSZ]; let b = vec![1; ZIPSZ]; let c = vec![1; ZIPSZ]; @@ -137,7 +155,8 @@ fn sum_3_std_zip3(bench: &mut Bencher) { } #[bench] -fn vector_sum_3_std_zip(bench: &mut Bencher) { +fn vector_sum_3_std_zip(bench: &mut Bencher) +{ let a = vec![1.; ZIPSZ]; let b = vec![1.; ZIPSZ]; let mut c = vec![1.; ZIPSZ]; @@ -149,7 +168,8 @@ fn vector_sum_3_std_zip(bench: &mut Bencher) { } #[bench] -fn sum_3_azip(bench: &mut Bencher) { +fn sum_3_azip(bench: &mut Bencher) +{ let a = vec![1; ZIPSZ]; let b = vec![1; ZIPSZ]; let c = vec![1; ZIPSZ]; @@ -163,7 +183,8 @@ fn sum_3_azip(bench: &mut Bencher) { } #[bench] -fn sum_3_azip_fold(bench: &mut Bencher) { +fn sum_3_azip_fold(bench: &mut Bencher) +{ let a = vec![1; ZIPSZ]; let b = vec![1; ZIPSZ]; let c = vec![1; ZIPSZ]; @@ -177,7 +198,8 @@ fn sum_3_azip_fold(bench: &mut Bencher) { } #[bench] -fn vector_sum_3_azip(bench: 
&mut Bencher) { +fn vector_sum_3_azip(bench: &mut Bencher) +{ let a = vec![1.; ZIPSZ]; let b = vec![1.; ZIPSZ]; let mut c = vec![1.; ZIPSZ]; @@ -188,7 +210,8 @@ fn vector_sum_3_azip(bench: &mut Bencher) { }); } -fn vector_sum3_unchecked(a: &[f64], b: &[f64], c: &mut [f64]) { +fn vector_sum3_unchecked(a: &[f64], b: &[f64], c: &mut [f64]) +{ for i in 0..c.len() { unsafe { *c.get_unchecked_mut(i) += *a.get_unchecked(i) + *b.get_unchecked(i); @@ -197,7 +220,8 @@ fn vector_sum3_unchecked(a: &[f64], b: &[f64], c: &mut [f64]) { } #[bench] -fn vector_sum_3_zip_unchecked(bench: &mut Bencher) { +fn vector_sum_3_zip_unchecked(bench: &mut Bencher) +{ let a = vec![1.; ZIPSZ]; let b = vec![1.; ZIPSZ]; let mut c = vec![1.; ZIPSZ]; @@ -207,7 +231,8 @@ fn vector_sum_3_zip_unchecked(bench: &mut Bencher) { } #[bench] -fn vector_sum_3_zip_unchecked_manual(bench: &mut Bencher) { +fn vector_sum_3_zip_unchecked_manual(bench: &mut Bencher) +{ let a = vec![1.; ZIPSZ]; let b = vec![1.; ZIPSZ]; let mut c = vec![1.; ZIPSZ]; @@ -227,7 +252,8 @@ const ISZ: usize = 16; const I2DSZ: usize = 64; #[bench] -fn indexed_iter_1d_ix1(bench: &mut Bencher) { +fn indexed_iter_1d_ix1(bench: &mut Bencher) +{ let mut a = Array::::zeros(I2DSZ * I2DSZ); for (i, elt) in a.indexed_iter_mut() { *elt = i as _; @@ -242,14 +268,15 @@ fn indexed_iter_1d_ix1(bench: &mut Bencher) { } #[bench] -fn indexed_zip_1d_ix1(bench: &mut Bencher) { +fn indexed_zip_1d_ix1(bench: &mut Bencher) +{ let mut a = Array::::zeros(I2DSZ * I2DSZ); for (i, elt) in a.indexed_iter_mut() { *elt = i as _; } bench.iter(|| { - Zip::indexed(&a).apply(|i, &_elt| { + Zip::indexed(&a).for_each(|i, &_elt| { black_box(i); //assert!(a[i] == elt); }); @@ -257,7 +284,8 @@ fn indexed_zip_1d_ix1(bench: &mut Bencher) { } #[bench] -fn indexed_iter_2d_ix2(bench: &mut Bencher) { +fn indexed_iter_2d_ix2(bench: &mut Bencher) +{ let mut a = Array::::zeros((I2DSZ, I2DSZ)); for ((i, j), elt) in a.indexed_iter_mut() { *elt = (i + 100 * j) as _; @@ -271,14 +299,15 @@ fn indexed_iter_2d_ix2(bench: &mut Bencher) { }) } #[bench] -fn indexed_zip_2d_ix2(bench: &mut Bencher) { +fn indexed_zip_2d_ix2(bench: &mut Bencher) +{ let mut a = Array::::zeros((I2DSZ, I2DSZ)); for ((i, j), elt) in a.indexed_iter_mut() { *elt = (i + 100 * j) as _; } bench.iter(|| { - Zip::indexed(&a).apply(|i, &_elt| { + Zip::indexed(&a).for_each(|i, &_elt| { black_box(i); //assert!(a[i] == elt); }); @@ -286,7 +315,8 @@ fn indexed_zip_2d_ix2(bench: &mut Bencher) { } #[bench] -fn indexed_iter_3d_ix3(bench: &mut Bencher) { +fn indexed_iter_3d_ix3(bench: &mut Bencher) +{ let mut a = Array::::zeros((ISZ, ISZ, ISZ)); for ((i, j, k), elt) in a.indexed_iter_mut() { *elt = (i + 100 * j + 10000 * k) as _; @@ -301,14 +331,15 @@ fn indexed_iter_3d_ix3(bench: &mut Bencher) { } #[bench] -fn indexed_zip_3d_ix3(bench: &mut Bencher) { +fn indexed_zip_3d_ix3(bench: &mut Bencher) +{ let mut a = Array::::zeros((ISZ, ISZ, ISZ)); for ((i, j, k), elt) in a.indexed_iter_mut() { *elt = (i + 100 * j + 10000 * k) as _; } bench.iter(|| { - Zip::indexed(&a).apply(|i, &_elt| { + Zip::indexed(&a).for_each(|i, &_elt| { black_box(i); //assert!(a[i] == elt); }); @@ -316,12 +347,13 @@ fn indexed_zip_3d_ix3(bench: &mut Bencher) { } #[bench] -fn indexed_iter_3d_dyn(bench: &mut Bencher) { +fn indexed_iter_3d_dyn(bench: &mut Bencher) +{ let mut a = Array::::zeros((ISZ, ISZ, ISZ)); for ((i, j, k), elt) in a.indexed_iter_mut() { *elt = (i + 100 * j + 10000 * k) as _; } - let a = a.into_shape(&[ISZ; 3][..]).unwrap(); + let a = a.into_shape_with_order(&[ISZ; 
3][..]).unwrap(); bench.iter(|| { for (i, &_elt) in a.indexed_iter() { @@ -332,27 +364,31 @@ fn indexed_iter_3d_dyn(bench: &mut Bencher) { } #[bench] -fn iter_sum_1d_strided_fold(bench: &mut Bencher) { +fn iter_sum_1d_strided_fold(bench: &mut Bencher) +{ let mut a = Array::::ones(10240); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); bench.iter(|| a.iter().sum::()); } #[bench] -fn iter_sum_1d_strided_rfold(bench: &mut Bencher) { +fn iter_sum_1d_strided_rfold(bench: &mut Bencher) +{ let mut a = Array::::ones(10240); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); bench.iter(|| a.iter().rfold(0, |acc, &x| acc + x)); } #[bench] -fn iter_axis_iter_sum(bench: &mut Bencher) { +fn iter_axis_iter_sum(bench: &mut Bencher) +{ let a = Array::::zeros((64, 64)); bench.iter(|| a.axis_iter(Axis(0)).map(|plane| plane.sum()).sum::()); } #[bench] -fn iter_axis_chunks_1_iter_sum(bench: &mut Bencher) { +fn iter_axis_chunks_1_iter_sum(bench: &mut Bencher) +{ let a = Array::::zeros((64, 64)); bench.iter(|| { a.axis_chunks_iter(Axis(0), 1) @@ -362,7 +398,8 @@ fn iter_axis_chunks_1_iter_sum(bench: &mut Bencher) { } #[bench] -fn iter_axis_chunks_5_iter_sum(bench: &mut Bencher) { +fn iter_axis_chunks_5_iter_sum(bench: &mut Bencher) +{ let a = Array::::zeros((64, 64)); bench.iter(|| { a.axis_chunks_iter(Axis(0), 5) @@ -370,3 +407,26 @@ fn iter_axis_chunks_5_iter_sum(bench: &mut Bencher) { .sum::() }); } + +pub fn zip_mut_with(data: &Array3, out: &mut Array3) +{ + out.zip_mut_with(&data, |o, &i| { + *o = i; + }); +} + +#[bench] +fn zip_mut_with_cc(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros((ISZ, ISZ, ISZ)); + let mut out = Array3::zeros(data.dim()); + b.iter(|| zip_mut_with(&data, &mut out)); +} + +#[bench] +fn zip_mut_with_ff(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros((ISZ, ISZ, ISZ).f()); + let mut out = Array3::zeros(data.dim().f()); + b.iter(|| zip_mut_with(&data, &mut out)); +} diff --git a/benches/numeric.rs b/benches/numeric.rs index 4c579eb71..e2ffa1b84 100644 --- a/benches/numeric.rs +++ b/benches/numeric.rs @@ -10,9 +10,10 @@ const X: usize = 64; const Y: usize = 16; #[bench] -fn clip(bench: &mut Bencher) { +fn clip(bench: &mut Bencher) +{ let mut a = Array::linspace(0., 127., N * 2) - .into_shape([X, Y * 2]) + .into_shape_with_order([X, Y * 2]) .unwrap(); let min = 2.; let max = 5.; diff --git a/benches/par_rayon.rs b/benches/par_rayon.rs index e8c4cfef3..1301ae75a 100644 --- a/benches/par_rayon.rs +++ b/benches/par_rayon.rs @@ -12,7 +12,8 @@ use ndarray::Zip; const EXP_N: usize = 256; const ADDN: usize = 512; -fn set_threads() { +fn set_threads() +{ // Consider setting a fixed number of threads here, for example to avoid // oversubscribing on hyperthreaded cores. // let n = 4; @@ -20,7 +21,8 @@ fn set_threads() { } #[bench] -fn map_exp_regular(bench: &mut Bencher) { +fn map_exp_regular(bench: &mut Bencher) +{ let mut a = Array2::::zeros((EXP_N, EXP_N)); a.swap_axes(0, 1); bench.iter(|| { @@ -29,7 +31,8 @@ fn map_exp_regular(bench: &mut Bencher) { } #[bench] -fn rayon_exp_regular(bench: &mut Bencher) { +fn rayon_exp_regular(bench: &mut Bencher) +{ set_threads(); let mut a = Array2::::zeros((EXP_N, EXP_N)); a.swap_axes(0, 1); @@ -41,19 +44,22 @@ fn rayon_exp_regular(bench: &mut Bencher) { const FASTEXP: usize = EXP_N; #[inline] -fn fastexp(x: f64) -> f64 { +fn fastexp(x: f64) -> f64 +{ let x = 1. 
+ x / 1024.; x.powi(1024) } #[bench] -fn map_fastexp_regular(bench: &mut Bencher) { +fn map_fastexp_regular(bench: &mut Bencher) +{ let mut a = Array2::::zeros((FASTEXP, FASTEXP)); bench.iter(|| a.mapv_inplace(|x| fastexp(x))); } #[bench] -fn rayon_fastexp_regular(bench: &mut Bencher) { +fn rayon_fastexp_regular(bench: &mut Bencher) +{ set_threads(); let mut a = Array2::::zeros((FASTEXP, FASTEXP)); bench.iter(|| { @@ -62,14 +68,16 @@ fn rayon_fastexp_regular(bench: &mut Bencher) { } #[bench] -fn map_fastexp_cut(bench: &mut Bencher) { +fn map_fastexp_cut(bench: &mut Bencher) +{ let mut a = Array2::::zeros((FASTEXP, FASTEXP)); let mut a = a.slice_mut(s![.., ..-1]); bench.iter(|| a.mapv_inplace(|x| fastexp(x))); } #[bench] -fn rayon_fastexp_cut(bench: &mut Bencher) { +fn rayon_fastexp_cut(bench: &mut Bencher) +{ set_threads(); let mut a = Array2::::zeros((FASTEXP, FASTEXP)); let mut a = a.slice_mut(s![.., ..-1]); @@ -79,7 +87,8 @@ fn rayon_fastexp_cut(bench: &mut Bencher) { } #[bench] -fn map_fastexp_by_axis(bench: &mut Bencher) { +fn map_fastexp_by_axis(bench: &mut Bencher) +{ let mut a = Array2::::zeros((FASTEXP, FASTEXP)); bench.iter(|| { for mut sheet in a.axis_iter_mut(Axis(0)) { @@ -89,7 +98,8 @@ fn map_fastexp_by_axis(bench: &mut Bencher) { } #[bench] -fn rayon_fastexp_by_axis(bench: &mut Bencher) { +fn rayon_fastexp_by_axis(bench: &mut Bencher) +{ set_threads(); let mut a = Array2::::zeros((FASTEXP, FASTEXP)); bench.iter(|| { @@ -100,7 +110,8 @@ fn rayon_fastexp_by_axis(bench: &mut Bencher) { } #[bench] -fn rayon_fastexp_zip(bench: &mut Bencher) { +fn rayon_fastexp_zip(bench: &mut Bencher) +{ set_threads(); let mut a = Array2::::zeros((FASTEXP, FASTEXP)); bench.iter(|| { @@ -111,7 +122,8 @@ fn rayon_fastexp_zip(bench: &mut Bencher) { } #[bench] -fn add(bench: &mut Bencher) { +fn add(bench: &mut Bencher) +{ let mut a = Array2::::zeros((ADDN, ADDN)); let b = Array2::::zeros((ADDN, ADDN)); let c = Array2::::zeros((ADDN, ADDN)); @@ -124,7 +136,8 @@ fn add(bench: &mut Bencher) { } #[bench] -fn rayon_add(bench: &mut Bencher) { +fn rayon_add(bench: &mut Bencher) +{ set_threads(); let mut a = Array2::::zeros((ADDN, ADDN)); let b = Array2::::zeros((ADDN, ADDN)); @@ -136,3 +149,34 @@ fn rayon_add(bench: &mut Bencher) { }); }); } + +const COLL_STRING_N: usize = 64; +const COLL_F64_N: usize = 128; + +#[bench] +fn vec_string_collect(bench: &mut test::Bencher) +{ + let v = vec![""; COLL_STRING_N * COLL_STRING_N]; + bench.iter(|| v.iter().map(|s| s.to_owned()).collect::>()); +} + +#[bench] +fn array_string_collect(bench: &mut test::Bencher) +{ + let v = Array::from_elem((COLL_STRING_N, COLL_STRING_N), ""); + bench.iter(|| Zip::from(&v).par_map_collect(|s| s.to_owned())); +} + +#[bench] +fn vec_f64_collect(bench: &mut test::Bencher) +{ + let v = vec![1.; COLL_F64_N * COLL_F64_N]; + bench.iter(|| v.iter().map(|s| s + 1.).collect::>()); +} + +#[bench] +fn array_f64_collect(bench: &mut test::Bencher) +{ + let v = Array::from_elem((COLL_F64_N, COLL_F64_N), 1.); + bench.iter(|| Zip::from(&v).par_map_collect(|s| s + 1.)); +} diff --git a/benches/reserve.rs b/benches/reserve.rs new file mode 100644 index 000000000..14ebf9f1a --- /dev/null +++ b/benches/reserve.rs @@ -0,0 +1,31 @@ +#![feature(test)] + +extern crate test; +use test::Bencher; + +use ndarray::prelude::*; + +#[bench] +fn push_reserve(bench: &mut Bencher) +{ + let ones: Array = array![1f32]; + bench.iter(|| { + let mut a: Array = array![]; + a.reserve(Axis(0), 100).unwrap(); + for _ in 0..100 { + a.append(Axis(0), ones.view()).unwrap(); + } + }); 
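    // The companion `push_no_reserve` benchmark below runs the same append loop
    // without the up-front `reserve(Axis(0), 100)`, so the pair measures what the
    // pre-allocation saves when growing the array along Axis(0).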
+} + +#[bench] +fn push_no_reserve(bench: &mut Bencher) +{ + let ones: Array = array![1f32]; + bench.iter(|| { + let mut a: Array = array![]; + for _ in 0..100 { + a.append(Axis(0), ones.view()).unwrap(); + } + }); +} diff --git a/benches/to_shape.rs b/benches/to_shape.rs new file mode 100644 index 000000000..f056a9852 --- /dev/null +++ b/benches/to_shape.rs @@ -0,0 +1,95 @@ +#![feature(test)] + +extern crate test; +use test::Bencher; + +use ndarray::prelude::*; +use ndarray::Order; + +#[bench] +fn to_shape2_1(bench: &mut Bencher) +{ + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| view.to_shape(4 * 5).unwrap()); +} + +#[bench] +fn to_shape2_2_same(bench: &mut Bencher) +{ + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| view.to_shape((4, 5)).unwrap()); +} + +#[bench] +fn to_shape2_2_flip(bench: &mut Bencher) +{ + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| view.to_shape((5, 4)).unwrap()); +} + +#[bench] +fn to_shape2_3(bench: &mut Bencher) +{ + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| view.to_shape((2, 5, 2)).unwrap()); +} + +#[bench] +fn to_shape3_1(bench: &mut Bencher) +{ + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| view.to_shape(3 * 4 * 5).unwrap()); +} + +#[bench] +fn to_shape3_2_order(bench: &mut Bencher) +{ + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| view.to_shape((12, 5)).unwrap()); +} + +#[bench] +fn to_shape3_2_outoforder(bench: &mut Bencher) +{ + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| view.to_shape((4, 15)).unwrap()); +} + +#[bench] +fn to_shape3_3c(bench: &mut Bencher) +{ + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| view.to_shape((3, 4, 5)).unwrap()); +} + +#[bench] +fn to_shape3_3f(bench: &mut Bencher) +{ + let a = Array::::zeros((3, 4, 5).f()); + let view = a.view(); + bench.iter(|| view.to_shape(((3, 4, 5), Order::F)).unwrap()); +} + +#[bench] +fn to_shape3_4c(bench: &mut Bencher) +{ + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| view.to_shape(((2, 3, 2, 5), Order::C)).unwrap()); +} + +#[bench] +fn to_shape3_4f(bench: &mut Bencher) +{ + let a = Array::::zeros((3, 4, 5).f()); + let view = a.view(); + bench.iter(|| view.to_shape(((2, 3, 2, 5), Order::F)).unwrap()); +} diff --git a/benches/zip.rs b/benches/zip.rs new file mode 100644 index 000000000..461497310 --- /dev/null +++ b/benches/zip.rs @@ -0,0 +1,133 @@ +#![feature(test)] +extern crate test; +use ndarray::s; +use ndarray::IntoNdProducer; +use ndarray::{Array3, ShapeBuilder, Zip}; +use test::{black_box, Bencher}; + +pub fn zip_copy<'a, A, P, Q>(data: P, out: Q) +where + P: IntoNdProducer, + Q: IntoNdProducer, + A: Copy + 'a, +{ + Zip::from(data).and(out).for_each(|&i, o| { + *o = i; + }); +} + +pub fn zip_copy_split<'a, A, P, Q>(data: P, out: Q) +where + P: IntoNdProducer, + Q: IntoNdProducer, + A: Copy + 'a, +{ + let z = Zip::from(data).and(out); + let (z1, z2) = z.split(); + let (z11, z12) = z1.split(); + let (z21, z22) = z2.split(); + let f = |&i: &A, o: &mut A| *o = i; + z11.for_each(f); + z12.for_each(f); + z21.for_each(f); + z22.for_each(f); +} + +pub fn zip_indexed(data: &Array3, out: &mut Array3) +{ + Zip::indexed(data).and(out).for_each(|idx, &i, o| { + let _ = black_box(idx); + *o = i; + }); +} + +// array size in benchmarks +const SZ3: (usize, usize, usize) = (100, 110, 100); + +#[bench] +fn zip_cc(b: &mut Bencher) +{ + let data: Array3 = 
Array3::zeros(SZ3); + let mut out = Array3::zeros(data.dim()); + b.iter(|| zip_copy(&data, &mut out)); +} + +#[bench] +fn zip_cf(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros(SZ3); + let mut out = Array3::zeros(data.dim().f()); + b.iter(|| zip_copy(&data, &mut out)); +} + +#[bench] +fn zip_fc(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros(SZ3.f()); + let mut out = Array3::zeros(data.dim()); + b.iter(|| zip_copy(&data, &mut out)); +} + +#[bench] +fn zip_ff(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros(SZ3.f()); + let mut out = Array3::zeros(data.dim().f()); + b.iter(|| zip_copy(&data, &mut out)); +} + +#[bench] +fn zip_indexed_cc(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros(SZ3); + let mut out = Array3::zeros(data.dim()); + b.iter(|| zip_indexed(&data, &mut out)); +} + +#[bench] +fn zip_indexed_ff(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros(SZ3.f()); + let mut out = Array3::zeros(data.dim().f()); + b.iter(|| zip_indexed(&data, &mut out)); +} + +#[bench] +fn slice_zip_cc(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros(SZ3); + let mut out = Array3::zeros(data.dim()); + let data = data.slice(s![1.., 1.., 1..]); + let mut out = out.slice_mut(s![1.., 1.., 1..]); + b.iter(|| zip_copy(&data, &mut out)); +} + +#[bench] +fn slice_zip_ff(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros(SZ3.f()); + let mut out = Array3::zeros(data.dim().f()); + let data = data.slice(s![1.., 1.., 1..]); + let mut out = out.slice_mut(s![1.., 1.., 1..]); + b.iter(|| zip_copy(&data, &mut out)); +} + +#[bench] +fn slice_split_zip_cc(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros(SZ3); + let mut out = Array3::zeros(data.dim()); + let data = data.slice(s![1.., 1.., 1..]); + let mut out = out.slice_mut(s![1.., 1.., 1..]); + b.iter(|| zip_copy_split(&data, &mut out)); +} + +#[bench] +fn slice_split_zip_ff(b: &mut Bencher) +{ + let data: Array3 = Array3::zeros(SZ3.f()); + let mut out = Array3::zeros(data.dim().f()); + let data = data.slice(s![1.., 1.., 1..]); + let mut out = out.slice_mut(s![1.., 1.., 1..]); + b.iter(|| zip_copy_split(&data, &mut out)); +} diff --git a/blas-tests/Cargo.toml b/blas-tests/Cargo.toml deleted file mode 100644 index 9853ac634..000000000 --- a/blas-tests/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -name = "blas-tests" -version = "0.1.0" -authors = ["bluss"] -publish = false - -[lib] -test = false - -[dev-dependencies] -approx = "0.3.2" -ndarray = { path = "../", features = ["approx", "blas"] } -blas-src = { version = "0.2.0", default-features = false, features = ["openblas"] } -openblas-src = { version = "0.6.0", default-features = false, features = ["cblas", "system"] } -defmac = "0.2" -num-traits = "0.2" diff --git a/blas-tests/tests/oper.rs b/blas-tests/tests/oper.rs deleted file mode 100644 index 2741123f9..000000000 --- a/blas-tests/tests/oper.rs +++ /dev/null @@ -1,614 +0,0 @@ -extern crate approx; -extern crate defmac; -extern crate ndarray; -extern crate num_traits; - -use ndarray::linalg::general_mat_mul; -use ndarray::linalg::general_mat_vec_mul; -use ndarray::prelude::*; -use ndarray::{Data, LinalgScalar}; -use ndarray::{Ix, Ixs, SliceInfo, SliceOrIndex}; -use std::iter::FromIterator; - -use approx::{assert_abs_diff_eq, assert_relative_eq}; -use defmac::defmac; - -fn reference_dot<'a, A, V1, V2>(a: V1, b: V2) -> A -where - A: NdFloat, - V1: AsArray<'a, A>, - V2: AsArray<'a, A>, -{ - let a = a.into(); - let b = b.into(); - a.iter() - .zip(b.iter()) - .fold(A::zero(), |acc, (&x, &y)| acc + x * y) -} - -#[test] -fn 
dot_product() { - let a = Array::range(0., 69., 1.); - let b = &a * 2. - 7.; - let dot = 197846.; - assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5); - - // test different alignments - let max = 8 as Ixs; - for i in 1..max { - let a1 = a.slice(s![i..]); - let b1 = b.slice(s![i..]); - assert_abs_diff_eq!(a1.dot(&b1), reference_dot(&a1, &b1), epsilon = 1e-5); - let a2 = a.slice(s![..-i]); - let b2 = b.slice(s![i..]); - assert_abs_diff_eq!(a2.dot(&b2), reference_dot(&a2, &b2), epsilon = 1e-5); - } - - let a = a.map(|f| *f as f32); - let b = b.map(|f| *f as f32); - assert_abs_diff_eq!(a.dot(&b), dot as f32, epsilon = 1e-5); - - let max = 8 as Ixs; - for i in 1..max { - let a1 = a.slice(s![i..]); - let b1 = b.slice(s![i..]); - assert_abs_diff_eq!(a1.dot(&b1), reference_dot(&a1, &b1), epsilon = 1e-5); - let a2 = a.slice(s![..-i]); - let b2 = b.slice(s![i..]); - assert_abs_diff_eq!(a2.dot(&b2), reference_dot(&a2, &b2), epsilon = 1e-5); - } - - let a = a.map(|f| *f as i32); - let b = b.map(|f| *f as i32); - assert_eq!(a.dot(&b), dot as i32); -} - -#[test] -fn mat_vec_product_1d() { - let a = arr2(&[[1.], [2.]]); - let b = arr1(&[1., 2.]); - let ans = arr1(&[5.]); - assert_eq!(a.t().dot(&b), ans); -} - -// test that we can dot product with a broadcast array -#[test] -fn dot_product_0() { - let a = Array::range(0., 69., 1.); - let x = 1.5; - let b = aview0(&x); - let b = b.broadcast(a.dim()).unwrap(); - assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5); - - // test different alignments - let max = 8 as Ixs; - for i in 1..max { - let a1 = a.slice(s![i..]); - let b1 = b.slice(s![i..]); - assert_abs_diff_eq!(a1.dot(&b1), reference_dot(&a1, &b1), epsilon = 1e-5); - let a2 = a.slice(s![..-i]); - let b2 = b.slice(s![i..]); - assert_abs_diff_eq!(a2.dot(&b2), reference_dot(&a2, &b2), epsilon = 1e-5); - } -} - -#[test] -fn dot_product_neg_stride() { - // test that we can dot with negative stride - let a = Array::range(0., 69., 1.); - let b = &a * 2. 
- 7.; - for stride in -10..0 { - // both negative - let a = a.slice(s![..;stride]); - let b = b.slice(s![..;stride]); - assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5); - } - for stride in -10..0 { - // mixed - let a = a.slice(s![..;-stride]); - let b = b.slice(s![..;stride]); - assert_abs_diff_eq!(a.dot(&b), reference_dot(&a, &b), epsilon = 1e-5); - } -} - -fn range_mat(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f32 - 1., m * n) - .into_shape((m, n)) - .unwrap() -} - -fn range_mat64(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f64 - 1., m * n) - .into_shape((m, n)) - .unwrap() -} - -fn range1_mat64(m: Ix) -> Array1 { - Array::linspace(0., m as f64 - 1., m) -} - -fn range_i32(m: Ix, n: Ix) -> Array2 { - Array::from_iter(0..(m * n) as i32) - .into_shape((m, n)) - .unwrap() -} - -// simple, slow, correct (hopefully) mat mul -fn reference_mat_mul(lhs: &ArrayBase, rhs: &ArrayBase) -> Array2 -where - A: LinalgScalar, - S: Data, - S2: Data, -{ - let ((m, k), (k2, n)) = (lhs.dim(), rhs.dim()); - assert!(m.checked_mul(n).is_some()); - assert_eq!(k, k2); - let mut res_elems = Vec::::with_capacity(m * n); - unsafe { - res_elems.set_len(m * n); - } - - let mut i = 0; - let mut j = 0; - for rr in &mut res_elems { - unsafe { - *rr = (0..k).fold(A::zero(), move |s, x| { - s + *lhs.uget((i, x)) * *rhs.uget((x, j)) - }); - } - j += 1; - if j == n { - j = 0; - i += 1; - } - } - unsafe { ArrayBase::from_shape_vec_unchecked((m, n), res_elems) } -} - -// simple, slow, correct (hopefully) mat mul -fn reference_mat_vec_mul(lhs: &ArrayBase, rhs: &ArrayBase) -> Array1 -where - A: LinalgScalar, - S: Data, - S2: Data, -{ - let ((m, _), k) = (lhs.dim(), rhs.dim()); - reference_mat_mul(lhs, &rhs.to_owned().into_shape((k, 1)).unwrap()) - .into_shape(m) - .unwrap() -} - -// simple, slow, correct (hopefully) mat mul -fn reference_vec_mat_mul(lhs: &ArrayBase, rhs: &ArrayBase) -> Array1 -where - A: LinalgScalar, - S: Data, - S2: Data, -{ - let (m, (_, n)) = (lhs.dim(), rhs.dim()); - reference_mat_mul(&lhs.to_owned().into_shape((1, m)).unwrap(), rhs) - .into_shape(n) - .unwrap() -} - -#[test] -fn mat_mul() { - let (m, n, k) = (8, 8, 8); - let a = range_mat(m, n); - let b = range_mat(n, k); - let mut b = b / 4.; - { - let mut c = b.column_mut(0); - c += 1.0; - } - let ab = a.dot(&b); - - let mut af = Array::zeros(a.dim().f()); - let mut bf = Array::zeros(b.dim().f()); - af.assign(&a); - bf.assign(&b); - - assert_eq!(ab, a.dot(&bf)); - assert_eq!(ab, af.dot(&b)); - assert_eq!(ab, af.dot(&bf)); - - let (m, n, k) = (10, 5, 11); - let a = range_mat(m, n); - let b = range_mat(n, k); - let mut b = b / 4.; - { - let mut c = b.column_mut(0); - c += 1.0; - } - let ab = a.dot(&b); - - let mut af = Array::zeros(a.dim().f()); - let mut bf = Array::zeros(b.dim().f()); - af.assign(&a); - bf.assign(&b); - - assert_eq!(ab, a.dot(&bf)); - assert_eq!(ab, af.dot(&b)); - assert_eq!(ab, af.dot(&bf)); - - let (m, n, k) = (10, 8, 1); - let a = range_mat(m, n); - let b = range_mat(n, k); - let mut b = b / 4.; - { - let mut c = b.column_mut(0); - c += 1.0; - } - let ab = a.dot(&b); - - let mut af = Array::zeros(a.dim().f()); - let mut bf = Array::zeros(b.dim().f()); - af.assign(&a); - bf.assign(&b); - - assert_eq!(ab, a.dot(&bf)); - assert_eq!(ab, af.dot(&b)); - assert_eq!(ab, af.dot(&bf)); -} - -// Check that matrix multiplication of contiguous matrices returns a -// matrix with the same order -#[test] -fn mat_mul_order() { - let (m, n, k) = (8, 8, 8); - let a = range_mat(m, n); - let b = 
range_mat(n, k); - let mut af = Array::zeros(a.dim().f()); - let mut bf = Array::zeros(b.dim().f()); - af.assign(&a); - bf.assign(&b); - - let cc = a.dot(&b); - let ff = af.dot(&bf); - - assert_eq!(cc.strides()[1], 1); - assert_eq!(ff.strides()[0], 1); -} - -// test matrix multiplication shape mismatch -#[test] -#[should_panic] -fn mat_mul_shape_mismatch() { - let (m, k, k2, n) = (8, 8, 9, 8); - let a = range_mat(m, k); - let b = range_mat(k2, n); - a.dot(&b); -} - -// test matrix multiplication shape mismatch -#[test] -#[should_panic] -fn mat_mul_shape_mismatch_2() { - let (m, k, k2, n) = (8, 8, 8, 8); - let a = range_mat(m, k); - let b = range_mat(k2, n); - let mut c = range_mat(m, n + 1); - general_mat_mul(1., &a, &b, 1., &mut c); -} - -// Check that matrix multiplication -// supports broadcast arrays. -#[test] -fn mat_mul_broadcast() { - let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); - let x1 = 1.; - let x = Array::from(vec![x1]); - let b0 = x.broadcast((n, k)).unwrap(); - let b1 = Array::from_elem(n, x1); - let b1 = b1.broadcast((n, k)).unwrap(); - let b2 = Array::from_elem((n, k), x1); - - let c2 = a.dot(&b2); - let c1 = a.dot(&b1); - let c0 = a.dot(&b0); - assert_eq!(c2, c1); - assert_eq!(c2, c0); -} - -// Check that matrix multiplication supports reversed axes -#[test] -fn mat_mul_rev() { - let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); - let b = range_mat(n, k); - let mut rev = Array::zeros(b.dim()); - let mut rev = rev.slice_mut(s![..;-1, ..]); - rev.assign(&b); - println!("{:.?}", rev); - - let c1 = a.dot(&b); - let c2 = a.dot(&rev); - assert_eq!(c1, c2); -} - -// Check that matrix multiplication supports arrays with zero rows or columns -#[test] -fn mat_mut_zero_len() { - defmac!(mat_mul_zero_len range_mat_fn => { - for n in 0..4 { - for m in 0..4 { - let a = range_mat_fn(m, n); - let b = range_mat_fn(n, 0); - assert_eq!(a.dot(&b), Array2::zeros((m, 0))); - } - for k in 0..4 { - let a = range_mat_fn(0, n); - let b = range_mat_fn(n, k); - assert_eq!(a.dot(&b), Array2::zeros((0, k))); - } - } - }); - mat_mul_zero_len!(range_mat); - mat_mul_zero_len!(range_mat64); - mat_mul_zero_len!(range_i32); -} - -#[test] -fn scaled_add() { - let a = range_mat(16, 15); - let mut b = range_mat(16, 15); - b.mapv_inplace(f32::exp); - - let alpha = 0.2_f32; - let mut c = a.clone(); - c.scaled_add(alpha, &b); - - let d = alpha * &b + &a; - assert_eq!(c, d); -} - -#[test] -fn scaled_add_2() { - let beta = -2.3; - let sizes = vec![ - (4, 4, 1, 4), - (8, 8, 1, 8), - (17, 15, 17, 15), - (4, 17, 4, 17), - (17, 3, 1, 3), - (19, 18, 19, 18), - (16, 17, 16, 17), - (15, 16, 15, 16), - (67, 63, 1, 63), - ]; - // test different strides - for &s1 in &[1, 2, -1, -2] { - for &s2 in &[1, 2, -1, -2] { - for &(m, k, n, q) in &sizes { - let mut a = range_mat64(m, k); - let mut answer = a.clone(); - let c = range_mat64(n, q); - - { - let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(s![..;s1, ..;s2]); - - let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); - answerv += &(beta * &c); - av.scaled_add(beta, &c); - } - assert_relative_eq!(a, answer, epsilon = 1e-12, max_relative = 1e-7); - } - } - } -} - -#[test] -fn scaled_add_3() { - let beta = -2.3; - let sizes = vec![ - (4, 4, 1, 4), - (8, 8, 1, 8), - (17, 15, 17, 15), - (4, 17, 4, 17), - (17, 3, 1, 3), - (19, 18, 19, 18), - (16, 17, 16, 17), - (15, 16, 15, 16), - (67, 63, 1, 63), - ]; - // test different strides - for &s1 in &[1, 2, -1, -2] { - for &s2 in &[1, 2, -1, -2] { - for &(m, k, n, q) in &sizes { - let mut a = range_mat64(m, 
k); - let mut answer = a.clone(); - let cdim = if n == 1 { vec![q] } else { vec![n, q] }; - let cslice = if n == 1 { - vec![SliceOrIndex::from(..).step_by(s2)] - } else { - vec![ - SliceOrIndex::from(..).step_by(s1), - SliceOrIndex::from(..).step_by(s2), - ] - }; - - let c = range_mat64(n, q).into_shape(cdim).unwrap(); - - { - let mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(SliceInfo::<_, IxDyn>::new(cslice).unwrap().as_ref()); - - let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); - answerv += &(beta * &c); - av.scaled_add(beta, &c); - } - assert_relative_eq!(a, answer, epsilon = 1e-12, max_relative = 1e-7); - } - } - } -} - -#[test] -fn gen_mat_mul() { - let alpha = -2.3; - let beta = 3.14; - let sizes = vec![ - (4, 4, 4), - (8, 8, 8), - (17, 15, 16), - (4, 17, 3), - (17, 3, 22), - (19, 18, 2), - (16, 17, 15), - (15, 16, 17), - (67, 63, 62), - ]; - // test different strides - for &s1 in &[1, 2, -1, -2] { - for &s2 in &[1, 2, -1, -2] { - for &(m, k, n) in &sizes { - let a = range_mat64(m, k); - let b = range_mat64(k, n); - let mut c = range_mat64(m, n); - let mut answer = c.clone(); - - { - let a = a.slice(s![..;s1, ..;s2]); - let b = b.slice(s![..;s2, ..;s2]); - let mut cv = c.slice_mut(s![..;s1, ..;s2]); - - let answer_part = alpha * reference_mat_mul(&a, &b) + beta * &cv; - answer.slice_mut(s![..;s1, ..;s2]).assign(&answer_part); - - general_mat_mul(alpha, &a, &b, beta, &mut cv); - } - assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7); - } - } - } -} - -// Test y = A x where A is f-order -#[test] -fn gemm_64_1_f() { - let a = range_mat64(64, 64).reversed_axes(); - let (m, n) = a.dim(); - // m x n times n x 1 == m x 1 - let x = range_mat64(n, 1); - let mut y = range_mat64(m, 1); - let answer = reference_mat_mul(&a, &x) + &y; - general_mat_mul(1.0, &a, &x, 1.0, &mut y); - assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7); -} - -#[test] -fn gen_mat_mul_i32() { - let alpha = -1; - let beta = 2; - let sizes = vec![ - (4, 4, 4), - (8, 8, 8), - (17, 15, 16), - (4, 17, 3), - (17, 3, 22), - (19, 18, 2), - (16, 17, 15), - (15, 16, 17), - (67, 63, 62), - ]; - for &(m, k, n) in &sizes { - let a = range_i32(m, k); - let b = range_i32(k, n); - let mut c = range_i32(m, n); - - let answer = alpha * reference_mat_mul(&a, &b) + beta * &c; - general_mat_mul(alpha, &a, &b, beta, &mut c); - assert_eq!(&c, &answer); - } -} - -#[test] -fn gen_mat_vec_mul() { - let alpha = -2.3; - let beta = 3.14; - let sizes = vec![ - (4, 4), - (8, 8), - (17, 15), - (4, 17), - (17, 3), - (19, 18), - (16, 17), - (15, 16), - (67, 63), - ]; - // test different strides - for &s1 in &[1, 2, -1, -2] { - for &s2 in &[1, 2, -1, -2] { - for &(m, k) in &sizes { - for &rev in &[false, true] { - let mut a = range_mat64(m, k); - if rev { - a = a.reversed_axes(); - } - let (m, k) = a.dim(); - let b = range1_mat64(k); - let mut c = range1_mat64(m); - let mut answer = c.clone(); - - { - let a = a.slice(s![..;s1, ..;s2]); - let b = b.slice(s![..;s2]); - let mut cv = c.slice_mut(s![..;s1]); - - let answer_part = alpha * reference_mat_vec_mul(&a, &b) + beta * &cv; - answer.slice_mut(s![..;s1]).assign(&answer_part); - - general_mat_vec_mul(alpha, &a, &b, beta, &mut cv); - } - assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7); - } - } - } - } -} - -#[test] -fn vec_mat_mul() { - let sizes = vec![ - (4, 4), - (8, 8), - (17, 15), - (4, 17), - (17, 3), - (19, 18), - (16, 17), - (15, 16), - (67, 63), - ]; - // test different strides - for &s1 in &[1, 2, -1, -2] { - for &s2 in 
&[1, 2, -1, -2] { - for &(m, n) in &sizes { - for &rev in &[false, true] { - let mut b = range_mat64(m, n); - if rev { - b = b.reversed_axes(); - } - let (m, n) = b.dim(); - let a = range1_mat64(m); - let mut c = range1_mat64(n); - let mut answer = c.clone(); - - { - let b = b.slice(s![..;s1, ..;s2]); - let a = a.slice(s![..;s1]); - - let answer_part = reference_vec_mat_mul(&a, &b); - answer.slice_mut(s![..;s2]).assign(&answer_part); - - c.slice_mut(s![..;s2]).assign(&a.dot(&b)); - } - assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7); - } - } - } - } -} diff --git a/build.rs b/build.rs deleted file mode 100644 index ceeeff389..000000000 --- a/build.rs +++ /dev/null @@ -1,10 +0,0 @@ -/// -/// This build script emits the openblas linking directive if requested -/// - -fn main() { - println!("cargo:rerun-if-changed=build.rs"); - if cfg!(feature = "test-blas-openblas-sys") { - println!("cargo:rustc-link-lib={}=openblas", "dylib"); - } -} diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 000000000..5ef4300ee --- /dev/null +++ b/clippy.toml @@ -0,0 +1 @@ +single-char-binding-names-threshold = 1000 diff --git a/crates/blas-mock-tests/Cargo.toml b/crates/blas-mock-tests/Cargo.toml new file mode 100644 index 000000000..39ef9cf99 --- /dev/null +++ b/crates/blas-mock-tests/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "blas-mock-tests" +version = "0.1.0" +edition = "2018" +publish = false + +[lib] +test = false +doc = false +doctest = false + +[dependencies] +cblas-sys = { workspace = true } + +[dev-dependencies] +ndarray = { workspace = true, features = ["approx", "blas"] } +ndarray-gen = { workspace = true } +itertools = { workspace = true } diff --git a/crates/blas-mock-tests/src/lib.rs b/crates/blas-mock-tests/src/lib.rs new file mode 100644 index 000000000..11fc5975e --- /dev/null +++ b/crates/blas-mock-tests/src/lib.rs @@ -0,0 +1,100 @@ +//! Mock interfaces to BLAS + +use core::cell::RefCell; +use core::ffi::{c_double, c_float, c_int}; +use std::thread_local; + +use cblas_sys::{c_double_complex, c_float_complex, CBLAS_LAYOUT, CBLAS_TRANSPOSE}; + +thread_local! 
{ + /// This counter is incremented every time a gemm function is called + pub static CALL_COUNT: RefCell = RefCell::new(0); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_sgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: c_float, + a: *const c_float, + lda: c_int, + b: *const c_float, + ldb: c_int, + beta: c_float, + c: *mut c_float, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_dgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: c_double, + a: *const c_double, + lda: c_int, + b: *const c_double, + ldb: c_int, + beta: c_double, + c: *mut c_double, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_cgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: *const c_float_complex, + a: *const c_float_complex, + lda: c_int, + b: *const c_float_complex, + ldb: c_int, + beta: *const c_float_complex, + c: *mut c_float_complex, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_zgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: *const c_double_complex, + a: *const c_double_complex, + lda: c_int, + b: *const c_double_complex, + ldb: c_int, + beta: *const c_double_complex, + c: *mut c_double_complex, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} diff --git a/crates/blas-mock-tests/tests/use-blas.rs b/crates/blas-mock-tests/tests/use-blas.rs new file mode 100644 index 000000000..217508af6 --- /dev/null +++ b/crates/blas-mock-tests/tests/use-blas.rs @@ -0,0 +1,88 @@ +extern crate ndarray; + +use ndarray::prelude::*; + +use blas_mock_tests::CALL_COUNT; +use ndarray::linalg::general_mat_mul; +use ndarray::Order; +use ndarray_gen::array_builder::ArrayBuilder; + +use itertools::iproduct; + +#[test] +fn test_gen_mat_mul_uses_blas() +{ + let alpha = 1.0; + let beta = 0.0; + + let sizes = vec![ + (8, 8, 8), + (10, 10, 10), + (8, 8, 1), + (1, 10, 10), + (10, 1, 10), + (10, 10, 1), + (1, 10, 1), + (10, 1, 1), + (1, 1, 10), + (4, 17, 3), + (17, 3, 22), + (19, 18, 2), + (16, 17, 15), + (15, 16, 17), + (67, 63, 62), + ]; + let strides = &[1, 2, -1, -2]; + let cf_order = [Order::C, Order::F]; + + // test different strides and memory orders + for &(m, k, n) in &sizes { + for (&s1, &s2) in iproduct!(strides, strides) { + for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) { + println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3); + + let a = ArrayBuilder::new((m, k)).memory_order(ord1).build(); + let b = ArrayBuilder::new((k, n)).memory_order(ord2).build(); + let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build(); + + { + let av; + let bv; + let mut cv; + + if s1 != 1 || s2 != 1 { + av = a.slice(s![..;s1, ..;s2]); + bv = b.slice(s![..;s2, ..;s2]); + cv = c.slice_mut(s![..;s1, ..;s2]); + } else { + // different stride cases for slicing versus not sliced (for axes of + // len=1); so test not sliced here. 
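                        // (For a length-1 axis, slicing with a step leaves different strides
                        // than the unsliced view, which would change the unit-stride check
                        // that decides whether BLAS is expected to be called further down.)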
+ av = a.view(); + bv = b.view(); + cv = c.view_mut(); + } + + let pre_count = CALL_COUNT.with(|ctx| *ctx.borrow()); + general_mat_mul(alpha, &av, &bv, beta, &mut cv); + let after_count = CALL_COUNT.with(|ctx| *ctx.borrow()); + let ncalls = after_count - pre_count; + debug_assert!(ncalls <= 1); + + let always_uses_blas = s1 == 1 && s2 == 1; + + if always_uses_blas { + assert_eq!(ncalls, 1, "Contiguous arrays should use blas, orders={:?}", (ord1, ord2, ord3)); + } + + let should_use_blas = av.strides().iter().all(|&s| s > 0) + && bv.strides().iter().all(|&s| s > 0) + && cv.strides().iter().all(|&s| s > 0) + && av.strides().iter().any(|&s| s == 1) + && bv.strides().iter().any(|&s| s == 1) + && cv.strides().iter().any(|&s| s == 1); + assert_eq!(should_use_blas, ncalls > 0); + } + } + } + } +} diff --git a/crates/blas-tests/Cargo.toml b/crates/blas-tests/Cargo.toml new file mode 100644 index 000000000..ff556873a --- /dev/null +++ b/crates/blas-tests/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "blas-tests" +version = "0.1.0" +authors = ["bluss"] +publish = false +edition = "2018" + +[lib] +test = false +doc = false +doctest = false + +[dependencies] +ndarray = { workspace = true, features = ["approx", "blas"] } +ndarray-gen = { workspace = true } + +blas-src = { version = "0.10", optional = true } +openblas-src = { version = ">=0.10.11", optional = true } +netlib-src = { version = "0.8", optional = true } +blis-src = { version = "0.2", features = ["system"], optional = true } + +[dev-dependencies] +defmac = "0.2" +approx = { workspace = true } +num-traits = { workspace = true } +num-complex = { workspace = true } +itertools = { workspace = true } + +[features] +# Just for making an example and to help testing, , multiple different possible +# configurations are selectable here. +openblas-system = ["blas-src", "blas-src/openblas", "openblas-src/system"] +openblas-cache = ["blas-src", "blas-src/openblas", "openblas-src/cache"] +netlib = ["blas-src", "blas-src/netlib"] +netlib-system = ["blas-src", "blas-src/netlib", "netlib-src/system"] +blis-system = ["blas-src", "blas-src/blis", "blis-src/system"] diff --git a/crates/blas-tests/src/lib.rs b/crates/blas-tests/src/lib.rs new file mode 100644 index 000000000..fc031eedb --- /dev/null +++ b/crates/blas-tests/src/lib.rs @@ -0,0 +1,4 @@ +#[cfg(not(feature = "blas-src"))] +compile_error!("Missing backend: could not compile. 
+ Help: For this testing crate, select one of the blas backend features, for example \ + openblas-system"); diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs new file mode 100644 index 000000000..a9dca7e83 --- /dev/null +++ b/crates/blas-tests/tests/oper.rs @@ -0,0 +1,447 @@ +extern crate approx; +extern crate blas_src; +extern crate defmac; +extern crate ndarray; +extern crate num_complex; +extern crate num_traits; + +use ndarray::prelude::*; + +use ndarray::linalg::general_mat_mul; +use ndarray::linalg::general_mat_vec_mul; +use ndarray::Order; +use ndarray::{Data, Ix, LinalgScalar}; +use ndarray_gen::array_builder::ArrayBuilder; +use ndarray_gen::array_builder::ElementGenerator; + +use approx::assert_relative_eq; +use defmac::defmac; +use itertools::iproduct; +use num_complex::Complex32; +use num_complex::Complex64; +use num_traits::Num; + +#[test] +fn mat_vec_product_1d() +{ + let a = arr2(&[[1.], [2.]]); + let b = arr1(&[1., 2.]); + let ans = arr1(&[5.]); + assert_eq!(a.t().dot(&b), ans); +} + +#[test] +fn mat_vec_product_1d_broadcast() +{ + let a = arr2(&[[1.], [2.], [3.]]); + let b = arr1(&[1.]); + let b = b.broadcast(3).unwrap(); + let ans = arr1(&[6.]); + assert_eq!(a.t().dot(&b), ans); +} + +#[test] +fn mat_vec_product_1d_inverted_axis() +{ + let a = arr2(&[[1.], [2.], [3.]]); + let mut b = arr1(&[1., 2., 3.]); + b.invert_axis(Axis(0)); + + let ans = arr1(&[3. + 4. + 3.]); + assert_eq!(a.t().dot(&b), ans); +} + +fn range_mat(m: Ix, n: Ix) -> Array2 +{ + ArrayBuilder::new((m, n)).build() +} + +fn range_mat_complex(m: Ix, n: Ix) -> Array2 +{ + ArrayBuilder::new((m, n)).build() +} + +fn range_mat_complex64(m: Ix, n: Ix) -> Array2 +{ + ArrayBuilder::new((m, n)).build() +} + +fn range1_mat64(m: Ix) -> Array1 +{ + ArrayBuilder::new(m).build() +} + +fn range_i32(m: Ix, n: Ix) -> Array2 +{ + ArrayBuilder::new((m, n)).build() +} + +// simple, slow, correct (hopefully) mat mul +fn reference_mat_mul(lhs: &ArrayBase, rhs: &ArrayBase) -> Array2 +where + A: LinalgScalar, + S: Data, + S2: Data, +{ + let ((m, k), (k2, n)) = (lhs.dim(), rhs.dim()); + assert!(m.checked_mul(n).is_some()); + assert_eq!(k, k2); + let mut res_elems = Vec::::with_capacity(m * n); + unsafe { + res_elems.set_len(m * n); + } + + let mut i = 0; + let mut j = 0; + for rr in &mut res_elems { + unsafe { + *rr = (0..k).fold(A::zero(), move |s, x| s + *lhs.uget((i, x)) * *rhs.uget((x, j))); + } + j += 1; + if j == n { + j = 0; + i += 1; + } + } + unsafe { ArrayBase::from_shape_vec_unchecked((m, n), res_elems) } +} + +// simple, slow, correct (hopefully) mat mul +fn reference_mat_vec_mul(lhs: &ArrayBase, rhs: &ArrayBase) -> Array1 +where + A: LinalgScalar, + S: Data, + S2: Data, +{ + let ((m, _), k) = (lhs.dim(), rhs.dim()); + reference_mat_mul( + lhs, + &rhs.as_standard_layout() + .into_shape_with_order((k, 1)) + .unwrap(), + ) + .into_shape_with_order(m) + .unwrap() +} + +// simple, slow, correct (hopefully) mat mul +fn reference_vec_mat_mul(lhs: &ArrayBase, rhs: &ArrayBase) -> Array1 +where + A: LinalgScalar, + S: Data, + S2: Data, +{ + let (m, (_, n)) = (lhs.dim(), rhs.dim()); + reference_mat_mul( + &lhs.as_standard_layout() + .into_shape_with_order((1, m)) + .unwrap(), + rhs, + ) + .into_shape_with_order(n) + .unwrap() +} + +// Check that matrix multiplication of contiguous matrices returns a +// matrix with the same order +#[test] +fn mat_mul_order() +{ + let (m, n, k) = (50, 50, 50); + let a = range_mat::(m, n); + let b = range_mat::(n, k); + let mut af = Array::zeros(a.dim().f()); + let mut 
bf = Array::zeros(b.dim().f()); + af.assign(&a); + bf.assign(&b); + + let cc = a.dot(&b); + let ff = af.dot(&bf); + + assert_eq!(cc.strides()[1], 1); + assert_eq!(ff.strides()[0], 1); +} + +// Check that matrix multiplication +// supports broadcast arrays. +#[test] +fn mat_mul_broadcast() +{ + let (m, n, k) = (16, 16, 16); + let a = range_mat::(m, n); + let x1 = 1.; + let x = Array::from(vec![x1]); + let b0 = x.broadcast((n, k)).unwrap(); + let b1 = Array::from_elem(n, x1); + let b1 = b1.broadcast((n, k)).unwrap(); + let b2 = Array::from_elem((n, k), x1); + + let c2 = a.dot(&b2); + let c1 = a.dot(&b1); + let c0 = a.dot(&b0); + assert_eq!(c2, c1); + assert_eq!(c2, c0); +} + +// Check that matrix multiplication supports reversed axes +#[test] +fn mat_mul_rev() +{ + let (m, n, k) = (16, 16, 16); + let a = range_mat::(m, n); + let b = range_mat::(n, k); + let mut rev = Array::zeros(b.dim()); + let mut rev = rev.slice_mut(s![..;-1, ..]); + rev.assign(&b); + println!("{:.?}", rev); + + let c1 = a.dot(&b); + let c2 = a.dot(&rev); + assert_eq!(c1, c2); +} + +// Check that matrix multiplication supports arrays with zero rows or columns +#[test] +fn mat_mut_zero_len() +{ + defmac!(mat_mul_zero_len range_mat_fn => { + for n in 0..4 { + for m in 0..4 { + let a = range_mat_fn(m, n); + let b = range_mat_fn(n, 0); + assert_eq!(a.dot(&b), Array2::zeros((m, 0))); + } + for k in 0..4 { + let a = range_mat_fn(0, n); + let b = range_mat_fn(n, k); + assert_eq!(a.dot(&b), Array2::zeros((0, k))); + } + } + }); + mat_mul_zero_len!(range_mat::); + mat_mul_zero_len!(range_mat::); + mat_mul_zero_len!(range_i32); +} + +#[test] +fn gen_mat_mul() +{ + let alpha = -2.3; + let beta = 3.14; + let sizes = vec![ + (4, 4, 4), + (8, 8, 8), + (8, 8, 1), + (1, 10, 10), + (10, 1, 10), + (10, 10, 1), + (1, 10, 1), + (10, 1, 1), + (1, 1, 10), + (4, 17, 3), + (17, 3, 22), + (19, 18, 2), + (15, 16, 17), + (67, 50, 62), + ]; + let strides = &[1, 2, -1, -2]; + let cf_order = [Order::C, Order::F]; + let generator = [ElementGenerator::Sequential, ElementGenerator::Checkerboard]; + + // test different strides and memory orders + for (&s1, &s2, &gen) in iproduct!(strides, strides, &generator) { + for &(m, k, n) in &sizes { + for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) { + println!("Case s1={}, s2={}, gen={:?}, orders={:?}, {:?}, {:?}", s1, s2, gen, ord1, ord2, ord3); + let a = ArrayBuilder::new((m, k)) + .memory_order(ord1) + .generator(gen) + .build() + * 0.5; + let b = ArrayBuilder::new((k, n)).memory_order(ord2).build(); + let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build(); + + let mut answer = c.clone(); + + { + let av; + let bv; + let mut cv; + + if s1 != 1 || s2 != 1 { + av = a.slice(s![..;s1, ..;s2]); + bv = b.slice(s![..;s2, ..;s2]); + cv = c.slice_mut(s![..;s1, ..;s2]); + } else { + // different stride cases for slicing versus not sliced (for axes of + // len=1); so test not sliced here. 
+ av = a.view(); + bv = b.view(); + cv = c.view_mut(); + } + + let answer_part = alpha * reference_mat_mul(&av, &bv) + beta * &cv; + answer.slice_mut(s![..;s1, ..;s2]).assign(&answer_part); + + general_mat_mul(alpha, &av, &bv, beta, &mut cv); + } + assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7); + } + } + } +} + +// Test y = A x where A is f-order +#[test] +fn gemm_64_1_f() +{ + let a = range_mat::(64, 64).reversed_axes(); + let (m, n) = a.dim(); + // m x n times n x 1 == m x 1 + let x = range_mat::(n, 1); + let mut y = range_mat::(m, 1); + let answer = reference_mat_mul(&a, &x) + &y; + general_mat_mul(1.0, &a, &x, 1.0, &mut y); + assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7); +} + +#[test] +fn gemm_c64_1_f() +{ + let a = range_mat_complex64(64, 64).reversed_axes(); + let (m, n) = a.dim(); + // m x n times n x 1 == m x 1 + let x = range_mat_complex64(n, 1); + let mut y = range_mat_complex64(m, 1); + let answer = reference_mat_mul(&a, &x) + &y; + general_mat_mul(Complex64::new(1.0, 0.), &a, &x, Complex64::new(1.0, 0.), &mut y); + assert_relative_eq!( + y.mapv(|i| i.norm_sqr()), + answer.mapv(|i| i.norm_sqr()), + epsilon = 1e-12, + max_relative = 1e-7 + ); +} + +#[test] +fn gemm_c32_1_f() +{ + let a = range_mat_complex(64, 64).reversed_axes(); + let (m, n) = a.dim(); + // m x n times n x 1 == m x 1 + let x = range_mat_complex(n, 1); + let mut y = range_mat_complex(m, 1); + let answer = reference_mat_mul(&a, &x) + &y; + general_mat_mul(Complex32::new(1.0, 0.), &a, &x, Complex32::new(1.0, 0.), &mut y); + assert_relative_eq!( + y.mapv(|i| i.norm_sqr()), + answer.mapv(|i| i.norm_sqr()), + epsilon = 1e-12, + max_relative = 1e-7 + ); +} + +#[test] +fn gemm_c64_actually_complex() +{ + let mut a = range_mat_complex64(4, 4); + a = a.map(|&i| if i.re > 8. { i.conj() } else { i }); + let mut b = range_mat_complex64(4, 6); + b = b.map(|&i| if i.re > 4. 
{ i.conj() } else { i }); + let mut y = range_mat_complex64(4, 6); + let alpha = Complex64::new(0., 1.0); + let beta = Complex64::new(1.0, 1.0); + let answer = alpha * reference_mat_mul(&a, &b) + beta * &y; + general_mat_mul(alpha.clone(), &a, &b, beta.clone(), &mut y); + assert_relative_eq!( + y.mapv(|i| i.norm_sqr()), + answer.mapv(|i| i.norm_sqr()), + epsilon = 1e-12, + max_relative = 1e-7 + ); +} + +#[test] +fn gen_mat_vec_mul() +{ + let alpha = -2.3; + let beta = 3.14; + let sizes = vec![ + (4, 4), + (8, 8), + (17, 15), + (4, 17), + (17, 3), + (19, 18), + (16, 17), + (15, 16), + (67, 63), + ]; + // test different strides + for &s1 in &[1, 2, -1, -2] { + for &s2 in &[1, 2, -1, -2] { + for &(m, k) in &sizes { + for order in [Order::C, Order::F] { + let a = ArrayBuilder::new((m, k)).memory_order(order).build(); + let (m, k) = a.dim(); + let b = range1_mat64(k); + let mut c = range1_mat64(m); + let mut answer = c.clone(); + + { + let a = a.slice(s![..;s1, ..;s2]); + let b = b.slice(s![..;s2]); + let mut cv = c.slice_mut(s![..;s1]); + + let answer_part = alpha * reference_mat_vec_mul(&a, &b) + beta * &cv; + answer.slice_mut(s![..;s1]).assign(&answer_part); + + general_mat_vec_mul(alpha, &a, &b, beta, &mut cv); + } + assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7); + } + } + } + } +} + +#[test] +fn vec_mat_mul() +{ + let sizes = vec![ + (4, 4), + (8, 8), + (17, 15), + (4, 17), + (17, 3), + (19, 18), + (16, 17), + (15, 16), + (67, 63), + ]; + // test different strides + for &s1 in &[1, 2, -1, -2] { + for &s2 in &[1, 2, -1, -2] { + for &(m, n) in &sizes { + for order in [Order::C, Order::F] { + let b = ArrayBuilder::new((m, n)).memory_order(order).build(); + let (m, n) = b.dim(); + let a = range1_mat64(m); + let mut c = range1_mat64(n); + let mut answer = c.clone(); + + { + let b = b.slice(s![..;s1, ..;s2]); + let a = a.slice(s![..;s1]); + + let answer_part = reference_vec_mat_mul(&a, &b); + answer.slice_mut(s![..;s2]).assign(&answer_part); + + c.slice_mut(s![..;s2]).assign(&a.dot(&b)); + } + assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7); + } + } + } + } +} diff --git a/crates/ndarray-gen/Cargo.toml b/crates/ndarray-gen/Cargo.toml new file mode 100644 index 000000000..6818e4b65 --- /dev/null +++ b/crates/ndarray-gen/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "ndarray-gen" +version = "0.1.0" +edition = "2018" +publish = false + +[dependencies] +ndarray = { workspace = true, default-features = false } +num-traits = { workspace = true } diff --git a/crates/ndarray-gen/README.md b/crates/ndarray-gen/README.md new file mode 100644 index 000000000..7dd02320c --- /dev/null +++ b/crates/ndarray-gen/README.md @@ -0,0 +1,4 @@ + +## ndarray-gen + +Array generation functions, used for testing. diff --git a/crates/ndarray-gen/src/array_builder.rs b/crates/ndarray-gen/src/array_builder.rs new file mode 100644 index 000000000..9351aadc5 --- /dev/null +++ b/crates/ndarray-gen/src/array_builder.rs @@ -0,0 +1,96 @@ +// Copyright 2024 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
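// Illustrative sketch (hypothetical helper, not from the original patch): how the
// builder defined below is meant to be driven, relying only on items from this file
// and the `ndarray` imports further down; compare the `test_order` test at the end.
#[allow(dead_code)]
fn example_usage() -> Array<f64, ndarray::Ix2>
{
    // A 3x4 array in column-major (F) memory order, filled with 1, 0, 1, 0, ...
    ArrayBuilder::new((3, 4))
        .memory_order(Order::F)
        .generator(ElementGenerator::Checkerboard)
        .build::<f64>()
}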
+ +use ndarray::Array; +use ndarray::Dimension; +use ndarray::IntoDimension; +use ndarray::Order; + +use num_traits::Num; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct ArrayBuilder +{ + dim: D, + memory_order: Order, + generator: ElementGenerator, +} + +/// How to generate elements +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ElementGenerator +{ + Sequential, + Checkerboard, + Zero, +} + +impl Default for ArrayBuilder +{ + fn default() -> Self + { + Self::new(D::zeros(D::NDIM.unwrap_or(1))) + } +} + +impl ArrayBuilder +where D: Dimension +{ + pub fn new(dim: impl IntoDimension) -> Self + { + ArrayBuilder { + dim: dim.into_dimension(), + memory_order: Order::C, + generator: ElementGenerator::Sequential, + } + } + + pub fn memory_order(mut self, order: Order) -> Self + { + self.memory_order = order; + self + } + + pub fn generator(mut self, generator: ElementGenerator) -> Self + { + self.generator = generator; + self + } + + pub fn build(self) -> Array + where T: Num + Clone + { + let zero = T::zero(); + let size = self.dim.size(); + (match self.generator { + ElementGenerator::Sequential => + Array::from_iter(core::iter::successors(Some(zero), |elt| Some(elt.clone() + T::one())).take(size)), + ElementGenerator::Checkerboard => Array::from_iter([T::one(), zero].iter().cycle().take(size).cloned()), + ElementGenerator::Zero => Array::zeros(size), + }) + .into_shape_with_order((self.dim, self.memory_order)) + .unwrap() + } +} + +#[test] +fn test_order() +{ + let (m, n) = (12, 13); + let c = ArrayBuilder::new((m, n)) + .memory_order(Order::C) + .build::(); + let f = ArrayBuilder::new((m, n)) + .memory_order(Order::F) + .build::(); + + assert_eq!(c.shape(), &[m, n]); + assert_eq!(f.shape(), &[m, n]); + assert_eq!(c.strides(), &[n as isize, 1]); + assert_eq!(f.strides(), &[1, m as isize]); +} diff --git a/crates/ndarray-gen/src/lib.rs b/crates/ndarray-gen/src/lib.rs new file mode 100644 index 000000000..09440e68d --- /dev/null +++ b/crates/ndarray-gen/src/lib.rs @@ -0,0 +1,11 @@ +#![no_std] +// Copyright 2024 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. 
+ +/// Build ndarray arrays for test purposes +pub mod array_builder; diff --git a/crates/numeric-tests/Cargo.toml b/crates/numeric-tests/Cargo.toml new file mode 100644 index 000000000..93a182e66 --- /dev/null +++ b/crates/numeric-tests/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "numeric-tests" +version = "0.1.0" +authors = ["bluss"] +publish = false +edition = "2018" + +[lib] +test = false +doc = false +doctest = false + +[dependencies] +ndarray = { workspace = true, features = ["approx"] } +ndarray-rand = { workspace = true } + +approx = { workspace = true } +rand = { workspace = true } +rand_distr = { workspace = true } + +blas-src = { optional = true, version = "0.10", default-features = false, features = ["openblas"] } +openblas-src = { optional = true, version = ">=0.10.11", default-features = false, features = ["cblas", "system"] } + +[dev-dependencies] +num-traits = { workspace = true } +num-complex = { workspace = true } + +[features] +test_blas = ["ndarray/blas", "blas-src", "openblas-src"] diff --git a/crates/numeric-tests/src/lib.rs b/crates/numeric-tests/src/lib.rs new file mode 100644 index 000000000..79ffc274e --- /dev/null +++ b/crates/numeric-tests/src/lib.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "test_blas")] +extern crate blas_src; diff --git a/crates/numeric-tests/tests/accuracy.rs b/crates/numeric-tests/tests/accuracy.rs new file mode 100644 index 000000000..c594f020d --- /dev/null +++ b/crates/numeric-tests/tests/accuracy.rs @@ -0,0 +1,303 @@ +extern crate approx; +extern crate ndarray; +extern crate ndarray_rand; +extern crate rand; +extern crate rand_distr; + +extern crate numeric_tests; + +use std::fmt; + +use ndarray_rand::RandomExt; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; + +use ndarray::linalg::general_mat_mul; +use ndarray::prelude::*; +use ndarray::{Data, LinalgScalar}; + +use num_complex::Complex; +use num_traits::{AsPrimitive, Float}; +use rand_distr::{Distribution, Normal, StandardNormal}; + +use approx::{assert_abs_diff_eq, assert_relative_eq}; + +fn kahan_sum(iter: impl Iterator) -> A +where A: LinalgScalar +{ + let mut sum = A::zero(); + let mut compensation = A::zero(); + + for elt in iter { + let y = elt - compensation; + let t = sum + y; + compensation = (t - sum) - y; + sum = t; + } + + sum +} + +// simple, slow, correct (hopefully) mat mul +fn reference_mat_mul(lhs: &ArrayBase, rhs: &ArrayBase) -> Array +where + A: LinalgScalar, + S: Data, + S2: Data, +{ + let ((m, k), (_, n)) = (lhs.dim(), rhs.dim()); + let mut res_elems = Array::zeros(m * n); + + let mut i = 0; + let mut j = 0; + for rr in &mut res_elems { + let lhs_i = lhs.row(i); + let rhs_j = rhs.column(j); + *rr = kahan_sum((0..k).map(move |x| lhs_i[x] * rhs_j[x])); + + j += 1; + if j == n { + j = 0; + i += 1; + } + } + + res_elems.into_shape_with_order((m, n)).unwrap() +} + +fn gen(d: D, rng: &mut SmallRng) -> Array +where + D: Dimension, + A: Float, + StandardNormal: Distribution, +{ + Array::random_using(d, Normal::new(A::zero(), A::one()).unwrap(), rng) +} + +fn gen_complex(d: D, rng: &mut SmallRng) -> Array, D> +where + D: Dimension, + A: Float, + StandardNormal: Distribution, +{ + gen(d.clone(), rng).mapv(Complex::from) + gen(d, rng).mapv(|x| Complex::new(A::zero(), x)) +} + +#[test] +fn accurate_eye_f32() +{ + let rng = &mut SmallRng::from_entropy(); + for i in 0..20 { + let eye = Array::eye(i); + for j in 0..20 { + let a = gen::(Ix2(i, j), rng); + let a2 = eye.dot(&a); + assert_abs_diff_eq!(a, a2, epsilon = 1e-6); + let a3 = a.t().dot(&eye); + 
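            // Right-multiplying the transpose by the identity must likewise give a.t() back.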
assert_abs_diff_eq!(a.t(), a3, epsilon = 1e-6); + } + } + // pick a few random sizes + for _ in 0..10 { + let i = rng.gen_range(15..512); + let j = rng.gen_range(15..512); + println!("Testing size {} by {}", i, j); + let a = gen::(Ix2(i, j), rng); + let eye = Array::eye(i); + let a2 = eye.dot(&a); + assert_abs_diff_eq!(a, a2, epsilon = 1e-6); + let a3 = a.t().dot(&eye); + assert_abs_diff_eq!(a.t(), a3, epsilon = 1e-6); + } +} + +#[test] +fn accurate_eye_f64() +{ + let rng = &mut SmallRng::from_entropy(); + let abs_tol = 1e-15; + for i in 0..20 { + let eye = Array::eye(i); + for j in 0..20 { + let a = gen::(Ix2(i, j), rng); + let a2 = eye.dot(&a); + assert_abs_diff_eq!(a, a2, epsilon = abs_tol); + let a3 = a.t().dot(&eye); + assert_abs_diff_eq!(a.t(), a3, epsilon = abs_tol); + } + } + // pick a few random sizes + for _ in 0..10 { + let i = rng.gen_range(15..512); + let j = rng.gen_range(15..512); + println!("Testing size {} by {}", i, j); + let a = gen::(Ix2(i, j), rng); + let eye = Array::eye(i); + let a2 = eye.dot(&a); + assert_abs_diff_eq!(a, a2, epsilon = 1e-6); + let a3 = a.t().dot(&eye); + assert_abs_diff_eq!(a.t(), a3, epsilon = 1e-6); + } +} + +#[test] +fn accurate_mul_f32_dot() +{ + accurate_mul_float_general::(1e-5, false); +} + +#[test] +fn accurate_mul_f32_general() +{ + accurate_mul_float_general::(1e-5, true); +} + +#[test] +fn accurate_mul_f64_dot() +{ + accurate_mul_float_general::(1e-14, false); +} + +#[test] +fn accurate_mul_f64_general() +{ + accurate_mul_float_general::(1e-14, true); +} + +/// Generate random sized matrices using the given generator function. +/// Compute gemm using either .dot() (if use_general is false) otherwise general_mat_mul. +/// Return tuple of actual result matrix and reference matrix, which should be equal. 
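/// (Hypothetical example, not from the original patch: with the helpers in this file,
/// `random_matrix_mul(&mut rng, false, false, gen::<f64, _>)` computes a plain `.dot()`
/// product together with the matching `reference_mat_mul` result; passing
/// `use_general = true` uses `general_mat_mul` instead.)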
+fn random_matrix_mul( + rng: &mut SmallRng, use_stride: bool, use_general: bool, generator: fn(Ix2, &mut SmallRng) -> Array2, +) -> (Array2, Array2) +where A: LinalgScalar +{ + let m = rng.gen_range(15..128); + let k = rng.gen_range(15..128); + let n = rng.gen_range(15..512); + let a = generator(Ix2(m, k), rng); + let b = generator(Ix2(n, k), rng); + let c = if use_general { + Some(generator(Ix2(m, n), rng)) + } else { + None + }; + + let b = b.t(); + let (a, b, mut c) = if use_stride { + (a.slice(s![..;2, ..;2]), b.slice(s![..;2, ..;2]), c.map(|c_| c_.slice_move(s![..;2, ..;2]))) + } else { + (a.view(), b, c) + }; + + println!("Testing size {} by {} by {}", a.shape()[0], a.shape()[1], b.shape()[1]); + if let Some(c) = &mut c { + general_mat_mul(A::one(), &a, &b, A::zero(), c); + } else { + c = Some(a.dot(&b)); + } + let c = c.unwrap(); + let reference = reference_mat_mul(&a, &b); + + (c, reference) +} + +fn accurate_mul_float_general(limit: f64, use_general: bool) +where + A: Float + Copy + 'static + AsPrimitive, + StandardNormal: Distribution, + A: fmt::Debug, +{ + // pick a few random sizes + let mut rng = SmallRng::from_entropy(); + for i in 0..20 { + let (c, reference) = random_matrix_mul(&mut rng, i > 10, use_general, gen::); + + let diff = &c - &reference; + let max_diff = diff.iter().copied().fold(A::zero(), A::max); + let max_elt = reference.iter().copied().fold(A::zero(), A::max); + println!("Max elt diff={:?}, max={:?}, ratio={:.4e}", max_diff, max_elt, (max_diff/max_elt).as_()); + assert!((max_diff / max_elt).as_() < limit, + "Expected relative norm diff < {:e}, found {:?} / {:?}", limit, max_diff, max_elt); + } +} + +#[test] +fn accurate_mul_complex32() +{ + accurate_mul_complex_general::(1e-5); +} + +#[test] +fn accurate_mul_complex64() +{ + accurate_mul_complex_general::(1e-14); +} + +fn accurate_mul_complex_general(limit: f64) +where + A: Float + Copy + 'static + AsPrimitive, + StandardNormal: Distribution, + A: fmt::Debug, +{ + // pick a few random sizes + let mut rng = SmallRng::from_entropy(); + for i in 0..20 { + let (c, reference) = random_matrix_mul(&mut rng, i > 10, true, gen_complex::); + + let diff = &c - &reference; + let max_elt = |elt: &Complex<_>| A::max(A::abs(elt.re), A::abs(elt.im)); + let max_diff = diff.iter().map(max_elt).fold(A::zero(), A::max); + let max_elt = reference.iter().map(max_elt).fold(A::zero(), A::max); + println!("Max elt diff={:?}, max={:?}, ratio={:.4e}", max_diff, max_elt, (max_diff/max_elt).as_()); + assert!((max_diff / max_elt).as_() < limit, + "Expected relative norm diff < {:e}, found {:?} / {:?}", limit, max_diff, max_elt); + } +} + +#[test] +fn accurate_mul_with_column_f64() +{ + // pick a few random sizes + let rng = &mut SmallRng::from_entropy(); + for i in 0..10 { + let m = rng.gen_range(1..128); + let k = rng.gen_range(1..350); + let a = gen::(Ix2(m, k), rng); + let b_owner = gen::(Ix2(k, k), rng); + let b_row_col; + let b_sq; + + // pick dense square or broadcasted to square matrix + match i { + 0..=3 => b_sq = b_owner.view(), + 4..=7 => { + b_row_col = b_owner.column(0); + b_sq = b_row_col.broadcast((k, k)).unwrap(); + } + _otherwise => { + b_row_col = b_owner.row(0); + b_sq = b_row_col.broadcast((k, k)).unwrap(); + } + }; + + for j in 0..k { + for &flip in &[true, false] { + let j = j as isize; + let b = if flip { + // one row in 2D + b_sq.slice(s![j..j + 1, ..]).reversed_axes() + } else { + // one column in 2D + b_sq.slice(s![.., j..j + 1]) + }; + println!("Testing size ({} × {}) by ({} × {})", a.shape()[0], a.shape()[1], 
b.shape()[0], b.shape()[1]); + println!("Strides ({:?}) by ({:?})", a.strides(), b.strides()); + let c = a.dot(&b); + let reference = reference_mat_mul(&a, &b); + + assert_relative_eq!(c, reference, epsilon = 1e-12, max_relative = 1e-7); + } + } + } +} diff --git a/crates/serialization-tests/Cargo.toml b/crates/serialization-tests/Cargo.toml new file mode 100644 index 000000000..be7c4c17b --- /dev/null +++ b/crates/serialization-tests/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "serialization-tests" +version = "0.1.0" +authors = ["bluss"] +publish = false +edition = "2018" + +[lib] +test = false +doc = false +doctest = false + +[dependencies] +ndarray = { workspace = true, features = ["serde"] } + +serde = { version = "1.0.100", default-features = false } +ron = { version = "0.8.1" } + +[dev-dependencies] +serde_json = { version = "1.0.40" } +# Old version to work with Rust 1.64+ +rmp = { version = "=0.8.10" } +# Old version to work with Rust 1.64+ +rmp-serde = { version = "0.14" } diff --git a/blas-tests/src/lib.rs b/crates/serialization-tests/src/lib.rs similarity index 100% rename from blas-tests/src/lib.rs rename to crates/serialization-tests/src/lib.rs diff --git a/serialization-tests/tests/serialize.rs b/crates/serialization-tests/tests/serialize.rs similarity index 88% rename from serialization-tests/tests/serialize.rs rename to crates/serialization-tests/tests/serialize.rs index efb3bacd9..478eb20ef 100644 --- a/serialization-tests/tests/serialize.rs +++ b/crates/serialization-tests/tests/serialize.rs @@ -6,13 +6,13 @@ extern crate serde_json; extern crate rmp_serde; -#[cfg(feature = "ron")] extern crate ron; use ndarray::{arr0, arr1, arr2, s, ArcArray, ArcArray2, ArrayD, IxDyn}; #[test] -fn serial_many_dim_serde() { +fn serial_many_dim_serde() +{ { let a = arr0::(2.72); let serial = serde_json::to_string(&a).unwrap(); @@ -45,18 +45,21 @@ fn serial_many_dim_serde() { { // Test a sliced array. - let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4)); + let mut a = ArcArray::from_iter(0..32) + .into_shape_with_order((2, 2, 2, 4)) + .unwrap(); a.slice_collapse(s![..;-1, .., .., ..2]); let serial = serde_json::to_string(&a).unwrap(); println!("Encode {:?} => {:?}", a, serial); - let res = serde_json::from_str::>(&serial); + let res = serde_json::from_str::>(&serial); println!("{:?}", res); assert_eq!(a, res.unwrap()); } } #[test] -fn serial_ixdyn_serde() { +fn serial_ixdyn_serde() +{ { let a = arr0::(2.72).into_dyn(); let serial = serde_json::to_string(&a).unwrap(); @@ -77,7 +80,7 @@ fn serial_ixdyn_serde() { { let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]]) - .into_shape(IxDyn(&[3, 1, 1, 1, 2, 1])) + .into_shape_with_order(IxDyn(&[3, 1, 1, 1, 2, 1])) .unwrap(); let serial = serde_json::to_string(&a).unwrap(); println!("Serde encode {:?} => {:?}", a, serial); @@ -95,7 +98,8 @@ fn serial_ixdyn_serde() { } #[test] -fn serial_wrong_count_serde() { +fn serial_wrong_count_serde() +{ // one element too few let text = r##"{"v":1,"dim":[2,3],"data":[3,1,2.2,3.1,4]}"##; let arr = serde_json::from_str::>(text); @@ -110,7 +114,8 @@ fn serial_wrong_count_serde() { } #[test] -fn serial_many_dim_serde_msgpack() { +fn serial_many_dim_serde_msgpack() +{ { let a = arr0::(2.72); @@ -155,7 +160,9 @@ fn serial_many_dim_serde_msgpack() { { // Test a sliced array. 
- let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4)); + let mut a = ArcArray::from_iter(0..32) + .into_shape_with_order((2, 2, 2, 4)) + .unwrap(); a.slice_collapse(s![..;-1, .., .., ..2]); let mut buf = Vec::new(); @@ -164,15 +171,15 @@ fn serial_many_dim_serde_msgpack() { .unwrap(); let mut deserializer = rmp_serde::Deserializer::new(&buf[..]); - let a_de: ArcArray = serde::Deserialize::deserialize(&mut deserializer).unwrap(); + let a_de: ArcArray = serde::Deserialize::deserialize(&mut deserializer).unwrap(); assert_eq!(a, a_de); } } #[test] -#[cfg(feature = "ron")] -fn serial_many_dim_ron() { +fn serial_many_dim_ron() +{ use ron::de::from_str as ron_deserialize; use ron::ser::to_string as ron_serialize; @@ -208,12 +215,14 @@ fn serial_many_dim_ron() { { // Test a sliced array. - let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4)); + let mut a = ArcArray::from_iter(0..32) + .into_shape_with_order((2, 2, 2, 4)) + .unwrap(); a.slice_collapse(s![..;-1, .., .., ..2]); let a_s = ron_serialize(&a).unwrap(); - let a_de: ArcArray = ron_deserialize(&a_s).unwrap(); + let a_de: ArcArray = ron_deserialize(&a_s).unwrap(); assert_eq!(a, a_de); } diff --git a/examples/axis_ops.rs b/examples/axis_ops.rs index 3dbf0eee9..3a54a52fb 100644 --- a/examples/axis_ops.rs +++ b/examples/axis_ops.rs @@ -1,34 +1,45 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] use ndarray::prelude::*; -fn regularize(a: &mut Array) -> Result<(), ()> +/// Reorder a's axes so that they are in "standard" axis order; +/// make sure axes are in positive stride direction, and merge adjacent +/// axes if possible. +/// +/// This changes the logical order of the elements in the +/// array, so that if we read them in row-major order after regularization, +/// it corresponds to their order in memory. +/// +/// Errors if array has a 0-stride axis +fn regularize(a: &mut Array) -> Result<(), &'static str> where D: Dimension, A: ::std::fmt::Debug, { println!("Regularize:\n{:?}", a); // reverse all neg axes - while let Some(ax) = a.axes().find(|ax| ax.stride() <= 0) { - if ax.stride() == 0 { - return Err(()); + while let Some(ax) = a.axes().find(|ax| ax.stride <= 0) { + if ax.stride == 0 { + // no real reason to error on this case; other than + // stride == 0 is incompatible with an owned array. 
+ return Err("Cannot regularize array with stride == 0 axis"); } // reverse ax - println!("Reverse {:?}", ax.axis()); - a.invert_axis(ax.axis()); + println!("Reverse {:?}", ax.axis); + a.invert_axis(ax.axis); } // sort by least stride let mut i = 0; let n = a.ndim(); - while let Some(ax) = a.axes().rev().skip(i).min_by_key(|ax| ax.stride().abs()) { - a.swap_axes(n - 1 - i, ax.axis().index()); - println!("Swap {:?} <=> {}", ax.axis(), n - 1 - i); + while let Some(ax) = a.axes().rev().skip(i).min_by_key(|ax| ax.stride.abs()) { + let cur_axis = Axis(n - 1 - i); + if ax.axis != cur_axis { + a.swap_axes(cur_axis.index(), ax.axis.index()); + println!("Swap {:?} <=> {:?}", cur_axis, ax.axis); + } i += 1; } @@ -40,11 +51,12 @@ where break; } } - println!("{:?}", a); + println!("Result:\n{:?}\n", a); Ok(()) } -fn main() { +fn main() +{ let mut a = Array::::zeros((2, 3, 4)); for (i, elt) in (0..).zip(&mut a) { *elt = i; @@ -52,22 +64,24 @@ fn main() { a.swap_axes(0, 1); a.swap_axes(0, 2); a.slice_collapse(s![.., ..;-1, ..]); - regularize(&mut a).ok(); + regularize(&mut a).unwrap(); let mut b = Array::::zeros((2, 3, 4)); for (i, elt) in (0..).zip(&mut b) { *elt = i; } - regularize(&mut b).ok(); - let mut b = b.into_shape(a.len()).unwrap(); - regularize(&mut b).ok(); + regularize(&mut b).unwrap(); + + let mut b = b.into_shape_with_order(a.len()).unwrap(); + regularize(&mut b).unwrap(); + b.invert_axis(Axis(0)); - regularize(&mut b).ok(); + regularize(&mut b).unwrap(); let mut a = Array::::zeros((2, 3, 4)); for (i, elt) in (0..).zip(&mut a) { *elt = i; } a.slice_collapse(s![..;-1, ..;2, ..]); - regularize(&mut a).ok(); + regularize(&mut a).unwrap(); } diff --git a/examples/bounds_check_elim.rs b/examples/bounds_check_elim.rs index d1c247433..f1a91cca0 100644 --- a/examples/bounds_check_elim.rs +++ b/examples/bounds_check_elim.rs @@ -1,9 +1,6 @@ #![crate_type = "lib"] #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] // Test cases for bounds check elimination @@ -38,7 +35,8 @@ pub fn testvec_as_slice(a: &Vec) -> f64 { */ #[no_mangle] -pub fn test1d_single(a: &Array1, i: usize) -> f64 { +pub fn test1d_single(a: &Array1, i: usize) -> f64 +{ if i < a.len() { a[i] } else { @@ -47,7 +45,8 @@ pub fn test1d_single(a: &Array1, i: usize) -> f64 { } #[no_mangle] -pub fn test1d_single_mut(a: &mut Array1, i: usize) -> f64 { +pub fn test1d_single_mut(a: &mut Array1, i: usize) -> f64 +{ if i < a.len() { *&mut a[i] } else { @@ -56,8 +55,9 @@ pub fn test1d_single_mut(a: &mut Array1, i: usize) -> f64 { } #[no_mangle] -pub fn test1d_len_of(a: &Array1) -> f64 { - let a = &*a; +pub fn test1d_len_of(a: &Array1) -> f64 +{ + let a = a; let mut sum = 0.; for i in 0..a.len_of(Axis(0)) { sum += a[i]; @@ -66,7 +66,8 @@ pub fn test1d_len_of(a: &Array1) -> f64 { } #[no_mangle] -pub fn test1d_range(a: &Array1) -> f64 { +pub fn test1d_range(a: &Array1) -> f64 +{ let mut sum = 0.; for i in 0..a.len() { sum += a[i]; @@ -75,7 +76,8 @@ pub fn test1d_range(a: &Array1) -> f64 { } #[no_mangle] -pub fn test1d_while(a: &Array1) -> f64 { +pub fn test1d_while(a: &Array1) -> f64 +{ let mut sum = 0.; let mut i = 0; while i < a.len() { @@ -86,7 +88,8 @@ pub fn test1d_while(a: &Array1) -> f64 { } #[no_mangle] -pub fn test2d_ranges(a: &Array2) -> f64 { +pub fn test2d_ranges(a: &Array2) -> f64 +{ let mut sum = 0.; for i in 0..a.nrows() { for j in 
0..a.ncols() { @@ -97,7 +100,8 @@ pub fn test2d_ranges(a: &Array2) -> f64 { } #[no_mangle] -pub fn test2d_whiles(a: &Array2) -> f64 { +pub fn test2d_whiles(a: &Array2) -> f64 +{ let mut sum = 0.; let mut i = 0; while i < a.nrows() { diff --git a/examples/column_standardize.rs b/examples/column_standardize.rs index c360170bd..329ad2ccb 100644 --- a/examples/column_standardize.rs +++ b/examples/column_standardize.rs @@ -1,31 +1,27 @@ +#[cfg(feature = "std")] use ndarray::prelude::*; -fn std1d(a: ArrayView1<'_, f64>) -> f64 { - let n = a.len() as f64; - if n == 0. { - return 0.; - } - let mean = a.sum() / n; - (a.fold(0., |acc, &x| acc + (x - mean).powi(2)) / n).sqrt() -} - -fn std(a: &Array2, axis: Axis) -> Array1 { - a.map_axis(axis, std1d) -} - -fn main() { - // "recreating the following" +#[cfg(feature = "std")] +fn main() +{ + // This example recreates the following from python/numpy // counts -= np.mean(counts, axis=0) // counts /= np.std(counts, axis=0) let mut data = array![[-1., -2., -3.], [1., -3., 5.], [2., 2., 2.]]; println!("{:8.4}", data); - println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0)).unwrap()); + println!("Mean along axis=0 (along columns):\n{:8.4}", data.mean_axis(Axis(0)).unwrap()); data -= &data.mean_axis(Axis(0)).unwrap(); - println!("{:8.4}", data); + println!("Centered around mean:\n{:8.4}", data); - data /= &std(&data, Axis(0)); - println!("{:8.4}", data); + data /= &data.std_axis(Axis(0), 0.); + println!("Scaled to normalize std:\n{:8.4}", data); + + println!("New mean:\n{:8.4}", data.mean_axis(Axis(0)).unwrap()); + println!("New std: \n{:8.4}", data.std_axis(Axis(0), 0.)); } + +#[cfg(not(feature = "std"))] +fn main() {} diff --git a/examples/convo.rs b/examples/convo.rs index a9f073bd5..a59795e12 100644 --- a/examples/convo.rs +++ b/examples/convo.rs @@ -1,7 +1,7 @@ #![allow(unused)] extern crate ndarray; -extern crate num_traits; +#[cfg(feature = "std")] use num_traits::Float; use ndarray::prelude::*; @@ -13,9 +13,9 @@ const SHARPEN: [[f32; 3]; 3] = [[0., -1., 0.], [-1., 5., -1.], [0., -1., 0.]]; type Kernel3x3 = [[A; 3]; 3]; #[inline(never)] +#[cfg(feature = "std")] fn conv_3x3(a: &ArrayView2<'_, F>, out: &mut ArrayViewMut2<'_, F>, kernel: &Kernel3x3) -where - F: Float, +where F: Float { let (n, m) = a.dim(); let (np, mp) = out.dim(); @@ -41,7 +41,9 @@ where } } -fn main() { +#[cfg(feature = "std")] +fn main() +{ let n = 16; let mut a = Array::zeros((n, n)); // make a circle @@ -61,3 +63,5 @@ fn main() { } println!("{:2}", res); } +#[cfg(not(feature = "std"))] +fn main() {} diff --git a/examples/life.rs b/examples/life.rs index 1c2789389..7db384678 100644 --- a/examples/life.rs +++ b/examples/life.rs @@ -1,12 +1,8 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] use ndarray::prelude::*; -use std::iter::FromIterator; const INPUT: &[u8] = include_bytes!("life.txt"); @@ -14,7 +10,8 @@ const N: usize = 100; type Board = Array2; -fn parse(x: &[u8]) -> Board { +fn parse(x: &[u8]) -> Board +{ // make a border of 0 cells let mut map = Board::from_elem(((N + 2), (N + 2)), 0); let a = Array::from_iter(x.iter().filter_map(|&b| match b { @@ -23,7 +20,7 @@ fn parse(x: &[u8]) -> Board { _ => None, })); - let a = a.into_shape((N, N)).unwrap(); + let a = a.into_shape_with_order((N, N)).unwrap(); map.slice_mut(s![1..-1, 1..-1]).assign(&a); map } @@ -34,7 +31,8 @@ fn 
parse(x: &[u8]) -> Board { // 3 neighbors: birth // otherwise: death -fn iterate(z: &mut Board, scratch: &mut Board) { +fn iterate(z: &mut Board, scratch: &mut Board) +{ // compute number of neighbors let mut neigh = scratch.view_mut(); neigh.fill(0); @@ -57,7 +55,8 @@ fn iterate(z: &mut Board, scratch: &mut Board) { zv.zip_mut_with(&neigh, |y, &n| *y = ((n == 3) || (n == 2 && *y > 0)) as u8); } -fn turn_on_corners(z: &mut Board) { +fn turn_on_corners(z: &mut Board) +{ let n = z.nrows(); let m = z.ncols(); z[[1, 1]] = 1; @@ -66,8 +65,9 @@ fn turn_on_corners(z: &mut Board) { z[[n - 2, m - 2]] = 1; } -fn render(a: &Board) { - for row in a.genrows() { +fn render(a: &Board) +{ + for row in a.rows() { for &x in row { if x > 0 { print!("#"); @@ -79,7 +79,8 @@ fn render(a: &Board) { } } -fn main() { +fn main() +{ let mut a = parse(INPUT); let mut scratch = Board::zeros((N, N)); let steps = 100; diff --git a/examples/rollaxis.rs b/examples/rollaxis.rs index 8efdd0ce0..82c381297 100644 --- a/examples/rollaxis.rs +++ b/examples/rollaxis.rs @@ -22,7 +22,8 @@ where a } -fn main() { +fn main() +{ let mut data = array![ [[-1., 0., -2.], [1., 7., -3.]], [[1., 0., -3.], [1., 7., 5.]], diff --git a/examples/sort-axis.rs b/examples/sort-axis.rs index 4465464b8..4da3a64d5 100644 --- a/examples/sort-axis.rs +++ b/examples/sort-axis.rs @@ -1,18 +1,27 @@ +//! This is an example of sorting arrays along an axis. +//! This file may not be so instructive except for advanced users, instead it +//! could be a "feature preview" before sorting is added to the main crate. +//! use ndarray::prelude::*; use ndarray::{Data, RemoveAxis, Zip}; +use rawpointer::PointerExt; + use std::cmp::Ordering; use std::ptr::copy_nonoverlapping; // Type invariant: Each index appears exactly once #[derive(Clone, Debug)] -pub struct Permutation { +pub struct Permutation +{ indices: Vec, } -impl Permutation { +impl Permutation +{ /// Checks if the permutation is correct - pub fn from_indices(v: Vec) -> Result { + pub fn from_indices(v: Vec) -> Result + { let perm = Permutation { indices: v }; if perm.correct() { Ok(perm) @@ -21,34 +30,35 @@ impl Permutation { } } - fn correct(&self) -> bool { + fn correct(&self) -> bool + { let axis_len = self.indices.len(); let mut seen = vec![false; axis_len]; for &i in &self.indices { match seen.get_mut(i) { None => return false, - Some(s) => { + Some(s) => if *s { return false; } else { *s = true; - } - } + }, } } true } } -pub trait SortArray { +pub trait SortArray +{ /// ***Panics*** if `axis` is out of bounds. 
fn identity(&self, axis: Axis) -> Permutation; fn sort_axis_by(&self, axis: Axis, less_than: F) -> Permutation - where - F: FnMut(usize, usize) -> bool; + where F: FnMut(usize, usize) -> bool; } -pub trait PermuteArray { +pub trait PermuteArray +{ type Elem; type Dim; fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array @@ -62,15 +72,15 @@ where S: Data, D: Dimension, { - fn identity(&self, axis: Axis) -> Permutation { + fn identity(&self, axis: Axis) -> Permutation + { Permutation { indices: (0..self.len_of(axis)).collect(), } } fn sort_axis_by(&self, axis: Axis, mut less_than: F) -> Permutation - where - F: FnMut(usize, usize) -> bool, + where F: FnMut(usize, usize) -> bool { let mut perm = self.identity(axis); perm.indices.sort_by(move |&a, &b| { @@ -87,46 +97,81 @@ where } impl PermuteArray for Array -where - D: Dimension, +where D: Dimension { type Elem = A; type Dim = D; fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array - where - D: RemoveAxis, + where D: RemoveAxis { - let axis = axis; let axis_len = self.len_of(axis); + let axis_stride = self.stride_of(axis); assert_eq!(axis_len, perm.indices.len()); debug_assert!(perm.correct()); - let mut v = Vec::with_capacity(self.len()); - let mut result; + if self.is_empty() { + return self; + } + + let mut result = Array::uninit(self.dim()); - // panic-critical begin: we must not panic unsafe { - v.set_len(self.len()); - result = Array::from_shape_vec_unchecked(self.dim(), v); - for i in 0..axis_len { - let perm_i = perm.indices[i]; - Zip::from(result.index_axis_mut(axis, perm_i)) - .and(self.index_axis(axis, i)) - .apply(|to, from| copy_nonoverlapping(from, to, 1)); - } - // forget moved array elements but not its vec - let mut old_storage = self.into_raw_vec(); + // logically move ownership of all elements from self into result + // the result realizes this ownership at .assume_init() further down + let mut moved_elements = 0; + + // the permutation vector is used like this: + // + // index: 0 1 2 3 (index in result) + // permut: 2 3 0 1 (index in the source) + // + // move source 2 -> result 0, + // move source 3 -> result 1, + // move source 0 -> result 2, + // move source 1 -> result 3, + // et.c. + + let source_0 = self.raw_view().index_axis_move(axis, 0); + + Zip::from(&perm.indices) + .and(result.axis_iter_mut(axis)) + .for_each(|&perm_i, result_pane| { + // Use a shortcut to avoid bounds checking in `index_axis` for the source. + // + // It works because for any given element pointer in the array we have the + // relationship: + // + // .index_axis(axis, 0) + .stride_of(axis) * j == .index_axis(axis, j) + // + // where + is pointer arithmetic on the element pointers. 
+ // + // Here source_0 and the offset is equivalent to self.index_axis(axis, perm_i) + Zip::from(result_pane) + .and(source_0.clone()) + .for_each(|to, from_0| { + let from = from_0.stride_offset(axis_stride, perm_i); + copy_nonoverlapping(from, to.as_mut_ptr(), 1); + moved_elements += 1; + }); + }); + debug_assert_eq!(result.len(), moved_elements); + // forget the old elements but not the allocation + let mut old_storage = self.into_raw_vec_and_offset().0; old_storage.set_len(0); - // old_storage drops empty + + // transfer ownership of the elements into the result + result.assume_init() } - // panic-critical end - result } } -fn main() { - let a = Array::linspace(0., 63., 64).into_shape((8, 8)).unwrap(); +#[cfg(feature = "std")] +fn main() +{ + let a = Array::linspace(0., 63., 64) + .into_shape_with_order((8, 8)) + .unwrap(); let strings = a.map(|x| x.to_string()); let perm = a.sort_axis_by(Axis(1), |i, j| a[[i, 0]] > a[[j, 0]]); @@ -138,3 +183,77 @@ fn main() { let c = strings.permute_axis(Axis(1), &perm); println!("{:?}", c); } + +#[cfg(not(feature = "std"))] +fn main() {} + +#[cfg(test)] +mod tests +{ + use super::*; + #[test] + fn test_permute_axis() + { + let a = array![ + [107998.96, 1.], + [107999.08, 2.], + [107999.20, 3.], + [108000.33, 4.], + [107999.45, 5.], + [107999.57, 6.], + [108010.69, 7.], + [107999.81, 8.], + [107999.94, 9.], + [75600.09, 10.], + [75600.21, 11.], + [75601.33, 12.], + [75600.45, 13.], + [75600.58, 14.], + [109000.70, 15.], + [75600.82, 16.], + [75600.94, 17.], + [75601.06, 18.], + ]; + let answer = array![ + [75600.09, 10.], + [75600.21, 11.], + [75600.45, 13.], + [75600.58, 14.], + [75600.82, 16.], + [75600.94, 17.], + [75601.06, 18.], + [75601.33, 12.], + [107998.96, 1.], + [107999.08, 2.], + [107999.20, 3.], + [107999.45, 5.], + [107999.57, 6.], + [107999.81, 8.], + [107999.94, 9.], + [108000.33, 4.], + [108010.69, 7.], + [109000.70, 15.], + ]; + + // f layout copy of a + let mut af = Array::zeros(a.dim().f()); + af.assign(&a); + + // transposed copy of a + let at = a.t().to_owned(); + + // c layout permute + let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] < a[[j, 0]]); + + let b = a.permute_axis(Axis(0), &perm); + assert_eq!(b, answer); + + // f layout permute + let bf = af.permute_axis(Axis(0), &perm); + assert_eq!(bf, answer); + + // transposed permute + let bt = at.permute_axis(Axis(1), &perm); + assert_eq!(bt, answer.t()); + } +} diff --git a/examples/type_conversion.rs b/examples/type_conversion.rs new file mode 100644 index 000000000..a419af740 --- /dev/null +++ b/examples/type_conversion.rs @@ -0,0 +1,120 @@ +#[cfg(feature = "approx")] +use std::convert::TryFrom; + +#[cfg(feature = "approx")] +use approx::assert_abs_diff_eq; +#[cfg(feature = "approx")] +use ndarray::prelude::*; + +#[cfg(feature = "approx")] +fn main() +{ + // Converting an array from one datatype to another is implemented with the + // `ArrayBase::mapv()` function. We pass a closure that is applied to each + // element independently. This allows for more control and flexiblity in + // converting types. + // + // Below, we illustrate four different approaches for the actual conversion + // in the closure. + // - `From` ensures lossless conversions known at compile time and is the + // best default choice. + // - `TryFrom` either converts data losslessly or panics, ensuring that the + // rest of the program does not continue with unexpected data. + // - `as` never panics and may silently convert in a lossy way, depending + // on the source and target datatypes. 
More details can be found in the + // reference: https://doc.rust-lang.org/reference/expressions/operator-expr.html#numeric-cast + // - Using custom logic in the closure, e.g. to clip values or for NaN + // handling in floats. + // + // For a brush-up on casting between numeric types in Rust, refer to: + // https://doc.rust-lang.org/rust-by-example/types/cast.html + + // Infallible, lossless conversion with `From` + // The trait `std::convert::From` is only implemented for conversions that + // can be guaranteed to be lossless at compile time. This is the safest + // approach. + let a_u8: Array = array![[1, 2, 3], [4, 5, 6]]; + let a_f32 = a_u8.mapv(|element| f32::from(element)); + assert_abs_diff_eq!(a_f32, array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + + // Fallible, lossless conversion with `TryFrom` + // `i8` numbers can be negative, in such a case, there is no perfect + // conversion to `u8` defined. In this example, all numbers are positive and + // in bounds and can be converted at runtime. But for unknown runtime input, + // this would panic with the message provided in `.expect()`. Note that you + // can also use `.unwrap()` to be more concise. + let a_i8: Array = array![120, 8, 0]; + let a_u8 = a_i8.mapv(|element| u8::try_from(element).expect("Could not convert i8 to u8")); + assert_eq!(a_u8, array![120u8, 8u8, 0u8]); + + // Unsigned to signed integer conversion with `as` + // A real-life example of this would be coordinates on a grid. + // A `usize` value can be larger than what fits into a `isize`, therefore, + // it would be safer to use `TryFrom`. Nevertheless, `as` can be used for + // either simplicity or performance. + // The example includes `usize::MAX` to illustrate potentially undesired + // behavior. It will be interpreted as -1 (noop-casting + 2-complement), see + // https://doc.rust-lang.org/reference/expressions/operator-expr.html#numeric-cast + let a_usize: Array = array![1, 2, 3, usize::MAX]; + let a_isize = a_usize.mapv(|element| element as isize); + assert_eq!(a_isize, array![1_isize, 2_isize, 3_isize, -1_isize]); + + // Simple upcasting with `as` + // Every `u8` fits perfectly into a `u32`, therefore this is a lossless + // conversion. + // Note that it is up to the programmer to ensure the validity of the + // conversion over the lifetime of a program. With type inference, subtle + // bugs can creep in since conversions with `as` will always compile, so a + // programmer might not notice that a prior lossless conversion became a + // lossy conversion. With `From`, this would be noticed at compile-time and + // with `TryFrom`, it would also be either handled or make the program + // panic. + let a_u8: Array = array![[1, 2, 3], [4, 5, 6]]; + let a_u32 = a_u8.mapv(|element| element as u32); + assert_eq!(a_u32, array![[1u32, 2u32, 3u32], [4u32, 5u32, 6u32]]); + + // Saturating cast with `as` + // The `as` keyword performs a *saturating cast* When casting floats to + // ints. This means that numbers which do not fit into the target datatype + // will silently be clipped to the maximum/minimum numbers. Since this is + // not obvious, we discourage the intentional use of casting with `as` with + // silent saturation and recommend a custom logic instead which makes the + // intent clear. 
+ let a_f32: Array = array![ + 256.0, // saturated to 255 + 255.7, // saturated to 255 + 255.1, // saturated to 255 + 254.7, // rounded down to 254 by cutting the decimal part + 254.1, // rounded down to 254 by cutting the decimal part + -1.0, // saturated to 0 on the lower end + f32::INFINITY, // saturated to 255 + f32::NAN, // converted to zero + ]; + let a_u8 = a_f32.mapv(|element| element as u8); + assert_eq!(a_u8, array![255, 255, 255, 254, 254, 0, 255, 0]); + + // Custom mapping logic + // Given that we pass a closure for the conversion, we can also define + // custom logic to e.g. replace NaN values and clip others. This also + // makes the intent clear. + let a_f32: Array = array![ + 270.0, // clipped to 200 + -1.2, // clipped to 0 + 4.7, // rounded up to 5 instead of just stripping decimals + f32::INFINITY, // clipped to 200 + f32::NAN, // replaced with upper bound 200 + ]; + let a_u8_custom = a_f32.mapv(|element| { + if element == f32::INFINITY || element.is_nan() { + return 200; + } + if let Some(std::cmp::Ordering::Less) = element.partial_cmp(&0.0) { + return 0; + } + 200.min(element.round() as u8) + }); + assert_eq!(a_u8_custom, array![200, 0, 5, 200, 200]); +} + +#[cfg(not(feature = "approx"))] +fn main() {} diff --git a/examples/zip_many.rs b/examples/zip_many.rs index 0059cd650..57d66a956 100644 --- a/examples/zip_many.rs +++ b/examples/zip_many.rs @@ -1,53 +1,49 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] use ndarray::prelude::*; use ndarray::Zip; -fn main() { - let n = 16; +fn main() +{ + let n = 6; + let mut a = Array::::zeros((n, n)); - let mut b = Array::::from_elem((n, n), 1.); + let mut b = Array::::zeros((n, n)); for ((i, j), elt) in b.indexed_iter_mut() { - *elt /= 1. + (i + 2 * j) as f32; + *elt = 1. / (1. 
+ (i + 2 * j) as f32); } let c = Array::::from_elem((n, n + 1), 1.7); let c = c.slice(s![.., ..-1]); - { - let a = a.view_mut().reversed_axes(); - azip!((a in a, &b in b.t()) *a = b); - } - assert_eq!(a, b); + // Using Zip for arithmetic ops across a, b, c + Zip::from(&mut a) + .and(&b) + .and(&c) + .for_each(|a, &b, &c| *a = b + c); + assert_eq!(a, &b + &c); + // and this is how to do the *same thing* with azip!() azip!((a in &mut a, &b in &b, &c in c) *a = b + c); - assert_eq!(a, &b + &c); - // sum of each row - let ax = Axis(0); - let mut sums = Array::zeros(a.len_of(ax)); - azip!((s in &mut sums, a in a.axis_iter(ax)) *s = a.sum()); + println!("{:8.4}", a); - // sum of each chunk + // sum of each row + let mut sums = Array::zeros(a.nrows()); + Zip::from(a.rows()) + .and(&mut sums) + .for_each(|row, sum| *sum = row.sum()); + // show sums as a column matrix + println!("{:8.4}", sums.insert_axis(Axis(1))); + + // sum of each 2x2 chunk let chunk_sz = (2, 2); let nchunks = (n / chunk_sz.0, n / chunk_sz.1); let mut sums = Array::zeros(nchunks); - azip!((s in &mut sums, a in a.exact_chunks(chunk_sz)) *s = a.sum()); - - // Let's imagine we split to parallelize - { - let (x, y) = Zip::indexed(&mut a).split(); - x.apply(|(_, j), elt| { - *elt = elt.powi(j as i32); - }); - - y.apply(|(_, j), elt| { - *elt = elt.powi(j as i32); - }); - } - println!("{:8.3?}", a); + + Zip::from(a.exact_chunks(chunk_sz)) + .and(&mut sums) + .for_each(|chunk, sum| *sum = chunk.sum()); + println!("{:8.4}", sums); } diff --git a/docgen/images/axis_iter.svg b/misc/axis_iter.svg similarity index 100% rename from docgen/images/axis_iter.svg rename to misc/axis_iter.svg diff --git a/docgen/images/split_at.svg b/misc/split_at.svg similarity index 100% rename from docgen/images/split_at.svg rename to misc/split_at.svg diff --git a/ndarray-rand/Cargo.toml b/ndarray-rand/Cargo.toml index aa4715fc4..b58e752a5 100644 --- a/ndarray-rand/Cargo.toml +++ b/ndarray-rand/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "ndarray-rand" -version = "0.11.0" +version = "0.15.0" edition = "2018" authors = ["bluss"] -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" repository = "https://github.com/rust-ndarray/ndarray" documentation = "https://docs.rs/ndarray-rand/" @@ -14,17 +14,15 @@ description = "Constructors for randomized arrays. `rand` integration for `ndarr keywords = ["multidimensional", "matrix", "rand", "ndarray"] [dependencies] -ndarray = { version = "0.13", path = ".." } -rand_distr = "0.2.1" -quickcheck = { version = "0.9", default-features = false, optional = true } +ndarray = { workspace = true } -[dependencies.rand] -version = "0.7.0" -features = ["small_rng"] +rand = { workspace = true } +rand_distr = { workspace = true } +quickcheck = { workspace = true, optional = true } [dev-dependencies] -rand_isaac = "0.2.0" -quickcheck = { version = "0.9", default-features = false } +rand_isaac = "0.3.0" +quickcheck = { workspace = true } [package.metadata.release] no-dev-version = true diff --git a/ndarray-rand/README.md b/ndarray-rand/README.md index e993440df..0109e9732 100644 --- a/ndarray-rand/README.md +++ b/ndarray-rand/README.md @@ -26,9 +26,9 @@ fn main() { Dependencies ============ -``ndarray-rand`` depends on ``rand`` 0.7. +``ndarray-rand`` depends on ``rand``. 
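As a quick illustration of how ndarray-rand sits on top of `rand`, here is a minimal usage sketch (not part of this patch; it only uses the public `RandomExt::random` API and the re-exported `rand_distr` submodule, and the shape and distribution are illustrative):

use ndarray::Array;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;

fn main() {
    // 2×5 array with elements drawn uniformly from [0, 10)
    let a = Array::random((2, 5), Uniform::new(0., 10.));
    println!("{:8.4}", a);
}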
-[`rand`](https://docs.rs/rand/0.7.0/rand/) and [`rand-distr`](https://docs.rs/rand_distr/0.2.1/rand_distr/) are +[`rand`](https://docs.rs/rand/) and [`rand-distr`](https://docs.rs/rand_distr/) are re-exported as sub-modules, `ndarray_rand::rand` and `ndarray_rand::rand_distr` respectively. Please rely on these submodules for guaranteed version compatibility. @@ -41,15 +41,7 @@ necessary trait). Recent changes ============== -0.10.0 ------- - - - Require `rand` 0.7 - - Require Rust 1.32 or later - - Re-export `rand` as a submodule, `ndarray_rand::rand` - - Re-export `rand-distr` as a submodule, `ndarray_rand::rand_distr` - -Check _[Changelogs](https://github.com/rust-ndarray/ndarray/ndarray-rand/RELEASES.md)_ to see +Check _[RELEASES.md](https://github.com/rust-ndarray/ndarray/blob/master/ndarray-rand/RELEASES.md)_ to see the changes introduced in previous releases. diff --git a/ndarray-rand/RELEASES.md b/ndarray-rand/RELEASES.md index feea2dbce..cff9acd96 100644 --- a/ndarray-rand/RELEASES.md +++ b/ndarray-rand/RELEASES.md @@ -1,6 +1,30 @@ Recent Changes -------------- +- 0.15.0 + + - Require ndarray 0.16 + - Remove deprecated F32 by [@bluss](https://github.com/bluss) [#1409](https://github.com/rust-ndarray/ndarray/pull/1409) + +- 0.14.0 + + - Require ndarray 0.15 + - Require rand 0.8 (unchanged from previous version) + - The F32 wrapper is now deprecated, it's redundant + +- 0.13.0 + + - Require ndarray 0.14 (unchanged from previous version) + - Require rand 0.8 + - Require rand_distr 0.4 + - Fix methods `sample_axis` and `sample_axis_using` so that they can be used on array views too. + +- 0.12.0 + + - Require ndarray 0.14 + - Require rand 0.7 (unchanged from previous version) + - Require rand_distr 0.3 + - 0.11.0 - Require ndarray 0.13 diff --git a/ndarray-rand/benches/bench.rs b/ndarray-rand/benches/bench.rs index bdd010bc1..0e5eb2ff7 100644 --- a/ndarray-rand/benches/bench.rs +++ b/ndarray-rand/benches/bench.rs @@ -4,26 +4,28 @@ extern crate test; use ndarray::Array; use ndarray_rand::RandomExt; -use ndarray_rand::F32; use rand_distr::Normal; use rand_distr::Uniform; use test::Bencher; #[bench] -fn uniform_f32(b: &mut Bencher) { +fn uniform_f32(b: &mut Bencher) +{ let m = 100; b.iter(|| Array::random((m, m), Uniform::new(-1f32, 1.))); } #[bench] -fn norm_f32(b: &mut Bencher) { +fn norm_f32(b: &mut Bencher) +{ let m = 100; - b.iter(|| Array::random((m, m), F32(Normal::new(0., 1.).unwrap()))); + b.iter(|| Array::random((m, m), Normal::new(0f32, 1.).unwrap())); } #[bench] -fn norm_f64(b: &mut Bencher) { +fn norm_f64(b: &mut Bencher) +{ let m = 100; - b.iter(|| Array::random((m, m), Normal::new(0., 1.).unwrap())); + b.iter(|| Array::random((m, m), Normal::new(0f64, 1.).unwrap())); } diff --git a/ndarray-rand/src/lib.rs b/ndarray-rand/src/lib.rs index 63cf1c397..6671ab334 100644 --- a/ndarray-rand/src/lib.rs +++ b/ndarray-rand/src/lib.rs @@ -8,20 +8,20 @@ //! Constructors for randomized arrays: `rand` integration for `ndarray`. //! -//! See [**`RandomExt`**](trait.RandomExt.html) for usage examples. +//! See **[`RandomExt`]** for usage examples. //! //! ## Note //! -//! `ndarray-rand` depends on [`rand` 0.7][rand]. +//! `ndarray-rand` depends on [`rand` 0.8][rand]. //! //! [`rand`][rand] and [`rand_distr`][rand_distr] -//! are re-exported as sub-modules, [`ndarray_rand::rand`](rand/index.html) -//! and [`ndarray_rand::rand_distr`](rand_distr/index.html) respectively. +//! are re-exported as sub-modules, [`ndarray_rand::rand`](rand) +//! 
and [`ndarray_rand::rand_distr`](rand_distr) respectively. //! You can use these submodules for guaranteed version compatibility or //! convenience. //! -//! [rand]: https://docs.rs/rand/0.7 -//! [rand_distr]: https://docs.rs/rand_distr/0.2 +//! [rand]: https://docs.rs/rand/0.8 +//! [rand_distr]: https://docs.rs/rand_distr/0.4 //! //! If you want to use a random number generator or distribution from another crate //! with `ndarray-rand`, you need to make sure that the other crate also depends on the @@ -35,17 +35,19 @@ use crate::rand::seq::index; use crate::rand::{thread_rng, Rng, SeedableRng}; use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder}; -use ndarray::{ArrayBase, DataOwned, Dimension}; +use ndarray::{ArrayBase, Data, DataOwned, Dimension, RawData}; #[cfg(feature = "quickcheck")] use quickcheck::{Arbitrary, Gen}; -/// [`rand`](https://docs.rs/rand/0.7), re-exported for convenience and version-compatibility. -pub mod rand { +/// `rand`, re-exported for convenience and version-compatibility. +pub mod rand +{ pub use rand::*; } -/// [`rand-distr`](https://docs.rs/rand_distr/0.2), re-exported for convenience and version-compatibility. -pub mod rand_distr { +/// `rand-distr`, re-exported for convenience and version-compatibility. +pub mod rand_distr +{ pub use rand_distr::*; } @@ -55,16 +57,15 @@ pub mod rand_distr { /// for other types. /// /// The default RNG is a fast automatically seeded rng (currently -/// [`rand::rngs::SmallRng`](https://docs.rs/rand/0.7/rand/rngs/struct.SmallRng.html) -/// seeded from [`rand::thread_rng`](https://docs.rs/rand/0.7/rand/fn.thread_rng.html)). +/// [`rand::rngs::SmallRng`], seeded from [`rand::thread_rng`]). /// /// Note that `SmallRng` is cheap to initialize and fast, but it may generate /// low-quality random numbers, and reproducibility is not guaranteed. See its /// documentation for information. You can select a different RNG with -/// [`.random_using()`](#tymethod.random_using). +/// [`.random_using()`](Self::random_using). pub trait RandomExt where - S: DataOwned, + S: RawData, D: Dimension, { /// Create an array with shape `dim` with elements drawn from @@ -88,6 +89,7 @@ where fn random(shape: Sh, distribution: IdS) -> ArrayBase where IdS: Distribution, + S: DataOwned, Sh: ShapeBuilder; /// Create an array with shape `dim` with elements drawn from @@ -118,6 +120,7 @@ where where IdS: Distribution, R: Rng + ?Sized, + S: DataOwned, Sh: ShapeBuilder; /// Sample `n_samples` lanes slicing along `axis` using the default RNG. @@ -164,6 +167,7 @@ where fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array where A: Copy, + S: Data, D: RemoveAxis; /// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`. 
@@ -215,26 +219,24 @@ where /// # } /// ``` fn sample_axis_using( - &self, - axis: Axis, - n_samples: usize, - strategy: SamplingStrategy, - rng: &mut R, + &self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R, ) -> Array where R: Rng + ?Sized, A: Copy, + S: Data, D: RemoveAxis; } impl RandomExt for ArrayBase where - S: DataOwned, + S: RawData, D: Dimension, { fn random(shape: Sh, dist: IdS) -> ArrayBase where IdS: Distribution, + S: DataOwned, Sh: ShapeBuilder, { Self::random_using(shape, dist, &mut get_rng()) @@ -244,6 +246,7 @@ where where IdS: Distribution, R: Rng + ?Sized, + S: DataOwned, Sh: ShapeBuilder, { Self::from_shape_simple_fn(shape, move || dist.sample(rng)) @@ -252,21 +255,17 @@ where fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array where A: Copy, + S: Data, D: RemoveAxis, { self.sample_axis_using(axis, n_samples, strategy, &mut get_rng()) } - fn sample_axis_using( - &self, - axis: Axis, - n_samples: usize, - strategy: SamplingStrategy, - rng: &mut R, - ) -> Array + fn sample_axis_using(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array where R: Rng + ?Sized, A: Copy, + S: Data, D: RemoveAxis, { let indices: Vec<_> = match strategy { @@ -274,9 +273,7 @@ where let distribution = Uniform::from(0..self.len_of(axis)); (0..n_samples).map(|_| distribution.sample(rng)).collect() } - SamplingStrategy::WithoutReplacement => { - index::sample(rng, self.len_of(axis), n_samples).into_vec() - } + SamplingStrategy::WithoutReplacement => index::sample(rng, self.len_of(axis), n_samples).into_vec(), }; self.select(axis, &indices) } @@ -286,19 +283,22 @@ where /// if lanes from the original array should only be sampled once (*without replacement*) or /// multiple times (*with replacement*). /// -/// [`sample_axis`]: trait.RandomExt.html#tymethod.sample_axis -/// [`sample_axis_using`]: trait.RandomExt.html#tymethod.sample_axis_using +/// [`sample_axis`]: RandomExt::sample_axis +/// [`sample_axis_using`]: RandomExt::sample_axis_using #[derive(Debug, Clone)] -pub enum SamplingStrategy { +pub enum SamplingStrategy +{ WithReplacement, WithoutReplacement, } // `Arbitrary` enables `quickcheck` to generate random `SamplingStrategy` values for testing. 
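For orientation, a minimal sketch (not part of this diff) of how the two `SamplingStrategy` variants are used through `RandomExt::sample_axis`; the array shape, sample counts, and distribution are illustrative only:

use ndarray::{Array, Axis};
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::{RandomExt, SamplingStrategy};

fn main() {
    let a = Array::random((10, 4), Uniform::new(0., 10.));
    // 3 distinct rows: each row of `a` is picked at most once
    let without = a.sample_axis(Axis(0), 3, SamplingStrategy::WithoutReplacement);
    // 20 rows: rows of `a` may repeat, so oversampling is allowed
    let with = a.sample_axis(Axis(0), 20, SamplingStrategy::WithReplacement);
    assert_eq!(without.shape(), &[3, 4]);
    assert_eq!(with.shape(), &[20, 4]);
}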
#[cfg(feature = "quickcheck")] -impl Arbitrary for SamplingStrategy { - fn arbitrary(g: &mut G) -> Self { - if g.gen_bool(0.5) { +impl Arbitrary for SamplingStrategy +{ + fn arbitrary(g: &mut Gen) -> Self + { + if bool::arbitrary(g) { SamplingStrategy::WithReplacement } else { SamplingStrategy::WithoutReplacement @@ -306,33 +306,7 @@ impl Arbitrary for SamplingStrategy { } } -fn get_rng() -> SmallRng { - SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed") -} - -/// A wrapper type that allows casting f64 distributions to f32 -/// -/// ``` -/// use ndarray::Array; -/// use ndarray_rand::{RandomExt, F32}; -/// use ndarray_rand::rand_distr::Normal; -/// -/// # fn main() { -/// let distribution_f64 = Normal::new(0., 1.).expect("Failed to create normal distribution"); -/// let a = Array::random((2, 5), F32(distribution_f64)); -/// println!("{:8.4}", a); -/// // Example Output: -/// // [[ -0.6910, 1.1730, 1.0902, -0.4092, -1.7340], -/// // [ -0.6810, 0.1678, -0.9487, 0.3150, 1.2981]] -/// # } -#[derive(Copy, Clone, Debug)] -pub struct F32(pub S); - -impl Distribution for F32 -where - S: Distribution, +fn get_rng() -> SmallRng { - fn sample(&self, rng: &mut R) -> f32 { - self.0.sample(rng) as f32 - } + SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed") } diff --git a/ndarray-rand/tests/tests.rs b/ndarray-rand/tests/tests.rs index f7860ac12..d38e8636e 100644 --- a/ndarray-rand/tests/tests.rs +++ b/ndarray-rand/tests/tests.rs @@ -5,10 +5,11 @@ use ndarray_rand::rand::{distributions::Distribution, thread_rng}; use ndarray::ShapeBuilder; use ndarray_rand::rand_distr::Uniform; use ndarray_rand::{RandomExt, SamplingStrategy}; -use quickcheck::quickcheck; +use quickcheck::{quickcheck, TestResult}; #[test] -fn test_dim() { +fn test_dim() +{ let (mm, nn) = (5, 5); for m in 0..mm { for n in 0..nn { @@ -22,7 +23,8 @@ fn test_dim() { } #[test] -fn test_dim_f() { +fn test_dim_f() +{ let (mm, nn) = (5, 5); for m in 0..mm { for n in 0..nn { @@ -35,16 +37,29 @@ fn test_dim_f() { } } +#[test] +fn sample_axis_on_view() +{ + let m = 5; + let a = Array::random((m, 4), Uniform::new(0., 2.)); + let _samples = a + .view() + .sample_axis(Axis(0), m, SamplingStrategy::WithoutReplacement); +} + #[test] #[should_panic] -fn oversampling_without_replacement_should_panic() { +fn oversampling_without_replacement_should_panic() +{ let m = 5; let a = Array::random((m, 4), Uniform::new(0., 2.)); let _samples = a.sample_axis(Axis(0), m + 1, SamplingStrategy::WithoutReplacement); } quickcheck! { - fn oversampling_with_replacement_is_fine(m: usize, n: usize) -> bool { + #[cfg_attr(miri, ignore)] // Takes an insufferably long time + fn oversampling_with_replacement_is_fine(m: u8, n: u8) -> TestResult { + let (m, n) = (m as usize, n as usize); let a = Array::random((m, n), Uniform::new(0., 2.)); // Higher than the length of both axes let n_samples = m + n + 1; @@ -52,24 +67,29 @@ quickcheck! { // We don't want to deal with sampling from 0-length axes in this test if m != 0 { if !sampling_works(&a, SamplingStrategy::WithReplacement, Axis(0), n_samples) { - return false; + return TestResult::failed(); } + } else { + return TestResult::discard(); } // We don't want to deal with sampling from 0-length axes in this test if n != 0 { if !sampling_works(&a, SamplingStrategy::WithReplacement, Axis(1), n_samples) { - return false; + return TestResult::failed(); } + } else { + return TestResult::discard(); } - - true + TestResult::passed() } } #[cfg(feature = "quickcheck")] quickcheck! 
{ - fn sampling_behaves_as_expected(m: usize, n: usize, strategy: SamplingStrategy) -> bool { + #[cfg_attr(miri, ignore)] // This takes *forever* with Miri + fn sampling_behaves_as_expected(m: u8, n: u8, strategy: SamplingStrategy) -> TestResult { + let (m, n) = (m as usize, n as usize); let a = Array::random((m, n), Uniform::new(0., 2.)); let mut rng = &mut thread_rng(); @@ -77,42 +97,44 @@ quickcheck! { if m != 0 { let n_row_samples = Uniform::from(1..m+1).sample(&mut rng); if !sampling_works(&a, strategy.clone(), Axis(0), n_row_samples) { - return false; + return TestResult::failed(); } + } else { + return TestResult::discard(); } // We don't want to deal with sampling from 0-length axes in this test if n != 0 { let n_col_samples = Uniform::from(1..n+1).sample(&mut rng); if !sampling_works(&a, strategy, Axis(1), n_col_samples) { - return false; + return TestResult::failed(); } + } else { + return TestResult::discard(); } - true + TestResult::passed() } } -fn sampling_works( - a: &Array2, - strategy: SamplingStrategy, - axis: Axis, - n_samples: usize, -) -> bool { +fn sampling_works(a: &Array2, strategy: SamplingStrategy, axis: Axis, n_samples: usize) -> bool +{ let samples = a.sample_axis(axis, n_samples, strategy); samples .axis_iter(axis) - .all(|lane| is_subset(&a, &lane, axis)) + .all(|lane| is_subset(a, &lane, axis)) } // Check if, when sliced along `axis`, there is at least one lane in `a` equal to `b` -fn is_subset(a: &Array2, b: &ArrayView1, axis: Axis) -> bool { +fn is_subset(a: &Array2, b: &ArrayView1, axis: Axis) -> bool +{ a.axis_iter(axis).any(|lane| &lane == b) } #[test] #[should_panic] -fn sampling_without_replacement_from_a_zero_length_axis_should_panic() { +fn sampling_without_replacement_from_a_zero_length_axis_should_panic() +{ let n = 5; let a = Array::random((0, n), Uniform::new(0., 2.)); let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithoutReplacement); @@ -120,7 +142,8 @@ fn sampling_without_replacement_from_a_zero_length_axis_should_panic() { #[test] #[should_panic] -fn sampling_with_replacement_from_a_zero_length_axis_should_panic() { +fn sampling_with_replacement_from_a_zero_length_axis_should_panic() +{ let n = 5; let a = Array::random((0, n), Uniform::new(0., 2.)); let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithReplacement); diff --git a/numeric-tests/Cargo.toml b/numeric-tests/Cargo.toml deleted file mode 100644 index dc1261512..000000000 --- a/numeric-tests/Cargo.toml +++ /dev/null @@ -1,26 +0,0 @@ -[package] -name = "numeric-tests" -version = "0.1.0" -authors = ["bluss"] -publish = false - -[dependencies] -approx = "0.3.2" -ndarray = { path = "..", features = ["approx"] } -ndarray-rand = { path = "../ndarray-rand/" } -rand_distr = "0.2.1" - -[dependencies.rand] -version = "0.7.0" -features = ["small_rng"] - -[lib] -test = false - -[features] -test_blas = ["ndarray/blas", "ndarray/test-blas-openblas-sys"] - -[profile.dev] -opt-level = 2 -[profile.test] -opt-level = 2 diff --git a/numeric-tests/src/lib.rs b/numeric-tests/src/lib.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/numeric-tests/tests/accuracy.rs b/numeric-tests/tests/accuracy.rs deleted file mode 100644 index 4f61248df..000000000 --- a/numeric-tests/tests/accuracy.rs +++ /dev/null @@ -1,263 +0,0 @@ -extern crate approx; -extern crate rand_distr; -extern crate ndarray; -extern crate ndarray_rand; -extern crate rand; - -use ndarray_rand::{RandomExt, F32}; -use rand::{Rng, SeedableRng}; -use rand::rngs::SmallRng; - -use ndarray::prelude::*; -use 
ndarray::{ - Data, - LinalgScalar, -}; -use ndarray::linalg::general_mat_mul; - -use rand_distr::Normal; - -use approx::{assert_abs_diff_eq, assert_relative_eq}; - -// simple, slow, correct (hopefully) mat mul -fn reference_mat_mul(lhs: &ArrayBase, rhs: &ArrayBase) - -> Array - where A: LinalgScalar, - S: Data, - S2: Data, -{ - let ((m, k), (_, n)) = (lhs.dim(), rhs.dim()); - let mut res_elems = Vec::::with_capacity(m * n); - unsafe { - res_elems.set_len(m * n); - } - - let mut i = 0; - let mut j = 0; - for rr in &mut res_elems { - unsafe { - *rr = (0..k).fold(A::zero(), - move |s, x| s + *lhs.uget((i, x)) * *rhs.uget((x, j))); - } - j += 1; - if j == n { - j = 0; - i += 1; - } - } - unsafe { - ArrayBase::from_shape_vec_unchecked((m, n), res_elems) - } -} - -fn gen(d: D) -> Array - where D: Dimension, -{ - Array::random(d, F32(Normal::new(0., 1.).unwrap())) -} -fn gen_f64(d: D) -> Array - where D: Dimension, -{ - Array::random(d, Normal::new(0., 1.).unwrap()) -} - -#[test] -fn accurate_eye_f32() { - for i in 0..20 { - let eye = Array::eye(i); - for j in 0..20 { - let a = gen(Ix2(i, j)); - let a2 = eye.dot(&a); - assert_abs_diff_eq!(a, a2, epsilon = 1e-6); - let a3 = a.t().dot(&eye); - assert_abs_diff_eq!(a.t(), a3, epsilon = 1e-6); - } - } - // pick a few random sizes - let mut rng = SmallRng::from_entropy(); - for _ in 0..10 { - let i = rng.gen_range(15, 512); - let j = rng.gen_range(15, 512); - println!("Testing size {} by {}", i, j); - let a = gen(Ix2(i, j)); - let eye = Array::eye(i); - let a2 = eye.dot(&a); - assert_abs_diff_eq!(a, a2, epsilon = 1e-6); - let a3 = a.t().dot(&eye); - assert_abs_diff_eq!(a.t(), a3, epsilon = 1e-6); - } -} - -#[test] -fn accurate_eye_f64() { - let abs_tol = 1e-15; - for i in 0..20 { - let eye = Array::eye(i); - for j in 0..20 { - let a = gen_f64(Ix2(i, j)); - let a2 = eye.dot(&a); - assert_abs_diff_eq!(a, a2, epsilon = abs_tol); - let a3 = a.t().dot(&eye); - assert_abs_diff_eq!(a.t(), a3, epsilon = abs_tol); - } - } - // pick a few random sizes - let mut rng = SmallRng::from_entropy(); - for _ in 0..10 { - let i = rng.gen_range(15, 512); - let j = rng.gen_range(15, 512); - println!("Testing size {} by {}", i, j); - let a = gen_f64(Ix2(i, j)); - let eye = Array::eye(i); - let a2 = eye.dot(&a); - assert_abs_diff_eq!(a, a2, epsilon = 1e-6); - let a3 = a.t().dot(&eye); - assert_abs_diff_eq!(a.t(), a3, epsilon = 1e-6); - } -} - -#[test] -fn accurate_mul_f32() { - // pick a few random sizes - let mut rng = SmallRng::from_entropy(); - for i in 0..20 { - let m = rng.gen_range(15, 512); - let k = rng.gen_range(15, 512); - let n = rng.gen_range(15, 1560); - let a = gen(Ix2(m, k)); - let b = gen(Ix2(n, k)); - let b = b.t(); - let (a, b) = if i > 10 { - (a.slice(s![..;2, ..;2]), - b.slice(s![..;2, ..;2])) - } else { (a.view(), b) }; - - println!("Testing size {} by {} by {}", a.shape()[0], a.shape()[1], b.shape()[1]); - let c = a.dot(&b); - let reference = reference_mat_mul(&a, &b); - - assert_relative_eq!(c, reference, epsilon = 1e-4, max_relative = 1e-3); - } -} - -#[test] -fn accurate_mul_f32_general() { - // pick a few random sizes - let mut rng = SmallRng::from_entropy(); - for i in 0..20 { - let m = rng.gen_range(15, 512); - let k = rng.gen_range(15, 512); - let n = rng.gen_range(15, 1560); - let a = gen(Ix2(m, k)); - let b = gen(Ix2(n, k)); - let mut c = gen(Ix2(m, n)); - let b = b.t(); - let (a, b, mut c) = if i > 10 { - (a.slice(s![..;2, ..;2]), - b.slice(s![..;2, ..;2]), - c.slice_mut(s![..;2, ..;2])) - } else { (a.view(), b, c.view_mut()) }; - - 
println!("Testing size {} by {} by {}", a.shape()[0], a.shape()[1], b.shape()[1]); - general_mat_mul(1., &a, &b, 0., &mut c); - let reference = reference_mat_mul(&a, &b); - - assert_relative_eq!(c, reference, epsilon = 1e-4, max_relative = 1e-3); - } -} - -#[test] -fn accurate_mul_f64() { - // pick a few random sizes - let mut rng = SmallRng::from_entropy(); - for i in 0..20 { - let m = rng.gen_range(15, 512); - let k = rng.gen_range(15, 512); - let n = rng.gen_range(15, 1560); - let a = gen_f64(Ix2(m, k)); - let b = gen_f64(Ix2(n, k)); - let b = b.t(); - let (a, b) = if i > 10 { - (a.slice(s![..;2, ..;2]), - b.slice(s![..;2, ..;2])) - } else { (a.view(), b) }; - - println!("Testing size {} by {} by {}", a.shape()[0], a.shape()[1], b.shape()[1]); - let c = a.dot(&b); - let reference = reference_mat_mul(&a, &b); - - assert_relative_eq!(c, reference, epsilon = 1e-12, max_relative = 1e-7); - } -} - -#[test] -fn accurate_mul_f64_general() { - // pick a few random sizes - let mut rng = SmallRng::from_entropy(); - for i in 0..20 { - let m = rng.gen_range(15, 512); - let k = rng.gen_range(15, 512); - let n = rng.gen_range(15, 1560); - let a = gen_f64(Ix2(m, k)); - let b = gen_f64(Ix2(n, k)); - let mut c = gen_f64(Ix2(m, n)); - let b = b.t(); - let (a, b, mut c) = if i > 10 { - (a.slice(s![..;2, ..;2]), - b.slice(s![..;2, ..;2]), - c.slice_mut(s![..;2, ..;2])) - } else { (a.view(), b, c.view_mut()) }; - - println!("Testing size {} by {} by {}", a.shape()[0], a.shape()[1], b.shape()[1]); - general_mat_mul(1., &a, &b, 0., &mut c); - let reference = reference_mat_mul(&a, &b); - - assert_relative_eq!(c, reference, epsilon = 1e-12, max_relative = 1e-7); - } -} - -#[test] -fn accurate_mul_with_column_f64() { - // pick a few random sizes - let mut rng = SmallRng::from_entropy(); - for i in 0..10 { - let m = rng.gen_range(1, 350); - let k = rng.gen_range(1, 350); - let a = gen_f64(Ix2(m, k)); - let b_owner = gen_f64(Ix2(k, k)); - let b_row_col; - let b_sq; - - // pick dense square or broadcasted to square matrix - match i { - 0 ..= 3 => b_sq = b_owner.view(), - 4 ..= 7 => { - b_row_col = b_owner.column(0); - b_sq = b_row_col.broadcast((k, k)).unwrap(); - } - _otherwise => { - b_row_col = b_owner.row(0); - b_sq = b_row_col.broadcast((k, k)).unwrap(); - } - }; - - for j in 0..k { - for &flip in &[true, false] { - let j = j as isize; - let b = if flip { - // one row in 2D - b_sq.slice(s![j..j + 1, ..]).reversed_axes() - } else { - // one column in 2D - b_sq.slice(s![.., j..j + 1]) - }; - println!("Testing size ({} × {}) by ({} × {})", a.shape()[0], a.shape()[1], b.shape()[0], b.shape()[1]); - println!("Strides ({:?}) by ({:?})", a.strides(), b.strides()); - let c = a.dot(&b); - let reference = reference_mat_mul(&a, &b); - - assert_relative_eq!(c, reference, epsilon = 1e-12, max_relative = 1e-7); - } - } - } -} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 000000000..f3e376ccc --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,26 @@ +edition = "2018" +array_width = 100 +chain_width = 60 +fn_call_width = 100 +max_width = 120 +brace_style = "AlwaysNextLine" +control_brace_style = "AlwaysSameLine" +fn_params_layout = "Compressed" # ? 
+format_macro_bodies = false +imports_granularity = "Preserve" +imports_indent = "Block" +imports_layout = "HorizontalVertical" +inline_attribute_width = 0 +indent_style = "Block" +match_arm_blocks = false +match_arm_leading_pipes = "Preserve" +merge_derives = false +overflow_delimited_expr = true +reorder_modules = false # impacts rustdoc order +short_array_element_width_threshold = 32 +skip_macro_invocations = ["*"] +unstable_features = true +where_single_line = true + +# ignored files +ignore = [] diff --git a/scripts/all-tests.sh b/scripts/all-tests.sh index 9b41b41d8..e98b90df1 100755 --- a/scripts/all-tests.sh +++ b/scripts/all-tests.sh @@ -6,17 +6,29 @@ set -e FEATURES=$1 CHANNEL=$2 -([ "$CHANNEL" != "beta" ] || (rustup component add rustfmt && cargo fmt --all -- --check)) -cargo build --verbose --no-default-features -# Testing both dev and release profiles helps find bugs, especially in low level code -cargo test --verbose --no-default-features -cargo test --release --verbose --no-default-features -cargo build --verbose --features "$FEATURES" -cargo test --verbose --features "$FEATURES" -cargo test --manifest-path=ndarray-rand/Cargo.toml --no-default-features --verbose -cargo test --manifest-path=ndarray-rand/Cargo.toml --features quickcheck --verbose -cargo test --manifest-path=serialization-tests/Cargo.toml --verbose -cargo test --manifest-path=blas-tests/Cargo.toml --verbose -CARGO_TARGET_DIR=target/ cargo test --manifest-path=numeric-tests/Cargo.toml --verbose -([ "$CHANNEL" != "beta" ] || (rustup component add clippy && cargo clippy)) +QC_FEAT=--features=ndarray-rand/quickcheck + +# build check with no features +cargo build -v --no-default-features + +# ndarray with no features +cargo test -p ndarray -v --no-default-features +# ndarray with no_std-compatible features +cargo test -p ndarray -v --no-default-features --features approx +# all with features +cargo test -v --features "$FEATURES" $QC_FEAT +# all with features and release (ignore test crates which is already optimized) +cargo test -v -p ndarray -p ndarray-rand --release --features "$FEATURES" $QC_FEAT --lib --tests + +# BLAS tests +cargo test -p ndarray --lib -v --features blas +cargo test -p blas-mock-tests -v +if [ "$CHANNEL" != "1.64.0" ]; then + ./scripts/blas-integ-tests.sh "$FEATURES" $CHANNEL +fi + +# Examples +cargo test --examples + +# Benchmarks ([ "$CHANNEL" != "nightly" ] || cargo bench --no-run --verbose --features "$FEATURES") diff --git a/scripts/blas-integ-tests.sh b/scripts/blas-integ-tests.sh new file mode 100755 index 000000000..fec938b83 --- /dev/null +++ b/scripts/blas-integ-tests.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +set -x +set -e + +CHANNEL=$1 + +# BLAS tests +cargo test -p blas-tests -v --features blas-tests/openblas-system +cargo test -p numeric-tests -v --features numeric-tests/test_blas diff --git a/scripts/cross-tests.sh b/scripts/cross-tests.sh new file mode 100755 index 000000000..80b37c339 --- /dev/null +++ b/scripts/cross-tests.sh @@ -0,0 +1,14 @@ +#!/bin/sh + +set -x +set -e + +FEATURES=$1 +CHANNEL=$2 +TARGET=$3 + +QC_FEAT=--features=ndarray-rand/quickcheck + +cross build -v --features="$FEATURES" $QC_FEAT --target=$TARGET +cross test -v --no-fail-fast --features="$FEATURES" $QC_FEAT --target=$TARGET +cross test -v -p blas-mock-tests diff --git a/scripts/makechangelog.sh b/scripts/makechangelog.sh new file mode 100755 index 000000000..535280804 --- /dev/null +++ b/scripts/makechangelog.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Usage: makechangelog +# +# This script depends on and uses the 
github `gh` binary +# which needs to be authenticated to use. +# +# Will produce some duplicates for PRs integrated using rebase, +# but those will not occur with current merge queue. + +git log --first-parent --pretty="tformat:%H" "$@" | while IFS= read -r commit_sha +do + gh api "/repos/:owner/:repo/commits/${commit_sha}/pulls" \ + -q ".[] | \"- \(.title) by [@\(.user.login)](\(.user.html_url)) [#\(.number)](\(.html_url))\"" +done | uniq + diff --git a/scripts/miri-tests.sh b/scripts/miri-tests.sh new file mode 100755 index 000000000..0100f3e6a --- /dev/null +++ b/scripts/miri-tests.sh @@ -0,0 +1,18 @@ +#!/bin/sh + +set -x +set -e + +# We rely on layout-dependent casts, which should be covered with #[repr(transparent)] +# This should catch if we missed that +RUSTFLAGS="-Zrandomize-layout" + +# Miri reports a stacked borrow violation deep within rayon, in a crate called crossbeam-epoch +# The crate has a PR to fix this: https://github.com/crossbeam-rs/crossbeam/pull/871 +# but using Miri's tree borrow mode may resolve it for now. +# Disabled until we can figure out a different rayon issue: https://github.com/rust-lang/miri/issues/1371 +# MIRIFLAGS="-Zmiri-tree-borrows" + +# General tests +# Note that we exclude blas feature because Miri can't do cblas_gemm +cargo miri test -v -p ndarray -p ndarray-rand --features approx,serde diff --git a/serialization-tests/Cargo.toml b/serialization-tests/Cargo.toml deleted file mode 100644 index 3aaad639c..000000000 --- a/serialization-tests/Cargo.toml +++ /dev/null @@ -1,27 +0,0 @@ -[package] -name = "serialization-tests" -version = "0.1.0" -authors = ["bluss"] -publish = false - -[lib] -test = false - -[dependencies] -ndarray = { path = "../", features = ["serde"] } - -[features] -default = ["ron"] - -[dev-dependencies.serde] -version = "1.0.100" - -[dev-dependencies.serde_json] -version = "1.0.40" - -[dev-dependencies.rmp-serde] -version = "0.14.0" - -[dependencies.ron] -version = "0.5.1" -optional = true diff --git a/serialization-tests/src/lib.rs b/serialization-tests/src/lib.rs deleted file mode 100644 index 8b1378917..000000000 --- a/serialization-tests/src/lib.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/aliases.rs b/src/aliases.rs index f7c71b3d4..5df0c95ec 100644 --- a/src/aliases.rs +++ b/src/aliases.rs @@ -2,56 +2,63 @@ //! 
use crate::dimension::Dim; -#[allow(deprecated)] -use crate::{ArcArray, Array, ArrayView, ArrayViewMut, Ix, IxDynImpl, RcArray}; +use crate::{ArcArray, Array, ArrayView, ArrayViewMut, Ix, IxDynImpl}; /// Create a zero-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub fn Ix0() -> Ix0 { +pub const fn Ix0() -> Ix0 +{ Dim::new([]) } /// Create a one-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub fn Ix1(i0: Ix) -> Ix1 { +pub const fn Ix1(i0: Ix) -> Ix1 +{ Dim::new([i0]) } /// Create a two-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub fn Ix2(i0: Ix, i1: Ix) -> Ix2 { +pub const fn Ix2(i0: Ix, i1: Ix) -> Ix2 +{ Dim::new([i0, i1]) } /// Create a three-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub fn Ix3(i0: Ix, i1: Ix, i2: Ix) -> Ix3 { +pub const fn Ix3(i0: Ix, i1: Ix, i2: Ix) -> Ix3 +{ Dim::new([i0, i1, i2]) } /// Create a four-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub fn Ix4(i0: Ix, i1: Ix, i2: Ix, i3: Ix) -> Ix4 { +pub const fn Ix4(i0: Ix, i1: Ix, i2: Ix, i3: Ix) -> Ix4 +{ Dim::new([i0, i1, i2, i3]) } /// Create a five-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub fn Ix5(i0: Ix, i1: Ix, i2: Ix, i3: Ix, i4: Ix) -> Ix5 { +pub const fn Ix5(i0: Ix, i1: Ix, i2: Ix, i3: Ix, i4: Ix) -> Ix5 +{ Dim::new([i0, i1, i2, i3, i4]) } /// Create a six-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub fn Ix6(i0: Ix, i1: Ix, i2: Ix, i3: Ix, i4: Ix, i5: Ix) -> Ix6 { +pub const fn Ix6(i0: Ix, i1: Ix, i2: Ix, i3: Ix, i4: Ix, i5: Ix) -> Ix6 +{ Dim::new([i0, i1, i2, i3, i4, i5]) } /// Create a dynamic-dimensional index #[allow(non_snake_case)] #[inline(always)] -pub fn IxDyn(ix: &[Ix]) -> IxDyn { +pub fn IxDyn(ix: &[Ix]) -> IxDyn +{ Dim(ix) } @@ -150,15 +157,6 @@ pub type ArrayViewMut6<'a, A> = ArrayViewMut<'a, A, Ix6>; /// dynamic-dimensional read-write array view pub type ArrayViewMutD<'a, A> = ArrayViewMut<'a, A, IxDyn>; -/// one-dimensional shared ownership array -#[allow(deprecated)] -#[deprecated(note = "`RcArray` has been renamed to `ArcArray`")] -pub type RcArray1 = RcArray; -/// two-dimensional shared ownership array -#[allow(deprecated)] -#[deprecated(note = "`RcArray` has been renamed to `ArcArray`")] -pub type RcArray2 = RcArray; - /// one-dimensional shared ownership array pub type ArcArray1 = ArcArray; /// two-dimensional shared ownership array diff --git a/src/argument_traits.rs b/src/argument_traits.rs new file mode 100644 index 000000000..c4e85186a --- /dev/null +++ b/src/argument_traits.rs @@ -0,0 +1,48 @@ +use std::cell::Cell; +use std::mem::MaybeUninit; + +use crate::math_cell::MathCell; + +/// A producer element that can be assigned to once +pub trait AssignElem +{ + /// Assign the value `input` to the element that self represents. + fn assign_elem(self, input: T); +} + +/// Assignable element, simply `*self = input`. +impl AssignElem for &mut T +{ + fn assign_elem(self, input: T) + { + *self = input; + } +} + +/// Assignable element, simply `self.set(input)`. +impl AssignElem for &Cell +{ + fn assign_elem(self, input: T) + { + self.set(input); + } +} + +/// Assignable element, simply `self.set(input)`. +impl AssignElem for &MathCell +{ + fn assign_elem(self, input: T) + { + self.set(input); + } +} + +/// Assignable element, the item in the MaybeUninit is overwritten (prior value, if any, is not +/// read or dropped). 
+impl AssignElem for &mut MaybeUninit +{ + fn assign_elem(self, input: T) + { + *self = MaybeUninit::new(input); + } +} diff --git a/src/array_approx.rs b/src/array_approx.rs index 82c95a224..493864c7e 100644 --- a/src/array_approx.rs +++ b/src/array_approx.rs @@ -1,139 +1,195 @@ -use crate::imp_prelude::*; -use crate::Zip; -use approx::{AbsDiffEq, RelativeEq, UlpsEq}; - -/// **Requires crate feature `"approx"`** -impl AbsDiffEq> for ArrayBase -where - A: AbsDiffEq, - A::Epsilon: Clone, - S: Data, - S2: Data, - D: Dimension, +#[cfg(feature = "approx")] +mod approx_methods { - type Epsilon = A::Epsilon; - - fn default_epsilon() -> A::Epsilon { - A::default_epsilon() - } - - fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool { - if self.shape() != other.shape() { - return false; + use crate::imp_prelude::*; + + impl ArrayBase + where + S: Data, + D: Dimension, + { + /// A test for equality that uses the elementwise absolute difference to compute the + /// approximate equality of two arrays. + /// + /// **Requires crate feature `"approx"`** + pub fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool + where + A: ::approx::AbsDiffEq, + A::Epsilon: Clone, + S2: Data, + { + >::abs_diff_eq(self, other, epsilon) } - Zip::from(self) - .and(other) - .all(|a, b| A::abs_diff_eq(a, b, epsilon.clone())) - } -} -/// **Requires crate feature `"approx"`** -impl RelativeEq> for ArrayBase -where - A: RelativeEq, - A::Epsilon: Clone, - S: Data, - S2: Data, - D: Dimension, -{ - fn default_max_relative() -> A::Epsilon { - A::default_max_relative() - } - - fn relative_eq( - &self, - other: &ArrayBase, - epsilon: A::Epsilon, - max_relative: A::Epsilon, - ) -> bool { - if self.shape() != other.shape() { - return false; + /// A test for equality that uses an elementwise relative comparison if the values are far + /// apart; and the absolute difference otherwise. + /// + /// **Requires crate feature `"approx"`** + pub fn relative_eq(&self, other: &ArrayBase, epsilon: A::Epsilon, max_relative: A::Epsilon) -> bool + where + A: ::approx::RelativeEq, + A::Epsilon: Clone, + S2: Data, + { + >::relative_eq(self, other, epsilon, max_relative) } - Zip::from(self) - .and(other) - .all(|a, b| A::relative_eq(a, b, epsilon.clone(), max_relative.clone())) } } -/// **Requires crate feature `"approx"`** -impl UlpsEq> for ArrayBase -where - A: UlpsEq, - A::Epsilon: Clone, - S: Data, - S2: Data, - D: Dimension, -{ - fn default_max_ulps() -> u32 { - A::default_max_ulps() - } - - fn ulps_eq(&self, other: &ArrayBase, epsilon: A::Epsilon, max_ulps: u32) -> bool { - if self.shape() != other.shape() { - return false; +macro_rules! 
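The inherent `abs_diff_eq`/`relative_eq` methods introduced here can be called without importing the approx traits. A minimal sketch, assuming the `approx` crate feature is enabled:

```rust
use ndarray::array;

fn main() {
    let a = array![[1.0, 2.0], [3.0, 4.0]];
    let b = array![[1.0, 2.0], [3.0, 4.0 + 1e-9]];
    // Elementwise absolute-difference comparison with an explicit tolerance.
    assert!(a.abs_diff_eq(&b, 1e-6));
    assert!(!a.abs_diff_eq(&b, 1e-12));
}
```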
impl_approx_traits { + ($approx:ident, $doc:expr) => { + mod $approx { + use crate::imp_prelude::*; + use crate::Zip; + use $approx::{AbsDiffEq, RelativeEq, UlpsEq}; + + #[doc = $doc] + impl AbsDiffEq> for ArrayBase + where + A: AbsDiffEq, + A::Epsilon: Clone, + S: Data, + S2: Data, + D: Dimension, + { + type Epsilon = A::Epsilon; + + fn default_epsilon() -> A::Epsilon { + A::default_epsilon() + } + + fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool { + if self.shape() != other.shape() { + return false; + } + + Zip::from(self) + .and(other) + .all(move |a, b| A::abs_diff_eq(a, b, epsilon.clone())) + } + } + + #[doc = $doc] + impl RelativeEq> for ArrayBase + where + A: RelativeEq, + A::Epsilon: Clone, + S: Data, + S2: Data, + D: Dimension, + { + fn default_max_relative() -> A::Epsilon { + A::default_max_relative() + } + + fn relative_eq( + &self, + other: &ArrayBase, + epsilon: A::Epsilon, + max_relative: A::Epsilon, + ) -> bool { + if self.shape() != other.shape() { + return false; + } + + Zip::from(self).and(other).all(move |a, b| { + A::relative_eq(a, b, epsilon.clone(), max_relative.clone()) + }) + } + } + + #[doc = $doc] + impl UlpsEq> for ArrayBase + where + A: UlpsEq, + A::Epsilon: Clone, + S: Data, + S2: Data, + D: Dimension, + { + fn default_max_ulps() -> u32 { + A::default_max_ulps() + } + + fn ulps_eq( + &self, + other: &ArrayBase, + epsilon: A::Epsilon, + max_ulps: u32, + ) -> bool { + if self.shape() != other.shape() { + return false; + } + + Zip::from(self) + .and(other) + .all(move |a, b| A::ulps_eq(a, b, epsilon.clone(), max_ulps)) + } + } + + #[cfg(test)] + mod tests { + use crate::prelude::*; + use alloc::vec; + use $approx::{ + assert_abs_diff_eq, assert_abs_diff_ne, assert_relative_eq, assert_relative_ne, + assert_ulps_eq, assert_ulps_ne, + }; + + #[test] + fn abs_diff_eq() { + let a: Array2 = array![[0., 2.], [-0.000010001, 100000000.]]; + let mut b: Array2 = array![[0., 1.], [-0.000010002, 100000001.]]; + assert_abs_diff_ne!(a, b); + b[(0, 1)] = 2.; + assert_abs_diff_eq!(a, b); + + // Check epsilon. + assert_abs_diff_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32); + assert_abs_diff_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32); + + // Make sure we can compare different shapes without failure. + let c = array![[1., 2.]]; + assert_abs_diff_ne!(a, c); + } + + #[test] + fn relative_eq() { + let a: Array2 = array![[1., 2.], [-0.000010001, 100000000.]]; + let mut b: Array2 = array![[1., 1.], [-0.000010002, 100000001.]]; + assert_relative_ne!(a, b); + b[(0, 1)] = 2.; + assert_relative_eq!(a, b); + + // Check epsilon. + assert_relative_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32); + assert_relative_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32); + + // Make sure we can compare different shapes without failure. + let c = array![[1., 2.]]; + assert_relative_ne!(a, c); + } + + #[test] + fn ulps_eq() { + let a: Array2 = array![[1., 2.], [-0.000010001, 100000000.]]; + let mut b: Array2 = array![[1., 1.], [-0.000010002, 100000001.]]; + assert_ulps_ne!(a, b); + b[(0, 1)] = 2.; + assert_ulps_eq!(a, b); + + // Check epsilon. + assert_ulps_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32); + assert_ulps_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32); + + // Make sure we can compare different shapes without failure. 
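Because the macro still implements the `approx` traits themselves (not only inherent methods), arrays keep working with code that is generic over `approx::AbsDiffEq`. A sketch assuming the `approx` feature and `approx` as a direct dependency:

```rust
use approx::AbsDiffEq;
use ndarray::{array, Array2};

// Generic over anything implementing approx's AbsDiffEq, arrays included.
fn close<T: AbsDiffEq>(a: &T, b: &T, eps: T::Epsilon) -> bool {
    a.abs_diff_eq(b, eps)
}

fn main() {
    let a: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
    let b = &a + 1e-9;
    assert!(close(&a, &b, 1e-6));
}
```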
+ let c = array![[1., 2.]]; + assert_ulps_ne!(a, c); + } + } } - Zip::from(self) - .and(other) - .all(|a, b| A::ulps_eq(a, b, epsilon.clone(), max_ulps)) - } -} - -#[cfg(test)] -mod tests { - use crate::prelude::*; - use approx::{ - assert_abs_diff_eq, assert_abs_diff_ne, assert_relative_eq, assert_relative_ne, - assert_ulps_eq, assert_ulps_ne, }; - - #[test] - fn abs_diff_eq() { - let a: Array2 = array![[0., 2.], [-0.000010001, 100000000.]]; - let mut b: Array2 = array![[0., 1.], [-0.000010002, 100000001.]]; - assert_abs_diff_ne!(a, b); - b[(0, 1)] = 2.; - assert_abs_diff_eq!(a, b); - - // Check epsilon. - assert_abs_diff_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32); - assert_abs_diff_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32); - - // Make sure we can compare different shapes without failure. - let c = array![[1., 2.]]; - assert_abs_diff_ne!(a, c); - } - - #[test] - fn relative_eq() { - let a: Array2 = array![[1., 2.], [-0.000010001, 100000000.]]; - let mut b: Array2 = array![[1., 1.], [-0.000010002, 100000001.]]; - assert_relative_ne!(a, b); - b[(0, 1)] = 2.; - assert_relative_eq!(a, b); - - // Check epsilon. - assert_relative_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32); - assert_relative_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32); - - // Make sure we can compare different shapes without failure. - let c = array![[1., 2.]]; - assert_relative_ne!(a, c); - } - - #[test] - fn ulps_eq() { - let a: Array2 = array![[1., 2.], [-0.000010001, 100000000.]]; - let mut b: Array2 = array![[1., 1.], [-0.000010002, 100000001.]]; - assert_ulps_ne!(a, b); - b[(0, 1)] = 2.; - assert_ulps_eq!(a, b); - - // Check epsilon. - assert_ulps_eq!(array![0.0f32], array![1e-40f32], epsilon = 1e-40f32); - assert_ulps_ne!(array![0.0f32], array![1e-40f32], epsilon = 1e-41f32); - - // Make sure we can compare different shapes without failure. - let c = array![[1., 2.]]; - assert_ulps_ne!(a, c); - } } + +#[cfg(feature = "approx")] +impl_approx_traits!(approx, "**Requires crate feature `\"approx\"`.**"); diff --git a/src/array_serde.rs b/src/array_serde.rs index c8b485d26..50d9c2905 100644 --- a/src/array_serde.rs +++ b/src/array_serde.rs @@ -9,6 +9,9 @@ use serde::de::{self, MapAccess, SeqAccess, Visitor}; use serde::ser::{SerializeSeq, SerializeStruct}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use alloc::format; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; use std::fmt; use std::marker::PhantomData; @@ -21,8 +24,7 @@ use crate::IntoDimension; /// Verifies that the version of the deserialized array matches the current /// `ARRAY_FORMAT_VERSION`. 
pub fn verify_version(v: u8) -> Result<(), E> -where - E: de::Error, +where E: de::Error { if v != ARRAY_FORMAT_VERSION { let err_msg = format!("unknown array version: {}", v); @@ -34,12 +36,10 @@ where /// **Requires crate feature `"serde"`** impl Serialize for Dim -where - I: Serialize, +where I: Serialize { fn serialize(&self, serializer: Se) -> Result - where - Se: Serializer, + where Se: Serializer { self.ix().serialize(serializer) } @@ -47,32 +47,30 @@ where /// **Requires crate feature `"serde"`** impl<'de, I> Deserialize<'de> for Dim -where - I: Deserialize<'de>, +where I: Deserialize<'de> { fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, + where D: Deserializer<'de> { I::deserialize(deserializer).map(Dim::new) } } /// **Requires crate feature `"serde"`** -impl Serialize for IxDyn { +impl Serialize for IxDyn +{ fn serialize(&self, serializer: Se) -> Result - where - Se: Serializer, + where Se: Serializer { self.ix().serialize(serializer) } } /// **Requires crate feature `"serde"`** -impl<'de> Deserialize<'de> for IxDyn { +impl<'de> Deserialize<'de> for IxDyn +{ fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, + where D: Deserializer<'de> { let v = Vec::::deserialize(deserializer)?; Ok(v.into_dimension()) @@ -87,8 +85,7 @@ where S: Data, { fn serialize(&self, serializer: Se) -> Result - where - Se: Serializer, + where Se: Serializer { let mut state = serializer.serialize_struct("Array", 3)?; state.serialize_field("v", &ARRAY_FORMAT_VERSION)?; @@ -101,14 +98,13 @@ where // private iterator wrapper struct Sequence<'a, A, D>(Iter<'a, A, D>); -impl<'a, A, D> Serialize for Sequence<'a, A, D> +impl Serialize for Sequence<'_, A, D> where A: Serialize, D: Dimension + Serialize, { fn serialize(&self, serializer: S) -> Result - where - S: Serializer, + where S: Serializer { let iter = &self.0; let mut seq = serializer.serialize_seq(Some(iter.len()))?; @@ -119,19 +115,23 @@ where } } -struct ArrayVisitor { +struct ArrayVisitor +{ _marker_a: PhantomData, _marker_b: PhantomData, } -enum ArrayField { +enum ArrayField +{ Version, Dim, Data, } -impl ArrayVisitor { - pub fn new() -> Self { +impl ArrayVisitor +{ + pub fn new() -> Self + { ArrayVisitor { _marker_a: PhantomData, _marker_b: PhantomData, @@ -149,30 +149,30 @@ where S: DataOwned, { fn deserialize(deserializer: D) -> Result, D::Error> - where - D: Deserializer<'de>, + where D: Deserializer<'de> { deserializer.deserialize_struct("Array", ARRAY_FIELDS, ArrayVisitor::new()) } } -impl<'de> Deserialize<'de> for ArrayField { +impl<'de> Deserialize<'de> for ArrayField +{ fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, + where D: Deserializer<'de> { struct ArrayFieldVisitor; - impl<'de> Visitor<'de> for ArrayFieldVisitor { + impl Visitor<'_> for ArrayFieldVisitor + { type Value = ArrayField; - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result + { formatter.write_str(r#""v", "dim", or "data""#) } fn visit_str(self, value: &str) -> Result - where - E: de::Error, + where E: de::Error { match value { "v" => Ok(ArrayField::Version), @@ -183,17 +183,13 @@ impl<'de> Deserialize<'de> for ArrayField { } fn visit_bytes(self, value: &[u8]) -> Result - where - E: de::Error, + where E: de::Error { match value { b"v" => Ok(ArrayField::Version), b"dim" => Ok(ArrayField::Dim), b"data" => Ok(ArrayField::Data), - other => Err(de::Error::unknown_field( - &format!("{:?}", other), - ARRAY_FIELDS, - 
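A round-trip sketch of the `v`/`dim`/`data` representation produced by these impls, assuming the `serde` crate feature and `serde_json` as the backend:

```rust
use ndarray::{array, Array2};

fn main() {
    let a = array![[1, 2], [3, 4]];
    // The serialized form is a struct of version, shape and flattened data,
    // roughly: {"v":1,"dim":[2,2],"data":[1,2,3,4]}
    let json = serde_json::to_string(&a).unwrap();
    let b: Array2<i32> = serde_json::from_str(&json).unwrap();
    assert_eq!(a, b);
}
```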
)), + other => Err(de::Error::unknown_field(&format!("{:?}", other), ARRAY_FIELDS)), } } } @@ -210,13 +206,13 @@ where { type Value = ArrayBase; - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result + { formatter.write_str("ndarray representation") } fn visit_seq(self, mut visitor: V) -> Result, V::Error> - where - V: SeqAccess<'de>, + where V: SeqAccess<'de> { let v: u8 = match visitor.next_element()? { Some(value) => value, @@ -249,8 +245,7 @@ where } fn visit_map(self, mut visitor: V) -> Result, V::Error> - where - V: MapAccess<'de>, + where V: MapAccess<'de> { let mut v: Option = None; let mut data: Option> = None; @@ -274,17 +269,17 @@ where let _v = match v { Some(v) => v, - None => Err(de::Error::missing_field("v"))?, + None => return Err(de::Error::missing_field("v")), }; let data = match data { Some(data) => data, - None => Err(de::Error::missing_field("data"))?, + None => return Err(de::Error::missing_field("data")), }; let dim = match dim { Some(dim) => dim, - None => Err(de::Error::missing_field("dim"))?, + None => return Err(de::Error::missing_field("dim")), }; if let Ok(array) = ArrayBase::from_shape_vec(dim, data) { diff --git a/src/arrayformat.rs b/src/arrayformat.rs index a7203b38a..1a3b714c3 100644 --- a/src/arrayformat.rs +++ b/src/arrayformat.rs @@ -7,6 +7,7 @@ // except according to those terms. use super::{ArrayBase, ArrayView, Axis, Data, Dimension, NdProducer}; use crate::aliases::{Ix1, IxDyn}; +use alloc::format; use std::fmt; /// Default threshold, below this element count, we don't ellipsize @@ -28,14 +29,17 @@ const AXIS_2D_OVERFLOW_LIMIT: usize = 22; const ELLIPSIS: &str = "..."; #[derive(Clone, Debug)] -struct FormatOptions { +struct FormatOptions +{ axis_collapse_limit: usize, axis_collapse_limit_next_last: usize, axis_collapse_limit_last: usize, } -impl FormatOptions { - pub(crate) fn default_for_array(nelem: usize, no_limit: bool) -> Self { +impl FormatOptions +{ + pub(crate) fn default_for_array(nelem: usize, no_limit: bool) -> Self + { let default = Self { axis_collapse_limit: AXIS_LIMIT_STACKED, axis_collapse_limit_next_last: AXIS_LIMIT_COL, @@ -44,20 +48,20 @@ impl FormatOptions { default.set_no_limit(no_limit || nelem < ARRAY_MANY_ELEMENT_LIMIT) } - fn set_no_limit(mut self, no_limit: bool) -> Self { + fn set_no_limit(mut self, no_limit: bool) -> Self + { if no_limit { - self.axis_collapse_limit = std::usize::MAX; - self.axis_collapse_limit_next_last = std::usize::MAX; - self.axis_collapse_limit_last = std::usize::MAX; - self - } else { - self + self.axis_collapse_limit = usize::MAX; + self.axis_collapse_limit_next_last = usize::MAX; + self.axis_collapse_limit_last = usize::MAX; } + self } /// Axis length collapse limit before ellipsizing, where `axis_rindex` is /// the index of the axis from the back. - pub(crate) fn collapse_limit(&self, axis_rindex: usize) -> usize { + pub(crate) fn collapse_limit(&self, axis_rindex: usize) -> usize + { match axis_rindex { 0 => self.axis_collapse_limit_last, 1 => self.axis_collapse_limit_next_last, @@ -79,13 +83,10 @@ impl FormatOptions { /// * `fmt_elem`: A function that formats an element in the list, given the /// formatter and the index of the item in the list. 
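The collapse limits in `FormatOptions` drive ellipsizing in `Display`; the alternate flag (`{:#}`) takes the no-limit path. A small sketch, where 10 000 is just an arbitrary count above the default element limit:

```rust
use ndarray::Array1;

fn main() {
    let a = Array1::from_elem(10_000, 1);
    // The default Display collapses long axes with an ellipsis...
    assert!(format!("{}", a).contains("..."));
    // ...while the alternate flag disables the limits and prints everything.
    assert!(!format!("{:#}", a).contains("..."));
}
```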
fn format_with_overflow( - f: &mut fmt::Formatter<'_>, - length: usize, - limit: usize, - separator: &str, - ellipsis: &str, + f: &mut fmt::Formatter<'_>, length: usize, limit: usize, separator: &str, ellipsis: &str, fmt_elem: &mut dyn FnMut(&mut fmt::Formatter, usize) -> fmt::Result, -) -> fmt::Result { +) -> fmt::Result +{ if length == 0 { // no-op } else if length <= limit { @@ -112,10 +113,7 @@ fn format_with_overflow( } fn format_array( - array: &ArrayBase, - f: &mut fmt::Formatter<'_>, - format: F, - fmt_opt: &FormatOptions, + array: &ArrayBase, f: &mut fmt::Formatter<'_>, format: F, fmt_opt: &FormatOptions, ) -> fmt::Result where F: FnMut(&A, &mut fmt::Formatter<'_>) -> fmt::Result + Clone, @@ -128,11 +126,7 @@ where } fn format_array_inner( - view: ArrayView, - f: &mut fmt::Formatter<'_>, - mut format: F, - fmt_opt: &FormatOptions, - depth: usize, + view: ArrayView, f: &mut fmt::Formatter<'_>, mut format: F, fmt_opt: &FormatOptions, depth: usize, full_ndim: usize, ) -> fmt::Result where @@ -151,14 +145,9 @@ where &[len] => { let view = view.view().into_dimensionality::().unwrap(); f.write_str("[")?; - format_with_overflow( - f, - len, - fmt_opt.collapse_limit(0), - ", ", - ELLIPSIS, - &mut |f, index| format(&view[index], f), - )?; + format_with_overflow(f, len, fmt_opt.collapse_limit(0), ", ", ELLIPSIS, &mut |f, index| { + format(&view[index], f) + })?; f.write_str("]")?; } // For n-dimensional arrays, we proceed recursively @@ -170,14 +159,7 @@ where f.write_str("[")?; let limit = fmt_opt.collapse_limit(full_ndim - depth - 1); format_with_overflow(f, shape[0], limit, &separator, ELLIPSIS, &mut |f, index| { - format_array_inner( - view.index_axis(Axis(0), index), - f, - format.clone(), - fmt_opt, - depth + 1, - full_ndim, - ) + format_array_inner(view.index_axis(Axis(0), index), f, format.clone(), fmt_opt, depth + 1, full_ndim) })?; f.write_str("]")?; } @@ -190,11 +172,11 @@ where /// to each element. /// /// The array is shown in multiline style. -impl<'a, A: fmt::Display, S, D: Dimension> fmt::Display for ArrayBase -where - S: Data, +impl fmt::Display for ArrayBase +where S: Data { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt) } @@ -204,11 +186,11 @@ where /// to each element. /// /// The array is shown in multiline style. -impl<'a, A: fmt::Debug, S, D: Dimension> fmt::Debug for ArrayBase -where - S: Data, +impl fmt::Debug for ArrayBase +where S: Data { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt)?; @@ -218,7 +200,7 @@ where ", shape={:?}, strides={:?}, layout={:?}", self.shape(), self.strides(), - layout = self.view().layout() + self.view().layout(), )?; match D::NDIM { Some(ndim) => write!(f, ", const ndim={}", ndim)?, @@ -232,11 +214,11 @@ where /// to each element. /// /// The array is shown in multiline style. 
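The `Debug` impl appends shape, stride and layout metadata after the element listing, which the following sketch relies on:

```rust
use ndarray::array;

fn main() {
    let a = array![[1, 2, 3], [4, 5, 6]];
    let dbg = format!("{:?}", a);
    // Roughly: "[[1, 2, 3], [4, 5, 6]], shape=[2, 3], strides=[3, 1],
    // layout=..., const ndim=2"
    assert!(dbg.contains("shape=[2, 3]"));
    assert!(dbg.contains("const ndim=2"));
}
```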
-impl<'a, A: fmt::LowerExp, S, D: Dimension> fmt::LowerExp for ArrayBase -where - S: Data, +impl fmt::LowerExp for ArrayBase +where S: Data { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt) } @@ -246,11 +228,11 @@ where /// to each element. /// /// The array is shown in multiline style. -impl<'a, A: fmt::UpperExp, S, D: Dimension> fmt::UpperExp for ArrayBase -where - S: Data, +impl fmt::UpperExp for ArrayBase +where S: Data { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt) } @@ -259,11 +241,11 @@ where /// to each element. /// /// The array is shown in multiline style. -impl<'a, A: fmt::LowerHex, S, D: Dimension> fmt::LowerHex for ArrayBase -where - S: Data, +impl fmt::LowerHex for ArrayBase +where S: Data { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt) } @@ -273,25 +255,30 @@ where /// to each element. /// /// The array is shown in multiline style. -impl<'a, A: fmt::Binary, S, D: Dimension> fmt::Binary for ArrayBase -where - S: Data, +impl fmt::Binary for ArrayBase +where S: Data { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { let fmt_opt = FormatOptions::default_for_array(self.len(), f.alternate()); format_array(self, f, <_>::fmt, &fmt_opt) } } #[cfg(test)] -mod formatting_with_omit { +mod formatting_with_omit +{ + #[cfg(not(feature = "std"))] + use alloc::string::String; + #[cfg(not(feature = "std"))] + use alloc::vec::Vec; use itertools::Itertools; - use std::fmt; use super::*; use crate::prelude::*; - fn assert_str_eq(expected: &str, actual: &str) { + fn assert_str_eq(expected: &str, actual: &str) + { // use assert to avoid printing the strings twice on failure assert!( expected == actual, @@ -301,11 +288,8 @@ mod formatting_with_omit { ); } - fn ellipsize( - limit: usize, - sep: &str, - elements: impl IntoIterator, - ) -> String { + fn ellipsize(limit: usize, sep: &str, elements: impl IntoIterator) -> String + { let elements = elements.into_iter().collect::>(); let edge = limit / 2; if elements.len() <= limit { @@ -323,7 +307,8 @@ mod formatting_with_omit { } #[test] - fn empty_arrays() { + fn empty_arrays() + { let a: Array2 = arr2(&[[], []]); let actual = format!("{}", a); let expected = "[[]]"; @@ -331,7 +316,8 @@ mod formatting_with_omit { } #[test] - fn zero_length_axes() { + fn zero_length_axes() + { let a = Array3::::zeros((3, 0, 4)); let actual = format!("{}", a); let expected = "[[[]]]"; @@ -339,7 +325,8 @@ mod formatting_with_omit { } #[test] - fn dim_0() { + fn dim_0() + { let element = 12; let a = arr0(element); let actual = format!("{}", a); @@ -348,7 +335,8 @@ mod formatting_with_omit { } #[test] - fn dim_1() { + fn dim_1() + { let overflow: usize = 2; let a = Array1::from_elem(ARRAY_MANY_ELEMENT_LIMIT + overflow, 1); let actual = format!("{}", a); @@ -357,7 +345,8 @@ mod formatting_with_omit { } #[test] - fn dim_1_alternate() { + fn dim_1_alternate() + { let overflow: usize = 2; let a = 
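The `LowerExp`, `LowerHex` and `Binary` impls reformatted here apply the respective element formatting inside the usual bracketed layout, for example:

```rust
use ndarray::array;

fn main() {
    let a = array![1.0_f64, 20.0, 300.0];
    // LowerExp formatting is applied to every element.
    assert_eq!(format!("{:e}", a), "[1e0, 2e1, 3e2]");

    let b = array![1_u8, 2, 3];
    // LowerHex and Binary work for integer elements.
    assert_eq!(format!("{:x}", b), "[1, 2, 3]");
    assert_eq!(format!("{:b}", b), "[1, 10, 11]");
}
```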
Array1::from_elem(ARRAY_MANY_ELEMENT_LIMIT + overflow, 1); let actual = format!("{:#}", a); @@ -366,12 +355,10 @@ mod formatting_with_omit { } #[test] - fn dim_2_last_axis_overflow() { + fn dim_2_last_axis_overflow() + { let overflow: usize = 2; - let a = Array2::from_elem( - (AXIS_2D_OVERFLOW_LIMIT, AXIS_2D_OVERFLOW_LIMIT + overflow), - 1, - ); + let a = Array2::from_elem((AXIS_2D_OVERFLOW_LIMIT, AXIS_2D_OVERFLOW_LIMIT + overflow), 1); let actual = format!("{}", a); let expected = "\ [[1, 1, 1, 1, 1, ..., 1, 1, 1, 1, 1], @@ -389,7 +376,8 @@ mod formatting_with_omit { } #[test] - fn dim_2_non_last_axis_overflow() { + fn dim_2_non_last_axis_overflow() + { let a = Array2::from_elem((ARRAY_MANY_ELEMENT_LIMIT / 10, 10), 1); let actual = format!("{}", a); let row = format!("{}", a.row(0)); @@ -401,7 +389,8 @@ mod formatting_with_omit { } #[test] - fn dim_2_non_last_axis_overflow_alternate() { + fn dim_2_non_last_axis_overflow_alternate() + { let a = Array2::from_elem((AXIS_LIMIT_COL * 4, 6), 1); let actual = format!("{:#}", a); let row = format!("{}", a.row(0)); @@ -410,15 +399,10 @@ mod formatting_with_omit { } #[test] - fn dim_2_multi_directional_overflow() { + fn dim_2_multi_directional_overflow() + { let overflow: usize = 2; - let a = Array2::from_elem( - ( - AXIS_2D_OVERFLOW_LIMIT + overflow, - AXIS_2D_OVERFLOW_LIMIT + overflow, - ), - 1, - ); + let a = Array2::from_elem((AXIS_2D_OVERFLOW_LIMIT + overflow, AXIS_2D_OVERFLOW_LIMIT + overflow), 1); let actual = format!("{}", a); let row = format!("[{}]", ellipsize(AXIS_LIMIT_ROW, ", ", a.row(0))); let expected = format!( @@ -429,15 +413,10 @@ mod formatting_with_omit { } #[test] - fn dim_2_multi_directional_overflow_alternate() { + fn dim_2_multi_directional_overflow_alternate() + { let overflow: usize = 2; - let a = Array2::from_elem( - ( - AXIS_2D_OVERFLOW_LIMIT + overflow, - AXIS_2D_OVERFLOW_LIMIT + overflow, - ), - 1, - ); + let a = Array2::from_elem((AXIS_2D_OVERFLOW_LIMIT + overflow, AXIS_2D_OVERFLOW_LIMIT + overflow), 1); let actual = format!("{:#}", a); let row = format!("{}", a.row(0)); let expected = format!("[{}]", (0..a.nrows()).map(|_| &row).format(",\n ")); @@ -445,13 +424,11 @@ mod formatting_with_omit { } #[test] - fn dim_3_overflow_most() { - let a = Array3::from_shape_fn( - (AXIS_LIMIT_STACKED + 1, AXIS_LIMIT_COL, AXIS_LIMIT_ROW + 1), - |(i, j, k)| { - 1000. + (100. * ((i as f64).sqrt() + (j as f64).sin() + k as f64)).round() / 100. - }, - ); + fn dim_3_overflow_most() + { + let a = Array3::from_shape_fn((AXIS_LIMIT_STACKED + 1, AXIS_LIMIT_COL, AXIS_LIMIT_ROW + 1), |(i, j, k)| { + 1000. + (100. * ((i as f64).sqrt() + (j as f64).sin() + k as f64)).round() / 100. + }); let actual = format!("{:6.1}", a); let expected = "\ [[[1000.0, 1001.0, 1002.0, 1003.0, 1004.0, ..., 1007.0, 1008.0, 1009.0, 1010.0, 1011.0], @@ -531,7 +508,8 @@ mod formatting_with_omit { } #[test] - fn dim_4_overflow_outer() { + fn dim_4_overflow_outer() + { let a = Array4::from_shape_fn((10, 10, 3, 3), |(i, j, k, l)| i + j + k + l); let actual = format!("{:2}", a); // Generated using NumPy with: diff --git a/src/arraytraits.rs b/src/arraytraits.rs index 143bb5faf..d7a00fcfe 100644 --- a/src/arraytraits.rs +++ b/src/arraytraits.rs @@ -6,23 +6,33 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. 
+#[cfg(not(feature = "std"))] +use alloc::boxed::Box; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; use std::hash; -use std::isize; -use std::iter::FromIterator; -use std::iter::IntoIterator; use std::mem; +use std::mem::size_of; use std::ops::{Index, IndexMut}; +use std::{iter::FromIterator, slice}; use crate::imp_prelude::*; -use crate::iter::{Iter, IterMut}; -use crate::NdIndex; +use crate::Arc; -use crate::numeric_util; -use crate::{FoldWhile, Zip}; +use crate::{ + dimension, + iter::{Iter, IterMut}, + numeric_util, + FoldWhile, + NdIndex, + OwnedArcRepr, + Zip, +}; #[cold] #[inline(never)] -pub(crate) fn array_out_of_bounds() -> ! { +pub(crate) fn array_out_of_bounds() -> ! +{ panic!("ndarray: index out of bounds"); } @@ -47,7 +57,8 @@ where { type Output = S::Elem; #[inline] - fn index(&self, index: I) -> &S::Elem { + fn index(&self, index: I) -> &S::Elem + { debug_bounds_check!(self, index); unsafe { &*self.ptr.as_ptr().offset( @@ -69,7 +80,8 @@ where S: DataMut, { #[inline] - fn index_mut(&mut self, index: I) -> &mut S::Elem { + fn index_mut(&mut self, index: I) -> &mut S::Elem + { debug_bounds_check!(self, index); unsafe { &mut *self.as_mut_ptr().offset( @@ -90,7 +102,8 @@ where S2: Data, D: Dimension, { - fn eq(&self, rhs: &ArrayBase) -> bool { + fn eq(&self, rhs: &ArrayBase) -> bool + { if self.shape() != rhs.shape() { return false; } @@ -112,6 +125,38 @@ where } } +/// Return `true` if the array shapes and all elements of `self` and +/// `rhs` are equal. Return `false` otherwise. +#[allow(clippy::unconditional_recursion)] // false positive +impl PartialEq<&ArrayBase> for ArrayBase +where + A: PartialEq, + S: Data, + S2: Data, + D: Dimension, +{ + fn eq(&self, rhs: &&ArrayBase) -> bool + { + *self == **rhs + } +} + +/// Return `true` if the array shapes and all elements of `self` and +/// `rhs` are equal. Return `false` otherwise. +#[allow(clippy::unconditional_recursion)] // false positive +impl PartialEq> for &ArrayBase +where + A: PartialEq, + S: Data, + S2: Data, + D: Dimension, +{ + fn eq(&self, rhs: &ArrayBase) -> bool + { + **self == *rhs + } +} + impl Eq for ArrayBase where D: Dimension, @@ -120,9 +165,20 @@ where { } +impl From> for ArrayBase +where S: DataOwned +{ + /// Create a one-dimensional array from a boxed slice (no copying needed). + /// + /// **Panics** if the length is greater than `isize::MAX`. + fn from(b: Box<[A]>) -> Self + { + Self::from_vec(b.into_vec()) + } +} + impl From> for ArrayBase -where - S: DataOwned, +where S: DataOwned { /// Create a one-dimensional array from a vector (no copying needed). /// @@ -133,20 +189,14 @@ where /// /// let array = Array::from(vec![1., 2., 3., 4.]); /// ``` - fn from(v: Vec) -> Self { - if mem::size_of::() == 0 { - assert!( - v.len() <= isize::MAX as usize, - "Length must fit in `isize`.", - ); - } - unsafe { Self::from_shape_vec_unchecked(v.len() as Ix, v) } + fn from(v: Vec) -> Self + { + Self::from_vec(v) } } impl FromIterator for ArrayBase -where - S: DataOwned, +where S: DataOwned { /// Create a one-dimensional array from an iterable. /// @@ -154,17 +204,15 @@ where /// /// ```rust /// use ndarray::{Array, arr1}; - /// use std::iter::FromIterator; /// /// // Either use `from_iter` directly or use `Iterator::collect`. 
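The new `PartialEq` impls against references and the `From<Box<[A]>>` conversion can be exercised like this (a minimal sketch):

```rust
use ndarray::{array, Array1};

fn main() {
    let a = array![1, 2, 3];
    let b = array![1, 2, 3];
    // Comparisons now also work with a reference on either side.
    assert!(a == &b);
    assert!(&a == b);

    // A boxed slice becomes a one-dimensional array without copying.
    let boxed: Box<[i32]> = vec![4, 5, 6].into_boxed_slice();
    let c = Array1::from(boxed);
    assert_eq!(c, array![4, 5, 6]);
}
```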
/// let array = Array::from_iter((0..5).map(|x| x * x)); /// assert!(array == arr1(&[0, 1, 4, 9, 16])) /// ``` fn from_iter(iterable: I) -> ArrayBase - where - I: IntoIterator, + where I: IntoIterator { - Self::from(iterable.into_iter().collect::>()) + Self::from_iter(iterable) } } @@ -176,7 +224,8 @@ where type Item = &'a S::Elem; type IntoIter = Iter<'a, S::Elem, D>; - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { self.iter() } } @@ -189,48 +238,50 @@ where type Item = &'a mut S::Elem; type IntoIter = IterMut<'a, S::Elem, D>; - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { self.iter_mut() } } impl<'a, A, D> IntoIterator for ArrayView<'a, A, D> -where - D: Dimension, +where D: Dimension { type Item = &'a A; type IntoIter = Iter<'a, A, D>; - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { self.into_iter_() } } impl<'a, A, D> IntoIterator for ArrayViewMut<'a, A, D> -where - D: Dimension, +where D: Dimension { type Item = &'a mut A; type IntoIter = IterMut<'a, A, D>; - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { self.into_iter_() } } -impl<'a, S, D> hash::Hash for ArrayBase +impl hash::Hash for ArrayBase where D: Dimension, S: Data, S::Elem: hash::Hash, { // Note: elements are hashed in the logical order - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) + { self.shape().hash(state); if let Some(self_s) = self.as_slice() { hash::Hash::hash_slice(self_s, state); } else { - for row in self.inner_rows() { + for row in self.rows() { if let Some(row_s) = row.as_slice() { hash::Hash::hash_slice(row_s, state); } else { @@ -264,28 +315,53 @@ where { } -#[cfg(any(feature = "serde"))] +#[cfg(feature = "serde")] // Use version number so we can add a packed format later. pub const ARRAY_FORMAT_VERSION: u8 = 1u8; // use "raw" form instead of type aliases here so that they show up in docs -/// Implementation of `ArrayView::from(&S)` where `S` is a slice or slicable. +/// Implementation of `ArrayView::from(&S)` where `S` is a slice or sliceable. +/// +/// **Panics** if the length of the slice overflows `isize`. (This can only +/// occur if `A` is zero-sized, because slices cannot contain more than +/// `isize::MAX` number of bytes.) impl<'a, A, Slice: ?Sized> From<&'a Slice> for ArrayView<'a, A, Ix1> -where - Slice: AsRef<[A]>, +where Slice: AsRef<[A]> { /// Create a one-dimensional read-only array view of the data in `slice`. /// /// **Panics** if the slice length is greater than `isize::MAX`. - fn from(slice: &'a Slice) -> Self { - let xs = slice.as_ref(); - if mem::size_of::() == 0 { - assert!( - xs.len() <= ::std::isize::MAX as usize, - "Slice length must fit in `isize`.", - ); - } - unsafe { Self::from_shape_ptr(xs.len(), xs.as_ptr()) } + fn from(slice: &'a Slice) -> Self + { + aview1(slice.as_ref()) + } +} + +/// Implementation of ArrayView2::from(&[[A; N]; M]) +/// +/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A +/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes). +/// **Panics** if N == 0 and the number of rows is greater than isize::MAX. 
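A sketch of the const-generic `From<&[[A; N]; M]>` view conversion added here:

```rust
use ndarray::ArrayView2;

fn main() {
    let data = [[1, 2, 3], [4, 5, 6]];
    // A borrowed nested fixed-size array becomes a 2-D view, no copying.
    let view = ArrayView2::from(&data);
    assert_eq!(view.shape(), &[2, 3]);
    assert_eq!(view[[1, 2]], 6);
}
```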
+impl<'a, A, const M: usize, const N: usize> From<&'a [[A; N]; M]> for ArrayView<'a, A, Ix2> +{ + /// Create a two-dimensional read-only array view of the data in `slice` + fn from(xs: &'a [[A; N]; M]) -> Self + { + Self::from(&xs[..]) + } +} + +/// Implementation of ArrayView2::from(&[[A; N]]) +/// +/// **Panics** if the product of non-zero axis lengths overflows `isize`. (This +/// can only occur if A is zero-sized or if `N` is zero, because slices cannot +/// contain more than `isize::MAX` number of bytes.) +impl<'a, A, const N: usize> From<&'a [[A; N]]> for ArrayView<'a, A, Ix2> +{ + /// Create a two-dimensional read-only array view of the data in `slice` + fn from(xs: &'a [[A; N]]) -> Self + { + aview2(xs) } } @@ -296,24 +372,25 @@ where D: Dimension, { /// Create a read-only array view of the array. - fn from(array: &'a ArrayBase) -> Self { + fn from(array: &'a ArrayBase) -> Self + { array.view() } } -/// Implementation of `ArrayViewMut::from(&mut S)` where `S` is a slice or slicable. +/// Implementation of `ArrayViewMut::from(&mut S)` where `S` is a slice or sliceable. impl<'a, A, Slice: ?Sized> From<&'a mut Slice> for ArrayViewMut<'a, A, Ix1> -where - Slice: AsMut<[A]>, +where Slice: AsMut<[A]> { /// Create a one-dimensional read-write array view of the data in `slice`. /// /// **Panics** if the slice length is greater than `isize::MAX`. - fn from(slice: &'a mut Slice) -> Self { + fn from(slice: &'a mut Slice) -> Self + { let xs = slice.as_mut(); if mem::size_of::() == 0 { assert!( - xs.len() <= ::std::isize::MAX as usize, + xs.len() <= isize::MAX as usize, "Slice length must fit in `isize`.", ); } @@ -321,6 +398,51 @@ where } } +/// Implementation of ArrayViewMut2::from(&mut [[A; N]; M]) +/// +/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A +/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes). +/// **Panics** if N == 0 and the number of rows is greater than isize::MAX. +impl<'a, A, const M: usize, const N: usize> From<&'a mut [[A; N]; M]> for ArrayViewMut<'a, A, Ix2> +{ + /// Create a two-dimensional read-write array view of the data in `slice` + fn from(xs: &'a mut [[A; N]; M]) -> Self + { + Self::from(&mut xs[..]) + } +} + +/// Implementation of ArrayViewMut2::from(&mut [[A; N]]) +/// +/// **Panics** if the product of non-zero axis lengths overflows `isize`. (This +/// can only occur if `A` is zero-sized or if `N` is zero, because slices +/// cannot contain more than `isize::MAX` number of bytes.) +impl<'a, A, const N: usize> From<&'a mut [[A; N]]> for ArrayViewMut<'a, A, Ix2> +{ + /// Create a two-dimensional read-write array view of the data in `slice` + fn from(xs: &'a mut [[A; N]]) -> Self + { + let cols = N; + let rows = xs.len(); + let dim = Ix2(rows, cols); + if size_of::() == 0 { + dimension::size_of_shape_checked(&dim).expect("Product of non-zero axis lengths must not overflow isize."); + } else if N == 0 { + assert!( + xs.len() <= isize::MAX as usize, + "Product of non-zero axis lengths must not overflow isize.", + ); + } + + // `cols * rows` is guaranteed to fit in `isize` because we checked that it fits in + // `isize::MAX` + unsafe { + let data = slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut A, cols * rows); + ArrayViewMut::from_shape_ptr(dim, data.as_mut_ptr()) + } + } +} + /// Implementation of `ArrayViewMut::from(&mut A)` where `A` is an array. 
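And the mutable counterpart, `ArrayViewMut2::from(&mut [[A; N]; M])`, writes straight through to the borrowed storage:

```rust
use ndarray::ArrayViewMut2;

fn main() {
    let mut data = [[0, 0], [0, 0], [0, 0]];
    {
        // Writes through the view land directly in `data`.
        let mut view = ArrayViewMut2::from(&mut data);
        assert_eq!(view.shape(), &[3, 2]);
        view[[2, 1]] = 9;
    }
    assert_eq!(data[2][1], 9);
}
```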
impl<'a, A, S, D> From<&'a mut ArrayBase> for ArrayViewMut<'a, A, D> where @@ -328,11 +450,23 @@ where D: Dimension, { /// Create a read-write array view of the array. - fn from(array: &'a mut ArrayBase) -> Self { + fn from(array: &'a mut ArrayBase) -> Self + { array.view_mut() } } +impl From> for ArcArray +where D: Dimension +{ + fn from(arr: Array) -> ArcArray + { + let data = OwnedArcRepr(Arc::new(arr.data)); + // safe because: equivalent unmoved data, ptr and dims remain valid + unsafe { ArrayBase::from_data_ptr(data, arr.ptr).with_strides_dim(arr.strides, arr.dim) } + } +} + /// Argument conversion into an array view /// /// The trait is parameterized over `A`, the element type, and `D`, the @@ -355,8 +489,7 @@ where /// /// ``` pub trait AsArray<'a, A: 'a, D = Ix1>: Into> -where - D: Dimension, +where D: Dimension { } impl<'a, A: 'a, D, T> AsArray<'a, A, D> for T @@ -386,7 +519,8 @@ where { // NOTE: We can implement Default for non-zero dimensional array views by // using an empty slice, however we need a trait for nonzero Dimension. - fn default() -> Self { + fn default() -> Self + { ArrayBase::default(D::default()) } } diff --git a/src/data_repr.rs b/src/data_repr.rs new file mode 100644 index 000000000..4041c192b --- /dev/null +++ b/src/data_repr.rs @@ -0,0 +1,188 @@ +use crate::extension::nonnull; +#[cfg(not(feature = "std"))] +use alloc::borrow::ToOwned; +use alloc::slice; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; +use std::mem; +use std::mem::ManuallyDrop; +use std::ptr::NonNull; + +#[allow(unused_imports)] +use rawpointer::PointerExt; + +/// Array's representation. +/// +/// *Don’t use this type directly—use the type alias +/// [`Array`](crate::Array) for the array type!* +// Like a Vec, but with non-unique ownership semantics +// +// repr(C) to make it transmutable OwnedRepr -> OwnedRepr if +// transmutable A -> B. +#[derive(Debug)] +#[repr(C)] +pub struct OwnedRepr +{ + ptr: NonNull, + len: usize, + capacity: usize, +} + +impl OwnedRepr +{ + pub(crate) fn from(v: Vec) -> Self + { + let mut v = ManuallyDrop::new(v); + let len = v.len(); + let capacity = v.capacity(); + let ptr = nonnull::nonnull_from_vec_data(&mut v); + Self { ptr, len, capacity } + } + + pub(crate) fn into_vec(self) -> Vec + { + ManuallyDrop::new(self).take_as_vec() + } + + pub(crate) fn as_slice(&self) -> &[A] + { + unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) } + } + + pub(crate) fn len(&self) -> usize + { + self.len + } + + pub(crate) fn as_ptr(&self) -> *const A + { + self.ptr.as_ptr() + } + + pub(crate) fn as_nonnull_mut(&mut self) -> NonNull + { + self.ptr + } + + /// Return end pointer + pub(crate) fn as_end_nonnull(&self) -> NonNull + { + unsafe { self.ptr.add(self.len) } + } + + /// Reserve `additional` elements; return the new pointer + /// + /// ## Safety + /// + /// Note that existing pointers into the data are invalidated + #[must_use = "must use new pointer to update existing pointers"] + pub(crate) fn reserve(&mut self, additional: usize) -> NonNull + { + self.modify_as_vec(|mut v| { + v.reserve(additional); + v + }); + self.as_nonnull_mut() + } + + /// Set the valid length of the data + /// + /// ## Safety + /// + /// The first `new_len` elements of the data should be valid. 
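The new `From<Array> for ArcArray` impl moves the allocation into shared ownership without copying elements; a minimal sketch:

```rust
use ndarray::{array, ArcArray2};

fn main() {
    let owned = array![[1.0, 2.0], [3.0, 4.0]];
    // The Array's allocation is moved into the ArcArray; no elements are copied.
    let shared: ArcArray2<f64> = owned.into();
    let shared2 = shared.clone(); // only bumps the reference count
    assert_eq!(shared, shared2);
}
```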
+ pub(crate) unsafe fn set_len(&mut self, new_len: usize) + { + debug_assert!(new_len <= self.capacity); + self.len = new_len; + } + + /// Return the length (number of elements in total) + pub(crate) fn release_all_elements(&mut self) -> usize + { + let ret = self.len; + self.len = 0; + ret + } + + /// Cast self into equivalent repr of other element type + /// + /// ## Safety + /// + /// Caller must ensure the two types have the same representation. + /// **Panics** if sizes don't match (which is not a sufficient check). + pub(crate) unsafe fn data_subst(self) -> OwnedRepr + { + // necessary but not sufficient check + assert_eq!(mem::size_of::(), mem::size_of::()); + let self_ = ManuallyDrop::new(self); + OwnedRepr { + ptr: self_.ptr.cast::(), + len: self_.len, + capacity: self_.capacity, + } + } + + fn modify_as_vec(&mut self, f: impl FnOnce(Vec) -> Vec) + { + let v = self.take_as_vec(); + *self = Self::from(f(v)); + } + + fn take_as_vec(&mut self) -> Vec + { + let capacity = self.capacity; + let len = self.len; + self.len = 0; + self.capacity = 0; + unsafe { Vec::from_raw_parts(self.ptr.as_ptr(), len, capacity) } + } +} + +impl Clone for OwnedRepr +where A: Clone +{ + fn clone(&self) -> Self + { + Self::from(self.as_slice().to_owned()) + } + + fn clone_from(&mut self, other: &Self) + { + let mut v = self.take_as_vec(); + let other = other.as_slice(); + + if v.len() > other.len() { + v.truncate(other.len()); + } + let (front, back) = other.split_at(v.len()); + v.clone_from_slice(front); + v.extend_from_slice(back); + *self = Self::from(v); + } +} + +impl Drop for OwnedRepr +{ + fn drop(&mut self) + { + if self.capacity > 0 { + // correct because: If the elements don't need dropping, an + // empty Vec is ok. Only the Vec's allocation needs dropping. + // + // implemented because: in some places in ndarray + // where A: Copy (hence does not need drop) we use uninitialized elements in + // vectors. Setting the length to 0 avoids that the vector tries to + // drop, slice or otherwise produce values of these elements. + // (The details of the validity letting this happen with nonzero len, are + // under discussion as of this writing.) + if !mem::needs_drop::() { + self.len = 0; + } + // drop as a Vec. + self.take_as_vec(); + } + } +} + +unsafe impl Sync for OwnedRepr where A: Sync {} +unsafe impl Send for OwnedRepr where A: Send {} diff --git a/src/data_traits.rs b/src/data_traits.rs index b3f07959e..fc2fe4bfa 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -8,15 +8,22 @@ //! The data (inner representation) traits for ndarray -use crate::extension::nonnull::nonnull_from_vec_data; +#[allow(unused_imports)] use rawpointer::PointerExt; + +#[cfg(target_has_atomic = "ptr")] +use alloc::sync::Arc; + +#[cfg(not(target_has_atomic = "ptr"))] +use portable_atomic_util::Arc; + +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; +use std::mem::MaybeUninit; use std::mem::{self, size_of}; use std::ptr::NonNull; -use std::sync::Arc; -use crate::{ - ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRcRepr, OwnedRepr, RawViewRepr, ViewRepr, -}; +use crate::{ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr}; /// Array representation trait. /// @@ -27,13 +34,14 @@ use crate::{ /// ***Note:*** `RawData` is not an extension interface at this point. /// Traits in Rust can serve many different roles. This trait is public because /// it is used as a bound on public methods. 
-pub unsafe trait RawData: Sized { +#[allow(clippy::missing_safety_doc)] // not implementable downstream +pub unsafe trait RawData: Sized +{ /// The array element type. type Elem; #[doc(hidden)] - // This method is only used for debugging - fn _data_slice(&self) -> Option<&[Self::Elem]>; + fn _is_pointer_inbounds(&self, ptr: *const Self::Elem) -> bool; private_decl! {} } @@ -43,11 +51,16 @@ pub unsafe trait RawData: Sized { /// For an array with writable elements. /// /// ***Internal trait, see `RawData`.*** -pub unsafe trait RawDataMut: RawData { +#[allow(clippy::missing_safety_doc)] // not implementable downstream +pub unsafe trait RawDataMut: RawData +{ /// If possible, ensures that the array has unique access to its data. /// - /// If `Self` provides safe mutable access to array elements, then it - /// **must** panic or ensure that the data is unique. + /// The implementer must ensure that if the input is contiguous, then the + /// output has the same strides as input. + /// + /// Additionally, if `Self` provides safe mutable access to array elements, + /// then this method **must** panic or ensure that the data is unique. #[doc(hidden)] fn try_ensure_unique(_: &mut ArrayBase) where @@ -67,17 +80,16 @@ pub unsafe trait RawDataMut: RawData { /// An array representation that can be cloned. /// /// ***Internal trait, see `RawData`.*** -pub unsafe trait RawDataClone: RawData { +#[allow(clippy::missing_safety_doc)] // not implementable downstream +pub unsafe trait RawDataClone: RawData +{ #[doc(hidden)] /// Unsafe because, `ptr` must point inside the current storage. unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull); #[doc(hidden)] - unsafe fn clone_from_with_ptr( - &mut self, - other: &Self, - ptr: NonNull, - ) -> NonNull { + unsafe fn clone_from_with_ptr(&mut self, other: &Self, ptr: NonNull) -> NonNull + { let (data, ptr) = other.clone_with_ptr(ptr); *self = data; ptr @@ -89,13 +101,35 @@ pub unsafe trait RawDataClone: RawData { /// For an array with elements that can be accessed with safe code. /// /// ***Internal trait, see `RawData`.*** -pub unsafe trait Data: RawData { +#[allow(clippy::missing_safety_doc)] // not implementable downstream +pub unsafe trait Data: RawData +{ /// Converts the array to a uniquely owned array, cloning elements if necessary. #[doc(hidden)] - fn into_owned(self_: ArrayBase) -> ArrayBase, D> + #[allow(clippy::wrong_self_convention)] + fn into_owned(self_: ArrayBase) -> Array where Self::Elem: Clone, D: Dimension; + + /// Converts the array into `Array` if this is possible without + /// cloning the array elements. Otherwise, returns `self_` unchanged. + #[doc(hidden)] + fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> + where D: Dimension; + + /// Return a shared ownership (copy on write) array based on the existing one, + /// cloning elements if necessary. + #[doc(hidden)] + #[allow(clippy::wrong_self_convention)] + fn to_shared(self_: &ArrayBase) -> ArcArray + where + Self::Elem: Clone, + D: Dimension, + { + // clone to shared + self_.to_owned().into_shared() + } } /// Array representation trait. @@ -110,7 +144,9 @@ pub unsafe trait Data: RawData { // `RawDataMut::try_ensure_unique` implementation always panics or ensures that // the data is unique. You are also guaranteeing that `try_is_unique` always // returns `Some(_)`. 
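`Data::into_owned` and `Data::to_shared` are hidden trait methods; at the public level they back `ArrayBase::into_owned` and `ArrayBase::to_shared`, sketched here:

```rust
use ndarray::array;

fn main() {
    let a = array![[1, 2], [3, 4]];

    // A view does not own its elements; into_owned() clones them into a
    // fresh Array.
    let owned = a.view().into_owned();
    assert_eq!(owned, a);

    // to_shared() produces a copy-on-write ArcArray from any readable array.
    let shared = a.to_shared();
    assert_eq!(shared, a);
}
```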
-pub unsafe trait DataMut: Data + RawDataMut { +#[allow(clippy::missing_safety_doc)] // not implementable downstream +pub unsafe trait DataMut: Data + RawDataMut +{ /// Ensures that the array has unique access to its data. #[doc(hidden)] #[inline] @@ -125,46 +161,49 @@ pub unsafe trait DataMut: Data + RawDataMut { /// Returns whether the array has unique access to its data. #[doc(hidden)] #[inline] - fn is_unique(&mut self) -> bool { + #[allow(clippy::wrong_self_convention)] // mut needed for Arc types + fn is_unique(&mut self) -> bool + { self.try_is_unique().unwrap() } } -/// Array representation trait. -/// -/// An array representation that can be cloned and allows elements to be -/// accessed with safe code. -/// -/// ***Internal trait, see `Data`.*** -#[deprecated(note = "use `Data + RawDataClone` instead", since = "0.13.0")] -pub trait DataClone: Data + RawDataClone {} - -#[allow(deprecated)] -impl DataClone for T where T: Data + RawDataClone {} - -unsafe impl RawData for RawViewRepr<*const A> { +unsafe impl RawData for RawViewRepr<*const A> +{ type Elem = A; - fn _data_slice(&self) -> Option<&[A]> { - None + + #[inline(always)] + fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool + { + true } + private_impl! {} } -unsafe impl RawDataClone for RawViewRepr<*const A> { - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { +unsafe impl RawDataClone for RawViewRepr<*const A> +{ + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) + { (*self, ptr) } } -unsafe impl RawData for RawViewRepr<*mut A> { +unsafe impl RawData for RawViewRepr<*mut A> +{ type Elem = A; - fn _data_slice(&self) -> Option<&[A]> { - None + + #[inline(always)] + fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool + { + true } + private_impl! {} } -unsafe impl RawDataMut for RawViewRepr<*mut A> { +unsafe impl RawDataMut for RawViewRepr<*mut A> +{ #[inline] fn try_ensure_unique(_: &mut ArrayBase) where @@ -174,29 +213,35 @@ unsafe impl RawDataMut for RawViewRepr<*mut A> { } #[inline] - fn try_is_unique(&mut self) -> Option { + fn try_is_unique(&mut self) -> Option + { None } } -unsafe impl RawDataClone for RawViewRepr<*mut A> { - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { +unsafe impl RawDataClone for RawViewRepr<*mut A> +{ + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) + { (*self, ptr) } } -unsafe impl RawData for OwnedArcRepr { +unsafe impl RawData for OwnedArcRepr +{ type Elem = A; - fn _data_slice(&self) -> Option<&[A]> { - Some(&self.0) + + fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool + { + self.0._is_pointer_inbounds(self_ptr) } + private_impl! {} } // NOTE: Copy on write unsafe impl RawDataMut for OwnedArcRepr -where - A: Clone, +where A: Clone { fn try_ensure_unique(self_: &mut ArrayBase) where @@ -207,14 +252,9 @@ where return; } if self_.dim.size() <= self_.data.0.len() / 2 { - // Create a new vec if the current view is less than half of - // backing data. - unsafe { - *self_ = ArrayBase::from_shape_vec_unchecked( - self_.dim.clone(), - self_.iter().cloned().collect(), - ); - } + // Clone only the visible elements if the current view is less than + // half of backing data. 
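The copy-on-write behaviour implemented by `try_ensure_unique` for `OwnedArcRepr` is observable through `ArcArray`: clones share the allocation, and the first mutation unshares it. A minimal sketch:

```rust
use ndarray::array;

fn main() {
    let a = array![1, 2, 3].into_shared();
    let mut b = a.clone(); // cheap: shares the same allocation

    // The first mutable access unshares b's data (copy on write),
    // leaving a untouched.
    b[0] = 10;
    assert_eq!(a[0], 1);
    assert_eq!(b[0], 10);
}
```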
+ *self_ = self_.to_owned().into_shared(); return; } let rcvec = &mut self_.data.0; @@ -226,50 +266,85 @@ where }; let rvec = Arc::make_mut(rcvec); unsafe { - self_.ptr = nonnull_from_vec_data(rvec).offset(our_off); + self_.ptr = rvec.as_nonnull_mut().offset(our_off); } } - fn try_is_unique(&mut self) -> Option { + fn try_is_unique(&mut self) -> Option + { Some(Arc::get_mut(&mut self.0).is_some()) } } -unsafe impl Data for OwnedArcRepr { - fn into_owned(mut self_: ArrayBase) -> ArrayBase, D> +unsafe impl Data for OwnedArcRepr +{ + fn into_owned(mut self_: ArrayBase) -> Array where A: Clone, D: Dimension, { Self::ensure_unique(&mut self_); - let data = OwnedRepr(Arc::try_unwrap(self_.data.0).ok().unwrap()); - ArrayBase { - data, - ptr: self_.ptr, - dim: self_.dim, - strides: self_.strides, + let data = Arc::try_unwrap(self_.data.0).ok().unwrap(); + // safe because data is equivalent + unsafe { ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim) } + } + + fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> + where D: Dimension + { + match Arc::try_unwrap(self_.data.0) { + Ok(owned_data) => unsafe { + // Safe because the data is equivalent. + Ok(ArrayBase::from_data_ptr(owned_data, self_.ptr).with_strides_dim(self_.strides, self_.dim)) + }, + Err(arc_data) => unsafe { + // Safe because the data is equivalent; we're just + // reconstructing `self_`. + Err(ArrayBase::from_data_ptr(OwnedArcRepr(arc_data), self_.ptr) + .with_strides_dim(self_.strides, self_.dim)) + }, } } + + #[allow(clippy::wrong_self_convention)] + fn to_shared(self_: &ArrayBase) -> ArcArray + where + Self::Elem: Clone, + D: Dimension, + { + // to shared using clone of OwnedArcRepr without clone of raw data. + self_.clone() + } } unsafe impl DataMut for OwnedArcRepr where A: Clone {} -unsafe impl RawDataClone for OwnedArcRepr { - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { +unsafe impl RawDataClone for OwnedArcRepr +{ + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) + { // pointer is preserved (self.clone(), ptr) } } -unsafe impl RawData for OwnedRepr { +unsafe impl RawData for OwnedRepr +{ type Elem = A; - fn _data_slice(&self) -> Option<&[A]> { - Some(&self.0) + + fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool + { + let slc = self.as_slice(); + let ptr = slc.as_ptr() as *mut A; + let end = unsafe { ptr.add(slc.len()) }; + self_ptr >= ptr && self_ptr <= end } + private_impl! 
{} } -unsafe impl RawDataMut for OwnedRepr { +unsafe impl RawDataMut for OwnedRepr +{ #[inline] fn try_ensure_unique(_: &mut ArrayBase) where @@ -279,87 +354,112 @@ unsafe impl RawDataMut for OwnedRepr { } #[inline] - fn try_is_unique(&mut self) -> Option { + fn try_is_unique(&mut self) -> Option + { Some(true) } } -unsafe impl Data for OwnedRepr { +unsafe impl Data for OwnedRepr +{ #[inline] - fn into_owned(self_: ArrayBase) -> ArrayBase, D> + fn into_owned(self_: ArrayBase) -> Array where A: Clone, D: Dimension, { self_ } + + #[inline] + fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> + where D: Dimension + { + Ok(self_) + } } unsafe impl DataMut for OwnedRepr {} unsafe impl RawDataClone for OwnedRepr -where - A: Clone, +where A: Clone { - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) + { let mut u = self.clone(); - let mut new_ptr = nonnull_from_vec_data(&mut u.0); + let mut new_ptr = u.as_nonnull_mut(); if size_of::() != 0 { - let our_off = - (ptr.as_ptr() as isize - self.0.as_ptr() as isize) / mem::size_of::() as isize; + let our_off = (ptr.as_ptr() as isize - self.as_ptr() as isize) / mem::size_of::() as isize; new_ptr = new_ptr.offset(our_off); } (u, new_ptr) } - unsafe fn clone_from_with_ptr( - &mut self, - other: &Self, - ptr: NonNull, - ) -> NonNull { + unsafe fn clone_from_with_ptr(&mut self, other: &Self, ptr: NonNull) -> NonNull + { let our_off = if size_of::() != 0 { - (ptr.as_ptr() as isize - other.0.as_ptr() as isize) / mem::size_of::() as isize + (ptr.as_ptr() as isize - other.as_ptr() as isize) / mem::size_of::() as isize } else { 0 }; - self.0.clone_from(&other.0); - nonnull_from_vec_data(&mut self.0).offset(our_off) + self.clone_from(other); + self.as_nonnull_mut().offset(our_off) } } -unsafe impl<'a, A> RawData for ViewRepr<&'a A> { +unsafe impl RawData for ViewRepr<&A> +{ type Elem = A; - fn _data_slice(&self) -> Option<&[A]> { - None + + #[inline(always)] + fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool + { + true } + private_impl! {} } -unsafe impl<'a, A> Data for ViewRepr<&'a A> { - fn into_owned(self_: ArrayBase) -> ArrayBase, D> +unsafe impl Data for ViewRepr<&A> +{ + fn into_owned(self_: ArrayBase) -> Array where Self::Elem: Clone, D: Dimension, { self_.to_owned() } + + fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> + where D: Dimension + { + Err(self_) + } } -unsafe impl<'a, A> RawDataClone for ViewRepr<&'a A> { - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { +unsafe impl RawDataClone for ViewRepr<&A> +{ + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) + { (*self, ptr) } } -unsafe impl<'a, A> RawData for ViewRepr<&'a mut A> { +unsafe impl RawData for ViewRepr<&mut A> +{ type Elem = A; - fn _data_slice(&self) -> Option<&[A]> { - None + + #[inline(always)] + fn _is_pointer_inbounds(&self, _ptr: *const Self::Elem) -> bool + { + true } + private_impl! 
{} } -unsafe impl<'a, A> RawDataMut for ViewRepr<&'a mut A> { +unsafe impl RawDataMut for ViewRepr<&mut A> +{ #[inline] fn try_ensure_unique(_: &mut ArrayBase) where @@ -369,36 +469,59 @@ unsafe impl<'a, A> RawDataMut for ViewRepr<&'a mut A> { } #[inline] - fn try_is_unique(&mut self) -> Option { + fn try_is_unique(&mut self) -> Option + { Some(true) } } -unsafe impl<'a, A> Data for ViewRepr<&'a mut A> { - fn into_owned(self_: ArrayBase) -> ArrayBase, D> +unsafe impl Data for ViewRepr<&mut A> +{ + fn into_owned(self_: ArrayBase) -> Array where Self::Elem: Clone, D: Dimension, { self_.to_owned() } + + fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> + where D: Dimension + { + Err(self_) + } } -unsafe impl<'a, A> DataMut for ViewRepr<&'a mut A> {} +unsafe impl DataMut for ViewRepr<&mut A> {} /// Array representation trait. /// -/// A representation that is a unique or shared owner of its data. +/// A representation which can be the owner of its data. /// /// ***Internal trait, see `Data`.*** -pub unsafe trait DataOwned: Data { +// The owned storage represents the ownership and allocation of the array's elements. +// The storage may be unique or shared ownership style; it must be an aliasable owner +// (permit aliasing pointers, such as our separate array head pointer). +// +// The array storage must be initially mutable - copy on write arrays may require copying for +// unsharing storage before mutating it. The initially allocated storage must be mutable so +// that it can be mutated directly - through .raw_view_mut_unchecked() - for initialization. +#[allow(clippy::missing_safety_doc)] // not implementable downstream +pub unsafe trait DataOwned: Data +{ + /// Corresponding owned data with MaybeUninit elements + type MaybeUninit: DataOwned> + RawDataSubst; #[doc(hidden)] fn new(elements: Vec) -> Self; /// Converts the data representation to a shared (copy on write) - /// representation, without any copying. + /// representation, cloning the array elements if necessary. #[doc(hidden)] - fn into_shared(self) -> OwnedRcRepr; + #[allow(clippy::wrong_self_convention)] + fn into_shared(self_: ArrayBase) -> ArcArray + where + Self::Elem: Clone, + D: Dimension; } /// Array representation trait. @@ -406,44 +529,66 @@ pub unsafe trait DataOwned: Data { /// A representation that is a lightweight view. 
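The `MaybeUninit` associated type on `DataOwned` is the plumbing behind uninitialized construction; at the public level that surfaces as the existing `Array::uninit` plus `assume_init` pair, sketched here:

```rust
use ndarray::Array2;
use std::mem::MaybeUninit;

fn main() {
    // Allocate storage without initializing it...
    let mut a = Array2::<f64>::uninit((2, 3));
    // ...fill every element...
    for (index, elt) in a.indexed_iter_mut() {
        *elt = MaybeUninit::new((index.0 * 3 + index.1) as f64);
    }
    // ...and only then assert that it is fully initialized.
    let a = unsafe { a.assume_init() };
    assert_eq!(a[[1, 2]], 5.0);
}
```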
/// /// ***Internal trait, see `Data`.*** +#[allow(clippy::missing_safety_doc)] // not implementable downstream pub unsafe trait DataShared: Clone + Data + RawDataClone {} -unsafe impl DataShared for OwnedRcRepr {} -unsafe impl<'a, A> DataShared for ViewRepr<&'a A> {} +unsafe impl DataShared for OwnedArcRepr {} +unsafe impl DataShared for ViewRepr<&A> {} -unsafe impl DataOwned for OwnedRepr { - fn new(elements: Vec) -> Self { - OwnedRepr(elements) +unsafe impl DataOwned for OwnedRepr +{ + type MaybeUninit = OwnedRepr>; + + fn new(elements: Vec) -> Self + { + OwnedRepr::from(elements) } - fn into_shared(self) -> OwnedRcRepr { - OwnedArcRepr(Arc::new(self.0)) + + fn into_shared(self_: ArrayBase) -> ArcArray + where + A: Clone, + D: Dimension, + { + ArcArray::from(self_) } } -unsafe impl DataOwned for OwnedArcRepr { - fn new(elements: Vec) -> Self { - OwnedArcRepr(Arc::new(elements)) +unsafe impl DataOwned for OwnedArcRepr +{ + type MaybeUninit = OwnedArcRepr>; + + fn new(elements: Vec) -> Self + { + OwnedArcRepr(Arc::new(OwnedRepr::from(elements))) } - fn into_shared(self) -> OwnedRcRepr { - self + fn into_shared(self_: ArrayBase) -> ArcArray + where + A: Clone, + D: Dimension, + { + self_ } } -unsafe impl<'a, A> RawData for CowRepr<'a, A> { +unsafe impl RawData for CowRepr<'_, A> +{ type Elem = A; - fn _data_slice(&self) -> Option<&[A]> { + + #[inline] + fn _is_pointer_inbounds(&self, ptr: *const Self::Elem) -> bool + { match self { - CowRepr::View(view) => view._data_slice(), - CowRepr::Owned(data) => data._data_slice(), + CowRepr::View(view) => view._is_pointer_inbounds(ptr), + CowRepr::Owned(data) => data._is_pointer_inbounds(ptr), } } + private_impl! {} } -unsafe impl<'a, A> RawDataMut for CowRepr<'a, A> -where - A: Clone, +unsafe impl RawDataMut for CowRepr<'_, A> +where A: Clone { #[inline] fn try_ensure_unique(array: &mut ArrayBase) @@ -464,16 +609,17 @@ where } #[inline] - fn try_is_unique(&mut self) -> Option { + fn try_is_unique(&mut self) -> Option + { Some(self.is_owned()) } } -unsafe impl<'a, A> RawDataClone for CowRepr<'a, A> -where - A: Clone, +unsafe impl RawDataClone for CowRepr<'_, A> +where A: Clone { - unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { + unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) + { match self { CowRepr::View(view) => { let (new_view, ptr) = view.clone_with_ptr(ptr); @@ -486,12 +632,8 @@ where } } - #[doc(hidden)] - unsafe fn clone_from_with_ptr( - &mut self, - other: &Self, - ptr: NonNull, - ) -> NonNull { + unsafe fn clone_from_with_ptr(&mut self, other: &Self, ptr: NonNull) -> NonNull + { match (&mut *self, other) { (CowRepr::View(self_), CowRepr::View(other)) => self_.clone_from_with_ptr(other, ptr), (CowRepr::Owned(self_), CowRepr::Owned(other)) => self_.clone_from_with_ptr(other, ptr), @@ -509,23 +651,145 @@ where } } -unsafe impl<'a, A> Data for CowRepr<'a, A> { +unsafe impl<'a, A> Data for CowRepr<'a, A> +{ #[inline] - fn into_owned(self_: ArrayBase, D>) -> ArrayBase, D> + fn into_owned(self_: ArrayBase, D>) -> Array where A: Clone, D: Dimension, { match self_.data { CowRepr::View(_) => self_.to_owned(), - CowRepr::Owned(data) => ArrayBase { - data, - ptr: self_.ptr, - dim: self_.dim, - strides: self_.strides, + CowRepr::Owned(data) => unsafe { + // safe because the data is equivalent so ptr, dims remain valid + ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim) }, } } + + fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> + where D: Dimension + { + match 
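These `CowRepr` impls back `CowArray`, which borrows until ownership is actually needed. A minimal sketch, assuming the existing public `is_owned`/`into_owned` methods:

```rust
use ndarray::{array, CowArray};

fn main() {
    let owned = array![1, 2, 3];

    // Borrowing variant: wraps a view, clones only when ownership is needed.
    let cow = CowArray::from(owned.view());
    assert!(!cow.is_owned());
    let materialized = cow.into_owned(); // elements are cloned here
    assert_eq!(materialized, owned);

    // Owning variant: into_owned() would hand back the data without copying.
    let cow = CowArray::from(array![4, 5, 6]);
    assert!(cow.is_owned());
}
```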
self_.data { + CowRepr::View(_) => Err(self_), + CowRepr::Owned(data) => unsafe { + // safe because the data is equivalent so ptr, dims remain valid + Ok(ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim)) + }, + } + } +} + +unsafe impl DataMut for CowRepr<'_, A> where A: Clone {} + +unsafe impl<'a, A> DataOwned for CowRepr<'a, A> +{ + type MaybeUninit = CowRepr<'a, MaybeUninit>; + + fn new(elements: Vec) -> Self + { + CowRepr::Owned(OwnedRepr::new(elements)) + } + + fn into_shared(self_: ArrayBase) -> ArcArray + where + A: Clone, + D: Dimension, + { + self_.into_owned().into_shared() + } +} + +/// Array representation trait. +/// +/// The RawDataSubst trait maps the element type of array storage, while +/// keeping the same kind of storage. +/// +/// For example, `RawDataSubst` can map the type `OwnedRepr` to `OwnedRepr`. +pub trait RawDataSubst: RawData +{ + /// The resulting array storage of the same kind but substituted element type + type Output: RawData; + + /// Unsafely translate the data representation from one element + /// representation to another. + /// + /// ## Safety + /// + /// Caller must ensure the two types have the same representation. + unsafe fn data_subst(self) -> Self::Output; +} + +impl RawDataSubst for OwnedRepr +{ + type Output = OwnedRepr; + + unsafe fn data_subst(self) -> Self::Output + { + self.data_subst() + } +} + +impl RawDataSubst for OwnedArcRepr +{ + type Output = OwnedArcRepr; + + unsafe fn data_subst(self) -> Self::Output + { + OwnedArcRepr(Arc::from_raw(Arc::into_raw(self.0) as *const OwnedRepr)) + } +} + +impl RawDataSubst for RawViewRepr<*const A> +{ + type Output = RawViewRepr<*const B>; + + unsafe fn data_subst(self) -> Self::Output + { + RawViewRepr::new() + } } -unsafe impl<'a, A> DataMut for CowRepr<'a, A> where A: Clone {} +impl RawDataSubst for RawViewRepr<*mut A> +{ + type Output = RawViewRepr<*mut B>; + + unsafe fn data_subst(self) -> Self::Output + { + RawViewRepr::new() + } +} + +impl<'a, A: 'a, B: 'a> RawDataSubst for ViewRepr<&'a A> +{ + type Output = ViewRepr<&'a B>; + + unsafe fn data_subst(self) -> Self::Output + { + ViewRepr::new() + } +} + +impl<'a, A: 'a, B: 'a> RawDataSubst for ViewRepr<&'a mut A> +{ + type Output = ViewRepr<&'a mut B>; + + unsafe fn data_subst(self) -> Self::Output + { + ViewRepr::new() + } +} + +impl<'a, A: 'a, B: 'a> RawDataSubst for CowRepr<'a, A> +{ + type Output = CowRepr<'a, B>; + + unsafe fn data_subst(self) -> Self::Output + { + match self { + CowRepr::View(view) => CowRepr::View(view.data_subst()), + CowRepr::Owned(owned) => CowRepr::Owned(owned.data_subst()), + } + } +} diff --git a/src/dimension/axes.rs b/src/dimension/axes.rs index 50830f3c6..c7aaff149 100644 --- a/src/dimension/axes.rs +++ b/src/dimension/axes.rs @@ -1,9 +1,8 @@ -use crate::{Axis, Dimension, Ix, Ixs}; +use crate::{Axis, Dimension, Ixs}; /// Create a new Axes iterator -pub fn axes_of<'a, D>(d: &'a D, strides: &'a D) -> Axes<'a, D> -where - D: Dimension, +pub(crate) fn axes_of<'a, D>(d: &'a D, strides: &'a D) -> Axes<'a, D> +where D: Dimension { Axes { dim: d, @@ -15,9 +14,10 @@ where /// An iterator over the length and stride of each axis of an array. /// -/// See [`.axes()`](../struct.ArrayBase.html#method.axes) for more information. +/// This iterator is created from the array method +/// [`.axes()`](crate::ArrayBase::axes). /// -/// Iterator element type is `AxisDescription`. +/// Iterator element type is [`AxisDescription`]. 
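A short sketch of how the `CowRepr` impls above behave when seen through the public `CowArray` type (the `CowArray` alias, `is_owned`, and `into_owned` are assumed from ndarray's public API): a borrowed view stays a view until ownership is actually required.

// Illustrative sketch only, using ndarray's public API.
use ndarray::{array, CowArray, Ix2};

let owned = array![[1, 2], [3, 4]];
let cow: CowArray<'_, i32, Ix2> = CowArray::from(owned.view());
assert!(!cow.is_owned());              // still a view: conversion must clone
let a = cow.into_owned();              // the `CowRepr::View` branch -> `to_owned()`
assert_eq!(a, array![[1, 2], [3, 4]]);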
/// /// # Examples /// @@ -27,13 +27,18 @@ where /// /// let a = Array3::::zeros((3, 5, 4)); /// +/// // find the largest axis in the array +/// // check the axis index and its length +/// /// let largest_axis = a.axes() -/// .max_by_key(|ax| ax.len()) -/// .unwrap().axis(); -/// assert_eq!(largest_axis, Axis(1)); +/// .max_by_key(|ax| ax.len) +/// .unwrap(); +/// assert_eq!(largest_axis.axis, Axis(1)); +/// assert_eq!(largest_axis.len, 5); /// ``` #[derive(Debug)] -pub struct Axes<'a, D> { +pub struct Axes<'a, D> +{ dim: &'a D, strides: &'a D, start: usize, @@ -42,107 +47,94 @@ pub struct Axes<'a, D> { /// Description of the axis, its length and its stride. #[derive(Debug)] -pub struct AxisDescription(pub Axis, pub Ix, pub Ixs); - -copy_and_clone!(AxisDescription); - -// AxisDescription can't really be empty -// https://github.com/rust-ndarray/ndarray/pull/642#discussion_r296051702 -#[allow(clippy::len_without_is_empty)] -impl AxisDescription { - /// Return axis - #[inline(always)] - pub fn axis(self) -> Axis { - self.0 - } - /// Return length - #[inline(always)] - pub fn len(self) -> Ix { - self.1 - } - /// Return stride - #[inline(always)] - pub fn stride(self) -> Ixs { - self.2 - } +pub struct AxisDescription +{ + /// Axis identifier (index) + pub axis: Axis, + /// Length in count of elements of the current axis + pub len: usize, + /// Stride in count of elements of the current axis + pub stride: isize, } +copy_and_clone!(AxisDescription); copy_and_clone!(['a, D] Axes<'a, D>); -impl<'a, D> Iterator for Axes<'a, D> -where - D: Dimension, +impl Iterator for Axes<'_, D> +where D: Dimension { /// Description of the axis, its length and its stride. type Item = AxisDescription; - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { if self.start < self.end { let i = self.start.post_inc(); - Some(AxisDescription( - Axis(i), - self.dim[i], - self.strides[i] as Ixs, - )) + Some(AxisDescription { + axis: Axis(i), + len: self.dim[i], + stride: self.strides[i] as Ixs, + }) } else { None } } fn fold(self, init: B, f: F) -> B - where - F: FnMut(B, AxisDescription) -> B, + where F: FnMut(B, AxisDescription) -> B { (self.start..self.end) - .map(move |i| AxisDescription(Axis(i), self.dim[i], self.strides[i] as isize)) + .map(move |i| AxisDescription { + axis: Axis(i), + len: self.dim[i], + stride: self.strides[i] as isize, + }) .fold(init, f) } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { let len = self.end - self.start; (len, Some(len)) } } -impl<'a, D> DoubleEndedIterator for Axes<'a, D> -where - D: Dimension, +impl DoubleEndedIterator for Axes<'_, D> +where D: Dimension { - fn next_back(&mut self) -> Option { + fn next_back(&mut self) -> Option + { if self.start < self.end { let i = self.end.pre_dec(); - Some(AxisDescription( - Axis(i), - self.dim[i], - self.strides[i] as Ixs, - )) + Some(AxisDescription { + axis: Axis(i), + len: self.dim[i], + stride: self.strides[i] as Ixs, + }) } else { None } } } -trait IncOps: Copy { +trait IncOps: Copy +{ fn post_inc(&mut self) -> Self; - fn post_dec(&mut self) -> Self; fn pre_dec(&mut self) -> Self; } -impl IncOps for usize { +impl IncOps for usize +{ #[inline(always)] - fn post_inc(&mut self) -> Self { + fn post_inc(&mut self) -> Self + { let x = *self; *self += 1; x } #[inline(always)] - fn post_dec(&mut self) -> Self { - let x = *self; - *self -= 1; - x - } - #[inline(always)] - fn pre_dec(&mut self) -> Self { + fn pre_dec(&mut self) -> Self + { *self -= 1; *self } diff --git a/src/dimension/axis.rs 
b/src/dimension/axis.rs index 42a1ee12c..8c896f6b7 100644 --- a/src/dimension/axis.rs +++ b/src/dimension/axis.rs @@ -8,18 +8,30 @@ /// An axis index. /// -/// An axis one of an array’s “dimensions”; an *n*-dimensional array has *n* axes. -/// Axis *0* is the array’s outermost axis and *n*-1 is the innermost. +/// An axis one of an array’s “dimensions”; an *n*-dimensional array has *n* +/// axes. Axis *0* is the array’s outermost axis and *n*-1 is the innermost. /// /// All array axis arguments use this type to make the code easier to write /// correctly and easier to understand. +/// +/// For example: in a method like `index_axis(axis, index)` the code becomes +/// self-explanatory when it's called like `.index_axis(Axis(1), i)`; it's +/// evident which integer is the axis number and which is the index. +/// +/// Note: This type does **not** implement From/Into usize and similar trait +/// based conversions, because we want to preserve code readability and quality. +/// +/// `Axis(1)` in itself is a very clear code style and the style that should be +/// avoided is code like `1.into()`. #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct Axis(pub usize); -impl Axis { +impl Axis +{ /// Return the index of the axis. #[inline(always)] - pub fn index(self) -> usize { + pub fn index(self) -> usize + { self.0 } } diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs new file mode 100644 index 000000000..fb9fc1a0c --- /dev/null +++ b/src/dimension/broadcast.rs @@ -0,0 +1,123 @@ +use crate::error::*; +use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; + +/// Calculate the common shape for a pair of array shapes, that they can be broadcasted +/// to. Return an error if the shapes are not compatible. +/// +/// Uses the [NumPy broadcasting rules] +// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules). +pub(crate) fn co_broadcast(shape1: &D1, shape2: &D2) -> Result +where + D1: Dimension, + D2: Dimension, + Output: Dimension, +{ + let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim()); + // Swap the order if d2 is longer. + if overflow { + return co_broadcast::(shape2, shape1); + } + // The output should be the same length as shape1. + let mut out = Output::zeros(shape1.ndim()); + for (out, s) in izip!(out.slice_mut(), shape1.slice()) { + *out = *s; + } + for (out, s2) in izip!(&mut out.slice_mut()[k..], shape2.slice()) { + if *out != *s2 { + if *out == 1 { + *out = *s2 + } else if *s2 != 1 { + return Err(from_kind(ErrorKind::IncompatibleShape)); + } + } + } + Ok(out) +} + +pub trait DimMax +{ + /// The resulting dimension type after broadcasting. + type Output: Dimension; +} + +/// Dimensions of the same type remain unchanged when co_broadcast. +/// So you can directly use `D` as the resulting type. +/// (Instead of `>::BroadcastOutput`) +impl DimMax for D +{ + type Output = D; +} + +macro_rules! 
impl_broadcast_distinct_fixed { + ($smaller:ty, $larger:ty) => { + impl DimMax<$larger> for $smaller { + type Output = $larger; + } + + impl DimMax<$smaller> for $larger { + type Output = $larger; + } + }; +} + +impl_broadcast_distinct_fixed!(Ix0, Ix1); +impl_broadcast_distinct_fixed!(Ix0, Ix2); +impl_broadcast_distinct_fixed!(Ix0, Ix3); +impl_broadcast_distinct_fixed!(Ix0, Ix4); +impl_broadcast_distinct_fixed!(Ix0, Ix5); +impl_broadcast_distinct_fixed!(Ix0, Ix6); +impl_broadcast_distinct_fixed!(Ix1, Ix2); +impl_broadcast_distinct_fixed!(Ix1, Ix3); +impl_broadcast_distinct_fixed!(Ix1, Ix4); +impl_broadcast_distinct_fixed!(Ix1, Ix5); +impl_broadcast_distinct_fixed!(Ix1, Ix6); +impl_broadcast_distinct_fixed!(Ix2, Ix3); +impl_broadcast_distinct_fixed!(Ix2, Ix4); +impl_broadcast_distinct_fixed!(Ix2, Ix5); +impl_broadcast_distinct_fixed!(Ix2, Ix6); +impl_broadcast_distinct_fixed!(Ix3, Ix4); +impl_broadcast_distinct_fixed!(Ix3, Ix5); +impl_broadcast_distinct_fixed!(Ix3, Ix6); +impl_broadcast_distinct_fixed!(Ix4, Ix5); +impl_broadcast_distinct_fixed!(Ix4, Ix6); +impl_broadcast_distinct_fixed!(Ix5, Ix6); +impl_broadcast_distinct_fixed!(Ix0, IxDyn); +impl_broadcast_distinct_fixed!(Ix1, IxDyn); +impl_broadcast_distinct_fixed!(Ix2, IxDyn); +impl_broadcast_distinct_fixed!(Ix3, IxDyn); +impl_broadcast_distinct_fixed!(Ix4, IxDyn); +impl_broadcast_distinct_fixed!(Ix5, IxDyn); +impl_broadcast_distinct_fixed!(Ix6, IxDyn); + +#[cfg(test)] +#[cfg(feature = "std")] +mod tests +{ + use super::co_broadcast; + use crate::{Dim, DimMax, Dimension, ErrorKind, Ix0, IxDynImpl, ShapeError}; + + #[test] + fn test_broadcast_shape() + { + fn test_co(d1: &D1, d2: &D2, r: Result<>::Output, ShapeError>) + where + D1: Dimension + DimMax, + D2: Dimension, + { + let d = co_broadcast::>::Output>(&d1, d2); + assert_eq!(d, r); + } + test_co(&Dim([2, 3]), &Dim([4, 1, 3]), Ok(Dim([4, 2, 3]))); + test_co(&Dim([1, 2, 2]), &Dim([1, 3, 4]), Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + test_co(&Dim([3, 4, 5]), &Ix0(), Ok(Dim([3, 4, 5]))); + let v = vec![1, 2, 3, 4, 5, 6, 7]; + test_co(&Dim(vec![1, 1, 3, 1, 5, 1, 7]), &Dim([2, 1, 4, 1, 6, 1]), Ok(Dim(IxDynImpl::from(v.as_slice())))); + let d = Dim([1, 2, 1, 3]); + test_co(&d, &d, Ok(d)); + test_co(&Dim([2, 1, 2]).into_dyn(), &Dim(0), Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + test_co(&Dim([2, 1, 1]), &Dim([0, 0, 1, 3, 4]), Ok(Dim([0, 0, 2, 3, 4]))); + test_co(&Dim([0]), &Dim([0, 0, 0]), Ok(Dim([0, 0, 0]))); + test_co(&Dim(1), &Dim([1, 0, 0]), Ok(Dim([1, 0, 0]))); + test_co(&Dim([1, 3, 0, 1, 1]), &Dim([1, 2, 3, 1]), Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + } +} diff --git a/src/dimension/conversion.rs b/src/dimension/conversion.rs index bf48dae2f..0cf2e1296 100644 --- a/src/dimension/conversion.rs +++ b/src/dimension/conversion.rs @@ -8,6 +8,8 @@ //! Tuple to array conversion, IntoDimension, and related things +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; use num_traits::Zero; use std::ops::{Index, IndexMut}; @@ -38,47 +40,55 @@ macro_rules! index_item { } /// Argument conversion a dimension. 
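In the public API, `DimMax` and `co_broadcast` are what allow arithmetic between arrays of different dimensionality; a minimal sketch of the behaviour (assuming the `array!` macro and the standard arithmetic impls on array references):

// Illustrative sketch only, using ndarray's public API.
use ndarray::array;

let a = array![[1., 2., 3.], [4., 5., 6.]]; // shape (2, 3)
let b = array![10., 20., 30.];              // shape (3,)
// The shapes are co-broadcast to (2, 3) following the NumPy rules above.
let sum = &a + &b;
assert_eq!(sum, array![[11., 22., 33.], [14., 25., 36.]]);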
-pub trait IntoDimension { +pub trait IntoDimension +{ type Dim: Dimension; fn into_dimension(self) -> Self::Dim; } -impl IntoDimension for Ix { +impl IntoDimension for Ix +{ type Dim = Ix1; #[inline(always)] - fn into_dimension(self) -> Ix1 { + fn into_dimension(self) -> Ix1 + { Ix1(self) } } impl IntoDimension for D -where - D: Dimension, +where D: Dimension { type Dim = D; #[inline(always)] - fn into_dimension(self) -> Self { + fn into_dimension(self) -> Self + { self } } -impl IntoDimension for IxDynImpl { +impl IntoDimension for IxDynImpl +{ type Dim = IxDyn; #[inline(always)] - fn into_dimension(self) -> Self::Dim { + fn into_dimension(self) -> Self::Dim + { Dim::new(self) } } -impl IntoDimension for Vec { +impl IntoDimension for Vec +{ type Dim = IxDyn; #[inline(always)] - fn into_dimension(self) -> Self::Dim { + fn into_dimension(self) -> Self::Dim + { Dim::new(IxDynImpl::from(self)) } } -pub trait Convert { +pub trait Convert +{ type To; fn convert(self) -> Self::To; } @@ -92,25 +102,25 @@ macro_rules! sub { macro_rules! tuple_type { ([$T:ident] $($index:tt)*) => ( ( $(sub!($index $T), )* ) - ) + ); } macro_rules! tuple_expr { ([$self_:expr] $($index:tt)*) => ( ( $($self_[$index], )* ) - ) + ); } macro_rules! array_expr { ([$self_:expr] $($index:tt)*) => ( [$($self_ . $index, )*] - ) + ); } macro_rules! array_zero { ([] $($index:tt)*) => ( [$(sub!($index 0), )*] - ) + ); } macro_rules! tuple_to_array { @@ -166,7 +176,7 @@ macro_rules! tuple_to_array { } )* - } + }; } index_item!(tuple_to_array [] 7); diff --git a/src/dimension/dim.rs b/src/dimension/dim.rs index 3f47e15ae..96e433bb3 100644 --- a/src/dimension/dim.rs +++ b/src/dimension/dim.rs @@ -6,19 +6,18 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::fmt; - use super::Dimension; use super::IntoDimension; use crate::itertools::zip; use crate::Ix; +use std::fmt; /// Dimension description. /// /// `Dim` describes the number of axes and the length of each axis /// in an array. It is also used as an index type. /// -/// See also the [`Dimension` trait](trait.Dimension.html) for its methods and +/// See also the [`Dimension`] trait for its methods and /// operations. /// /// # Examples @@ -36,21 +35,26 @@ use crate::Ix; /// assert_eq!(array.raw_dim(), Dim([3, 2])); /// ``` #[derive(Copy, Clone, PartialEq, Eq, Hash, Default)] -pub struct Dim { +pub struct Dim +{ index: I, } -impl Dim { +impl Dim +{ /// Private constructor and accessors for Dim - pub(crate) fn new(index: I) -> Dim { + pub(crate) const fn new(index: I) -> Dim + { Dim { index } } #[inline(always)] - pub(crate) fn ix(&self) -> &I { + pub(crate) fn ix(&self) -> &I + { &self.index } #[inline(always)] - pub(crate) fn ixm(&mut self) -> &mut I { + pub(crate) fn ixm(&mut self) -> &mut I + { &mut self.index } } @@ -58,26 +62,25 @@ impl Dim { /// Create a new dimension value. 
#[allow(non_snake_case)] pub fn Dim(index: T) -> T::Dim -where - T: IntoDimension, +where T: IntoDimension { index.into_dimension() } impl PartialEq for Dim -where - I: PartialEq, +where I: PartialEq { - fn eq(&self, rhs: &I) -> bool { + fn eq(&self, rhs: &I) -> bool + { self.index == *rhs } } impl fmt::Debug for Dim -where - I: fmt::Debug, +where I: fmt::Debug { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { write!(f, "{:?}", self.index) } } diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 4bfe7c0b2..3544a7f3c 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -6,19 +6,22 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; use std::ops::{Index, IndexMut}; use super::axes_of; use super::conversion::Convert; +use super::ops::DimAdd; use super::{stride_offset, stride_offset_checked}; use crate::itertools::{enumerate, zip}; -use crate::Axis; use crate::IntoDimension; use crate::RemoveAxis; use crate::{ArrayView1, ArrayViewMut1}; -use crate::{Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs, SliceOrIndex}; +use crate::{Axis, DimMax}; +use crate::{Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs}; /// Array shape and index trait. /// @@ -45,26 +48,22 @@ pub trait Dimension: + MulAssign + for<'x> MulAssign<&'x Self> + MulAssign + + DimMax + + DimMax + + DimMax + + DimMax<::Smaller, Output = Self> + + DimMax<::Larger, Output = ::Larger> + + DimAdd + + DimAdd<::Smaller> + + DimAdd<::Larger> + + DimAdd + + DimAdd::Larger> + + DimAdd { /// For fixed-size dimension representations (e.g. `Ix2`), this should be /// `Some(ndim)`, and for variable-size dimension representations (e.g. /// `IxDyn`), this should be `None`. const NDIM: Option; - /// `SliceArg` is the type which is used to specify slicing for this - /// dimension. - /// - /// For the fixed size dimensions it is a fixed size array of the correct - /// size, which you pass by reference. For the dynamic dimension it is - /// a slice. - /// - /// - For `Ix1`: `[SliceOrIndex; 1]` - /// - For `Ix2`: `[SliceOrIndex; 2]` - /// - and so on.. - /// - For `IxDyn`: `[SliceOrIndex]` - /// - /// The easiest way to create a `&SliceInfo` is using the - /// [`s![]`](macro.s!.html) macro. - type SliceArg: ?Sized + AsRef<[SliceOrIndex]>; /// Pattern matching friendly form of the dimension value. /// /// - For `Ix1`: `usize`, @@ -75,7 +74,7 @@ pub trait Dimension: /// Next smaller dimension (if applicable) type Smaller: Dimension; /// Next larger dimension - type Larger: Dimension; + type Larger: Dimension + RemoveAxis; /// Returns the number of dimensions (number of axes). fn ndim(&self) -> usize; @@ -84,15 +83,17 @@ pub trait Dimension: fn into_pattern(self) -> Self::Pattern; /// Compute the size of the dimension (number of elements) - fn size(&self) -> usize { - self.slice().iter().fold(1, |s, &a| s * a as usize) + fn size(&self) -> usize + { + self.slice().iter().product() } /// Compute the size while checking for overflow. 
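`size` and `size_checked` are ordinary public methods of `Dimension`; a small sketch of the difference, assuming the `Dim` constructor shown earlier in this file:

// Illustrative sketch only, using ndarray's public API.
use ndarray::{Dim, Dimension};

let d = Dim([3, 4, 5]);
assert_eq!(d.ndim(), 3);
assert_eq!(d.size(), 60);               // product of the axis lengths
assert_eq!(d.size_checked(), Some(60)); // `None` if the product would overflow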
- fn size_checked(&self) -> Option { + fn size_checked(&self) -> Option + { self.slice() .iter() - .fold(Some(1), |s, &a| s.and_then(|s_| s_.checked_mul(a))) + .try_fold(1_usize, |s, &a| s.checked_mul(a)) } #[doc(hidden)] @@ -102,17 +103,20 @@ pub trait Dimension: fn slice_mut(&mut self) -> &mut [Ix]; /// Borrow as a read-only array view. - fn as_array_view(&self) -> ArrayView1<'_, Ix> { + fn as_array_view(&self) -> ArrayView1<'_, Ix> + { ArrayView1::from(self.slice()) } /// Borrow as a read-write array view. - fn as_array_view_mut(&mut self) -> ArrayViewMut1<'_, Ix> { + fn as_array_view_mut(&mut self) -> ArrayViewMut1<'_, Ix> + { ArrayViewMut1::from(self.slice_mut()) } #[doc(hidden)] - fn equal(&self, rhs: &Self) -> bool { + fn equal(&self, rhs: &Self) -> bool + { self.slice() == rhs.slice() } @@ -121,7 +125,8 @@ pub trait Dimension: /// If the array is non-empty, the strides result in contiguous layout; if /// the array is empty, the strides are all zeros. #[doc(hidden)] - fn default_strides(&self) -> Self { + fn default_strides(&self) -> Self + { // Compute default array strides // Shape (a, b, c) => Give strides (b * c, c, 1) let mut strides = Self::zeros(self.ndim()); @@ -146,7 +151,8 @@ pub trait Dimension: /// If the array is non-empty, the strides result in contiguous layout; if /// the array is empty, the strides are all zeros. #[doc(hidden)] - fn fortran_strides(&self) -> Self { + fn fortran_strides(&self) -> Self + { // Compute fortran array strides // Shape (a, b, c) => Give strides (1, a, a * b) let mut strides = Self::zeros(self.ndim()); @@ -176,7 +182,8 @@ pub trait Dimension: #[doc(hidden)] #[inline] - fn first_index(&self) -> Option { + fn first_index(&self) -> Option + { for ax in self.slice().iter() { if *ax == 0 { return None; @@ -190,7 +197,8 @@ pub trait Dimension: /// or None if there are no more. // FIXME: use &Self for index or even &mut? #[inline] - fn next_for(&self, index: Self) -> Option { + fn next_for(&self, index: Self) -> Option + { let mut index = index; let mut done = false; for (&dim, ix) in zip(self.slice(), index.slice_mut()).rev() { @@ -215,7 +223,8 @@ pub trait Dimension: /// /// Next in f-order #[inline] - fn next_for_f(&self, index: &mut Self) -> bool { + fn next_for_f(&self, index: &mut Self) -> bool + { let mut end_iteration = true; for (&dim, ix) in zip(self.slice(), index.slice_mut()) { *ix += 1; @@ -229,9 +238,28 @@ pub trait Dimension: !end_iteration } + /// Returns `true` iff `strides1` and `strides2` are equivalent for the + /// shape `self`. + /// + /// The strides are equivalent if, for each axis with length > 1, the + /// strides are equal. + /// + /// Note: Returns `false` if any of the ndims don't match. + #[doc(hidden)] + fn strides_equivalent(&self, strides1: &Self, strides2: &D) -> bool + where D: Dimension + { + let shape_ndim = self.ndim(); + shape_ndim == strides1.ndim() + && shape_ndim == strides2.ndim() + && izip!(self.slice(), strides1.slice(), strides2.slice()) + .all(|(&d, &s1, &s2)| d <= 1 || s1 as isize == s2 as isize) + } + #[doc(hidden)] /// Return stride offset for index. - fn stride_offset(index: &Self, strides: &Self) -> isize { + fn stride_offset(index: &Self, strides: &Self) -> isize + { let mut offset = 0; for (&i, &s) in izip!(index.slice(), strides.slice()) { offset += stride_offset(i, s); @@ -241,12 +269,14 @@ pub trait Dimension: #[doc(hidden)] /// Return stride offset for this dimension and index. 
- fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option { + fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option + { stride_offset_checked(self.slice(), strides.slice(), index.slice()) } #[doc(hidden)] - fn last_elem(&self) -> usize { + fn last_elem(&self) -> usize + { if self.ndim() == 0 { 0 } else { @@ -255,56 +285,64 @@ pub trait Dimension: } #[doc(hidden)] - fn set_last_elem(&mut self, i: usize) { + fn set_last_elem(&mut self, i: usize) + { let nd = self.ndim(); self.slice_mut()[nd - 1] = i; } #[doc(hidden)] - fn is_contiguous(dim: &Self, strides: &Self) -> bool { + fn is_contiguous(dim: &Self, strides: &Self) -> bool + { let defaults = dim.default_strides(); if strides.equal(&defaults) { return true; } if dim.ndim() == 1 { - return false; - } - let order = strides._fastest_varying_stride_order(); - let strides = strides.slice(); - - // FIXME: Negative strides - let dim_slice = dim.slice(); - let mut cstride = 1; - for &i in order.slice() { - // a dimension of length 1 can have unequal strides - if dim_slice[i] != 1 && strides[i] != cstride { - return false; + // fast case for ndim == 1: + // Either we have length <= 1, then stride is arbitrary, + // or we have stride == 1 or stride == -1, but +1 case is already handled above. + dim[0] <= 1 || strides[0] as isize == -1 + } else { + let order = strides._fastest_varying_stride_order(); + let strides = strides.slice(); + + let dim_slice = dim.slice(); + let mut cstride = 1; + for &i in order.slice() { + // a dimension of length 1 can have unequal strides + if dim_slice[i] != 1 && (strides[i] as isize).unsigned_abs() != cstride { + return false; + } + cstride *= dim_slice[i]; } - cstride *= dim_slice[i]; + true } - true } /// Return the axis ordering corresponding to the fastest variation /// (in ascending order). /// - /// Assumes that no stride value appears twice. This cannot yield the correct - /// result the strides are not positive. + /// Assumes that no stride value appears twice. 
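The switch to absolute stride values in these helpers is part of making negative-stride views first-class; a small sketch of where such views come from in practice (assuming the `s![]` macro from the public API):

// Illustrative sketch only, using ndarray's public API.
use ndarray::{array, s};

let a = array![[1, 2, 3], [4, 5, 6]];
// A step of -1 walks axis 1 backwards, so the view has a negative stride.
let flipped = a.slice(s![.., ..;-1]);
assert_eq!(flipped, array![[3, 2, 1], [6, 5, 4]]);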
#[doc(hidden)] - fn _fastest_varying_stride_order(&self) -> Self { + fn _fastest_varying_stride_order(&self) -> Self + { let mut indices = self.clone(); for (i, elt) in enumerate(indices.slice_mut()) { *elt = i; } let strides = self.slice(); - indices.slice_mut().sort_by_key(|&i| strides[i]); + indices + .slice_mut() + .sort_by_key(|&i| (strides[i] as isize).abs()); indices } /// Compute the minimum stride axis (absolute value), under the constraint /// that the length of the axis is > 1; #[doc(hidden)] - fn min_stride_axis(&self, strides: &Self) -> Axis { + fn min_stride_axis(&self, strides: &Self) -> Axis + { let n = match self.ndim() { 0 => panic!("min_stride_axis: Array must have ndim > 0"), 1 => return Axis(0), @@ -312,32 +350,35 @@ pub trait Dimension: }; axes_of(self, strides) .rev() - .min_by_key(|ax| ax.stride().abs()) - .map_or(Axis(n - 1), |ax| ax.axis()) + .min_by_key(|ax| ax.stride.abs()) + .map_or(Axis(n - 1), |ax| ax.axis) } /// Compute the maximum stride axis (absolute value), under the constraint /// that the length of the axis is > 1; #[doc(hidden)] - fn max_stride_axis(&self, strides: &Self) -> Axis { + fn max_stride_axis(&self, strides: &Self) -> Axis + { match self.ndim() { 0 => panic!("max_stride_axis: Array must have ndim > 0"), 1 => return Axis(0), _ => {} } axes_of(self, strides) - .filter(|ax| ax.len() > 1) - .max_by_key(|ax| ax.stride().abs()) - .map_or(Axis(0), |ax| ax.axis()) + .filter(|ax| ax.len > 1) + .max_by_key(|ax| ax.stride.abs()) + .map_or(Axis(0), |ax| ax.axis) } /// Convert the dimensional into a dynamic dimensional (IxDyn). - fn into_dyn(self) -> IxDyn { + fn into_dyn(self) -> IxDyn + { IxDyn(self.slice()) } #[doc(hidden)] - fn from_dimension(d: &D2) -> Option { + fn from_dimension(d: &D2) -> Option + { let mut s = Self::default(); if s.ndim() == d.ndim() { for i in 0..d.ndim() { @@ -362,6 +403,7 @@ pub trait Dimension: macro_rules! impl_insert_axis_array( ($n:expr) => ( + #[inline] fn insert_axis(&self, axis: Axis) -> Self::Larger { debug_assert!(axis.index() <= $n); let mut out = [1; $n + 1]; @@ -372,79 +414,91 @@ macro_rules! impl_insert_axis_array( ); ); -impl Dimension for Dim<[Ix; 0]> { +impl Dimension for Dim<[Ix; 0]> +{ const NDIM: Option = Some(0); - type SliceArg = [SliceOrIndex; 0]; type Pattern = (); type Smaller = Self; type Larger = Ix1; // empty product is 1 -> size is 1 #[inline] - fn ndim(&self) -> usize { + fn ndim(&self) -> usize + { 0 } #[inline] - fn slice(&self) -> &[Ix] { + fn slice(&self) -> &[Ix] + { &[] } #[inline] - fn slice_mut(&mut self) -> &mut [Ix] { + fn slice_mut(&mut self) -> &mut [Ix] + { &mut [] } #[inline] - fn _fastest_varying_stride_order(&self) -> Self { + fn _fastest_varying_stride_order(&self) -> Self + { Ix0() } #[inline] fn into_pattern(self) -> Self::Pattern {} #[inline] - fn zeros(ndim: usize) -> Self { + fn zeros(ndim: usize) -> Self + { assert_eq!(ndim, 0); Self::default() } #[inline] - fn next_for(&self, _index: Self) -> Option { + fn next_for(&self, _index: Self) -> Option + { None } - #[inline] impl_insert_axis_array!(0); #[inline] - fn try_remove_axis(&self, _ignore: Axis) -> Self::Smaller { + fn try_remove_axis(&self, _ignore: Axis) -> Self::Smaller + { *self } private_impl! 
{} } -impl Dimension for Dim<[Ix; 1]> { +impl Dimension for Dim<[Ix; 1]> +{ const NDIM: Option = Some(1); - type SliceArg = [SliceOrIndex; 1]; type Pattern = Ix; type Smaller = Ix0; type Larger = Ix2; #[inline] - fn ndim(&self) -> usize { + fn ndim(&self) -> usize + { 1 } #[inline] - fn slice(&self) -> &[Ix] { + fn slice(&self) -> &[Ix] + { self.ix() } #[inline] - fn slice_mut(&mut self) -> &mut [Ix] { + fn slice_mut(&mut self) -> &mut [Ix] + { self.ixm() } #[inline] - fn into_pattern(self) -> Self::Pattern { + fn into_pattern(self) -> Self::Pattern + { get!(&self, 0) } #[inline] - fn zeros(ndim: usize) -> Self { + fn zeros(ndim: usize) -> Self + { assert_eq!(ndim, 1); Self::default() } #[inline] - fn next_for(&self, mut index: Self) -> Option { + fn next_for(&self, mut index: Self) -> Option + { getm!(index, 0) += 1; if get!(&index, 0) < get!(self, 0) { Some(index) @@ -454,21 +508,25 @@ impl Dimension for Dim<[Ix; 1]> { } #[inline] - fn equal(&self, rhs: &Self) -> bool { + fn equal(&self, rhs: &Self) -> bool + { get!(self, 0) == get!(rhs, 0) } #[inline] - fn size(&self) -> usize { + fn size(&self) -> usize + { get!(self, 0) } #[inline] - fn size_checked(&self) -> Option { + fn size_checked(&self) -> Option + { Some(get!(self, 0)) } #[inline] - fn default_strides(&self) -> Self { + fn default_strides(&self) -> Self + { if get!(self, 0) == 0 { Ix1(0) } else { @@ -477,22 +535,26 @@ impl Dimension for Dim<[Ix; 1]> { } #[inline] - fn _fastest_varying_stride_order(&self) -> Self { + fn _fastest_varying_stride_order(&self) -> Self + { Ix1(0) } #[inline(always)] - fn min_stride_axis(&self, _: &Self) -> Axis { + fn min_stride_axis(&self, _: &Self) -> Axis + { Axis(0) } #[inline(always)] - fn max_stride_axis(&self, _: &Self) -> Axis { + fn max_stride_axis(&self, _: &Self) -> Axis + { Axis(0) } #[inline] - fn first_index(&self) -> Option { + fn first_index(&self) -> Option + { if get!(self, 0) != 0 { Some(Ix1(0)) } else { @@ -502,57 +564,74 @@ impl Dimension for Dim<[Ix; 1]> { /// Self is an index, return the stride offset #[inline(always)] - fn stride_offset(index: &Self, stride: &Self) -> isize { + fn stride_offset(index: &Self, stride: &Self) -> isize + { stride_offset(get!(index, 0), get!(stride, 0)) } /// Return stride offset for this dimension and index. #[inline] - fn stride_offset_checked(&self, stride: &Self, index: &Self) -> Option { + fn stride_offset_checked(&self, stride: &Self, index: &Self) -> Option + { if get!(index, 0) < get!(self, 0) { Some(stride_offset(get!(index, 0), get!(stride, 0))) } else { None } } - #[inline] impl_insert_axis_array!(1); #[inline] - fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { + fn try_remove_axis(&self, axis: Axis) -> Self::Smaller + { self.remove_axis(axis) } + + fn from_dimension(d: &D2) -> Option + { + if 1 == d.ndim() { + Some(Ix1(d[0])) + } else { + None + } + } private_impl! 
{} } -impl Dimension for Dim<[Ix; 2]> { +impl Dimension for Dim<[Ix; 2]> +{ const NDIM: Option = Some(2); - type SliceArg = [SliceOrIndex; 2]; type Pattern = (Ix, Ix); type Smaller = Ix1; type Larger = Ix3; #[inline] - fn ndim(&self) -> usize { + fn ndim(&self) -> usize + { 2 } #[inline] - fn into_pattern(self) -> Self::Pattern { + fn into_pattern(self) -> Self::Pattern + { self.ix().convert() } #[inline] - fn slice(&self) -> &[Ix] { + fn slice(&self) -> &[Ix] + { self.ix() } #[inline] - fn slice_mut(&mut self) -> &mut [Ix] { + fn slice_mut(&mut self) -> &mut [Ix] + { self.ixm() } #[inline] - fn zeros(ndim: usize) -> Self { + fn zeros(ndim: usize) -> Self + { assert_eq!(ndim, 2); Self::default() } #[inline] - fn next_for(&self, index: Self) -> Option { + fn next_for(&self, index: Self) -> Option + { let mut i = get!(&index, 0); let mut j = get!(&index, 1); let imax = get!(self, 0); @@ -569,34 +648,40 @@ impl Dimension for Dim<[Ix; 2]> { } #[inline] - fn equal(&self, rhs: &Self) -> bool { + fn equal(&self, rhs: &Self) -> bool + { get!(self, 0) == get!(rhs, 0) && get!(self, 1) == get!(rhs, 1) } #[inline] - fn size(&self) -> usize { + fn size(&self) -> usize + { get!(self, 0) * get!(self, 1) } #[inline] - fn size_checked(&self) -> Option { + fn size_checked(&self) -> Option + { let m = get!(self, 0); let n = get!(self, 1); - (m as usize).checked_mul(n as usize) + m.checked_mul(n) } #[inline] - fn last_elem(&self) -> usize { + fn last_elem(&self) -> usize + { get!(self, 1) } #[inline] - fn set_last_elem(&mut self, i: usize) { + fn set_last_elem(&mut self, i: usize) + { getm!(self, 1) = i; } #[inline] - fn default_strides(&self) -> Self { + fn default_strides(&self) -> Self + { let m = get!(self, 0); let n = get!(self, 1); if m == 0 || n == 0 { @@ -606,7 +691,8 @@ impl Dimension for Dim<[Ix; 2]> { } } #[inline] - fn fortran_strides(&self) -> Self { + fn fortran_strides(&self) -> Self + { let m = get!(self, 0); let n = get!(self, 1); if m == 0 || n == 0 { @@ -617,8 +703,9 @@ impl Dimension for Dim<[Ix; 2]> { } #[inline] - fn _fastest_varying_stride_order(&self) -> Self { - if get!(self, 0) as Ixs <= get!(self, 1) as Ixs { + fn _fastest_varying_stride_order(&self) -> Self + { + if (get!(self, 0) as Ixs).abs() <= (get!(self, 1) as Ixs).abs() { Ix2(0, 1) } else { Ix2(1, 0) @@ -626,7 +713,8 @@ impl Dimension for Dim<[Ix; 2]> { } #[inline] - fn min_stride_axis(&self, strides: &Self) -> Axis { + fn min_stride_axis(&self, strides: &Self) -> Axis + { let s = get!(strides, 0) as Ixs; let t = get!(strides, 1) as Ixs; if s.abs() < t.abs() { @@ -637,7 +725,8 @@ impl Dimension for Dim<[Ix; 2]> { } #[inline] - fn first_index(&self) -> Option { + fn first_index(&self) -> Option + { let m = get!(self, 0); let n = get!(self, 1); if m != 0 && n != 0 { @@ -649,7 +738,8 @@ impl Dimension for Dim<[Ix; 2]> { /// Self is an index, return the stride offset #[inline(always)] - fn stride_offset(index: &Self, strides: &Self) -> isize { + fn stride_offset(index: &Self, strides: &Self) -> isize + { let i = get!(index, 0); let j = get!(index, 1); let s = get!(strides, 0); @@ -659,7 +749,8 @@ impl Dimension for Dim<[Ix; 2]> { /// Return stride offset for this dimension and index. 
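`into_pattern` is what gives the fixed-size dimensions their tuple form; in the public API it surfaces through methods such as `.dim()`. A small sketch (assuming `Array2` and `zeros` from the public API):

// Illustrative sketch only, using ndarray's public API.
use ndarray::Array2;

let a = Array2::<f64>::zeros((3, 4));
// For `Ix2` the pattern type is `(usize, usize)`, so `.dim()` destructures.
let (rows, cols) = a.dim();
assert_eq!((rows, cols), (3, 4));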
#[inline] - fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option { + fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option + { let m = get!(self, 0); let n = get!(self, 1); let i = get!(index, 0); @@ -672,54 +763,61 @@ impl Dimension for Dim<[Ix; 2]> { None } } - #[inline] impl_insert_axis_array!(2); #[inline] - fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { + fn try_remove_axis(&self, axis: Axis) -> Self::Smaller + { self.remove_axis(axis) } private_impl! {} } -impl Dimension for Dim<[Ix; 3]> { +impl Dimension for Dim<[Ix; 3]> +{ const NDIM: Option = Some(3); - type SliceArg = [SliceOrIndex; 3]; type Pattern = (Ix, Ix, Ix); type Smaller = Ix2; type Larger = Ix4; #[inline] - fn ndim(&self) -> usize { + fn ndim(&self) -> usize + { 3 } #[inline] - fn into_pattern(self) -> Self::Pattern { + fn into_pattern(self) -> Self::Pattern + { self.ix().convert() } #[inline] - fn slice(&self) -> &[Ix] { + fn slice(&self) -> &[Ix] + { self.ix() } #[inline] - fn slice_mut(&mut self) -> &mut [Ix] { + fn slice_mut(&mut self) -> &mut [Ix] + { self.ixm() } #[inline] - fn size(&self) -> usize { + fn size(&self) -> usize + { let m = get!(self, 0); let n = get!(self, 1); let o = get!(self, 2); - m as usize * n as usize * o as usize + m * n * o } #[inline] - fn zeros(ndim: usize) -> Self { + fn zeros(ndim: usize) -> Self + { assert_eq!(ndim, 3); Self::default() } #[inline] - fn next_for(&self, index: Self) -> Option { + fn next_for(&self, index: Self) -> Option + { let mut i = get!(&index, 0); let mut j = get!(&index, 1); let mut k = get!(&index, 2); @@ -743,7 +841,8 @@ impl Dimension for Dim<[Ix; 3]> { /// Self is an index, return the stride offset #[inline] - fn stride_offset(index: &Self, strides: &Self) -> isize { + fn stride_offset(index: &Self, strides: &Self) -> isize + { let i = get!(index, 0); let j = get!(index, 1); let k = get!(index, 2); @@ -755,7 +854,8 @@ impl Dimension for Dim<[Ix; 3]> { /// Return stride offset for this dimension and index. #[inline] - fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option { + fn stride_offset_checked(&self, strides: &Self, index: &Self) -> Option + { let m = get!(self, 0); let n = get!(self, 1); let l = get!(self, 2); @@ -773,12 +873,13 @@ impl Dimension for Dim<[Ix; 3]> { } #[inline] - fn _fastest_varying_stride_order(&self) -> Self { + fn _fastest_varying_stride_order(&self) -> Self + { let mut stride = *self; let mut order = Ix3(0, 1, 2); macro_rules! swap { ($stride:expr, $order:expr, $x:expr, $y:expr) => { - if $stride[$x] > $stride[$y] { + if ($stride[$x] as isize).abs() > ($stride[$y] as isize).abs() { $stride.swap($x, $y); $order.ixm().swap($x, $y); } @@ -793,10 +894,10 @@ impl Dimension for Dim<[Ix; 3]> { } order } - #[inline] impl_insert_axis_array!(3); #[inline] - fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { + fn try_remove_axis(&self, axis: Axis) -> Self::Smaller + { self.remove_axis(axis) } private_impl! {} @@ -806,7 +907,6 @@ macro_rules! large_dim { ($n:expr, $name:ident, $pattern:ty, $larger:ty, { $($insert_axis:tt)* }) => ( impl Dimension for Dim<[Ix; $n]> { const NDIM: Option = Some($n); - type SliceArg = [SliceOrIndex; $n]; type Pattern = $pattern; type Smaller = Dim<[Ix; $n - 1]>; type Larger = $larger; @@ -825,7 +925,6 @@ macro_rules! large_dim { assert_eq!(ndim, $n); Self::default() } - #[inline] $($insert_axis)* #[inline] fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { @@ -833,7 +932,7 @@ macro_rules! 
large_dim { } private_impl!{} } - ) + ); } large_dim!(4, Ix4, (Ix, Ix, Ix, Ix), Ix5, { @@ -855,42 +954,49 @@ large_dim!(6, Ix6, (Ix, Ix, Ix, Ix, Ix, Ix), IxDyn, { /// IxDyn is a "dynamic" index, pretty hard to use when indexing, /// and memory wasteful, but it allows an arbitrary and dynamic number of axes. -impl Dimension for IxDyn { +impl Dimension for IxDyn +{ const NDIM: Option = None; - type SliceArg = [SliceOrIndex]; type Pattern = Self; type Smaller = Self; type Larger = Self; #[inline] - fn ndim(&self) -> usize { + fn ndim(&self) -> usize + { self.ix().len() } #[inline] - fn slice(&self) -> &[Ix] { + fn slice(&self) -> &[Ix] + { self.ix() } #[inline] - fn slice_mut(&mut self) -> &mut [Ix] { + fn slice_mut(&mut self) -> &mut [Ix] + { self.ixm() } #[inline] - fn into_pattern(self) -> Self::Pattern { + fn into_pattern(self) -> Self::Pattern + { self } #[inline] - fn zeros(ndim: usize) -> Self { + fn zeros(ndim: usize) -> Self + { IxDyn::zeros(ndim) } #[inline] - fn insert_axis(&self, axis: Axis) -> Self::Larger { + fn insert_axis(&self, axis: Axis) -> Self::Larger + { debug_assert!(axis.index() <= self.ndim()); Dim::new(self.ix().insert(axis.index())) } #[inline] - fn try_remove_axis(&self, axis: Axis) -> Self::Smaller { + fn try_remove_axis(&self, axis: Axis) -> Self::Smaller + { if self.ndim() > 0 { self.remove_axis(axis) } else { @@ -898,21 +1004,32 @@ impl Dimension for IxDyn { } } - fn from_dimension(d: &D2) -> Option { + fn from_dimension(d: &D2) -> Option + { Some(IxDyn(d.slice())) } + + fn into_dyn(self) -> IxDyn + { + self + } + private_impl! {} } -impl Index for Dim { +impl Index for Dim +{ type Output = >::Output; - fn index(&self, index: usize) -> &Self::Output { + fn index(&self, index: usize) -> &Self::Output + { &self.ix()[index] } } -impl IndexMut for Dim { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { +impl IndexMut for Dim +{ + fn index_mut(&mut self, index: usize) -> &mut Self::Output + { &mut self.ixm()[index] } } diff --git a/src/dimension/dynindeximpl.rs b/src/dimension/dynindeximpl.rs index 76087aa52..60aeacd80 100644 --- a/src/dimension/dynindeximpl.rs +++ b/src/dimension/dynindeximpl.rs @@ -1,55 +1,68 @@ use crate::imp_prelude::*; +#[cfg(not(feature = "std"))] +use alloc::boxed::Box; +use alloc::vec; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; use std::hash::{Hash, Hasher}; use std::ops::{Deref, DerefMut, Index, IndexMut}; - const CAP: usize = 4; /// T is usize or isize #[derive(Debug)] -enum IxDynRepr { +enum IxDynRepr +{ Inline(u32, [T; CAP]), Alloc(Box<[T]>), } -impl Deref for IxDynRepr { +impl Deref for IxDynRepr +{ type Target = [T]; - fn deref(&self) -> &[T] { + fn deref(&self) -> &[T] + { match *self { IxDynRepr::Inline(len, ref ar) => { debug_assert!(len as usize <= ar.len()); unsafe { ar.get_unchecked(..len as usize) } } - IxDynRepr::Alloc(ref ar) => &*ar, + IxDynRepr::Alloc(ref ar) => ar, } } } -impl DerefMut for IxDynRepr { - fn deref_mut(&mut self) -> &mut [T] { +impl DerefMut for IxDynRepr +{ + fn deref_mut(&mut self) -> &mut [T] + { match *self { IxDynRepr::Inline(len, ref mut ar) => { debug_assert!(len as usize <= ar.len()); unsafe { ar.get_unchecked_mut(..len as usize) } } - IxDynRepr::Alloc(ref mut ar) => &mut *ar, + IxDynRepr::Alloc(ref mut ar) => ar, } } } /// The default is equivalent to `Self::from(&[0])`. 
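A quick sketch of `IxDyn` from the user side: it stores up to `CAP = 4` axes inline and spills to a boxed slice beyond that, trading a little overhead for a runtime-chosen number of axes (the `into_dyn` array method and `IxDyn` constructor are assumed from the public API):

// Illustrative sketch only, using ndarray's public API.
use ndarray::{array, IxDyn};

let a = array![[1, 2], [3, 4]].into_dyn(); // ArrayD: ndim decided at runtime
assert_eq!(a.ndim(), 2);
assert_eq!(a[IxDyn(&[1, 0])], 3);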
-impl Default for IxDynRepr { - fn default() -> Self { +impl Default for IxDynRepr +{ + fn default() -> Self + { Self::copy_from(&[0]) } } use num_traits::Zero; -impl IxDynRepr { - pub fn copy_from(x: &[T]) -> Self { +impl IxDynRepr +{ + pub fn copy_from(x: &[T]) -> Self + { if x.len() <= CAP { let mut arr = [T::zero(); CAP]; - arr[..x.len()].copy_from_slice(&x[..]); + arr[..x.len()].copy_from_slice(x); IxDynRepr::Inline(x.len() as _, arr) } else { Self::from(x) @@ -57,9 +70,11 @@ impl IxDynRepr { } } -impl IxDynRepr { +impl IxDynRepr +{ // make an Inline or Alloc version as appropriate - fn from_vec_auto(v: Vec) -> Self { + fn from_vec_auto(v: Vec) -> Self + { if v.len() <= CAP { Self::copy_from(&v) } else { @@ -68,18 +83,23 @@ impl IxDynRepr { } } -impl IxDynRepr { - fn from_vec(v: Vec) -> Self { +impl IxDynRepr +{ + fn from_vec(v: Vec) -> Self + { IxDynRepr::Alloc(v.into_boxed_slice()) } - fn from(x: &[T]) -> Self { + fn from(x: &[T]) -> Self + { Self::from_vec(x.to_vec()) } } -impl Clone for IxDynRepr { - fn clone(&self) -> Self { +impl Clone for IxDynRepr +{ + fn clone(&self) -> Self + { match *self { IxDynRepr::Inline(len, arr) => IxDynRepr::Inline(len, arr), _ => Self::from(&self[..]), @@ -89,22 +109,25 @@ impl Clone for IxDynRepr { impl Eq for IxDynRepr {} -impl PartialEq for IxDynRepr { - fn eq(&self, rhs: &Self) -> bool { +impl PartialEq for IxDynRepr +{ + fn eq(&self, rhs: &Self) -> bool + { match (self, rhs) { - (&IxDynRepr::Inline(slen, ref sarr), &IxDynRepr::Inline(rlen, ref rarr)) => { + (&IxDynRepr::Inline(slen, ref sarr), &IxDynRepr::Inline(rlen, ref rarr)) => slen == rlen - && (0..CAP as usize) + && (0..CAP) .filter(|&i| i < slen as usize) - .all(|i| sarr[i] == rarr[i]) - } + .all(|i| sarr[i] == rarr[i]), _ => self[..] == rhs[..], } } } -impl Hash for IxDynRepr { - fn hash(&self, state: &mut H) { +impl Hash for IxDynRepr +{ + fn hash(&self, state: &mut H) + { Hash::hash(&self[..], state) } } @@ -117,8 +140,10 @@ impl Hash for IxDynRepr { #[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub struct IxDynImpl(IxDynRepr); -impl IxDynImpl { - pub(crate) fn insert(&self, i: usize) -> Self { +impl IxDynImpl +{ + pub(crate) fn insert(&self, i: usize) -> Self + { let len = self.len(); debug_assert!(i <= len); IxDynImpl(if len < CAP { @@ -135,7 +160,8 @@ impl IxDynImpl { }) } - fn remove(&self, i: usize) -> Self { + fn remove(&self, i: usize) -> Self + { IxDynImpl(match self.0 { IxDynRepr::Inline(0, _) => IxDynRepr::Inline(0, [0; CAP]), IxDynRepr::Inline(1, _) => IxDynRepr::Inline(0, [0; CAP]), @@ -156,74 +182,88 @@ impl IxDynImpl { } } -impl<'a> From<&'a [Ix]> for IxDynImpl { +impl<'a> From<&'a [Ix]> for IxDynImpl +{ #[inline] - fn from(ix: &'a [Ix]) -> Self { + fn from(ix: &'a [Ix]) -> Self + { IxDynImpl(IxDynRepr::copy_from(ix)) } } -impl From> for IxDynImpl { +impl From> for IxDynImpl +{ #[inline] - fn from(ix: Vec) -> Self { + fn from(ix: Vec) -> Self + { IxDynImpl(IxDynRepr::from_vec_auto(ix)) } } impl Index for IxDynImpl -where - [Ix]: Index, +where [Ix]: Index { type Output = <[Ix] as Index>::Output; - fn index(&self, index: J) -> &Self::Output { + fn index(&self, index: J) -> &Self::Output + { &self.0[index] } } impl IndexMut for IxDynImpl -where - [Ix]: IndexMut, +where [Ix]: IndexMut { - fn index_mut(&mut self, index: J) -> &mut Self::Output { + fn index_mut(&mut self, index: J) -> &mut Self::Output + { &mut self.0[index] } } -impl Deref for IxDynImpl { +impl Deref for IxDynImpl +{ type Target = [Ix]; #[inline] - fn deref(&self) -> &[Ix] { + fn deref(&self) -> &[Ix] 
+ { &self.0 } } -impl DerefMut for IxDynImpl { +impl DerefMut for IxDynImpl +{ #[inline] - fn deref_mut(&mut self) -> &mut [Ix] { + fn deref_mut(&mut self) -> &mut [Ix] + { &mut self.0 } } -impl<'a> IntoIterator for &'a IxDynImpl { +impl<'a> IntoIterator for &'a IxDynImpl +{ type Item = &'a Ix; type IntoIter = <&'a [Ix] as IntoIterator>::IntoIter; #[inline] - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { self[..].iter() } } -impl RemoveAxis for Dim { - fn remove_axis(&self, axis: Axis) -> Self { +impl RemoveAxis for Dim +{ + fn remove_axis(&self, axis: Axis) -> Self + { debug_assert!(axis.index() < self.ndim()); Dim::new(self.ix().remove(axis.index())) } } -impl IxDyn { +impl IxDyn +{ /// Create a new dimension value with `n` axes, all zeros #[inline] - pub fn zeros(n: usize) -> IxDyn { + pub fn zeros(n: usize) -> IxDyn + { const ZEROS: &[usize] = &[0; 4]; if n <= ZEROS.len() { Dim(&ZEROS[..n]) diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 28d2e9b2c..eb07252b2 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -7,36 +7,47 @@ // except according to those terms. use crate::error::{from_kind, ErrorKind, ShapeError}; -use crate::{Ix, Ixs, Slice, SliceOrIndex}; +use crate::shape_builder::Strides; +use crate::slice::SliceArg; +use crate::{Ix, Ixs, Slice, SliceInfoElem}; use num_integer::div_floor; -pub use self::axes::{axes_of, Axes, AxisDescription}; +pub use self::axes::{Axes, AxisDescription}; pub use self::axis::Axis; +pub use self::broadcast::DimMax; pub use self::conversion::IntoDimension; pub use self::dim::*; pub use self::dimension_trait::Dimension; pub use self::dynindeximpl::IxDynImpl; pub use self::ndindex::NdIndex; +pub use self::ops::DimAdd; pub use self::remove_axis::RemoveAxis; -use std::isize; +pub(crate) use self::axes::axes_of; +pub(crate) use self::reshape::reshape_dim; + use std::mem; #[macro_use] mod macros; mod axes; mod axis; +pub(crate) mod broadcast; mod conversion; pub mod dim; mod dimension_trait; mod dynindeximpl; mod ndindex; +mod ops; mod remove_axis; +pub(crate) mod reshape; +mod sequence; /// Calculate offset from `Ix` stride converting sign properly #[inline(always)] -pub fn stride_offset(n: Ix, stride: Ix) -> isize { - (n as isize) * ((stride as Ixs) as isize) +pub fn stride_offset(n: Ix, stride: Ix) -> isize +{ + (n as isize) * (stride as Ixs) } /// Check whether the given `dim` and `stride` lead to overlapping indices @@ -44,15 +55,13 @@ pub fn stride_offset(n: Ix, stride: Ix) -> isize { /// There is overlap if, when iterating through the dimensions in order of /// increasing stride, the current stride is less than or equal to the maximum /// possible offset along the preceding axes. (Axes of length ≤1 are ignored.) -/// -/// The current implementation assumes that strides of axes with length > 1 are -/// nonnegative. Additionally, it does not check for overflow. -pub fn dim_stride_overlap(dim: &D, strides: &D) -> bool { +pub(crate) fn dim_stride_overlap(dim: &D, strides: &D) -> bool +{ let order = strides._fastest_varying_stride_order(); let mut sum_prev_offsets = 0; for &index in order.slice() { let d = dim[index]; - let s = strides[index] as isize; + let s = (strides[index] as isize).abs(); match d { 0 => return false, 1 => {} @@ -76,20 +85,36 @@ pub fn dim_stride_overlap(dim: &D, strides: &D) -> bool { /// are met to construct an array from the data buffer, `dim`, and `strides`. /// (The data buffer being a slice or `Vec` guarantees that it contains no more /// than `isize::MAX` bytes.) 
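These size checks are what fallible constructors rely on; for example, `from_shape_vec` rejects a buffer whose length does not match the shape, and a shape whose element count would overflow `isize::MAX`. A minimal sketch (assuming `Array2::from_shape_vec` from the public API):

// Illustrative sketch only, using ndarray's public API.
use ndarray::Array2;

let ok = Array2::from_shape_vec((2, 3), vec![1, 2, 3, 4, 5, 6]);
assert!(ok.is_ok());
let too_short = Array2::from_shape_vec((2, 3), vec![1, 2, 3]);
assert!(too_short.is_err());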
-pub fn size_of_shape_checked(dim: &D) -> Result { +pub fn size_of_shape_checked(dim: &D) -> Result +{ let size_nonzero = dim .slice() .iter() .filter(|&&d| d != 0) .try_fold(1usize, |acc, &d| acc.checked_mul(d)) .ok_or_else(|| from_kind(ErrorKind::Overflow))?; - if size_nonzero > ::std::isize::MAX as usize { + if size_nonzero > isize::MAX as usize { Err(from_kind(ErrorKind::Overflow)) } else { Ok(dim.size()) } } +/// Select how aliasing is checked +/// +/// For owned or mutable data: +/// +/// The strides must not allow any element to be referenced by two different indices. +/// +#[derive(Copy, Clone, PartialEq)] +pub(crate) enum CanIndexCheckMode +{ + /// Owned or mutable: No aliasing + OwnedMutable, + /// Aliasing + ReadOnly, +} + /// Checks whether the given data and dimension meet the invariants of the /// `ArrayBase` type, assuming the strides are created using /// `dim.default_strides()` or `dim.fortran_strides()`. @@ -114,11 +139,24 @@ pub fn size_of_shape_checked(dim: &D) -> Result /// conditions 1 and 2 are sufficient to guarantee that the offset in units of /// `A` and in units of bytes between the least address and greatest address /// accessible by moving along all axes does not exceed `isize::MAX`. -pub fn can_index_slice_not_custom(data: &[A], dim: &D) -> Result<(), ShapeError> { +pub(crate) fn can_index_slice_with_strides( + data: &[A], dim: &D, strides: &Strides, mode: CanIndexCheckMode, +) -> Result<(), ShapeError> +{ + if let Strides::Custom(strides) = strides { + can_index_slice(data, dim, strides, mode) + } else { + // contiguous shapes: never aliasing, mode does not matter + can_index_slice_not_custom(data.len(), dim) + } +} + +pub(crate) fn can_index_slice_not_custom(data_len: usize, dim: &D) -> Result<(), ShapeError> +{ // Condition 1. let len = size_of_shape_checked(dim)?; // Condition 2. - if len > data.len() { + if len > data_len { return Err(from_kind(ErrorKind::OutOfBounds)); } Ok(()) @@ -139,8 +177,13 @@ pub fn can_index_slice_not_custom(data: &[A], dim: &D) -> Resul /// also implies that the length of any individual axis does not exceed /// `isize::MAX`.) pub fn max_abs_offset_check_overflow(dim: &D, strides: &D) -> Result -where - D: Dimension, +where D: Dimension +{ + max_abs_offset_check_overflow_impl(mem::size_of::(), dim, strides) +} + +fn max_abs_offset_check_overflow_impl(elem_size: usize, dim: &D, strides: &D) -> Result +where D: Dimension { // Condition 1. if dim.ndim() != strides.ndim() { @@ -156,7 +199,7 @@ where .try_fold(0usize, |acc, (&d, &s)| { let s = s as isize; // Calculate maximum possible absolute movement along this axis. - let off = d.saturating_sub(1).checked_mul(s.abs() as usize)?; + let off = d.saturating_sub(1).checked_mul(s.unsigned_abs())?; acc.checked_add(off) }) .ok_or_else(|| from_kind(ErrorKind::Overflow))?; @@ -168,7 +211,7 @@ where // Determine absolute difference in units of bytes between least and // greatest address accessible by moving along all axes let max_offset_bytes = max_offset - .checked_mul(mem::size_of::()) + .checked_mul(elem_size) .ok_or_else(|| from_kind(ErrorKind::Overflow))?; // Condition 2b. if max_offset_bytes > isize::MAX as usize { @@ -187,11 +230,7 @@ where /// /// 2. The product of non-zero axis lengths must not exceed `isize::MAX`. /// -/// 3. For axes with length > 1, the stride must be nonnegative. This is -/// necessary to make sure the pointer cannot move backwards outside the -/// slice. For axes with length ≤ 1, the stride can be anything. -/// -/// 4. 
If the array will be empty (any axes are zero-length), the difference +/// 3. If the array will be empty (any axes are zero-length), the difference /// between the least address and greatest address accessible by moving /// along all axes must be ≤ `data.len()`. (It's fine in this case to move /// one byte past the end of the slice since the pointers will be offset but @@ -202,40 +241,44 @@ where /// `data.len()`. This and #3 ensure that all dereferenceable pointers point /// to elements within the slice. /// -/// 5. The strides must not allow any element to be referenced by two different +/// 4. The strides must not allow any element to be referenced by two different /// indices. /// /// Note that since slices cannot contain more than `isize::MAX` bytes, /// condition 4 is sufficient to guarantee that the absolute difference in /// units of `A` and in units of bytes between the least address and greatest /// address accessible by moving along all axes does not exceed `isize::MAX`. -pub fn can_index_slice( - data: &[A], - dim: &D, - strides: &D, -) -> Result<(), ShapeError> { +/// +/// Warning: This function is sufficient to check the invariants of ArrayBase +/// only if the pointer to the first element of the array is chosen such that +/// the element with the smallest memory address is at the start of the +/// allocation. (In other words, the pointer to the first element of the array +/// must be computed using `offset_from_low_addr_ptr_to_logical_ptr` so that +/// negative strides are correctly handled.) +pub(crate) fn can_index_slice( + data: &[A], dim: &D, strides: &D, mode: CanIndexCheckMode, +) -> Result<(), ShapeError> +{ // Check conditions 1 and 2 and calculate `max_offset`. let max_offset = max_abs_offset_check_overflow::(dim, strides)?; + can_index_slice_impl(max_offset, data.len(), dim, strides, mode) +} - // Check condition 4. +fn can_index_slice_impl( + max_offset: usize, data_len: usize, dim: &D, strides: &D, mode: CanIndexCheckMode, +) -> Result<(), ShapeError> +{ + // Check condition 3. let is_empty = dim.slice().iter().any(|&d| d == 0); - if is_empty && max_offset > data.len() { + if is_empty && max_offset > data_len { return Err(from_kind(ErrorKind::OutOfBounds)); } - if !is_empty && max_offset >= data.len() { + if !is_empty && max_offset >= data_len { return Err(from_kind(ErrorKind::OutOfBounds)); } - // Check condition 3. - for (&d, &s) in izip!(dim.slice(), strides.slice()) { - let s = s as isize; - if d > 1 && s < 0 { - return Err(from_kind(ErrorKind::Unsupported)); - } - } - - // Check condition 5. - if !is_empty && dim_stride_overlap(dim, strides) { + // Check condition 4. + if !is_empty && mode != CanIndexCheckMode::ReadOnly && dim_stride_overlap(dim, strides) { return Err(from_kind(ErrorKind::Unsupported)); } @@ -244,7 +287,8 @@ pub fn can_index_slice( /// Stride offset checked general version (slices) #[inline] -pub fn stride_offset_checked(dim: &[Ix], strides: &[Ix], index: &[Ix]) -> Option { +pub fn stride_offset_checked(dim: &[Ix], strides: &[Ix], index: &[Ix]) -> Option +{ if index.len() != dim.len() { return None; } @@ -258,44 +302,63 @@ pub fn stride_offset_checked(dim: &[Ix], strides: &[Ix], index: &[Ix]) -> Option Some(offset) } +/// Checks if strides are non-negative. 
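`can_index_slice` is the gatekeeper for user-supplied strides; for instance, a column-major layout expressed with the `ShapeBuilder::strides` builder passes these checks, while strides that would alias elements are rejected. A minimal sketch (builder API assumed from the public crate surface):

// Illustrative sketch only, using ndarray's public API.
use ndarray::{array, Array, ShapeBuilder};

let v = vec![1, 2, 3, 4];
// Strides (1, 2) lay a 2×2 array out column-major over `v`.
let a = Array::from_shape_vec((2, 2).strides((1, 2)), v).unwrap();
assert_eq!(a, array![[1, 3], [2, 4]]);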
+pub fn strides_non_negative(strides: &D) -> Result<(), ShapeError> +where D: Dimension +{ + for &stride in strides.slice() { + if (stride as isize) < 0 { + return Err(from_kind(ErrorKind::Unsupported)); + } + } + Ok(()) +} + /// Implementation-specific extensions to `Dimension` -pub trait DimensionExt { +pub trait DimensionExt +{ // note: many extensions go in the main trait if they need to be special- // cased per dimension /// Get the dimension at `axis`. /// /// *Panics* if `axis` is out of bounds. + #[track_caller] fn axis(&self, axis: Axis) -> Ix; /// Set the dimension at `axis`. /// /// *Panics* if `axis` is out of bounds. + #[track_caller] fn set_axis(&mut self, axis: Axis, value: Ix); } impl DimensionExt for D -where - D: Dimension, +where D: Dimension { #[inline] - fn axis(&self, axis: Axis) -> Ix { + fn axis(&self, axis: Axis) -> Ix + { self[axis.index()] } #[inline] - fn set_axis(&mut self, axis: Axis, value: Ix) { + fn set_axis(&mut self, axis: Axis, value: Ix) + { self[axis.index()] = value; } } -impl<'a> DimensionExt for [Ix] { +impl DimensionExt for [Ix] +{ #[inline] - fn axis(&self, axis: Axis) -> Ix { + fn axis(&self, axis: Axis) -> Ix + { self[axis.index()] } #[inline] - fn set_axis(&mut self, axis: Axis, value: Ix) { + fn set_axis(&mut self, axis: Axis, value: Ix) + { self[axis.index()] = value; } } @@ -304,13 +367,10 @@ impl<'a> DimensionExt for [Ix] { /// available. /// /// **Panics** if `index` is larger than the size of the axis +#[track_caller] // FIXME: Move to Dimension trait -pub fn do_collapse_axis( - dims: &mut D, - strides: &D, - axis: usize, - index: usize, -) -> isize { +pub fn do_collapse_axis(dims: &mut D, strides: &D, axis: usize, index: usize) -> isize +{ let dim = dims.slice()[axis]; let stride = strides.slice()[axis]; ndassert!( @@ -327,7 +387,8 @@ pub fn do_collapse_axis( /// Compute the equivalent unsigned index given the axis length and signed index. #[inline] -pub fn abs_index(len: Ix, index: Ixs) -> Ix { +pub fn abs_index(len: Ix, index: Ixs) -> Ix +{ if index < 0 { len - (-index as Ix) } else { @@ -340,7 +401,9 @@ pub fn abs_index(len: Ix, index: Ixs) -> Ix { /// The return value is (start, end, step). /// /// **Panics** if stride is 0 or if any index is out of bounds. -fn to_abs_slice(axis_len: usize, slice: Slice) -> (usize, usize, isize) { +#[track_caller] +fn to_abs_slice(axis_len: usize, slice: Slice) -> (usize, usize, isize) +{ let Slice { start, end, step } = slice; let start = abs_index(axis_len, start); let mut end = abs_index(axis_len, end.unwrap_or(axis_len as isize)); @@ -363,10 +426,28 @@ fn to_abs_slice(axis_len: usize, slice: Slice) -> (usize, usize, isize) { (start, end, step) } +/// This function computes the offset from the lowest address element to the +/// logically first element. +pub fn offset_from_low_addr_ptr_to_logical_ptr(dim: &D, strides: &D) -> usize +{ + let offset = izip!(dim.slice(), strides.slice()).fold(0, |_offset, (&d, &s)| { + let s = s as isize; + if s < 0 && d > 1 { + _offset - s * (d as isize - 1) + } else { + _offset + } + }); + debug_assert!(offset >= 0); + offset as usize +} + /// Modify dimension, stride and return data pointer offset /// /// **Panics** if stride is 0 or if any index is out of bounds. 
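`to_abs_slice` and `do_slice` implement the semantics of the public `Slice` type and the `s![]` macro, including negative bounds counted from the end of the axis; a small sketch (public slicing methods assumed):

// Illustrative sketch only, using ndarray's public API.
use ndarray::{array, s, Axis, Slice};

let a = array![0, 1, 2, 3, 4, 5];
// `-1` means "one before the end"; step 2 keeps every other element.
assert_eq!(a.slice(s![1..-1;2]), array![1, 3]);
assert_eq!(a.slice_axis(Axis(0), Slice::new(1, Some(-1), 2)), array![1, 3]);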
-pub fn do_slice(dim: &mut usize, stride: &mut usize, slice: Slice) -> isize { +#[track_caller] +pub fn do_slice(dim: &mut usize, stride: &mut usize, slice: Slice) -> isize +{ let (start, end, step) = to_abs_slice(*dim, slice); let m = end - start; @@ -397,7 +478,7 @@ pub fn do_slice(dim: &mut usize, stride: &mut usize, slice: Slice) -> isize { }; // Update dimension. - let abs_step = step.abs() as usize; + let abs_step = step.unsigned_abs(); *dim = if abs_step == 1 { m } else { @@ -419,7 +500,8 @@ pub fn do_slice(dim: &mut usize, stride: &mut usize, slice: Slice) -> isize { /// nonnegative. /// /// See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm -fn extended_gcd(a: isize, b: isize) -> (isize, (isize, isize)) { +fn extended_gcd(a: isize, b: isize) -> (isize, (isize, isize)) +{ if a == 0 { (b.abs(), (0, b.signum())) } else if b == 0 { @@ -455,7 +537,8 @@ fn extended_gcd(a: isize, b: isize) -> (isize, (isize, isize)) { /// /// See https://en.wikipedia.org/wiki/Diophantine_equation#One_equation /// and https://math.stackexchange.com/questions/1656120#1656138 -fn solve_linear_diophantine_eq(a: isize, b: isize, c: isize) -> Option<(isize, isize)> { +fn solve_linear_diophantine_eq(a: isize, b: isize, c: isize) -> Option<(isize, isize)> +{ debug_assert_ne!(a, 0); debug_assert_ne!(b, 0); let (g, (u, _)) = extended_gcd(a, b); @@ -473,10 +556,8 @@ fn solve_linear_diophantine_eq(a: isize, b: isize, c: isize) -> Option<(isize, i /// consecutive elements (the sign is irrelevant). /// /// **Note** `step1` and `step2` must be nonzero. -fn arith_seq_intersect( - (min1, max1, step1): (isize, isize, isize), - (min2, max2, step2): (isize, isize, isize), -) -> bool { +fn arith_seq_intersect((min1, max1, step1): (isize, isize, isize), (min2, max2, step2): (isize, isize, isize)) -> bool +{ debug_assert!(max1 >= min1); debug_assert!(max2 >= min2); debug_assert_eq!((max1 - min1) % step1, 0); @@ -532,7 +613,8 @@ fn arith_seq_intersect( /// Returns the minimum and maximum values of the indices (inclusive). /// /// If the slice is empty, then returns `None`, otherwise returns `Some((min, max))`. -fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> { +fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> +{ let (start, end, step) = to_abs_slice(axis_len, slice); if start == end { None @@ -544,22 +626,23 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> { } /// Returns `true` iff the slices intersect. -pub fn slices_intersect( - dim: &D, - indices1: &D::SliceArg, - indices2: &D::SliceArg, -) -> bool { - debug_assert_eq!(indices1.as_ref().len(), indices2.as_ref().len()); - for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) { - // The slices do not intersect iff any pair of `SliceOrIndex` does not intersect. +pub fn slices_intersect(dim: &D, indices1: impl SliceArg, indices2: impl SliceArg) -> bool +{ + debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim()); + for (&axis_len, &si1, &si2) in izip!( + dim.slice(), + indices1.as_ref().iter().filter(|si| !si.is_new_axis()), + indices2.as_ref().iter().filter(|si| !si.is_new_axis()), + ) { + // The slices do not intersect iff any pair of `SliceInfoElem` does not intersect. 
match (si1, si2) { ( - SliceOrIndex::Slice { + SliceInfoElem::Slice { start: start1, end: end1, step: step1, }, - SliceOrIndex::Slice { + SliceInfoElem::Slice { start: start2, end: end2, step: step2, @@ -573,39 +656,88 @@ pub fn slices_intersect( Some(m) => m, None => return false, }; - if !arith_seq_intersect( - (min1 as isize, max1 as isize, step1), - (min2 as isize, max2 as isize, step2), - ) { + if !arith_seq_intersect((min1 as isize, max1 as isize, step1), (min2 as isize, max2 as isize, step2)) { return false; } } - (SliceOrIndex::Slice { start, end, step }, SliceOrIndex::Index(ind)) - | (SliceOrIndex::Index(ind), SliceOrIndex::Slice { start, end, step }) => { + (SliceInfoElem::Slice { start, end, step }, SliceInfoElem::Index(ind)) + | (SliceInfoElem::Index(ind), SliceInfoElem::Slice { start, end, step }) => { let ind = abs_index(axis_len, ind); let (min, max) = match slice_min_max(axis_len, Slice::new(start, end, step)) { Some(m) => m, None => return false, }; - if ind < min || ind > max || (ind - min) % step.abs() as usize != 0 { + if ind < min || ind > max || (ind - min) % step.unsigned_abs() != 0 { return false; } } - (SliceOrIndex::Index(ind1), SliceOrIndex::Index(ind2)) => { + (SliceInfoElem::Index(ind1), SliceInfoElem::Index(ind2)) => { let ind1 = abs_index(axis_len, ind1); let ind2 = abs_index(axis_len, ind2); if ind1 != ind2 { return false; } } + (SliceInfoElem::NewAxis, _) | (_, SliceInfoElem::NewAxis) => unreachable!(), + } + } + true +} + +pub(crate) fn is_layout_c(dim: &D, strides: &D) -> bool +{ + if let Some(1) = D::NDIM { + return strides[0] == 1 || dim[0] <= 1; + } + + for &d in dim.slice() { + if d == 0 { + return true; + } + } + + let mut contig_stride = 1_isize; + // check all dimensions -- a dimension of length 1 can have unequal strides + for (&dim, &s) in izip!(dim.slice().iter().rev(), strides.slice().iter().rev()) { + if dim != 1 { + let s = s as isize; + if s != contig_stride { + return false; + } + contig_stride *= dim as isize; + } + } + true +} + +pub(crate) fn is_layout_f(dim: &D, strides: &D) -> bool +{ + if let Some(1) = D::NDIM { + return strides[0] == 1 || dim[0] <= 1; + } + + for &d in dim.slice() { + if d == 0 { + return true; + } + } + + let mut contig_stride = 1_isize; + // check all dimensions -- a dimension of length 1 can have unequal strides + for (&dim, &s) in izip!(dim.slice(), strides.slice()) { + if dim != 1 { + let s = s as isize; + if s != contig_stride { + return false; + } + contig_stride *= dim as isize; } } true } pub fn merge_axes(dim: &mut D, strides: &mut D, take: Axis, into: Axis) -> bool -where - D: Dimension, +where D: Dimension { let into_len = dim.axis(into); let into_stride = strides.axis(into) as isize; @@ -630,49 +762,94 @@ where } } +/// Move the axis which has the smallest absolute stride and a length +/// greater than one to be the last axis. 
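As an aside on `is_layout_c`/`is_layout_f` above: the contiguity test amounts to checking that each axis, walking from the fastest-varying end and ignoring length-1 axes, has exactly the stride a packed layout would give it. A hedged standalone sketch of the row-major case (not the crate's implementation):

```rust
fn is_c_contiguous(dim: &[usize], strides: &[usize]) -> bool {
    if dim.iter().any(|&d| d == 0) {
        return true; // empty arrays are trivially contiguous
    }
    let mut expected = 1isize; // stride a packed layout would need next
    for (&d, &s) in dim.iter().rev().zip(strides.iter().rev()) {
        if d != 1 {
            if s as isize != expected {
                return false;
            }
            expected *= d as isize;
        }
    }
    true
}

fn main() {
    assert!(is_c_contiguous(&[2, 3], &[3, 1]));        // standard row-major
    assert!(!is_c_contiguous(&[2, 3], &[1, 2]));       // column-major layout
    assert!(is_c_contiguous(&[2, 1, 3], &[3, 99, 1])); // length-1 axis stride ignored
}
```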
+pub fn move_min_stride_axis_to_last(dim: &mut D, strides: &mut D) +where D: Dimension +{ + debug_assert_eq!(dim.ndim(), strides.ndim()); + match dim.ndim() { + 0 | 1 => {} + 2 => + if dim[1] <= 1 || dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs() { + dim.slice_mut().swap(0, 1); + strides.slice_mut().swap(0, 1); + }, + n => { + if let Some(min_stride_axis) = (0..n) + .filter(|&ax| dim[ax] > 1) + .min_by_key(|&ax| (strides[ax] as isize).abs()) + { + let last = n - 1; + dim.slice_mut().swap(last, min_stride_axis); + strides.slice_mut().swap(last, min_stride_axis); + } + } + } +} + #[cfg(test)] -mod test { +mod test +{ use super::{ - arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd, - max_abs_offset_check_overflow, slice_min_max, slices_intersect, - solve_linear_diophantine_eq, IntoDimension, + arith_seq_intersect, + can_index_slice, + can_index_slice_not_custom, + extended_gcd, + max_abs_offset_check_overflow, + slice_min_max, + slices_intersect, + solve_linear_diophantine_eq, + CanIndexCheckMode, + IntoDimension, }; use crate::error::{from_kind, ErrorKind}; use crate::slice::Slice; - use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn}; + use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn, NewAxis}; use num_integer::gcd; use quickcheck::{quickcheck, TestResult}; #[test] - fn slice_indexing_uncommon_strides() { - let v: Vec<_> = (0..12).collect(); + fn slice_indexing_uncommon_strides() + { + let v: alloc::vec::Vec<_> = (0..12).collect(); let dim = (2, 3, 2).into_dimension(); let strides = (1, 2, 6).into_dimension(); - assert!(super::can_index_slice(&v, &dim, &strides).is_ok()); + assert!(super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok()); let strides = (2, 4, 12).into_dimension(); assert_eq!( - super::can_index_slice(&v, &dim, &strides), + super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable), Err(from_kind(ErrorKind::OutOfBounds)) ); } #[test] - fn overlapping_strides_dim() { + fn overlapping_strides_dim() + { let dim = (2, 3, 2).into_dimension(); let strides = (5, 2, 1).into_dimension(); assert!(super::dim_stride_overlap(&dim, &strides)); + let strides = (-5isize as usize, 2, -1isize as usize).into_dimension(); + assert!(super::dim_stride_overlap(&dim, &strides)); let strides = (6, 2, 1).into_dimension(); assert!(!super::dim_stride_overlap(&dim, &strides)); + let strides = (6, -2isize as usize, 1).into_dimension(); + assert!(!super::dim_stride_overlap(&dim, &strides)); let strides = (6, 0, 1).into_dimension(); assert!(super::dim_stride_overlap(&dim, &strides)); + let strides = (-6isize as usize, 0, 1).into_dimension(); + assert!(super::dim_stride_overlap(&dim, &strides)); let dim = (2, 2).into_dimension(); let strides = (3, 2).into_dimension(); assert!(!super::dim_stride_overlap(&dim, &strides)); + let strides = (3, -2isize as usize).into_dimension(); + assert!(!super::dim_stride_overlap(&dim, &strides)); } #[test] - fn max_abs_offset_check_overflow_examples() { + fn max_abs_offset_check_overflow_examples() + { let dim = (1, ::std::isize::MAX as usize, 1).into_dimension(); let strides = (1, 1, 1).into_dimension(); max_abs_offset_check_overflow::(&dim, &strides).unwrap(); @@ -688,98 +865,116 @@ mod test { } #[test] - fn can_index_slice_ix0() { - can_index_slice::(&[1], &Ix0(), &Ix0()).unwrap(); - can_index_slice::(&[], &Ix0(), &Ix0()).unwrap_err(); + fn can_index_slice_ix0() + { + can_index_slice::(&[1], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap(); + 
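A brief aside on the `-5isize as usize` pattern in the tests above and below (a hedged illustration, not crate code): strides are stored as `usize`, and a negative stride is simply kept as its two's-complement bit pattern, recovered losslessly by casting back to `isize`.

```rust
fn main() {
    let stored: usize = -5isize as usize; // two's-complement bit pattern of -5
    assert_eq!(stored as isize, -5);      // lossless round-trip
    // the signed value is what matters when an index becomes an element offset
    let offset = |i: usize| (i as isize) * (stored as isize);
    assert_eq!(offset(3), -15);
}
```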
can_index_slice::(&[], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap_err(); } #[test] - fn can_index_slice_ix1() { - can_index_slice::(&[], &Ix1(0), &Ix1(0)).unwrap(); - can_index_slice::(&[], &Ix1(0), &Ix1(1)).unwrap(); - can_index_slice::(&[], &Ix1(1), &Ix1(0)).unwrap_err(); - can_index_slice::(&[], &Ix1(1), &Ix1(1)).unwrap_err(); - can_index_slice::(&[1], &Ix1(1), &Ix1(0)).unwrap(); - can_index_slice::(&[1], &Ix1(1), &Ix1(2)).unwrap(); - can_index_slice::(&[1], &Ix1(1), &Ix1(-1isize as usize)).unwrap(); - can_index_slice::(&[1], &Ix1(2), &Ix1(1)).unwrap_err(); - can_index_slice::(&[1, 2], &Ix1(2), &Ix1(0)).unwrap_err(); - can_index_slice::(&[1, 2], &Ix1(2), &Ix1(1)).unwrap(); - can_index_slice::(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap_err(); + fn can_index_slice_ix1() + { + let mode = CanIndexCheckMode::OwnedMutable; + can_index_slice::(&[], &Ix1(0), &Ix1(0), mode).unwrap(); + can_index_slice::(&[], &Ix1(0), &Ix1(1), mode).unwrap(); + can_index_slice::(&[], &Ix1(1), &Ix1(0), mode).unwrap_err(); + can_index_slice::(&[], &Ix1(1), &Ix1(1), mode).unwrap_err(); + can_index_slice::(&[1], &Ix1(1), &Ix1(0), mode).unwrap(); + can_index_slice::(&[1], &Ix1(1), &Ix1(2), mode).unwrap(); + can_index_slice::(&[1], &Ix1(1), &Ix1(-1isize as usize), mode).unwrap(); + can_index_slice::(&[1], &Ix1(2), &Ix1(1), mode).unwrap_err(); + can_index_slice::(&[1, 2], &Ix1(2), &Ix1(0), mode).unwrap_err(); + can_index_slice::(&[1, 2], &Ix1(2), &Ix1(1), mode).unwrap(); + can_index_slice::(&[1, 2], &Ix1(2), &Ix1(-1isize as usize), mode).unwrap(); } #[test] - fn can_index_slice_ix2() { - can_index_slice::(&[], &Ix2(0, 0), &Ix2(0, 0)).unwrap(); - can_index_slice::(&[], &Ix2(0, 0), &Ix2(2, 1)).unwrap(); - can_index_slice::(&[], &Ix2(0, 1), &Ix2(0, 0)).unwrap(); - can_index_slice::(&[], &Ix2(0, 1), &Ix2(2, 1)).unwrap(); - can_index_slice::(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap(); - can_index_slice::(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err(); - can_index_slice::(&[1], &Ix2(1, 2), &Ix2(5, 1)).unwrap_err(); - can_index_slice::(&[1, 2], &Ix2(1, 2), &Ix2(5, 1)).unwrap(); - can_index_slice::(&[1, 2], &Ix2(1, 2), &Ix2(5, 2)).unwrap_err(); - can_index_slice::(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1)).unwrap(); - can_index_slice::(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1)).unwrap_err(); + fn can_index_slice_ix2() + { + let mode = CanIndexCheckMode::OwnedMutable; + can_index_slice::(&[], &Ix2(0, 0), &Ix2(0, 0), mode).unwrap(); + can_index_slice::(&[], &Ix2(0, 0), &Ix2(2, 1), mode).unwrap(); + can_index_slice::(&[], &Ix2(0, 1), &Ix2(0, 0), mode).unwrap(); + can_index_slice::(&[], &Ix2(0, 1), &Ix2(2, 1), mode).unwrap(); + can_index_slice::(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap(); + can_index_slice::(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err(); + can_index_slice::(&[1], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap_err(); + can_index_slice::(&[1, 2], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap(); + can_index_slice::(&[1, 2], &Ix2(1, 2), &Ix2(5, 2), mode).unwrap_err(); + can_index_slice::(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap(); + can_index_slice::(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap_err(); + + // aliasing strides: ok when readonly + can_index_slice::(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::OwnedMutable).unwrap_err(); + can_index_slice::(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::ReadOnly).unwrap(); } #[test] - fn can_index_slice_ix3() { - can_index_slice::(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3)).unwrap(); - can_index_slice::(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap_err(); - 
can_index_slice::(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap(); - can_index_slice::(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap_err(); - can_index_slice::(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap(); + fn can_index_slice_ix3() + { + let mode = CanIndexCheckMode::OwnedMutable; + can_index_slice::(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3), mode).unwrap(); + can_index_slice::(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap_err(); + can_index_slice::(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap(); + can_index_slice::(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap_err(); + can_index_slice::(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap(); } #[test] - fn can_index_slice_zero_size_elem() { - can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1)).unwrap(); - can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1)).unwrap(); - can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1)).unwrap(); + fn can_index_slice_zero_size_elem() + { + let mode = CanIndexCheckMode::OwnedMutable; + can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1), mode).unwrap(); + can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1), mode).unwrap(); + can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1), mode).unwrap(); // These might seem okay because the element type is zero-sized, but // there could be a zero-sized type such that the number of instances // in existence are carefully controlled. - can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1)).unwrap_err(); - can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1)).unwrap_err(); + can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1), mode).unwrap_err(); + can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1), mode).unwrap_err(); - can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0)).unwrap(); - can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap(); + can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0), mode).unwrap(); + can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap(); // This case would be probably be sound, but that's not entirely clear // and it's not worth the special case code. - can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err(); + can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err(); } quickcheck! { - fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec, dim: Vec) -> bool { + fn can_index_slice_not_custom_same_as_can_index_slice(data: alloc::vec::Vec, dim: alloc::vec::Vec) -> bool { let dim = IxDyn(&dim); - let result = can_index_slice_not_custom(&data, &dim); + let result = can_index_slice_not_custom(data.len(), &dim); if dim.size_checked().is_none() { // Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`. result.is_err() } else { - result == can_index_slice(&data, &dim, &dim.default_strides()) && - result == can_index_slice(&data, &dim, &dim.fortran_strides()) + result == can_index_slice(&data, &dim, &dim.default_strides(), CanIndexCheckMode::OwnedMutable) && + result == can_index_slice(&data, &dim, &dim.fortran_strides(), CanIndexCheckMode::OwnedMutable) } } } quickcheck! 
{ - fn extended_gcd_solves_eq(a: isize, b: isize) -> bool { + // FIXME: This test can't handle larger values at the moment + fn extended_gcd_solves_eq(a: i16, b: i16) -> bool { + let (a, b) = (a as isize, b as isize); let (g, (x, y)) = extended_gcd(a, b); a * x + b * y == g } - fn extended_gcd_correct_gcd(a: isize, b: isize) -> bool { + // FIXME: This test can't handle larger values at the moment + fn extended_gcd_correct_gcd(a: i16, b: i16) -> bool { + let (a, b) = (a as isize, b as isize); let (g, _) = extended_gcd(a, b); g == gcd(a, b) } } #[test] - fn extended_gcd_zero() { + fn extended_gcd_zero() + { assert_eq!(extended_gcd(0, 0), (0, (0, 0))); assert_eq!(extended_gcd(0, 5), (5, (0, 1))); assert_eq!(extended_gcd(5, 0), (5, (1, 0))); @@ -788,9 +983,12 @@ mod test { } quickcheck! { + // FIXME: This test can't handle larger values at the moment fn solve_linear_diophantine_eq_solution_existence( - a: isize, b: isize, c: isize + a: i16, b: i16, c: i16 ) -> TestResult { + let (a, b, c) = (a as isize, b as isize, c as isize); + if a == 0 || b == 0 { TestResult::discard() } else { @@ -800,9 +998,12 @@ mod test { } } + // FIXME: This test can't handle larger values at the moment fn solve_linear_diophantine_eq_correct_solution( - a: isize, b: isize, c: isize, t: isize + a: i8, b: i8, c: i8, t: i8 ) -> TestResult { + let (a, b, c, t) = (a as isize, b as isize, c as isize, t as isize); + if a == 0 || b == 0 { TestResult::discard() } else { @@ -819,17 +1020,24 @@ mod test { } quickcheck! { + #[cfg_attr(miri, ignore)] // Very slow on CI/CD machines + // FIXME: This test is extremely slow, even with i16 values, investigate fn arith_seq_intersect_correct( - first1: isize, len1: isize, step1: isize, - first2: isize, len2: isize, step2: isize + first1: i8, len1: i8, step1: i8, + first2: i8, len2: i8, step2: i8 ) -> TestResult { use std::cmp; + let (len1, len2) = (len1 as isize, len2 as isize); + let (first1, step1) = (first1 as isize, step1 as isize); + let (first2, step2) = (first2 as isize, step2 as isize); + if len1 == 0 || len2 == 0 { // This case is impossible to reach in `arith_seq_intersect()` // because the `min*` and `max*` arguments are inclusive. return TestResult::discard(); } + let len1 = len1.abs(); let len2 = len2.abs(); @@ -840,7 +1048,7 @@ mod test { let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2)); // Naively determine if the sequences intersect. 
- let seq1: Vec<_> = (0..len1) + let seq1: alloc::vec::Vec<_> = (0..len1) .map(|n| first1 + step1 * n) .collect(); let intersects = (0..len2) @@ -857,7 +1065,8 @@ mod test { } #[test] - fn slice_min_max_empty() { + fn slice_min_max_empty() + { assert_eq!(slice_min_max(0, Slice::new(0, None, 3)), None); assert_eq!(slice_min_max(10, Slice::new(1, Some(1), 3)), None); assert_eq!(slice_min_max(10, Slice::new(-1, Some(-1), 3)), None); @@ -866,7 +1075,8 @@ mod test { } #[test] - fn slice_min_max_pos_step() { + fn slice_min_max_pos_step() + { assert_eq!(slice_min_max(10, Slice::new(1, Some(8), 3)), Some((1, 7))); assert_eq!(slice_min_max(10, Slice::new(1, Some(9), 3)), Some((1, 7))); assert_eq!(slice_min_max(10, Slice::new(-9, Some(8), 3)), Some((1, 7))); @@ -882,7 +1092,8 @@ mod test { } #[test] - fn slice_min_max_neg_step() { + fn slice_min_max_neg_step() + { assert_eq!(slice_min_max(10, Slice::new(1, Some(8), -3)), Some((1, 7))); assert_eq!(slice_min_max(10, Slice::new(2, Some(8), -3)), Some((4, 7))); assert_eq!(slice_min_max(10, Slice::new(-9, Some(8), -3)), Some((1, 7))); @@ -904,18 +1115,48 @@ mod test { } #[test] - fn slices_intersect_true() { - assert!(slices_intersect(&Dim([4, 5]), s![.., ..], s![.., ..])); - assert!(slices_intersect(&Dim([4, 5]), s![0, ..], s![0, ..])); - assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, ..])); - assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3])); + fn slices_intersect_true() + { + assert!(slices_intersect( + &Dim([4, 5]), + s![NewAxis, .., NewAxis, ..], + s![.., NewAxis, .., NewAxis] + )); + assert!(slices_intersect( + &Dim([4, 5]), + s![NewAxis, 0, ..], + s![0, ..] + )); + assert!(slices_intersect( + &Dim([4, 5]), + s![..;2, ..], + s![..;3, NewAxis, ..] + )); + assert!(slices_intersect( + &Dim([4, 5]), + s![.., ..;2], + s![.., 1..;3, NewAxis] + )); assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6])); } #[test] - fn slices_intersect_false() { - assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;2, ..])); - assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;3, ..])); - assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6])); + fn slices_intersect_false() + { + assert!(!slices_intersect( + &Dim([4, 5]), + s![..;2, ..], + s![NewAxis, 1..;2, ..] + )); + assert!(!slices_intersect( + &Dim([4, 5]), + s![..;2, NewAxis, ..], + s![1..;3, ..] + )); + assert!(!slices_intersect( + &Dim([4, 5]), + s![.., ..;9], + s![.., 3..;6, NewAxis] + )); } } diff --git a/src/dimension/ndindex.rs b/src/dimension/ndindex.rs index d9bac1d94..ca2a3ea69 100644 --- a/src/dimension/ndindex.rs +++ b/src/dimension/ndindex.rs @@ -2,9 +2,7 @@ use std::fmt::Debug; use super::{stride_offset, stride_offset_checked}; use crate::itertools::zip; -use crate::{ - Dim, Dimension, IntoDimension, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, -}; +use crate::{Dim, Dimension, IntoDimension, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl}; /// Tuple or fixed size arrays that can be used to index an array. 
/// @@ -18,7 +16,9 @@ use crate::{ /// a[[1, 1]] += 1; /// assert_eq!(a[(1, 1)], 4); /// ``` -pub unsafe trait NdIndex: Debug { +#[allow(clippy::missing_safety_doc)] // TODO: Add doc +pub unsafe trait NdIndex: Debug +{ #[doc(hidden)] fn index_checked(&self, dim: &E, strides: &E) -> Option; #[doc(hidden)] @@ -26,96 +26,134 @@ pub unsafe trait NdIndex: Debug { } unsafe impl NdIndex for D -where - D: Dimension, +where D: Dimension { - fn index_checked(&self, dim: &D, strides: &D) -> Option { + fn index_checked(&self, dim: &D, strides: &D) -> Option + { dim.stride_offset_checked(strides, self) } - fn index_unchecked(&self, strides: &D) -> isize { + fn index_unchecked(&self, strides: &D) -> isize + { D::stride_offset(self, strides) } } -unsafe impl NdIndex for () { +unsafe impl NdIndex for () +{ #[inline] - fn index_checked(&self, dim: &Ix0, strides: &Ix0) -> Option { + fn index_checked(&self, dim: &Ix0, strides: &Ix0) -> Option + { dim.stride_offset_checked(strides, &Ix0()) } #[inline(always)] - fn index_unchecked(&self, _strides: &Ix0) -> isize { + fn index_unchecked(&self, _strides: &Ix0) -> isize + { 0 } } -unsafe impl NdIndex for (Ix, Ix) { +unsafe impl NdIndex for (Ix, Ix) +{ #[inline] - fn index_checked(&self, dim: &Ix2, strides: &Ix2) -> Option { + fn index_checked(&self, dim: &Ix2, strides: &Ix2) -> Option + { dim.stride_offset_checked(strides, &Ix2(self.0, self.1)) } #[inline] - fn index_unchecked(&self, strides: &Ix2) -> isize { + fn index_unchecked(&self, strides: &Ix2) -> isize + { stride_offset(self.0, get!(strides, 0)) + stride_offset(self.1, get!(strides, 1)) } } -unsafe impl NdIndex for (Ix, Ix, Ix) { +unsafe impl NdIndex for (Ix, Ix, Ix) +{ #[inline] - fn index_checked(&self, dim: &Ix3, strides: &Ix3) -> Option { + fn index_checked(&self, dim: &Ix3, strides: &Ix3) -> Option + { dim.stride_offset_checked(strides, &self.into_dimension()) } #[inline] - fn index_unchecked(&self, strides: &Ix3) -> isize { + fn index_unchecked(&self, strides: &Ix3) -> isize + { stride_offset(self.0, get!(strides, 0)) + stride_offset(self.1, get!(strides, 1)) + stride_offset(self.2, get!(strides, 2)) } } -unsafe impl NdIndex for (Ix, Ix, Ix, Ix) { +unsafe impl NdIndex for (Ix, Ix, Ix, Ix) +{ + #[inline] + fn index_checked(&self, dim: &Ix4, strides: &Ix4) -> Option + { + dim.stride_offset_checked(strides, &self.into_dimension()) + } + #[inline] + fn index_unchecked(&self, strides: &Ix4) -> isize + { + zip(strides.ix(), self.into_dimension().ix()) + .map(|(&s, &i)| stride_offset(i, s)) + .sum() + } +} +unsafe impl NdIndex for (Ix, Ix, Ix, Ix, Ix) +{ #[inline] - fn index_checked(&self, dim: &Ix4, strides: &Ix4) -> Option { + fn index_checked(&self, dim: &Ix5, strides: &Ix5) -> Option + { dim.stride_offset_checked(strides, &self.into_dimension()) } #[inline] - fn index_unchecked(&self, strides: &Ix4) -> isize { + fn index_unchecked(&self, strides: &Ix5) -> isize + { zip(strides.ix(), self.into_dimension().ix()) .map(|(&s, &i)| stride_offset(i, s)) .sum() } } -unsafe impl NdIndex for (Ix, Ix, Ix, Ix, Ix) { + +unsafe impl NdIndex for (Ix, Ix, Ix, Ix, Ix, Ix) +{ #[inline] - fn index_checked(&self, dim: &Ix5, strides: &Ix5) -> Option { + fn index_checked(&self, dim: &Ix6, strides: &Ix6) -> Option + { dim.stride_offset_checked(strides, &self.into_dimension()) } #[inline] - fn index_unchecked(&self, strides: &Ix5) -> isize { + fn index_unchecked(&self, strides: &Ix6) -> isize + { zip(strides.ix(), self.into_dimension().ix()) .map(|(&s, &i)| stride_offset(i, s)) .sum() } } -unsafe impl NdIndex for Ix { +unsafe 
impl NdIndex for Ix +{ #[inline] - fn index_checked(&self, dim: &Ix1, strides: &Ix1) -> Option { + fn index_checked(&self, dim: &Ix1, strides: &Ix1) -> Option + { dim.stride_offset_checked(strides, &Ix1(*self)) } #[inline(always)] - fn index_unchecked(&self, strides: &Ix1) -> isize { + fn index_unchecked(&self, strides: &Ix1) -> isize + { stride_offset(*self, get!(strides, 0)) } } -unsafe impl NdIndex for Ix { +unsafe impl NdIndex for Ix +{ #[inline] - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option + { debug_assert_eq!(dim.ndim(), 1); stride_offset_checked(dim.ix(), strides.ix(), &[*self]) } #[inline(always)] - fn index_unchecked(&self, strides: &IxDyn) -> isize { + fn index_unchecked(&self, strides: &IxDyn) -> isize + { debug_assert_eq!(strides.ndim(), 1); stride_offset(*self, get!(strides, 0)) } @@ -139,50 +177,6 @@ macro_rules! ndindex_with_array { 0 } } - - // implement NdIndex for Dim<[Ix; 2]> and so on - unsafe impl NdIndex for Dim<[Ix; $n]> { - #[inline] - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { - debug_assert_eq!(strides.ndim(), $n, - "Attempted to index with {:?} in array with {} axes", - self, strides.ndim()); - stride_offset_checked(dim.ix(), strides.ix(), self.ix()) - } - - #[inline] - fn index_unchecked(&self, strides: &IxDyn) -> isize { - debug_assert_eq!(strides.ndim(), $n, - "Attempted to index with {:?} in array with {} axes", - self, strides.ndim()); - $( - stride_offset(get!(self, $index), get!(strides, $index)) + - )* - 0 - } - } - - // implement NdIndex for [Ix; 2] and so on - unsafe impl NdIndex for [Ix; $n] { - #[inline] - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { - debug_assert_eq!(strides.ndim(), $n, - "Attempted to index with {:?} in array with {} axes", - self, strides.ndim()); - stride_offset_checked(dim.ix(), strides.ix(), self) - } - - #[inline] - fn index_unchecked(&self, strides: &IxDyn) -> isize { - debug_assert_eq!(strides.ndim(), $n, - "Attempted to index with {:?} in array with {} axes", - self, strides.ndim()); - $( - stride_offset(self[$index], get!(strides, $index)) + - )* - 0 - } - } )+ }; } @@ -197,27 +191,99 @@ ndindex_with_array! 
{ [6, Ix6 0 1 2 3 4 5] } -impl<'a> IntoDimension for &'a [Ix] { +// implement NdIndex for Dim<[Ix; 2]> and so on +unsafe impl NdIndex for Dim<[Ix; N]> +{ + #[inline] + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option + { + debug_assert_eq!( + strides.ndim(), + N, + "Attempted to index with {:?} in array with {} axes", + self, + strides.ndim() + ); + stride_offset_checked(dim.ix(), strides.ix(), self.ix()) + } + + #[inline] + fn index_unchecked(&self, strides: &IxDyn) -> isize + { + debug_assert_eq!( + strides.ndim(), + N, + "Attempted to index with {:?} in array with {} axes", + self, + strides.ndim() + ); + (0..N) + .map(|i| stride_offset(get!(self, i), get!(strides, i))) + .sum() + } +} + +// implement NdIndex for [Ix; 2] and so on +unsafe impl NdIndex for [Ix; N] +{ + #[inline] + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option + { + debug_assert_eq!( + strides.ndim(), + N, + "Attempted to index with {:?} in array with {} axes", + self, + strides.ndim() + ); + stride_offset_checked(dim.ix(), strides.ix(), self) + } + + #[inline] + fn index_unchecked(&self, strides: &IxDyn) -> isize + { + debug_assert_eq!( + strides.ndim(), + N, + "Attempted to index with {:?} in array with {} axes", + self, + strides.ndim() + ); + (0..N) + .map(|i| stride_offset(self[i], get!(strides, i))) + .sum() + } +} + +impl IntoDimension for &[Ix] +{ type Dim = IxDyn; - fn into_dimension(self) -> Self::Dim { + fn into_dimension(self) -> Self::Dim + { Dim(IxDynImpl::from(self)) } } -unsafe impl<'a> NdIndex for &'a IxDyn { - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { +unsafe impl NdIndex for &IxDyn +{ + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option + { (**self).index_checked(dim, strides) } - fn index_unchecked(&self, strides: &IxDyn) -> isize { + fn index_unchecked(&self, strides: &IxDyn) -> isize + { (**self).index_unchecked(strides) } } -unsafe impl<'a> NdIndex for &'a [Ix] { - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { - stride_offset_checked(dim.ix(), strides.ix(), *self) +unsafe impl NdIndex for &[Ix] +{ + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option + { + stride_offset_checked(dim.ix(), strides.ix(), self) } - fn index_unchecked(&self, strides: &IxDyn) -> isize { + fn index_unchecked(&self, strides: &IxDyn) -> isize + { zip(strides.ix(), *self) .map(|(&s, &i)| stride_offset(i, s)) .sum() diff --git a/src/dimension/ops.rs b/src/dimension/ops.rs new file mode 100644 index 000000000..1365ab488 --- /dev/null +++ b/src/dimension/ops.rs @@ -0,0 +1,93 @@ +use crate::imp_prelude::*; + +/// Adds the two dimensions at compile time. +pub trait DimAdd +{ + /// The sum of the two dimensions. + type Output: Dimension; +} + +macro_rules! impl_dimadd_const_out_const { + ($lhs:expr, $rhs:expr) => { + impl DimAdd> for Dim<[usize; $lhs]> { + type Output = Dim<[usize; $lhs + $rhs]>; + } + }; +} + +macro_rules! 
impl_dimadd_const_out_dyn { + ($lhs:expr, IxDyn) => { + impl DimAdd for Dim<[usize; $lhs]> { + type Output = IxDyn; + } + }; + ($lhs:expr, $rhs:expr) => { + impl DimAdd> for Dim<[usize; $lhs]> { + type Output = IxDyn; + } + }; +} + +impl DimAdd for Ix0 +{ + type Output = D; +} + +impl_dimadd_const_out_const!(1, 0); +impl_dimadd_const_out_const!(1, 1); +impl_dimadd_const_out_const!(1, 2); +impl_dimadd_const_out_const!(1, 3); +impl_dimadd_const_out_const!(1, 4); +impl_dimadd_const_out_const!(1, 5); +impl_dimadd_const_out_dyn!(1, 6); +impl_dimadd_const_out_dyn!(1, IxDyn); + +impl_dimadd_const_out_const!(2, 0); +impl_dimadd_const_out_const!(2, 1); +impl_dimadd_const_out_const!(2, 2); +impl_dimadd_const_out_const!(2, 3); +impl_dimadd_const_out_const!(2, 4); +impl_dimadd_const_out_dyn!(2, 5); +impl_dimadd_const_out_dyn!(2, 6); +impl_dimadd_const_out_dyn!(2, IxDyn); + +impl_dimadd_const_out_const!(3, 0); +impl_dimadd_const_out_const!(3, 1); +impl_dimadd_const_out_const!(3, 2); +impl_dimadd_const_out_const!(3, 3); +impl_dimadd_const_out_dyn!(3, 4); +impl_dimadd_const_out_dyn!(3, 5); +impl_dimadd_const_out_dyn!(3, 6); +impl_dimadd_const_out_dyn!(3, IxDyn); + +impl_dimadd_const_out_const!(4, 0); +impl_dimadd_const_out_const!(4, 1); +impl_dimadd_const_out_const!(4, 2); +impl_dimadd_const_out_dyn!(4, 3); +impl_dimadd_const_out_dyn!(4, 4); +impl_dimadd_const_out_dyn!(4, 5); +impl_dimadd_const_out_dyn!(4, 6); +impl_dimadd_const_out_dyn!(4, IxDyn); + +impl_dimadd_const_out_const!(5, 0); +impl_dimadd_const_out_const!(5, 1); +impl_dimadd_const_out_dyn!(5, 2); +impl_dimadd_const_out_dyn!(5, 3); +impl_dimadd_const_out_dyn!(5, 4); +impl_dimadd_const_out_dyn!(5, 5); +impl_dimadd_const_out_dyn!(5, 6); +impl_dimadd_const_out_dyn!(5, IxDyn); + +impl_dimadd_const_out_const!(6, 0); +impl_dimadd_const_out_dyn!(6, 1); +impl_dimadd_const_out_dyn!(6, 2); +impl_dimadd_const_out_dyn!(6, 3); +impl_dimadd_const_out_dyn!(6, 4); +impl_dimadd_const_out_dyn!(6, 5); +impl_dimadd_const_out_dyn!(6, 6); +impl_dimadd_const_out_dyn!(6, IxDyn); + +impl DimAdd for IxDyn +{ + type Output = IxDyn; +} diff --git a/src/dimension/remove_axis.rs b/src/dimension/remove_axis.rs index da366ae17..cbb039fc5 100644 --- a/src/dimension/remove_axis.rs +++ b/src/dimension/remove_axis.rs @@ -12,21 +12,26 @@ use crate::{Axis, Dim, Dimension, Ix, Ix0, Ix1}; /// /// `RemoveAxis` defines a larger-than relation for array shapes: /// removing one axis from *Self* gives smaller dimension *Smaller*. 
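The `DimAdd` impls above are pure type-level bookkeeping: every pair of fixed dimensionalities resolves to a concrete output dimensionality at compile time, falling back to `IxDyn` once the sum exceeds six. A hedged toy version of the same pattern, using bare marker types instead of `Dim` (not part of the crate):

```rust
trait DimAdd<Rhs> {
    type Output;
}

// Toy stand-ins for Ix1, Ix2, Ix3 and IxDyn.
struct D1;
struct D2;
struct D3;
struct DDyn;

impl DimAdd<D1> for D1 { type Output = D2; }
impl DimAdd<D2> for D1 { type Output = D3; }
impl DimAdd<DDyn> for D1 { type Output = DDyn; } // a dynamic operand stays dynamic

// Resolved entirely by the type checker; nothing happens at run time.
fn assert_adds_to<Lhs: DimAdd<Rhs, Output = Out>, Rhs, Out>() {}

fn main() {
    assert_adds_to::<D1, D1, D2>();
    assert_adds_to::<D1, D2, D3>();
    assert_adds_to::<D1, DDyn, DDyn>();
}
```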
-pub trait RemoveAxis: Dimension { +pub trait RemoveAxis: Dimension +{ fn remove_axis(&self, axis: Axis) -> Self::Smaller; } -impl RemoveAxis for Dim<[Ix; 1]> { +impl RemoveAxis for Dim<[Ix; 1]> +{ #[inline] - fn remove_axis(&self, axis: Axis) -> Ix0 { + fn remove_axis(&self, axis: Axis) -> Ix0 + { debug_assert!(axis.index() < self.ndim()); Ix0() } } -impl RemoveAxis for Dim<[Ix; 2]> { +impl RemoveAxis for Dim<[Ix; 2]> +{ #[inline] - fn remove_axis(&self, axis: Axis) -> Ix1 { + fn remove_axis(&self, axis: Axis) -> Ix1 + { let axis = axis.index(); debug_assert!(axis < self.ndim()); if axis == 0 { diff --git a/src/dimension/reshape.rs b/src/dimension/reshape.rs new file mode 100644 index 000000000..abcec4993 --- /dev/null +++ b/src/dimension/reshape.rs @@ -0,0 +1,236 @@ +use crate::dimension::sequence::{Forward, Reverse, Sequence, SequenceMut}; +use crate::{Dimension, ErrorKind, Order, ShapeError}; + +#[inline] +pub(crate) fn reshape_dim(from: &D, strides: &D, to: &E, order: Order) -> Result +where + D: Dimension, + E: Dimension, +{ + debug_assert_eq!(from.ndim(), strides.ndim()); + let mut to_strides = E::zeros(to.ndim()); + match order { + Order::RowMajor => { + reshape_dim_c(&Forward(from), &Forward(strides), &Forward(to), Forward(&mut to_strides))?; + } + Order::ColumnMajor => { + reshape_dim_c(&Reverse(from), &Reverse(strides), &Reverse(to), Reverse(&mut to_strides))?; + } + } + Ok(to_strides) +} + +/// Try to reshape an array with dimensions `from_dim` and strides `from_strides` to the new +/// dimension `to_dim`, while keeping the same layout of elements in memory. The strides needed +/// if this is possible are stored into `to_strides`. +/// +/// This function uses RowMajor index ordering if the inputs are read in the forward direction +/// (index 0 is axis 0 etc) and ColumnMajor index ordering if the inputs are read in reversed +/// direction (as made possible with the Sequence trait). +/// +/// Preconditions: +/// +/// 1. from_dim and to_dim are valid dimensions (product of all non-zero axes +/// fits in isize::MAX). +/// 2. from_dim and to_dim are don't have any axes that are zero (that should be handled before +/// this function). +/// 3. `to_strides` should be an all-zeros or all-ones dimension of the right dimensionality +/// (but it will be overwritten after successful exit of this function). +/// +/// This function returns: +/// +/// - IncompatibleShape if the two shapes are not of matching number of elements +/// - IncompatibleLayout if the input shape and stride can not be remapped to the output shape +/// without moving the array data into a new memory layout. +/// - Ok if the from dim could be mapped to the new to dim. +fn reshape_dim_c(from_dim: &D, from_strides: &D, to_dim: &E, mut to_strides: E2) -> Result<(), ShapeError> +where + D: Sequence, + E: Sequence, + E2: SequenceMut, +{ + // cursor indexes into the from and to dimensions + let mut fi = 0; // index into `from_dim` + let mut ti = 0; // index into `to_dim`. 
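Restating the contract above in executable form may help before reading the loop that follows. A hedged brute-force check for the row-major (C order) case, not the algorithm itself: a reshape is representable without moving data exactly when the old and new `(dim, strides)` pairs map the logical element order to the same sequence of memory offsets.

```rust
// Enumerate all logical indices in row-major order and map them to offsets.
fn offsets(dim: &[usize], strides: &[usize]) -> Vec<usize> {
    let total: usize = dim.iter().product();
    let mut out = Vec::with_capacity(total);
    for mut flat in 0..total {
        let mut off = 0usize;
        for (&d, &s) in dim.iter().zip(strides).rev() {
            off += (flat % d) * s;
            flat /= d;
        }
        out.push(off);
    }
    out
}

fn main() {
    // A (3, 4) row-major array with strides (4, 1) can be viewed as (12,) with stride (1,).
    assert_eq!(offsets(&[3, 4], &[4, 1]), offsets(&[12], &[1]));
    // With strides (8, 1) (every other row of a larger buffer) it cannot.
    assert_ne!(offsets(&[3, 4], &[8, 1]), offsets(&[12], &[1]));
}
```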
+ + while fi < from_dim.len() && ti < to_dim.len() { + let mut fd = from_dim[fi]; + let mut fs = from_strides[fi] as isize; + let mut td = to_dim[ti]; + + if fd == td { + to_strides[ti] = from_strides[fi]; + fi += 1; + ti += 1; + continue; + } + + if fd == 1 { + fi += 1; + continue; + } + + if td == 1 { + to_strides[ti] = 1; + ti += 1; + continue; + } + + if fd == 0 || td == 0 { + debug_assert!(false, "zero dim not handled by this function"); + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + // stride times element count is to be distributed out over a combination of axes. + let mut fstride_whole = fs * (fd as isize); + let mut fd_product = fd; // cumulative product of axis lengths in the combination (from) + let mut td_product = td; // cumulative product of axis lengths in the combination (to) + + // The two axis lengths are not a match, so try to combine multiple axes + // to get it to match up. + while fd_product != td_product { + if fd_product < td_product { + // Take another axis on the from side + fi += 1; + if fi >= from_dim.len() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + fd = from_dim[fi]; + fd_product *= fd; + if fd > 1 { + let fs_old = fs; + fs = from_strides[fi] as isize; + // check if this axis and the next are contiguous together + if fs_old != fd as isize * fs { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)); + } + } + } else { + // Take another axis on the `to` side + // First assign the stride to the axis we leave behind + fstride_whole /= td as isize; + to_strides[ti] = fstride_whole as usize; + ti += 1; + if ti >= to_dim.len() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + td = to_dim[ti]; + td_product *= td; + } + } + + fstride_whole /= td as isize; + to_strides[ti] = fstride_whole as usize; + + fi += 1; + ti += 1; + } + + // skip past 1-dims at the end + while fi < from_dim.len() && from_dim[fi] == 1 { + fi += 1; + } + + while ti < to_dim.len() && to_dim[ti] == 1 { + to_strides[ti] = 1; + ti += 1; + } + + if fi < from_dim.len() || ti < to_dim.len() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + Ok(()) +} + +#[cfg(feature = "std")] +#[test] +fn test_reshape() +{ + use crate::Dim; + + macro_rules! 
test_reshape { + (fail $order:ident from $from:expr, $stride:expr, to $to:expr) => { + let res = reshape_dim(&Dim($from), &Dim($stride), &Dim($to), Order::$order); + println!("Reshape {:?} {:?} to {:?}, order {:?}\n => {:?}", + $from, $stride, $to, Order::$order, res); + let _res = res.expect_err("Expected failed reshape"); + }; + (ok $order:ident from $from:expr, $stride:expr, to $to:expr, $to_stride:expr) => {{ + let res = reshape_dim(&Dim($from), &Dim($stride), &Dim($to), Order::$order); + println!("Reshape {:?} {:?} to {:?}, order {:?}\n => {:?}", + $from, $stride, $to, Order::$order, res); + println!("default stride for from dim: {:?}", Dim($from).default_strides()); + println!("default stride for to dim: {:?}", Dim($to).default_strides()); + let res = res.expect("Expected successful reshape"); + assert_eq!(res, Dim($to_stride), "mismatch in strides"); + }}; + } + + test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [1, 2, 3], [6, 3, 1]); + test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [2, 3], [3, 1]); + test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [6], [1]); + test_reshape!(fail C from [1, 2, 3], [6, 3, 1], to [1]); + test_reshape!(fail F from [1, 2, 3], [6, 3, 1], to [1]); + + test_reshape!(ok C from [6], [1], to [3, 2], [2, 1]); + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [4, 15], [15, 1]); + + test_reshape!(ok C from [4, 4, 4], [16, 4, 1], to [16, 4], [4, 1]); + + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 4, 1], [8, 4, 1, 1]); + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 4], [8, 4, 1]); + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 2, 2], [8, 4, 2, 1]); + + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 1, 4], [8, 4, 1, 1]); + + test_reshape!(ok C from [4, 4, 4], [16, 4, 1], to [16, 4], [4, 1]); + test_reshape!(ok C from [3, 4, 4], [16, 4, 1], to [3, 16], [16, 1]); + + test_reshape!(ok C from [4, 4], [8, 1], to [2, 2, 2, 2], [16, 8, 2, 1]); + + test_reshape!(fail C from [4, 4], [8, 1], to [2, 1, 4, 2]); + + test_reshape!(ok C from [16], [4], to [2, 2, 4], [32, 16, 4]); + test_reshape!(ok C from [16], [-4isize as usize], to [2, 2, 4], + [-32isize as usize, -16isize as usize, -4isize as usize]); + test_reshape!(ok F from [16], [4], to [2, 2, 4], [4, 8, 16]); + test_reshape!(ok F from [16], [-4isize as usize], to [2, 2, 4], + [-4isize as usize, -8isize as usize, -16isize as usize]); + + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [12, 5], [5, 1]); + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [4, 15], [15, 1]); + test_reshape!(fail F from [3, 4, 5], [20, 5, 1], to [4, 15]); + test_reshape!(ok C from [3, 4, 5, 7], [140, 35, 7, 1], to [28, 15], [15, 1]); + + // preserve stride if shape matches + test_reshape!(ok C from [10], [2], to [10], [2]); + test_reshape!(ok F from [10], [2], to [10], [2]); + test_reshape!(ok C from [2, 10], [1, 2], to [2, 10], [1, 2]); + test_reshape!(ok F from [2, 10], [1, 2], to [2, 10], [1, 2]); + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [3, 4, 5], [20, 5, 1]); + test_reshape!(ok F from [3, 4, 5], [20, 5, 1], to [3, 4, 5], [20, 5, 1]); + + test_reshape!(ok C from [3, 4, 5], [4, 1, 1], to [12, 5], [1, 1]); + test_reshape!(ok F from [3, 4, 5], [1, 3, 12], to [12, 5], [1, 12]); + test_reshape!(ok F from [3, 4, 5], [1, 3, 1], to [12, 5], [1, 1]); + + // broadcast shapes + test_reshape!(ok C from [3, 4, 5, 7], [0, 0, 7, 1], to [12, 35], [0, 1]); + test_reshape!(fail C from [3, 4, 5, 7], [0, 0, 7, 1], to [28, 15]); + + // one-filled shapes + test_reshape!(ok C from [10], [1], to [1, 10, 1, 1, 1], [1, 1, 1, 1, 
1]); + test_reshape!(ok F from [10], [1], to [1, 10, 1, 1, 1], [1, 1, 1, 1, 1]); + test_reshape!(ok C from [1, 10], [10, 1], to [1, 10, 1, 1, 1], [10, 1, 1, 1, 1]); + test_reshape!(ok F from [1, 10], [10, 1], to [1, 10, 1, 1, 1], [10, 1, 1, 1, 1]); + test_reshape!(ok C from [1, 10], [1, 1], to [1, 5, 1, 1, 2], [1, 2, 2, 2, 1]); + test_reshape!(ok F from [1, 10], [1, 1], to [1, 5, 1, 1, 2], [1, 1, 5, 5, 5]); + test_reshape!(ok C from [10, 1, 1, 1, 1], [1, 1, 1, 1, 1], to [10], [1]); + test_reshape!(ok F from [10, 1, 1, 1, 1], [1, 1, 1, 1, 1], to [10], [1]); + test_reshape!(ok C from [1, 5, 1, 2, 1], [1, 2, 1, 1, 1], to [10], [1]); + test_reshape!(fail F from [1, 5, 1, 2, 1], [1, 2, 1, 1, 1], to [10]); + test_reshape!(ok F from [1, 5, 1, 2, 1], [1, 1, 1, 5, 1], to [10], [1]); + test_reshape!(fail C from [1, 5, 1, 2, 1], [1, 1, 1, 5, 1], to [10]); +} diff --git a/src/dimension/sequence.rs b/src/dimension/sequence.rs new file mode 100644 index 000000000..ed3605d57 --- /dev/null +++ b/src/dimension/sequence.rs @@ -0,0 +1,129 @@ +use std::ops::Index; +use std::ops::IndexMut; + +use crate::dimension::Dimension; + +pub(in crate::dimension) struct Forward(pub(crate) D); +pub(in crate::dimension) struct Reverse(pub(crate) D); + +impl Index for Forward<&D> +where D: Dimension +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize + { + &self.0[index] + } +} + +impl Index for Forward<&mut D> +where D: Dimension +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize + { + &self.0[index] + } +} + +impl IndexMut for Forward<&mut D> +where D: Dimension +{ + #[inline] + fn index_mut(&mut self, index: usize) -> &mut usize + { + &mut self.0[index] + } +} + +impl Index for Reverse<&D> +where D: Dimension +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize + { + &self.0[self.len() - index - 1] + } +} + +impl Index for Reverse<&mut D> +where D: Dimension +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize + { + &self.0[self.len() - index - 1] + } +} + +impl IndexMut for Reverse<&mut D> +where D: Dimension +{ + #[inline] + fn index_mut(&mut self, index: usize) -> &mut usize + { + let len = self.len(); + &mut self.0[len - index - 1] + } +} + +/// Indexable sequence with length +pub(in crate::dimension) trait Sequence: Index +{ + fn len(&self) -> usize; +} + +/// Indexable sequence with length (mut) +pub(in crate::dimension) trait SequenceMut: Sequence + IndexMut {} + +impl Sequence for Forward<&D> +where D: Dimension +{ + #[inline] + fn len(&self) -> usize + { + self.0.ndim() + } +} + +impl Sequence for Forward<&mut D> +where D: Dimension +{ + #[inline] + fn len(&self) -> usize + { + self.0.ndim() + } +} + +impl SequenceMut for Forward<&mut D> where D: Dimension {} + +impl Sequence for Reverse<&D> +where D: Dimension +{ + #[inline] + fn len(&self) -> usize + { + self.0.ndim() + } +} + +impl Sequence for Reverse<&mut D> +where D: Dimension +{ + #[inline] + fn len(&self) -> usize + { + self.0.ndim() + } +} + +impl SequenceMut for Reverse<&mut D> where D: Dimension {} diff --git a/src/doc/crate_feature_flags.rs b/src/doc/crate_feature_flags.rs new file mode 100644 index 000000000..fc2c2bd49 --- /dev/null +++ b/src/doc/crate_feature_flags.rs @@ -0,0 +1,35 @@ +//! Crate Feature Flags +//! +//! The following crate feature flags are available. They are configured in your +//! `Cargo.toml` where the dependency on `ndarray` is defined. +//! +//! ## `std` +//! - Rust standard library (enabled by default) +//! 
- This crate can be used without the standard library by disabling the +//! default `std` feature. To do so, use `default-features = false` in +//! your `Cargo.toml`. +//! - The `geomspace` `linspace` `logspace` `range` `std` `var` `var_axis` +//! and `std_axis` methods are only available when `std` is enabled. +//! +//! ## `serde` +//! - Enables serialization support for serde 1.x +//! +//! ## `rayon` +//! - Enables parallel iterators, parallelized methods, the [`parallel`] module and [`par_azip!`]. +//! - Implies std +//! +//! ## `approx` +//! - Enables implementations of traits of the [`approx`] crate. +//! +//! ## `blas` +//! - Enable transparent BLAS support for matrix multiplication. +//! Uses ``blas-src`` for pluggable backend, which needs to be configured +//! separately (see the README). +//! +//! ## `matrixmultiply-threading` +//! - Enable the ``threading`` feature in the matrixmultiply package +//! +//! [`parallel`]: crate::parallel + +#[cfg(doc)] +use crate::parallel::par_azip; diff --git a/src/doc/mod.rs b/src/doc/mod.rs index b98c9cab8..c0d7fab91 100644 --- a/src/doc/mod.rs +++ b/src/doc/mod.rs @@ -1,3 +1,4 @@ //! Standalone documentation pages. +pub mod crate_feature_flags; pub mod ndarray_for_numpy_users; diff --git a/src/doc/ndarray_for_numpy_users/coord_transform.rs b/src/doc/ndarray_for_numpy_users/coord_transform.rs index ef922f2ad..1529e8746 100644 --- a/src/doc/ndarray_for_numpy_users/coord_transform.rs +++ b/src/doc/ndarray_for_numpy_users/coord_transform.rs @@ -49,44 +49,40 @@ //! This is a direct translation to `ndarray`: //! //! ``` -//! extern crate ndarray; -//! //! use ndarray::prelude::*; //! -//! fn main() { -//! let nelems = 4; -//! let bunge = Array::ones((3, nelems)); -//! -//! let s1 = bunge.slice(s![0, ..]).mapv(f64::sin); -//! let c1 = bunge.slice(s![0, ..]).mapv(f64::cos); -//! let s2 = bunge.slice(s![1, ..]).mapv(f64::sin); -//! let c2 = bunge.slice(s![1, ..]).mapv(f64::cos); -//! let s3 = bunge.slice(s![2, ..]).mapv(f64::sin); -//! let c3 = bunge.slice(s![2, ..]).mapv(f64::cos); -//! -//! let mut rmat = Array::zeros((3, 3, nelems).f()); -//! for i in 0..nelems { -//! rmat[[0, 0, i]] = c1[i] * c3[i] - s1[i] * s3[i] * c2[i]; -//! rmat[[0, 1, i]] = -c1[i] * s3[i] - s1[i] * c2[i] * c3[i]; -//! rmat[[0, 2, i]] = s1[i] * s2[i]; -//! -//! rmat[[1, 0, i]] = s1[i] * c3[i] + c1[i] * c2[i] * s3[i]; -//! rmat[[1, 1, i]] = -s1[i] * s3[i] + c1[i] * c2[i] * c3[i]; -//! rmat[[1, 2, i]] = -c1[i] * s2[i]; -//! -//! rmat[[2, 0, i]] = s2[i] * s3[i]; -//! rmat[[2, 1, i]] = s2[i] * c3[i]; -//! rmat[[2, 2, i]] = c2[i]; -//! } -//! -//! let eye2d = Array::eye(3); -//! -//! let mut rotated = Array::zeros((3, 3, nelems).f()); -//! for i in 0..nelems { -//! rotated -//! .slice_mut(s![.., .., i]) -//! .assign({ &rmat.slice(s![.., .., i]).dot(&eye2d) }); -//! } +//! let nelems = 4; +//! let bunge = Array::ones((3, nelems)); +//! +//! let s1 = bunge.slice(s![0, ..]).mapv(f64::sin); +//! let c1 = bunge.slice(s![0, ..]).mapv(f64::cos); +//! let s2 = bunge.slice(s![1, ..]).mapv(f64::sin); +//! let c2 = bunge.slice(s![1, ..]).mapv(f64::cos); +//! let s3 = bunge.slice(s![2, ..]).mapv(f64::sin); +//! let c3 = bunge.slice(s![2, ..]).mapv(f64::cos); +//! +//! let mut rmat = Array::zeros((3, 3, nelems).f()); +//! for i in 0..nelems { +//! rmat[[0, 0, i]] = c1[i] * c3[i] - s1[i] * s3[i] * c2[i]; +//! rmat[[0, 1, i]] = -c1[i] * s3[i] - s1[i] * c2[i] * c3[i]; +//! rmat[[0, 2, i]] = s1[i] * s2[i]; +//! +//! rmat[[1, 0, i]] = s1[i] * c3[i] + c1[i] * c2[i] * s3[i]; +//! 
rmat[[1, 1, i]] = -s1[i] * s3[i] + c1[i] * c2[i] * c3[i]; +//! rmat[[1, 2, i]] = -c1[i] * s2[i]; +//! +//! rmat[[2, 0, i]] = s2[i] * s3[i]; +//! rmat[[2, 1, i]] = s2[i] * c3[i]; +//! rmat[[2, 2, i]] = c2[i]; +//! } +//! +//! let eye2d = Array::eye(3); +//! +//! let mut rotated = Array::zeros((3, 3, nelems).f()); +//! for i in 0..nelems { +//! rotated +//! .slice_mut(s![.., .., i]) +//! .assign(&rmat.slice(s![.., .., i]).dot(&eye2d)); //! } //! ``` //! @@ -96,41 +92,37 @@ //! this: //! //! ``` -//! extern crate ndarray; -//! //! use ndarray::prelude::*; //! -//! fn main() { -//! let nelems = 4; -//! let bunge = Array2::::ones((3, nelems)); -//! -//! let mut rmat = Array::zeros((3, 3, nelems).f()); -//! azip!((mut rmat in rmat.axis_iter_mut(Axis(2)), bunge in bunge.axis_iter(Axis(1))) { -//! let s1 = bunge[0].sin(); -//! let c1 = bunge[0].cos(); -//! let s2 = bunge[1].sin(); -//! let c2 = bunge[1].cos(); -//! let s3 = bunge[2].sin(); -//! let c3 = bunge[2].cos(); -//! -//! rmat[[0, 0]] = c1 * c3 - s1 * s3 * c2; -//! rmat[[0, 1]] = -c1 * s3 - s1 * c2 * c3; -//! rmat[[0, 2]] = s1 * s2; -//! -//! rmat[[1, 0]] = s1 * c3 + c1 * c2 * s3; -//! rmat[[1, 1]] = -s1 * s3 + c1 * c2 * c3; -//! rmat[[1, 2]] = -c1 * s2; -//! -//! rmat[[2, 0]] = s2 * s3; -//! rmat[[2, 1]] = s2 * c3; -//! rmat[[2, 2]] = c2; -//! }); -//! -//! let eye2d = Array2::::eye(3); -//! -//! let mut rotated = Array3::::zeros((3, 3, nelems).f()); -//! azip!((mut rotated in rotated.axis_iter_mut(Axis(2)), rmat in rmat.axis_iter(Axis(2))) { -//! rotated.assign(&rmat.dot(&eye2d)); -//! }); -//! } +//! let nelems = 4; +//! let bunge = Array2::::ones((3, nelems)); +//! +//! let mut rmat = Array::zeros((3, 3, nelems).f()); +//! azip!((mut rmat in rmat.axis_iter_mut(Axis(2)), bunge in bunge.axis_iter(Axis(1))) { +//! let s1 = bunge[0].sin(); +//! let c1 = bunge[0].cos(); +//! let s2 = bunge[1].sin(); +//! let c2 = bunge[1].cos(); +//! let s3 = bunge[2].sin(); +//! let c3 = bunge[2].cos(); +//! +//! rmat[[0, 0]] = c1 * c3 - s1 * s3 * c2; +//! rmat[[0, 1]] = -c1 * s3 - s1 * c2 * c3; +//! rmat[[0, 2]] = s1 * s2; +//! +//! rmat[[1, 0]] = s1 * c3 + c1 * c2 * s3; +//! rmat[[1, 1]] = -s1 * s3 + c1 * c2 * c3; +//! rmat[[1, 2]] = -c1 * s2; +//! +//! rmat[[2, 0]] = s2 * s3; +//! rmat[[2, 1]] = s2 * c3; +//! rmat[[2, 2]] = c2; +//! }); +//! +//! let eye2d = Array2::::eye(3); +//! +//! let mut rotated = Array3::::zeros((3, 3, nelems).f()); +//! azip!((mut rotated in rotated.axis_iter_mut(Axis(2)), rmat in rmat.axis_iter(Axis(2))) { +//! rotated.assign(&rmat.dot(&eye2d)); +//! }); //! ``` diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs index d76de9130..eba96cdd0 100644 --- a/src/doc/ndarray_for_numpy_users/mod.rs +++ b/src/doc/ndarray_for_numpy_users/mod.rs @@ -19,11 +19,12 @@ //! * [Mathematics](#mathematics) //! * [Array manipulation](#array-manipulation) //! * [Iteration](#iteration) +//! * [Type conversions](#type-conversions) //! * [Convenience methods for 2-D arrays](#convenience-methods-for-2-d-arrays) //! //! # Similarities //! -//! `ndarray`'s array type ([`ArrayBase`][ArrayBase]), is very similar to +//! `ndarray`'s array type ([`ArrayBase`]), is very similar to //! NumPy's array type (`numpy.ndarray`): //! //! * Arrays have a single element type. @@ -70,12 +71,12 @@ //! //! //! -//! In `ndarray`, all arrays are instances of [`ArrayBase`][ArrayBase], but -//! `ArrayBase` is generic over the ownership of the data. [`Array`][Array] -//! owns its data; [`ArrayView`][ArrayView] is a view; -//! 
[`ArrayViewMut`][ArrayViewMut] is a mutable view; [`CowArray`][CowArray] +//! In `ndarray`, all arrays are instances of [`ArrayBase`], but +//! `ArrayBase` is generic over the ownership of the data. [`Array`] +//! owns its data; [`ArrayView`] is a view; +//! [`ArrayViewMut`] is a mutable view; [`CowArray`] //! either owns its data or is a view (with copy-on-write mutation of the view -//! variant); and [`ArcArray`][ArcArray] has a reference-counted pointer to its +//! variant); and [`ArcArray`] has a reference-counted pointer to its //! data (with copy-on-write mutation). Arrays and views follow Rust's aliasing //! rules. //! @@ -91,7 +92,7 @@ //! //! //! In `ndarray`, you can create fixed-dimension arrays, such as -//! [`Array2`][Array2]. This takes advantage of the type system to help you +//! [`Array2`]. This takes advantage of the type system to help you //! write correct code and also avoids small heap allocations for the shape and //! strides. //! @@ -110,7 +111,7 @@ //! When slicing in `ndarray`, the axis is first sliced with `start..end`. Then if //! `step` is positive, the first index is the front of the slice; if `step` is //! negative, the first index is the back of the slice. This means that the -//! behavior is the same as NumPy except when `step < -1`. See the docs for the +//! behavior is different from NumPy when `step < 0`. See the docs for the //! [`s![]` macro][s!] for more details. //! //! @@ -177,11 +178,9 @@ //! and `ndarray` like this: //! //! ``` -//! extern crate ndarray; -//! //! use ndarray::prelude::*; //! # -//! # fn main() {} +//! # fn main() { let _ = arr0(1); } //! ``` //! //! ## Array creation @@ -247,8 +246,8 @@ //! methods [`.slice_mut()`][.slice_mut()], [`.slice_move()`][.slice_move()], and //! [`.slice_collapse()`][.slice_collapse()]. //! -//! * The behavior of slicing is slightly different from NumPy for slices with -//! `step < -1`. See the docs for the [`s![]` macro][s!] for more details. +//! * The behavior of slicing is different from NumPy for slices with +//! `step < 0`. See the docs for the [`s![]` macro][s!] for more details. //! //! NumPy | `ndarray` | Notes //! ------|-----------|------ @@ -259,13 +258,14 @@ //! `a[-5:]` or `a[-5:, :]` | [`a.slice(s![-5.., ..])`][.slice()] or [`a.slice_axis(Axis(0), Slice::from(-5..))`][.slice_axis()] | get the last 5 rows of a 2-D array //! `a[:3, 4:9]` | [`a.slice(s![..3, 4..9])`][.slice()] | columns 4, 5, 6, 7, and 8 of the first 3 rows //! `a[1:4:2, ::-1]` | [`a.slice(s![1..4;2, ..;-1])`][.slice()] | rows 1 and 3 with the columns in reverse order +//! `a.take([4, 2])` | `a.select(Axis(0), &[4, 2])` | rows 4 and 2 of the array //! //! ## Shape and strides //! //! Note that [`a.shape()`][.shape()], [`a.dim()`][.dim()], and //! [`a.raw_dim()`][.raw_dim()] all return the shape of the array, but as //! different types. `a.shape()` returns the shape as `&[Ix]`, (where -//! [`Ix`][Ix] is `usize`) which is useful for general operations on the shape. +//! [`Ix`] is `usize`) which is useful for general operations on the shape. //! `a.dim()` returns the shape as `D::Pattern`, which is useful for //! pattern-matching shapes. `a.raw_dim()` returns the shape as `D`, which is //! useful for creating other arrays of the same shape. @@ -284,7 +284,7 @@ //! Note that [`.mapv()`][.mapv()] has corresponding methods [`.map()`][.map()], //! [`.mapv_into()`][.mapv_into()], [`.map_inplace()`][.map_inplace()], and //! [`.mapv_inplace()`][.mapv_inplace()]. Also look at [`.fold()`][.fold()], -//! 
[`.visit()`][.visit()], [`.fold_axis()`][.fold_axis()], and +//! [`.for_each()`][.for_each()], [`.fold_axis()`][.fold_axis()], and //! [`.map_axis()`][.map_axis()]. //! //! @@ -378,7 +378,7 @@ //! //! /// /// /// /// /// /// @@ -973,19 +1037,19 @@ pub type Ixs = isize; /// /// Input | Output | Methods /// ------|--------|-------- -/// `Vec` | `ArrayBase` | [`::from_vec()`](#method.from_vec) -/// `Vec` | `ArrayBase` | [`::from_shape_vec()`](#method.from_shape_vec) -/// `&[A]` | `ArrayView1` | [`::from()`](type.ArrayView.html#method.from) -/// `&[A]` | `ArrayView` | [`::from_shape()`](type.ArrayView.html#method.from_shape) -/// `&mut [A]` | `ArrayViewMut1` | [`::from()`](type.ArrayViewMut.html#method.from) -/// `&mut [A]` | `ArrayViewMut` | [`::from_shape()`](type.ArrayViewMut.html#method.from_shape) -/// `&ArrayBase` | `Vec` | [`.to_vec()`](#method.to_vec) -/// `Array` | `Vec` | [`.into_raw_vec()`](type.Array.html#method.into_raw_vec)[1](#into_raw_vec) -/// `&ArrayBase` | `&[A]` | [`.as_slice()`](#method.as_slice)[2](#req_contig_std), [`.as_slice_memory_order()`](#method.as_slice_memory_order)[3](#req_contig) -/// `&mut ArrayBase` | `&mut [A]` | [`.as_slice_mut()`](#method.as_slice_mut)[2](#req_contig_std), [`.as_slice_memory_order_mut()`](#method.as_slice_memory_order_mut)[3](#req_contig) -/// `ArrayView` | `&[A]` | [`.to_slice()`](type.ArrayView.html#method.to_slice)[2](#req_contig_std) -/// `ArrayViewMut` | `&mut [A]` | [`.into_slice()`](type.ArrayViewMut.html#method.into_slice)[2](#req_contig_std) -/// `Array0` | `A` | [`.into_scalar()`](type.Array.html#method.into_scalar) +/// `Vec` | `ArrayBase` | [`::from_vec()`](Self::from_vec) +/// `Vec` | `ArrayBase` | [`::from_shape_vec()`](Self::from_shape_vec) +/// `&[A]` | `ArrayView1` | [`::from()`](ArrayView#method.from) +/// `&[A]` | `ArrayView` | [`::from_shape()`](ArrayView#method.from_shape) +/// `&mut [A]` | `ArrayViewMut1` | [`::from()`](ArrayViewMut#method.from) +/// `&mut [A]` | `ArrayViewMut` | [`::from_shape()`](ArrayViewMut#method.from_shape) +/// `&ArrayBase` | `Vec` | [`.to_vec()`](Self::to_vec) +/// `Array` | `Vec` | [`.into_raw_vec()`](Array#method.into_raw_vec)[1](#into_raw_vec) +/// `&ArrayBase` | `&[A]` | [`.as_slice()`](Self::as_slice)[2](#req_contig_std), [`.as_slice_memory_order()`](Self::as_slice_memory_order)[3](#req_contig) +/// `&mut ArrayBase` | `&mut [A]` | [`.as_slice_mut()`](Self::as_slice_mut)[2](#req_contig_std), [`.as_slice_memory_order_mut()`](Self::as_slice_memory_order_mut)[3](#req_contig) +/// `ArrayView` | `&[A]` | [`.to_slice()`](ArrayView#method.to_slice)[2](#req_contig_std) +/// `ArrayViewMut` | `&mut [A]` | [`.into_slice()`](ArrayViewMut#method.into_slice)[2](#req_contig_std) +/// `Array0` | `A` | [`.into_scalar()`](Array#method.into_scalar) /// /// 1Returns the data in memory order. /// @@ -998,16 +1062,16 @@ pub type Ixs = isize; /// conversions to/from `Vec`s/slices. See /// [below](#constructor-methods-for-owned-arrays) for more constructors. 
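A short usage sketch of a few rows from the conversion table above (standard public API; shapes and values are illustrative only):

```rust
use ndarray::{Array2, ArrayView1};

fn main() {
    // Vec<A> -> Array2<A> via ::from_shape_vec() (errors if the length is not 2 * 3)
    let a = Array2::from_shape_vec((2, 3), vec![1, 2, 3, 4, 5, 6]).unwrap();

    // &[A] -> ArrayView1<A> via ::from()
    let v = ArrayView1::from(&[10, 20, 30][..]);
    assert_eq!(v.len(), 3);

    // &ArrayBase -> Option<&[A]> via .as_slice() (requires contiguous, row-major data)
    assert_eq!(a.as_slice(), Some(&[1, 2, 3, 4, 5, 6][..]));
}
```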
/// -/// [ArrayView::reborrow()]: type.ArrayView.html#method.reborrow -/// [ArrayViewMut::reborrow()]: type.ArrayViewMut.html#method.reborrow -/// [.into_dimensionality()]: #method.into_dimensionality -/// [.into_dyn()]: #method.into_dyn -/// [.into_owned()]: #method.into_owned -/// [.into_shared()]: #method.into_shared -/// [.to_owned()]: #method.to_owned -/// [.map()]: #method.map -/// [.view()]: #method.view -/// [.view_mut()]: #method.view_mut +/// [ArrayView::reborrow()]: ArrayView#method.reborrow +/// [ArrayViewMut::reborrow()]: ArrayViewMut#method.reborrow +/// [.into_dimensionality()]: Self::into_dimensionality +/// [.into_dyn()]: Self::into_dyn +/// [.into_owned()]: Self::into_owned +/// [.into_shared()]: Self::into_shared +/// [.to_owned()]: Self::to_owned +/// [.map()]: Self::map +/// [.view()]: Self::view +/// [.view_mut()]: Self::view_mut /// /// ### Conversions from Nested `Vec`s/`Array`s /// @@ -1049,7 +1113,7 @@ pub type Ixs = isize; /// If you don't know ahead-of-time the shape of the final array, then the /// cleanest solution is generally to append the data to a flat `Vec`, and then /// convert it to an `Array` at the end with -/// [`::from_shape_vec()`](#method.from_shape_vec). You just have to be careful +/// [`::from_shape_vec()`](Self::from_shape_vec). You just have to be careful /// that the layout of the data (the order of the elements in the flat `Vec`) /// is correct. /// @@ -1072,10 +1136,9 @@ pub type Ixs = isize; /// /// If neither of these options works for you, and you really need to convert /// nested `Vec`/`Array` instances to an `Array`, the cleanest solution is -/// generally to use -/// [`Iterator::flatten()`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.flatten) +/// generally to use [`Iterator::flatten()`] /// to get a flat `Vec`, and then convert the `Vec` to an `Array` with -/// [`::from_shape_vec()`](#method.from_shape_vec), like this: +/// [`::from_shape_vec()`](Self::from_shape_vec), like this: /// /// ```rust /// use ndarray::{array, Array2, Array3}; @@ -1215,8 +1278,7 @@ pub type Ixs = isize; // // [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset-1 pub struct ArrayBase -where - S: RawData, +where S: RawData { /// Data buffer / ownership information. (If owned, contains the data /// buffer; if borrowed, contains the lifetime and mutability.) @@ -1230,14 +1292,6 @@ where strides: D, } -/// An array where the data has shared ownership and is copy on write. -/// -/// It can act as both an owner as the data as well as a shared reference (view like). -/// -/// **Note: this type alias is obsolete.** See the equivalent [`ArcArray`] instead. -#[deprecated(note = "`RcArray` has been renamed to `ArcArray`")] -pub type RcArray = ArrayBase, D>; - /// An array where the data has shared ownership and is copy on write. /// /// The `ArcArray` is parameterized by `A` for the element type and `D` for @@ -1246,20 +1300,20 @@ pub type RcArray = ArrayBase, D>; /// It can act as both an owner as the data as well as a shared reference (view /// like). /// Calling a method for mutating elements on `ArcArray`, for example -/// [`view_mut()`](struct.ArrayBase.html#method.view_mut) or -/// [`get_mut()`](struct.ArrayBase.html#method.get_mut), will break sharing and +/// [`view_mut()`](ArrayBase::view_mut) or +/// [`get_mut()`](ArrayBase::get_mut), will break sharing and /// require a clone of the data (if it is not uniquely held). 
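A minimal sketch of the copy-on-write behaviour the `ArcArray` docs above describe, assuming only the public `into_shared()` constructor and index assignment: clones share one buffer, and the first mutating access through a handle that is not uniquely held clones the data.

```rust
use ndarray::{array, ArcArray2};

fn main() {
    // Owned array -> ArcArray (shared, reference-counted storage)
    let a: ArcArray2<f64> = array![[1., 2.], [3., 4.]].into_shared();
    let mut b = a.clone(); // cheap: only the reference count is bumped

    // Mutating through `b` breaks the sharing: the elements are cloned first
    b[[0, 0]] = 10.;

    assert_eq!(a[[0, 0]], 1.);  // `a` still sees the original data
    assert_eq!(b[[0, 0]], 10.); // `b` now holds its own copy
}
```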
/// /// `ArcArray` uses atomic reference counting like `Arc`, so it is `Send` and /// `Sync` (when allowed by the element type of the array too). /// -/// [**`ArrayBase`**](struct.ArrayBase.html) is used to implement both the owned +/// **[`ArrayBase`]** is used to implement both the owned /// arrays and the views; see its docs for an overview of all array features. /// /// See also: /// -/// + [Constructor Methods for Owned Arrays](struct.ArrayBase.html#constructor-methods-for-owned-arrays) -/// + [Methods For All Array Types](struct.ArrayBase.html#methods-for-all-array-types) +/// + [Constructor Methods for Owned Arrays](ArrayBase#constructor-methods-for-owned-arrays) +/// + [Methods For All Array Types](ArrayBase#methods-for-all-array-types) pub type ArcArray = ArrayBase, D>; /// An array that owns its data uniquely. @@ -1270,19 +1324,19 @@ pub type ArcArray = ArrayBase, D>; /// The `Array` is parameterized by `A` for the element type and `D` for /// the dimensionality. /// -/// [**`ArrayBase`**](struct.ArrayBase.html) is used to implement both the owned +/// **[`ArrayBase`]** is used to implement both the owned /// arrays and the views; see its docs for an overview of all array features. /// /// See also: /// -/// + [Constructor Methods for Owned Arrays](struct.ArrayBase.html#constructor-methods-for-owned-arrays) -/// + [Methods For All Array Types](struct.ArrayBase.html#methods-for-all-array-types) +/// + [Constructor Methods for Owned Arrays](ArrayBase#constructor-methods-for-owned-arrays) +/// + [Methods For All Array Types](ArrayBase#methods-for-all-array-types) /// + Dimensionality-specific type alises -/// [`Array1`](type.Array1.html), -/// [`Array2`](type.Array2.html), -/// [`Array3`](type.Array3.html), ..., -/// [`ArrayD`](type.ArrayD.html), -/// and so on. +/// [`Array1`], +/// [`Array2`], +/// [`Array3`], ..., +/// [`ArrayD`], +/// and so on. pub type Array = ArrayBase, D>; /// An array with copy-on-write behavior. @@ -1290,20 +1344,17 @@ pub type Array = ArrayBase, D>; /// An `CowArray` represents either a uniquely owned array or a view of an /// array. The `'a` corresponds to the lifetime of the view variant. /// -/// This type is analogous to -/// [`std::borrow::Cow`](https://doc.rust-lang.org/std/borrow/enum.Cow.html). +/// This type is analogous to [`std::borrow::Cow`]. /// If a `CowArray` instance is the immutable view variant, then calling a /// method for mutating elements in the array will cause it to be converted /// into the owned variant (by cloning all the elements) before the /// modification is performed. /// -/// Array views have all the methods of an array (see [`ArrayBase`][ab]). +/// Array views have all the methods of an array (see [`ArrayBase`]). /// -/// See also [`ArcArray`](type.ArcArray.html), which also provides +/// See also [`ArcArray`], which also provides /// copy-on-write behavior but has a reference-counted pointer to the data /// instead of either a view or a uniquely owned copy. -/// -/// [ab]: struct.ArrayBase.html pub type CowArray<'a, A, D> = ArrayBase, D>; /// A read-only array view. @@ -1314,11 +1365,9 @@ pub type CowArray<'a, A, D> = ArrayBase, D>; /// The `ArrayView<'a, A, D>` is parameterized by `'a` for the scope of the /// borrow, `A` for the element type and `D` for the dimensionality. /// -/// Array views have all the methods of an array (see [`ArrayBase`][ab]). -/// -/// See also [`ArrayViewMut`](type.ArrayViewMut.html). +/// Array views have all the methods of an array (see [`ArrayBase`]). 
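Likewise, a minimal sketch of the `CowArray` conversion described above, assuming the `is_view()`/`is_owned()` queries mirrored from `CowRepr` are also exposed on the array type:

```rust
use ndarray::{array, CowArray};

fn main() {
    let owned = array![[1., 2.], [3., 4.]];

    // Start from the borrowed (view) variant
    let mut cow = CowArray::from(owned.view());
    assert!(cow.is_view());

    // The first mutation clones the elements into the owned variant
    cow[[0, 0]] = 10.;
    assert!(cow.is_owned());

    // The array the view borrowed from is unaffected
    assert_eq!(owned[[0, 0]], 1.);
    assert_eq!(cow[[0, 0]], 10.);
}
```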
/// -/// [ab]: struct.ArrayBase.html +/// See also [`ArrayViewMut`]. pub type ArrayView<'a, A, D> = ArrayBase, D>; /// A read-write array view. @@ -1329,11 +1378,9 @@ pub type ArrayView<'a, A, D> = ArrayBase, D>; /// The `ArrayViewMut<'a, A, D>` is parameterized by `'a` for the scope of the /// borrow, `A` for the element type and `D` for the dimensionality. /// -/// Array views have all the methods of an array (see [`ArrayBase`][ab]). +/// Array views have all the methods of an array (see [`ArrayBase`]). /// -/// See also [`ArrayView`](type.ArrayView.html). -/// -/// [ab]: struct.ArrayBase.html +/// See also [`ArrayView`]. pub type ArrayViewMut<'a, A, D> = ArrayBase, D>; /// A read-only array view without a lifetime. @@ -1345,19 +1392,17 @@ pub type ArrayViewMut<'a, A, D> = ArrayBase, D>; /// T` and `&T`, but `RawArrayView` has additional requirements that `*const T` /// does not, such as non-nullness. /// -/// [`ArrayView`]: type.ArrayView.html -/// /// The `RawArrayView` is parameterized by `A` for the element type and /// `D` for the dimensionality. /// /// Raw array views have all the methods of an array (see -/// [`ArrayBase`](struct.ArrayBase.html)). +/// [`ArrayBase`]). /// -/// See also [`RawArrayViewMut`](type.RawArrayViewMut.html). +/// See also [`RawArrayViewMut`]. /// /// # Warning /// -/// You can't use this type wih an arbitrary raw pointer; see +/// You can't use this type with an arbitrary raw pointer; see /// [`from_shape_ptr`](#method.from_shape_ptr) for details. pub type RawArrayView = ArrayBase, D>; @@ -1370,45 +1415,33 @@ pub type RawArrayView = ArrayBase, D>; /// relationship between `*mut T` and `&mut T`, but `RawArrayViewMut` has /// additional requirements that `*mut T` does not, such as non-nullness. /// -/// [`ArrayViewMut`]: type.ArrayViewMut.html -/// /// The `RawArrayViewMut` is parameterized by `A` for the element type /// and `D` for the dimensionality. /// /// Raw array views have all the methods of an array (see -/// [`ArrayBase`](struct.ArrayBase.html)). +/// [`ArrayBase`]). /// -/// See also [`RawArrayView`](type.RawArrayView.html). +/// See also [`RawArrayView`]. /// /// # Warning /// -/// You can't use this type wih an arbitrary raw pointer; see +/// You can't use this type with an arbitrary raw pointer; see /// [`from_shape_ptr`](#method.from_shape_ptr) for details. pub type RawArrayViewMut = ArrayBase, D>; -/// Array's representation. -/// -/// *Don’t use this type directly—use the type alias -/// [`Array`](type.Array.html) for the array type!* -#[derive(Clone, Debug)] -pub struct OwnedRepr(Vec); - -/// RcArray's representation. -/// -/// *Don’t use this type directly—use the type alias -/// [`RcArray`](type.RcArray.html) for the array type!* -#[deprecated(note = "RcArray is replaced by ArcArray")] -pub use self::OwnedArcRepr as OwnedRcRepr; +pub use data_repr::OwnedRepr; /// ArcArray's representation. /// /// *Don’t use this type directly—use the type alias -/// [`ArcArray`](type.ArcArray.html) for the array type!* +/// [`ArcArray`] for the array type!* #[derive(Debug)] -pub struct OwnedArcRepr(Arc>); +pub struct OwnedArcRepr(Arc>); -impl Clone for OwnedArcRepr { - fn clone(&self) -> Self { +impl Clone for OwnedArcRepr +{ + fn clone(&self) -> Self + { OwnedArcRepr(self.0.clone()) } } @@ -1416,17 +1449,19 @@ impl Clone for OwnedArcRepr { /// Array pointer’s representation. 
/// /// *Don’t use this type directly—use the type aliases -/// [`RawArrayView`](type.RawArrayView.html) / -/// [`RawArrayViewMut`](type.RawArrayViewMut.html) for the array type!* +/// [`RawArrayView`] / [`RawArrayViewMut`] for the array type!* #[derive(Copy, Clone)] // This is just a marker type, to carry the mutability and element type. -pub struct RawViewRepr { +pub struct RawViewRepr +{ ptr: PhantomData, } -impl RawViewRepr { +impl RawViewRepr +{ #[inline(always)] - fn new() -> Self { + const fn new() -> Self + { RawViewRepr { ptr: PhantomData } } } @@ -1434,17 +1469,19 @@ impl RawViewRepr { /// Array view’s representation. /// /// *Don’t use this type directly—use the type aliases -/// [`ArrayView`](type.ArrayView.html) -/// / [`ArrayViewMut`](type.ArrayViewMut.html) for the array type!* +/// [`ArrayView`] / [`ArrayViewMut`] for the array type!* #[derive(Copy, Clone)] // This is just a marker type, to carry the lifetime parameter. -pub struct ViewRepr { +pub struct ViewRepr +{ life: PhantomData, } -impl ViewRepr { +impl ViewRepr +{ #[inline(always)] - fn new() -> Self { + const fn new() -> Self + { ViewRepr { life: PhantomData } } } @@ -1452,17 +1489,20 @@ impl ViewRepr { /// CowArray's representation. /// /// *Don't use this type directly—use the type alias -/// [`CowArray`](type.CowArray.html) for the array type!* -pub enum CowRepr<'a, A> { +/// [`CowArray`] for the array type!* +pub enum CowRepr<'a, A> +{ /// Borrowed data. View(ViewRepr<&'a A>), /// Owned data. Owned(OwnedRepr), } -impl<'a, A> CowRepr<'a, A> { +impl CowRepr<'_, A> +{ /// Returns `true` iff the data is the `View` variant. - pub fn is_view(&self) -> bool { + pub fn is_view(&self) -> bool + { match self { CowRepr::View(_) => true, CowRepr::Owned(_) => false, @@ -1470,7 +1510,8 @@ impl<'a, A> CowRepr<'a, A> { } /// Returns `true` iff the data is the `Owned` variant. - pub fn is_owned(&self) -> bool { + pub fn is_owned(&self) -> bool + { match self { CowRepr::View(_) => false, CowRepr::Owned(_) => true, @@ -1478,12 +1519,17 @@ impl<'a, A> CowRepr<'a, A> { } } +// NOTE: The order of modules decides in which order methods on the type ArrayBase +// (mainly mentioning that as the most relevant type) show up in the documentation. +// Consider the doc effect of ordering modules here. mod impl_clone; +mod impl_internal_constructors; mod impl_constructors; mod impl_methods; mod impl_owned_array; +mod impl_special_element_types; /// Private Methods impl ArrayBase @@ -1493,8 +1539,7 @@ where { #[inline] fn broadcast_unwrap(&self, dim: E) -> ArrayView<'_, A, E> - where - E: Dimension, + where E: Dimension { #[cold] #[inline(never)] @@ -1520,8 +1565,7 @@ where // (Checked in debug assertions). #[inline] fn broadcast_assume(&self, dim: E) -> ArrayView<'_, A, E> - where - E: Dimension, + where E: Dimension { let dim = dim.into_dimension(); debug_assert_eq!(self.shape(), dim.slice()); @@ -1531,51 +1575,13 @@ where unsafe { ArrayView::new(ptr, dim, strides) } } - fn raw_strides(&self) -> D { - self.strides.clone() - } - - /// Apply closure `f` to each element in the array, in whatever - /// order is the fastest to visit. - fn unordered_foreach_mut(&mut self, mut f: F) - where - S: DataMut, - F: FnMut(&mut A), - { - if let Some(slc) = self.as_slice_memory_order_mut() { - slc.iter_mut().for_each(f); - } else { - for row in self.inner_rows_mut() { - row.into_iter_().fold((), |(), elt| f(elt)); - } - } - } - /// Remove array axis `axis` and return the result. 
- fn try_remove_axis(self, axis: Axis) -> ArrayBase { + fn try_remove_axis(self, axis: Axis) -> ArrayBase + { let d = self.dim.try_remove_axis(axis); let s = self.strides.try_remove_axis(axis); - ArrayBase { - ptr: self.ptr, - data: self.data, - dim: d, - strides: s, - } - } - - /// n-d generalization of rows, just like inner iter - fn inner_rows(&self) -> iterators::Lanes<'_, A, D::Smaller> { - let n = self.ndim(); - Lanes::new(self.view(), Axis(n.saturating_sub(1))) - } - - /// n-d generalization of rows, just like inner iter - fn inner_rows_mut(&mut self) -> iterators::LanesMut<'_, A, D::Smaller> - where - S: DataMut, - { - let n = self.ndim(); - LanesMut::new(self.view_mut(), Axis(n.saturating_sub(1))) + // safe because new dimension, strides allow access to a subset of old data + unsafe { self.with_strides_dim(s, d) } } } @@ -1594,6 +1600,9 @@ pub mod linalg; mod impl_ops; pub use crate::impl_ops::ScalarOperand; +#[cfg(feature = "approx")] +mod array_approx; + // Array view methods mod impl_views; @@ -1603,24 +1612,14 @@ mod impl_raw_views; // Copy-on-write array methods mod impl_cow; -/// A contiguous array shape of n dimensions. -/// -/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default). -#[derive(Copy, Clone, Debug)] -pub struct Shape { - dim: D, - is_c: bool, -} - -/// An array shape of n dimensions in c-order, f-order or custom strides. -#[derive(Copy, Clone, Debug)] -pub struct StrideShape { - dim: D, - strides: D, - custom: bool, -} +// Arc array methods +mod impl_arc_array; /// Returns `true` if the pointer is aligned. -pub(crate) fn is_aligned(ptr: *const T) -> bool { +pub(crate) fn is_aligned(ptr: *const T) -> bool +{ (ptr as usize) % ::std::mem::align_of::() == 0 } + +// Triangular constructors +mod tri; diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index fd8d77d85..7472d8292 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1,4 +1,4 @@ -// Copyright 2014-2016 bluss and ndarray developers. +// Copyright 2014-2020 bluss and ndarray developers. // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -7,23 +7,28 @@ // except according to those terms. use crate::imp_prelude::*; + +#[cfg(feature = "blas")] +use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; use crate::numeric_util; use crate::{LinalgScalar, Zip}; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; use std::any::TypeId; +use std::mem::MaybeUninit; + +use num_complex::Complex; +use num_complex::{Complex32 as c32, Complex64 as c64}; #[cfg(feature = "blas")] -use std::cmp; -#[cfg(feature = "blas")] -use std::mem::swap; -#[cfg(feature = "blas")] -use std::os::raw::c_int; +use libc::c_int; #[cfg(feature = "blas")] use cblas_sys as blas_sys; #[cfg(feature = "blas")] -use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT}; +use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE}; /// len of vector before we use blas #[cfg(feature = "blas")] @@ -36,8 +41,7 @@ const GEMM_BLAS_CUTOFF: usize = 7; type blas_index = c_int; // blas index type impl ArrayBase -where - S: Data, +where S: Data { /// Perform dot product or matrix multiplication of arrays `self` and `rhs`. /// @@ -56,9 +60,9 @@ where /// **Panics** if the array shapes are incompatible.
/// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory /// layout allows. + #[track_caller] pub fn dot(&self, rhs: &Rhs) -> >::Output - where - Self: Dot, + where Self: Dot { Dot::dot(self, rhs) } @@ -137,11 +141,8 @@ where /// which agrees with our pointer for non-negative strides, but /// is at the opposite end for negative strides. #[cfg(feature = "blas")] -unsafe fn blas_1d_params
( - ptr: *const A, - len: usize, - stride: isize, -) -> (*const A, blas_index, blas_index) { +unsafe fn blas_1d_params(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index) +{ // [x x x x] // ^--ptr // stride = -1 @@ -158,7 +159,8 @@ unsafe fn blas_1d_params( /// /// For two-dimensional arrays, the dot method computes the matrix /// multiplication. -pub trait Dot { +pub trait Dot +{ /// The result of the operation. /// /// For two-dimensional arrays: a rectangular array. @@ -182,7 +184,9 @@ where /// **Panics** if the arrays are not of the same length.
/// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory /// layout allows. - fn dot(&self, rhs: &ArrayBase) -> A { + #[track_caller] + fn dot(&self, rhs: &ArrayBase) -> A + { self.dot_impl(rhs) } } @@ -204,14 +208,15 @@ where /// Return a result array with shape *N*. /// /// **Panics** if shapes are incompatible. - fn dot(&self, rhs: &ArrayBase) -> Array { + #[track_caller] + fn dot(&self, rhs: &ArrayBase) -> Array + { rhs.t().dot(self) } } impl ArrayBase -where - S: Data, +where S: Data { /// Perform matrix multiplication of rectangular arrays `self` and `rhs`. /// @@ -242,9 +247,9 @@ where /// [2., 3.]]) /// ); /// ``` + #[track_caller] pub fn dot(&self, rhs: &Rhs) -> >::Output - where - Self: Dot, + where Self: Dot { Dot::dot(self, rhs) } @@ -257,7 +262,8 @@ where A: LinalgScalar, { type Output = Array2
; - fn dot(&self, b: &ArrayBase) -> Array2 { + fn dot(&self, b: &ArrayBase) -> Array2 + { let a = self.view(); let b = b.view(); let ((m, k), (k2, n)) = (a.dim(), b.dim()); @@ -283,9 +289,10 @@ where /// Assumes that `m` and `n` are ≤ `isize::MAX`. #[cold] #[inline(never)] -fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! { +fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! +{ match m.checked_mul(n) { - Some(len) if len <= ::std::isize::MAX as usize => {} + Some(len) if len <= isize::MAX as usize => {} _ => panic!("ndarray: shape {} × {} overflows isize", m, n), } panic!( @@ -296,7 +303,8 @@ fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! { #[cold] #[inline(never)] -fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! { +fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! +{ panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication", m, k, k2, n, c1, c2); } @@ -317,7 +325,9 @@ where A: LinalgScalar, { type Output = Array; - fn dot(&self, rhs: &ArrayBase) -> Array { + #[track_caller] + fn dot(&self, rhs: &ArrayBase) -> Array + { let ((m, a), n) = (self.dim(), rhs.dim()); if a != n { dot_shape_error(m, a, n, 1); @@ -325,9 +335,9 @@ where // Avoid initializing the memory in vec -- set it during iteration unsafe { - let mut c = Array::uninitialized(m); - general_mat_vec_mul(A::one(), self, rhs, A::zero(), &mut c); - c + let mut c = Array1::uninit(m); + general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::()); + c.assume_init() } } } @@ -344,6 +354,7 @@ where /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. /// /// **Panics** if broadcasting isn’t possible. + #[track_caller] pub fn scaled_add(&mut self, alpha: A, rhs: &ArrayBase) where S: DataMut, @@ -361,109 +372,95 @@ where use self::mat_mul_general as mat_mul_impl; #[cfg(feature = "blas")] -fn mat_mul_impl( - alpha: A, - lhs: &ArrayView2<'_, A>, - rhs: &ArrayView2<'_, A>, - beta: A, - c: &mut ArrayViewMut2<'_, A>, -) where - A: LinalgScalar, +fn mat_mul_impl(alpha: A, a: &ArrayView2<'_, A>, b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>) +where A: LinalgScalar { - // size cutoff for using BLAS - let cut = GEMM_BLAS_CUTOFF; - let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim()); - if !(m > cut || n > cut || a > cut) || !(same_type::() || same_type::()) { - return mat_mul_general(alpha, lhs, rhs, beta, c); - } + let ((m, k), (k2, n)) = (a.dim(), b.dim()); + debug_assert_eq!(k, k2); + if (m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || k > GEMM_BLAS_CUTOFF) + && (same_type::() || same_type::() || same_type::() || same_type::()) { - // Use `c` for c-order and `f` for an f-order matrix - // We can handle c * c, f * f generally and - // c * f and f * c if the `f` matrix is square. 
- let mut lhs_ = lhs.view(); - let mut rhs_ = rhs.view(); - let mut c_ = c.view_mut(); - let lhs_s0 = lhs_.strides()[0]; - let rhs_s0 = rhs_.strides()[0]; - let both_f = lhs_s0 == 1 && rhs_s0 == 1; - let mut lhs_trans = CblasNoTrans; - let mut rhs_trans = CblasNoTrans; - if both_f { - // A^t B^t = C^t => B A = C - let lhs_t = lhs_.reversed_axes(); - lhs_ = rhs_.reversed_axes(); - rhs_ = lhs_t; - c_ = c_.reversed_axes(); - swap(&mut m, &mut n); - } else if lhs_s0 == 1 && m == a { - lhs_ = lhs_.reversed_axes(); - lhs_trans = CblasTrans; - } else if rhs_s0 == 1 && a == n { - rhs_ = rhs_.reversed_axes(); - rhs_trans = CblasTrans; - } + // Compute A B -> C + // We require for BLAS compatibility that: + // A, B, C are contiguous (stride=1) in their fastest dimension, + // but they can be either row major/"c" or col major/"f". + // + // The "normal case" is CblasRowMajor for cblas. + // Select CblasRowMajor / CblasColMajor to fit C's memory order. + // + // Apply transpose to A, B as needed if they differ from the row major case. + // If C is CblasColMajor then transpose both A, B (again!) + + if let (Some(a_layout), Some(b_layout), Some(c_layout)) = + (get_blas_compatible_layout(a), get_blas_compatible_layout(b), get_blas_compatible_layout(c)) + { + let cblas_layout = c_layout.to_cblas_layout(); + let a_trans = a_layout.to_cblas_transpose_for(cblas_layout); + let lda = blas_stride(&a, a_layout); + + let b_trans = b_layout.to_cblas_transpose_for(cblas_layout); + let ldb = blas_stride(&b, b_layout); - macro_rules! gemm { - ($ty:ty, $gemm:ident) => { - if blas_row_major_2d::<$ty, _>(&lhs_) - && blas_row_major_2d::<$ty, _>(&rhs_) - && blas_row_major_2d::<$ty, _>(&c_) - { - let (m, k) = match lhs_trans { - CblasNoTrans => lhs_.dim(), - _ => { - let (rows, cols) = lhs_.dim(); - (cols, rows) + let ldc = blas_stride(&c, c_layout); + + macro_rules! gemm_scalar_cast { + (f32, $var:ident) => { + cast_as(&$var) + }; + (f64, $var:ident) => { + cast_as(&$var) + }; + (c32, $var:ident) => { + &$var as *const A as *const _ + }; + (c64, $var:ident) => { + &$var as *const A as *const _ + }; + } + + macro_rules! 
gemm { + ($ty:tt, $gemm:ident) => { + if same_type::() { + // gemm is C ← αA^Op B^Op + βC + // Where Op is notrans/trans/conjtrans + unsafe { + blas_sys::$gemm( + cblas_layout, + a_trans, + b_trans, + m as blas_index, // m, rows of Op(a) + n as blas_index, // n, cols of Op(b) + k as blas_index, // k, cols of Op(a) + gemm_scalar_cast!($ty, alpha), // alpha + a.ptr.as_ptr() as *const _, // a + lda, // lda + b.ptr.as_ptr() as *const _, // b + ldb, // ldb + gemm_scalar_cast!($ty, beta), // beta + c.ptr.as_ptr() as *mut _, // c + ldc, // ldc + ); } - }; - let n = match rhs_trans { - CblasNoTrans => rhs_.raw_dim()[1], - _ => rhs_.raw_dim()[0], - }; - // adjust strides, these may [1, 1] for column matrices - let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index); - let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index); - let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index); - - // gemm is C ← αA^Op B^Op + βC - // Where Op is notrans/trans/conjtrans - unsafe { - blas_sys::$gemm( - CblasRowMajor, - lhs_trans, - rhs_trans, - m as blas_index, // m, rows of Op(a) - n as blas_index, // n, cols of Op(b) - k as blas_index, // k, cols of Op(a) - cast_as(&alpha), // alpha - lhs_.ptr.as_ptr() as *const _, // a - lhs_stride, // lda - rhs_.ptr.as_ptr() as *const _, // b - rhs_stride, // ldb - cast_as(&beta), // beta - c_.ptr.as_ptr() as *mut _, // c - c_stride, // ldc - ); + return; } - return; - } - }; + }; + } + + gemm!(f32, cblas_sgemm); + gemm!(f64, cblas_dgemm); + gemm!(c32, cblas_cgemm); + gemm!(c64, cblas_zgemm); + + unreachable!() // we checked above that A is one of f32, f64, c32, c64 } - gemm!(f32, cblas_sgemm); - gemm!(f64, cblas_dgemm); } - mat_mul_general(alpha, lhs, rhs, beta, c) + mat_mul_general(alpha, a, b, beta, c) } /// C ← α A B + β C fn mat_mul_general( - alpha: A, - lhs: &ArrayView2<'_, A>, - rhs: &ArrayView2<'_, A>, - beta: A, - c: &mut ArrayViewMut2<'_, A>, -) where - A: LinalgScalar, + alpha: A, lhs: &ArrayView2<'_, A>, rhs: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>, +) where A: LinalgScalar { let ((m, k), (_, n)) = (lhs.dim(), rhs.dim()); @@ -474,7 +471,7 @@ fn mat_mul_general( let (rsc, csc) = (c.strides()[0], c.strides()[1]); if same_type::() { unsafe { - ::matrixmultiply::sgemm( + matrixmultiply::sgemm( m, k, n, @@ -493,7 +490,7 @@ fn mat_mul_general( } } else if same_type::() { unsafe { - ::matrixmultiply::dgemm( + matrixmultiply::dgemm( m, k, n, @@ -510,6 +507,48 @@ fn mat_mul_general( csc, ); } + } else if same_type::() { + unsafe { + matrixmultiply::cgemm( + matrixmultiply::CGemmOption::Standard, + matrixmultiply::CGemmOption::Standard, + m, + k, + n, + complex_array(cast_as(&alpha)), + ap as *const _, + lhs.strides()[0], + lhs.strides()[1], + bp as *const _, + rhs.strides()[0], + rhs.strides()[1], + complex_array(cast_as(&beta)), + cp as *mut _, + rsc, + csc, + ); + } + } else if same_type::() { + unsafe { + matrixmultiply::zgemm( + matrixmultiply::CGemmOption::Standard, + matrixmultiply::CGemmOption::Standard, + m, + k, + n, + complex_array(cast_as(&alpha)), + ap as *const _, + lhs.strides()[0], + lhs.strides()[1], + bp as *const _, + rhs.strides()[0], + rhs.strides()[1], + complex_array(cast_as(&beta)), + cp as *mut _, + rsc, + csc, + ); + } } else { // It's a no-op if `c` has zero length. 
if c.is_empty() { @@ -526,11 +565,8 @@ fn mat_mul_general( loop { unsafe { let elt = c.uget_mut((i, j)); - *elt = *elt * beta - + alpha - * (0..k).fold(A::zero(), move |s, x| { - s + *lhs.uget((i, x)) * *rhs.uget((x, j)) - }); + *elt = + *elt * beta + alpha * (0..k).fold(A::zero(), move |s, x| s + *lhs.uget((i, x)) * *rhs.uget((x, j))); } j += 1; if j == n { @@ -555,12 +591,9 @@ fn mat_mul_general( /// *Note:* If enabled, uses blas `gemm` for elements of `f32, f64` when memory /// layout allows. The default matrixmultiply backend is otherwise used for /// `f32, f64` for all memory layouts. +#[track_caller] pub fn general_mat_mul( - alpha: A, - a: &ArrayBase, - b: &ArrayBase, - beta: A, - c: &mut ArrayBase, + alpha: A, a: &ArrayBase, b: &ArrayBase, beta: A, c: &mut ArrayBase, ) where S1: Data, S2: Data, @@ -586,18 +619,34 @@ pub fn general_mat_mul( /// ***Panics*** if array shapes are not compatible
/// *Note:* If enabled, uses blas `gemv` for elements of `f32, f64` when memory /// layout allows. +#[track_caller] #[allow(clippy::collapsible_if)] pub fn general_mat_vec_mul( - alpha: A, - a: &ArrayBase, - x: &ArrayBase, - beta: A, - y: &mut ArrayBase, + alpha: A, a: &ArrayBase, x: &ArrayBase, beta: A, y: &mut ArrayBase, ) where S1: Data, S2: Data, S3: DataMut, A: LinalgScalar, +{ + unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) } +} + +/// General matrix-vector multiplication +/// +/// Use a raw view for the destination vector, so that it can be uninitialized. +/// +/// ## Safety +/// +/// The caller must ensure that the raw view is valid for writing. +/// the destination may be uninitialized iff beta is zero. +#[allow(clippy::collapsible_else_if)] +unsafe fn general_mat_vec_mul_impl( + alpha: A, a: &ArrayBase, x: &ArrayBase, beta: A, y: RawArrayViewMut, +) where + S1: Data, + S2: Data, + A: LinalgScalar, { let ((m, k), k2) = (a.dim(), x.dim()); let m2 = y.dim(); @@ -607,42 +656,43 @@ pub fn general_mat_vec_mul( #[cfg(feature = "blas")] macro_rules! gemv { ($ty:ty, $gemv:ident) => { - if let Some(layout) = blas_layout::<$ty, _>(&a) { - if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) { - // Determine stride between rows or columns. Note that the stride is - // adjusted to at least `k` or `m` to handle the case of a matrix with a - // trivial (length 1) dimension, since the stride for the trivial dimension - // may be arbitrary. - let a_trans = CblasNoTrans; - let a_stride = match layout { - CBLAS_LAYOUT::CblasRowMajor => { - a.strides()[0].max(k as isize) as blas_index - } - CBLAS_LAYOUT::CblasColMajor => { - a.strides()[1].max(m as isize) as blas_index - } - }; - - let x_stride = x.strides()[0] as blas_index; - let y_stride = y.strides()[0] as blas_index; + if same_type::() { + if let Some(layout) = get_blas_compatible_layout(&a) { + if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) { + // Determine stride between rows or columns. Note that the stride is + // adjusted to at least `k` or `m` to handle the case of a matrix with a + // trivial (length 1) dimension, since the stride for the trivial dimension + // may be arbitrary. 
+ let a_trans = CblasNoTrans; + + let a_stride = blas_stride(&a, layout); + let cblas_layout = layout.to_cblas_layout(); + + // Low addr in memory pointers required for x, y + let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides); + let x_ptr = x.ptr.as_ptr().sub(x_offset); + let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides); + let y_ptr = y.ptr.as_ptr().sub(y_offset); + + let x_stride = x.strides()[0] as blas_index; + let y_stride = y.strides()[0] as blas_index; - unsafe { blas_sys::$gemv( - layout, + cblas_layout, a_trans, m as blas_index, // m, rows of Op(a) k as blas_index, // n, cols of Op(a) cast_as(&alpha), // alpha a.ptr.as_ptr() as *const _, // a a_stride, // lda - x.ptr.as_ptr() as *const _, // x + x_ptr as *const _, // x x_stride, - cast_as(&beta), // beta - y.ptr.as_ptr() as *mut _, // x + cast_as(&beta), // beta + y_ptr as *mut _, // y y_stride, ); + return; } - return; } } }; @@ -655,153 +705,285 @@ pub fn general_mat_vec_mul( /* general */ if beta.is_zero() { - Zip::from(a.outer_iter()).and(y).apply(|row, elt| { - *elt = row.dot(x) * alpha; + // when beta is zero, c may be uninitialized + Zip::from(a.outer_iter()).and(y).for_each(|row, elt| { + elt.write(row.dot(x) * alpha); }); } else { - Zip::from(a.outer_iter()).and(y).apply(|row, elt| { + Zip::from(a.outer_iter()).and(y).for_each(|row, elt| { *elt = *elt * beta + row.dot(x) * alpha; }); } } } +/// Kronecker product of 2D matrices. +/// +/// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R) +/// matrix K formed by the block multiplication A_ij * B. +pub fn kron(a: &ArrayBase, b: &ArrayBase) -> Array +where + S1: Data, + S2: Data, + A: LinalgScalar, +{ + let dimar = a.shape()[0]; + let dimac = a.shape()[1]; + let dimbr = b.shape()[0]; + let dimbc = b.shape()[1]; + let mut out: Array2> = Array2::uninit(( + dimar + .checked_mul(dimbr) + .expect("Dimensions of kronecker product output array overflows usize."), + dimac + .checked_mul(dimbc) + .expect("Dimensions of kronecker product output array overflows usize."), + )); + Zip::from(out.exact_chunks_mut((dimbr, dimbc))) + .and(a) + .for_each(|out, &a| { + Zip::from(out).and(b).for_each(|out, &b| { + *out = MaybeUninit::new(a * b); + }) + }); + unsafe { out.assume_init() } +} + #[inline(always)] /// Return `true` if `A` and `B` are the same type -fn same_type() -> bool { +fn same_type() -> bool +{ TypeId::of::
() == TypeId::of::() } // Read pointer to type `A` as type `B`. // // **Panics** if `A` and `B` are not the same type -fn cast_as(a: &A) -> B { - assert!(same_type::()); +fn cast_as(a: &A) -> B +{ + assert!(same_type::(), "expect type {} and {} to match", + std::any::type_name::(), std::any::type_name::()); unsafe { ::std::ptr::read(a as *const _ as *const B) } } +/// Return the complex in the form of an array [re, im] +#[inline] +fn complex_array(z: Complex) -> [A; 2] +{ + [z.re, z.im] +} + #[cfg(feature = "blas")] fn blas_compat_1d(a: &ArrayBase) -> bool where - S: Data, + S: RawData, A: 'static, S::Elem: 'static, { if !same_type::() { return false; } - if a.len() > blas_index::max_value() as usize { + if a.len() > blas_index::MAX as usize { return false; } let stride = a.strides()[0]; - if stride > blas_index::max_value() as isize || stride < blas_index::min_value() as isize { + if stride == 0 || stride > blas_index::MAX as isize || stride < blas_index::MIN as isize { return false; } true } #[cfg(feature = "blas")] -enum MemoryOrder { +#[derive(Copy, Clone)] +#[cfg_attr(test, derive(PartialEq, Eq, Debug))] +enum BlasOrder +{ C, F, } #[cfg(feature = "blas")] -fn blas_row_major_2d(a: &ArrayBase) -> bool -where - S: Data, - A: 'static, - S::Elem: 'static, +impl BlasOrder { - if !same_type::() { - return false; + fn transpose(self) -> Self + { + match self { + Self::C => Self::F, + Self::F => Self::C, + } } - is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) -} -#[cfg(feature = "blas")] -fn blas_column_major_2d(a: &ArrayBase) -> bool -where - S: Data, - A: 'static, - S::Elem: 'static, -{ - if !same_type::() { - return false; + #[inline] + /// Axis of leading stride (opposite of contiguous axis) + fn get_blas_lead_axis(self) -> usize + { + match self { + Self::C => 0, + Self::F => 1, + } + } + + fn to_cblas_layout(self) -> CBLAS_LAYOUT + { + match self { + Self::C => CBLAS_LAYOUT::CblasRowMajor, + Self::F => CBLAS_LAYOUT::CblasColMajor, + } + } + + /// When using cblas_sgemm (etc) with C matrix using `for_layout`, + /// how should this `self` matrix be transposed + fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE + { + let effective_order = match for_layout { + CBLAS_LAYOUT::CblasRowMajor => self, + CBLAS_LAYOUT::CblasColMajor => self.transpose(), + }; + + match effective_order { + Self::C => CblasNoTrans, + Self::F => CblasTrans, + } } - is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) } #[cfg(feature = "blas")] -fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool { +fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool +{ let (m, n) = dim.into_pattern(); let s0 = stride[0] as isize; let s1 = stride[1] as isize; - let (inner_stride, outer_dim) = match order { - MemoryOrder::C => (s1, n), - MemoryOrder::F => (s0, m), + let (inner_stride, outer_stride, inner_dim, outer_dim) = match order { + BlasOrder::C => (s1, s0, m, n), + BlasOrder::F => (s0, s1, n, m), }; + if !(inner_stride == 1 || outer_dim == 1) { return false; } + if s0 < 1 || s1 < 1 { return false; } - if (s0 > blas_index::max_value() as isize || s0 < blas_index::min_value() as isize) - || (s1 > blas_index::max_value() as isize || s1 < blas_index::min_value() as isize) + + if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize) + || (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize) { return false; } - if m > blas_index::max_value() as usize || n > blas_index::max_value() as usize { + + // leading stride must >= the dimension (no broadcasting/aliasing) + if 
inner_dim > 1 && (outer_stride as usize) < outer_dim { return false; } + + if m > blas_index::MAX as usize || n > blas_index::MAX as usize { + return false; + } + true } +/// Get BLAS compatible layout if any (C or F, preferring the former) +#[cfg(feature = "blas")] +fn get_blas_compatible_layout(a: &ArrayBase) -> Option +where S: Data +{ + if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) { + Some(BlasOrder::C) + } else if is_blas_2d(&a.dim, &a.strides, BlasOrder::F) { + Some(BlasOrder::F) + } else { + None + } +} + +/// `a` should be blas compatible. +/// axis: 0 or 1. +/// +/// Return leading stride (lda, ldb, ldc) of array +#[cfg(feature = "blas")] +fn blas_stride(a: &ArrayBase, order: BlasOrder) -> blas_index +where S: Data +{ + let axis = order.get_blas_lead_axis(); + let other_axis = 1 - axis; + let len_this = a.shape()[axis]; + let len_other = a.shape()[other_axis]; + let stride = a.strides()[axis]; + + // if current axis has length == 1, then stride does not matter for ndarray + // but for BLAS we need a stride that makes sense, i.e. it's >= the other axis + + // cast: a should already be blas compatible + (if len_this <= 1 { + Ord::max(stride, len_other as isize) + } else { + stride + }) as blas_index +} + +#[cfg(test)] #[cfg(feature = "blas")] -fn blas_layout(a: &ArrayBase) -> Option +fn blas_row_major_2d(a: &ArrayBase) -> bool where S: Data, A: 'static, S::Elem: 'static, { - if blas_row_major_2d::(a) { - Some(CBLAS_LAYOUT::CblasRowMajor) - } else if blas_column_major_2d::(a) { - Some(CBLAS_LAYOUT::CblasColMajor) - } else { - None + if !same_type::() { + return false; + } + is_blas_2d(&a.dim, &a.strides, BlasOrder::C) +} + +#[cfg(test)] +#[cfg(feature = "blas")] +fn blas_column_major_2d(a: &ArrayBase) -> bool +where + S: Data, + A: 'static, + S::Elem: 'static, +{ + if !same_type::() { + return false; } + is_blas_2d(&a.dim, &a.strides, BlasOrder::F) } #[cfg(test)] #[cfg(feature = "blas")] -mod blas_tests { +mod blas_tests +{ use super::*; #[test] - fn blas_row_major_2d_normal_matrix() { + fn blas_row_major_2d_normal_matrix() + { let m: Array2 = Array2::zeros((3, 5)); assert!(blas_row_major_2d::(&m)); assert!(!blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_row_matrix() { + fn blas_row_major_2d_row_matrix() + { let m: Array2 = Array2::zeros((1, 5)); assert!(blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_column_matrix() { + fn blas_row_major_2d_column_matrix() + { let m: Array2 = Array2::zeros((5, 1)); assert!(blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_transposed_row_matrix() { + fn blas_row_major_2d_transposed_row_matrix() + { let m: Array2 = Array2::zeros((1, 5)); let m_t = m.t(); assert!(blas_row_major_2d::(&m_t)); @@ -809,7 +991,8 @@ mod blas_tests { } #[test] - fn blas_row_major_2d_transposed_column_matrix() { + fn blas_row_major_2d_transposed_column_matrix() + { let m: Array2 = Array2::zeros((5, 1)); let m_t = m.t(); assert!(blas_row_major_2d::(&m_t)); @@ -817,9 +1000,70 @@ mod blas_tests { } #[test] - fn blas_column_major_2d_normal_matrix() { + fn blas_column_major_2d_normal_matrix() + { let m: Array2 = Array2::zeros((3, 5).f()); assert!(!blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } + + #[test] + fn blas_row_major_2d_skip_rows_ok() + { + let m: Array2 = Array2::zeros((5, 5)); + let mv = m.slice(s![..;2, ..]); + assert!(blas_row_major_2d::(&mv)); + assert!(!blas_column_major_2d::(&mv)); + } + + #[test] + fn blas_row_major_2d_skip_columns_fail() 
+ { + let m: Array2 = Array2::zeros((5, 5)); + let mv = m.slice(s![.., ..;2]); + assert!(!blas_row_major_2d::(&mv)); + assert!(!blas_column_major_2d::(&mv)); + } + + #[test] + fn blas_col_major_2d_skip_columns_ok() + { + let m: Array2 = Array2::zeros((5, 5).f()); + let mv = m.slice(s![.., ..;2]); + assert!(blas_column_major_2d::(&mv)); + assert!(!blas_row_major_2d::(&mv)); + } + + #[test] + fn blas_col_major_2d_skip_rows_fail() + { + let m: Array2 = Array2::zeros((5, 5).f()); + let mv = m.slice(s![..;2, ..]); + assert!(!blas_column_major_2d::(&mv)); + assert!(!blas_row_major_2d::(&mv)); + } + + #[test] + fn blas_too_short_stride() + { + // leading stride must be longer than the other dimension + // Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS. + + const N: usize = 5; + const MAXSTRIDE: usize = N + 2; + let mut data = [0; MAXSTRIDE * N]; + let mut iter = 0..data.len(); + data.fill_with(|| iter.next().unwrap()); + + for stride in 1..=MAXSTRIDE { + let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap(); + eprintln!("{:?}", m); + + if stride < N { + assert_eq!(get_blas_compatible_layout(&m), None); + } else { + assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C)); + } + } + } } diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 8575905cd..dc6964f9b 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -10,6 +10,7 @@ pub use self::impl_linalg::general_mat_mul; pub use self::impl_linalg::general_mat_vec_mul; +pub use self::impl_linalg::kron; pub use self::impl_linalg::Dot; mod impl_linalg; diff --git a/src/linalg_traits.rs b/src/linalg_traits.rs index a7f5a1a3e..65d264c40 100644 --- a/src/linalg_traits.rs +++ b/src/linalg_traits.rs @@ -5,39 +5,31 @@ // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. -use crate::ScalarOperand; -use num_traits::{Float, One, Zero}; + +#[cfg(feature = "std")] +use num_traits::Float; +use num_traits::{One, Zero}; + +#[cfg(feature = "std")] use std::fmt; use std::ops::{Add, Div, Mul, Sub}; +#[cfg(feature = "std")] use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign}; +#[cfg(feature = "std")] +use crate::ScalarOperand; + /// Elements that support linear algebra operations. /// /// `'static` for type-based specialization, `Copy` so that they don't need move /// semantics or destructors, and the rest are numerical traits. pub trait LinalgScalar: - 'static - + Copy - + Zero - + One - + Add - + Sub - + Mul - + Div + 'static + Copy + Zero + One + Add + Sub + Mul + Div { } -impl LinalgScalar for T where - T: 'static - + Copy - + Zero - + One - + Add - + Sub - + Mul - + Div -{ -} +impl LinalgScalar for T where T: 'static + Copy + Zero + One + Add + Sub + Mul + Div +{} /// Floating-point element types `f32` and `f64`. /// @@ -47,6 +39,7 @@ impl LinalgScalar for T where /// operations (`ScalarOperand`). /// /// This trait can only be implemented by `f32` and `f64`. +#[cfg(feature = "std")] pub trait NdFloat: Float + AddAssign @@ -65,5 +58,7 @@ pub trait NdFloat: { } +#[cfg(feature = "std")] impl NdFloat for f32 {} +#[cfg(feature = "std")] impl NdFloat for f64 {} diff --git a/src/linspace.rs b/src/linspace.rs index ca0eae470..411c480db 100644 --- a/src/linspace.rs +++ b/src/linspace.rs @@ -5,12 +5,14 @@ // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. +#![cfg(feature = "std")] use num_traits::Float; /// An iterator of a sequence of evenly spaced floats. 
/// /// Iterator element type is `F`. -pub struct Linspace { +pub struct Linspace +{ start: F, step: F, index: usize, @@ -18,13 +20,13 @@ pub struct Linspace { } impl Iterator for Linspace -where - F: Float, +where F: Float { type Item = F; #[inline] - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { if self.index >= self.len { None } else { @@ -36,18 +38,19 @@ where } #[inline] - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { let n = self.len - self.index; (n, Some(n)) } } impl DoubleEndedIterator for Linspace -where - F: Float, +where F: Float { #[inline] - fn next_back(&mut self) -> Option { + fn next_back(&mut self) -> Option + { if self.index >= self.len { None } else { @@ -65,14 +68,13 @@ impl ExactSizeIterator for Linspace where Linspace: Iterator {} /// /// The `Linspace` has `n` elements from `a` to `b` (inclusive). /// -/// The iterator element type is `F`, where `F` must implement `Float`, e.g. -/// `f32` or `f64`. +/// The iterator element type is `F`, where `F` must implement [`Float`], e.g. +/// [`f32`] or [`f64`]. /// /// **Panics** if converting `n - 1` to type `F` fails. #[inline] pub fn linspace(a: F, b: F, n: usize) -> Linspace -where - F: Float, +where F: Float { let step = if n > 1 { let num_steps = F::from(n - 1).expect("Converting number of steps to `A` must not fail."); @@ -88,19 +90,18 @@ where } } -/// Return an iterator of floats from `start` to `end` (exclusive), +/// Return an iterator of floats from `a` to `b` (exclusive), /// incrementing by `step`. /// /// Numerical reasons can result in `b` being included in the result. /// -/// The iterator element type is `F`, where `F` must implement `Float`, e.g. -/// `f32` or `f64`. +/// The iterator element type is `F`, where `F` must implement [`Float`], e.g. +/// [`f32`] or [`f64`]. /// /// **Panics** if converting `((b - a) / step).ceil()` to type `F` fails. #[inline] pub fn range(a: F, b: F, step: F) -> Linspace -where - F: Float, +where F: Float { let len = b - a; let steps = F::ceil(len / step); diff --git a/src/logspace.rs b/src/logspace.rs index 55b5397c8..6f8de885d 100644 --- a/src/logspace.rs +++ b/src/logspace.rs @@ -5,12 +5,14 @@ // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. +#![cfg(feature = "std")] use num_traits::Float; /// An iterator of a sequence of logarithmically spaced number. /// /// Iterator element type is `F`. -pub struct Logspace { +pub struct Logspace +{ sign: F, base: F, start: F, @@ -20,13 +22,13 @@ pub struct Logspace { } impl Iterator for Logspace -where - F: Float, +where F: Float { type Item = F; #[inline] - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { if self.index >= self.len { None } else { @@ -39,18 +41,19 @@ where } #[inline] - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { let n = self.len - self.index; (n, Some(n)) } } impl DoubleEndedIterator for Logspace -where - F: Float, +where F: Float { #[inline] - fn next_back(&mut self) -> Option { + fn next_back(&mut self) -> Option + { if self.index >= self.len { None } else { @@ -67,18 +70,17 @@ impl ExactSizeIterator for Logspace where Logspace: Iterator {} /// An iterator of a sequence of logarithmically spaced numbers. /// -/// The `Logspace` has `n` elements, where the first element is `base.powf(a)` +/// The [`Logspace`] has `n` elements, where the first element is `base.powf(a)` /// and the last element is `base.powf(b)`. 
If `base` is negative, this /// iterator will return all negative values. /// -/// The iterator element type is `F`, where `F` must implement `Float`, e.g. -/// `f32` or `f64`. +/// The iterator element type is `F`, where `F` must implement [`Float`], e.g. +/// [`f32`] or [`f64`]. /// /// **Panics** if converting `n - 1` to type `F` fails. #[inline] pub fn logspace(base: F, a: F, b: F, n: usize) -> Logspace -where - F: Float, +where F: Float { let step = if n > 1 { let num_steps = F::from(n - 1).expect("Converting number of steps to `A` must not fail."); @@ -97,12 +99,14 @@ where } #[cfg(test)] -mod tests { +mod tests +{ use super::logspace; #[test] #[cfg(feature = "approx")] - fn valid() { + fn valid() + { use crate::{arr1, Array1}; use approx::assert_abs_diff_eq; @@ -120,7 +124,8 @@ mod tests { } #[test] - fn iter_forward() { + fn iter_forward() + { let mut iter = logspace(10.0f64, 0.0, 3.0, 4); assert!(iter.size_hint() == (4, Some(4))); @@ -135,7 +140,8 @@ mod tests { } #[test] - fn iter_backward() { + fn iter_backward() + { let mut iter = logspace(10.0f64, 0.0, 3.0, 4); assert!(iter.size_hint() == (4, Some(4))); diff --git a/src/low_level_util.rs b/src/low_level_util.rs new file mode 100644 index 000000000..5a615a187 --- /dev/null +++ b/src/low_level_util.rs @@ -0,0 +1,43 @@ +// Copyright 2021 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +/// Guard value that will abort if it is dropped. +/// To defuse, this value must be forgotten before the end of the scope. +/// +/// The string value is added to the message printed if aborting. +#[must_use] +pub(crate) struct AbortIfPanic(pub(crate) &'static &'static str); + +impl AbortIfPanic +{ + /// Defuse the AbortIfPanic guard. This *must* be done when finished. + #[inline] + pub(crate) fn defuse(self) + { + std::mem::forget(self); + } +} + +impl Drop for AbortIfPanic +{ + // The compiler should be able to remove this, if it can see through that there + // is no panic in the code section. + fn drop(&mut self) + { + #[cfg(feature = "std")] + { + eprintln!("ndarray: panic in no-panic section, aborting: {}", self.0); + std::process::abort() + } + #[cfg(not(feature = "std"))] + { + // no-std uses panic-in-panic (should abort) + panic!("ndarray: panic in no-panic section, bailing out: {}", self.0); + } + } +} diff --git a/src/macro_utils.rs b/src/macro_utils.rs index 0480b7c91..75360de37 100644 --- a/src/macro_utils.rs +++ b/src/macro_utils.rs @@ -9,7 +9,7 @@ macro_rules! copy_and_clone { }; ($type_:ty) => { copy_and_clone!{ [] $type_ } - } + }; } macro_rules! clone_bounds { @@ -38,7 +38,7 @@ macro_rules! clone_bounds { /// debug assertions are enabled). #[cfg(debug_assertions)] macro_rules! ndassert { - ($e:expr, $($t:tt)*) => { assert!($e, $($t)*) } + ($e:expr, $($t:tt)*) => { assert!($e, $($t)*) }; } #[cfg(not(debug_assertions))] diff --git a/src/math_cell.rs b/src/math_cell.rs new file mode 100644 index 000000000..6ed1ed71f --- /dev/null +++ b/src/math_cell.rs @@ -0,0 +1,133 @@ +use std::cell::Cell; +use std::cmp::Ordering; +use std::fmt; + +use std::ops::{Deref, DerefMut}; + +/// A transparent wrapper of [`Cell`](std::cell::Cell) which is identical in every way, except +/// it will implement arithmetic operators as well. +/// +/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayBase::cell_view). 
+/// The `MathCell` derefs to `Cell`, so all the cell's methods are available. +#[repr(transparent)] +#[derive(Default)] +pub struct MathCell(Cell); + +impl MathCell +{ + /// Create a new cell with the given value + #[inline(always)] + pub const fn new(value: T) -> Self + { + MathCell(Cell::new(value)) + } + + /// Return the inner value + pub fn into_inner(self) -> T + { + Cell::into_inner(self.0) + } + + /// Swap value with another cell + pub fn swap(&self, other: &Self) + { + Cell::swap(&self.0, &other.0) + } +} + +impl Deref for MathCell +{ + type Target = Cell; + #[inline(always)] + fn deref(&self) -> &Self::Target + { + &self.0 + } +} + +impl DerefMut for MathCell +{ + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target + { + &mut self.0 + } +} + +impl Clone for MathCell +where T: Copy +{ + fn clone(&self) -> Self + { + MathCell::new(self.get()) + } +} + +impl PartialEq for MathCell +where T: Copy + PartialEq +{ + fn eq(&self, rhs: &Self) -> bool + { + self.get() == rhs.get() + } +} + +impl Eq for MathCell where T: Copy + Eq {} + +impl PartialOrd for MathCell +where T: Copy + PartialOrd +{ + fn partial_cmp(&self, rhs: &Self) -> Option + { + self.get().partial_cmp(&rhs.get()) + } + + fn lt(&self, rhs: &Self) -> bool + { + self.get().lt(&rhs.get()) + } + fn le(&self, rhs: &Self) -> bool + { + self.get().le(&rhs.get()) + } + fn gt(&self, rhs: &Self) -> bool + { + self.get().gt(&rhs.get()) + } + fn ge(&self, rhs: &Self) -> bool + { + self.get().ge(&rhs.get()) + } +} + +impl Ord for MathCell +where T: Copy + Ord +{ + fn cmp(&self, rhs: &Self) -> Ordering + { + self.get().cmp(&rhs.get()) + } +} + +impl fmt::Debug for MathCell +where T: Copy + fmt::Debug +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result + { + self.get().fmt(f) + } +} + +#[cfg(test)] +mod tests +{ + use super::MathCell; + + #[test] + fn test_basic() + { + let c = &MathCell::new(0); + c.set(1); + assert_eq!(c.get(), 1); + } +} diff --git a/src/numeric/impl_float_maths.rs b/src/numeric/impl_float_maths.rs new file mode 100644 index 000000000..54fed49c2 --- /dev/null +++ b/src/numeric/impl_float_maths.rs @@ -0,0 +1,170 @@ +// Element-wise methods for ndarray + +#[cfg(feature = "std")] +use num_traits::Float; + +use crate::imp_prelude::*; + +#[cfg(feature = "std")] +macro_rules! boolean_ops { + ($(#[$meta1:meta])* fn $func:ident + $(#[$meta2:meta])* fn $all:ident + $(#[$meta3:meta])* fn $any:ident) => { + $(#[$meta1])* + #[must_use = "method returns a new array and does not mutate the original value"] + pub fn $func(&self) -> Array { + self.mapv(A::$func) + } + $(#[$meta2])* + #[must_use = "method returns a new boolean value and does not mutate the original value"] + pub fn $all(&self) -> bool { + $crate::Zip::from(self).all(|&elt| !elt.$func()) + } + $(#[$meta3])* + #[must_use = "method returns a new boolean value and does not mutate the original value"] + pub fn $any(&self) -> bool { + !self.$all() + } + }; +} + +#[cfg(feature = "std")] +macro_rules! unary_ops { + ($($(#[$meta:meta])* fn $id:ident)+) => { + $($(#[$meta])* + #[must_use = "method returns a new array and does not mutate the original value"] + pub fn $id(&self) -> Array { + self.mapv(A::$id) + })+ + }; +} + +#[cfg(feature = "std")] +macro_rules! 
binary_ops { + ($($(#[$meta:meta])* fn $id:ident($ty:ty))+) => { + $($(#[$meta])* + #[must_use = "method returns a new array and does not mutate the original value"] + pub fn $id(&self, rhs: $ty) -> Array { + self.mapv(|v| A::$id(v, rhs)) + })+ + }; +} + +/// # Element-wise methods for float arrays +/// +/// Element-wise math functions for any array type that contains floats. +#[cfg(feature = "std")] +impl ArrayBase +where + A: 'static + Float, + S: Data, + D: Dimension, +{ + boolean_ops! { + /// If the number is `NaN` (not a number), then `true` is returned for each element. + fn is_nan + /// Return `true` if all elements are `NaN` (not a number). + fn is_all_nan + /// Return `true` if any element is `NaN` (not a number). + fn is_any_nan + } + boolean_ops! { + /// If the number is infinity, then `true` is returned for each element. + fn is_infinite + /// Return `true` if all elements are infinity. + fn is_all_infinite + /// Return `true` if any element is infinity. + fn is_any_infinite + } + unary_ops! { + /// The largest integer less than or equal to each element. + fn floor + /// The smallest integer greater than or equal to each element. + fn ceil + /// The nearest integer of each element. + fn round + /// The integer part of each element. + fn trunc + /// The fractional part of each element. + fn fract + /// Absolute value of each element. + fn abs + /// Sign number of each element. + /// + /// + `1.0` for all positive numbers. + /// + `-1.0` for all negative numbers. + /// + `NaN` for all `NaN` (not a number). + fn signum + /// The reciprocal (inverse) of each element, `1/x`. + fn recip + /// Square root of each element. + fn sqrt + /// `e^x` of each element (exponential function). + fn exp + /// `2^x` of each element. + fn exp2 + /// Natural logarithm of each element. + fn ln + /// Base 2 logarithm of each element. + fn log2 + /// Base 10 logarithm of each element. + fn log10 + /// Cubic root of each element. + fn cbrt + /// Sine of each element (in radians). + fn sin + /// Cosine of each element (in radians). + fn cos + /// Tangent of each element (in radians). + fn tan + /// Converts radians to degrees for each element. + fn to_degrees + /// Converts degrees to radians for each element. + fn to_radians + } + binary_ops! { + /// Integer power of each element. + /// + /// This function is generally faster than using float power. + fn powi(i32) + /// Float power of each element. + fn powf(A) + /// Logarithm of each element with respect to an arbitrary base. + fn log(A) + /// The positive difference between given number and each element. + fn abs_sub(A) + } + + /// Square (two powers) of each element. + #[must_use = "method returns a new array and does not mutate the original value"] + pub fn pow2(&self) -> Array + { + self.mapv(|v: A| v * v) + } +} + +impl ArrayBase +where + A: 'static + PartialOrd + Clone, + S: Data, + D: Dimension, +{ + /// Limit the values for each element, similar to NumPy's `clip` function. + /// + /// ``` + /// use ndarray::array; + /// + /// let a = array![0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]; + /// assert_eq!(a.clamp(1., 8.), array![1., 1., 2., 3., 4., 5., 6., 7., 8., 8.]); + /// assert_eq!(a.clamp(3., 6.), array![3., 3., 3., 3., 4., 5., 6., 6., 6., 6.]); + /// ``` + /// + /// # Panics + /// + /// Panics if `!(min <= max)`. 
+ pub fn clamp(&self, min: A, max: A) -> Array + { + assert!(min <= max, "min must be less than or equal to max"); + self.mapv(|a| num_traits::clamp(a, min.clone(), max.clone())) + } +} diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 85f69444d..6c67b9135 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -6,14 +6,15 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::{self, Float, FromPrimitive, Zero}; -use std::ops::{Add, Div, Mul}; +#[cfg(feature = "std")] +use num_traits::Float; +use num_traits::One; +use num_traits::{FromPrimitive, Zero}; +use std::ops::{Add, Div, Mul, Sub}; use crate::imp_prelude::*; -use crate::itertools::enumerate; use crate::numeric_util; - -use crate::{FoldWhile, Zip}; +use crate::Slice; /// # Numerical Methods for Arrays impl ArrayBase @@ -31,14 +32,13 @@ where /// assert_eq!(a.sum(), 10.); /// ``` pub fn sum(&self) -> A - where - A: Clone + Add + num_traits::Zero, + where A: Clone + Add + num_traits::Zero { if let Some(slc) = self.as_slice_memory_order() { return numeric_util::unrolled_fold(slc, A::zero, A::add); } let mut sum = A::zero(); - for row in self.inner_rows() { + for row in self.rows() { if let Some(slc) = row.as_slice() { sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add); } else { @@ -62,31 +62,17 @@ where /// /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean pub fn mean(&self) -> Option - where - A: Clone + FromPrimitive + Add + Div + Zero, + where A: Clone + FromPrimitive + Add + Div + Zero { let n_elements = self.len(); if n_elements == 0 { None } else { - let n_elements = A::from_usize(n_elements) - .expect("Converting number of elements to `A` must not fail."); + let n_elements = A::from_usize(n_elements).expect("Converting number of elements to `A` must not fail."); Some(self.sum() / n_elements) } } - /// Return the sum of all elements in the array. - /// - /// *This method has been renamed to `.sum()` and will be deprecated in the - /// next version.* - // #[deprecated(note="renamed to `sum`", since="0.13")] - pub fn scalar_sum(&self) -> A - where - A: Clone + Add + num_traits::Zero, - { - self.sum() - } - /// Return the product of all elements in the array. /// /// ``` @@ -97,14 +83,13 @@ where /// assert_eq!(a.product(), 24.); /// ``` pub fn product(&self) -> A - where - A: Clone + Mul + num_traits::One, + where A: Clone + Mul + num_traits::One { if let Some(slc) = self.as_slice_memory_order() { return numeric_util::unrolled_fold(slc, A::one, A::mul); } let mut sum = A::one(); - for row in self.inner_rows() { + for row in self.rows() { if let Some(slc) = row.as_slice() { sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul); } else { @@ -114,6 +99,118 @@ where sum } + /// Return variance of elements in the array. + /// + /// The variance is computed using the [Welford one-pass + /// algorithm](https://www.jstor.org/stable/1266577). + /// + /// The parameter `ddof` specifies the "delta degrees of freedom". For + /// example, to calculate the population variance, use `ddof = 0`, or to + /// calculate the sample variance, use `ddof = 1`. + /// + /// The variance is defined as: + /// + /// ```text + /// 1 n + /// variance = ―――――――― ∑ (xᵢ - x̅)² + /// n - ddof i=1 + /// ``` + /// + /// where + /// + /// ```text + /// 1 n + /// x̅ = ― ∑ xᵢ + /// n i=1 + /// ``` + /// + /// and `n` is the length of the array. 
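+    ///
+    /// As a small worked check of the formula: for the two-element array `[1., 3.]` the
+    /// mean is `x̅ = 2.`, so the squared deviations sum to `2.`; with `ddof = 0` the
+    /// variance is `1.`, and with `ddof = 1` it is `2.`.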
+ /// + /// **Panics** if `ddof` is less than zero or greater than `n` + /// + /// # Example + /// + /// ``` + /// use ndarray::array; + /// use approx::assert_abs_diff_eq; + /// + /// let a = array![1., -4.32, 1.14, 0.32]; + /// let var = a.var(1.); + /// assert_abs_diff_eq!(var, 6.7331, epsilon = 1e-4); + /// ``` + #[track_caller] + #[cfg(feature = "std")] + pub fn var(&self, ddof: A) -> A + where A: Float + FromPrimitive + { + let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); + let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail."); + assert!( + !(ddof < zero || ddof > n), + "`ddof` must not be less than zero or greater than the length of \ + the axis", + ); + let dof = n - ddof; + let mut mean = A::zero(); + let mut sum_sq = A::zero(); + let mut i = 0; + self.for_each(|&x| { + let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail."); + let delta = x - mean; + mean = mean + delta / count; + sum_sq = (x - mean).mul_add(delta, sum_sq); + i += 1; + }); + sum_sq / dof + } + + /// Return standard deviation of elements in the array. + /// + /// The standard deviation is computed from the variance using + /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577). + /// + /// The parameter `ddof` specifies the "delta degrees of freedom". For + /// example, to calculate the population standard deviation, use `ddof = 0`, + /// or to calculate the sample standard deviation, use `ddof = 1`. + /// + /// The standard deviation is defined as: + /// + /// ```text + /// ⎛ 1 n ⎞ + /// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟ + /// ⎝ n - ddof i=1 ⎠ + /// ``` + /// + /// where + /// + /// ```text + /// 1 n + /// x̅ = ― ∑ xᵢ + /// n i=1 + /// ``` + /// + /// and `n` is the length of the array. + /// + /// **Panics** if `ddof` is less than zero or greater than `n` + /// + /// # Example + /// + /// ``` + /// use ndarray::array; + /// use approx::assert_abs_diff_eq; + /// + /// let a = array![1., -4.32, 1.14, 0.32]; + /// let stddev = a.std(1.); + /// assert_abs_diff_eq!(stddev, 2.59483, epsilon = 1e-4); + /// ``` + #[track_caller] + #[cfg(feature = "std")] + pub fn std(&self, ddof: A) -> A + where A: Float + FromPrimitive + { + self.var(ddof).sqrt() + } + /// Return sum along `axis`. /// /// ``` @@ -130,27 +227,59 @@ where /// ``` /// /// **Panics** if `axis` is out of bounds. + #[track_caller] pub fn sum_axis(&self, axis: Axis) -> Array where A: Clone + Zero + Add, D: RemoveAxis, { - let n = self.len_of(axis); - let mut res = Array::zeros(self.raw_dim().remove_axis(axis)); - let stride = self.strides()[axis.index()]; - if self.ndim() == 2 && stride == 1 { - // contiguous along the axis we are summing - let ax = axis.index(); - for (i, elt) in enumerate(&mut res) { - *elt = self.index_axis(Axis(1 - ax), i).sum(); + let min_stride_axis = self.dim.min_stride_axis(&self.strides); + if axis == min_stride_axis { + crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.sum()) + } else { + let mut res = Array::zeros(self.raw_dim().remove_axis(axis)); + for subview in self.axis_iter(axis) { + res = res + &subview; } + res + } + } + + /// Return product along `axis`. + /// + /// The product of an empty array is 1. 
+ /// + /// ``` + /// use ndarray::{aview0, aview1, arr2, Axis}; + /// + /// let a = arr2(&[[1., 2., 3.], + /// [4., 5., 6.]]); + /// + /// assert!( + /// a.product_axis(Axis(0)) == aview1(&[4., 10., 18.]) && + /// a.product_axis(Axis(1)) == aview1(&[6., 120.]) && + /// + /// a.product_axis(Axis(0)).product_axis(Axis(0)) == aview0(&720.) + /// ); + /// ``` + /// + /// **Panics** if `axis` is out of bounds. + #[track_caller] + pub fn product_axis(&self, axis: Axis) -> Array + where + A: Clone + One + Mul, + D: RemoveAxis, + { + let min_stride_axis = self.dim.min_stride_axis(&self.strides); + if axis == min_stride_axis { + crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.product()) } else { - for i in 0..n { - let view = self.index_axis(axis, i); - res = res + &view; + let mut res = Array::ones(self.raw_dim().remove_axis(axis)); + for subview in self.axis_iter(axis) { + res = res * &subview; } + res } - res } /// Return mean along `axis`. @@ -172,6 +301,7 @@ where /// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5) /// ); /// ``` + #[track_caller] pub fn mean_axis(&self, axis: Axis) -> Option> where A: Clone + Zero + FromPrimitive + Add + Div, @@ -181,8 +311,7 @@ where if axis_length == 0 { None } else { - let axis_length = - A::from_usize(axis_length).expect("Converting axis length to `A` must not fail."); + let axis_length = A::from_usize(axis_length).expect("Converting axis length to `A` must not fail."); let sum = self.sum_axis(axis); Some(sum / aview0(&axis_length)) } @@ -230,6 +359,8 @@ where /// let var = a.var_axis(Axis(0), 1.); /// assert_eq!(var, aview1(&[4., 4.])); /// ``` + #[track_caller] + #[cfg(feature = "std")] pub fn var_axis(&self, axis: Axis, ddof: A) -> Array where A: Float + FromPrimitive, @@ -298,6 +429,8 @@ where /// let stddev = a.std_axis(Axis(0), 1.); /// assert_eq!(stddev, aview1(&[2., 2.])); /// ``` + #[track_caller] + #[cfg(feature = "std")] pub fn std_axis(&self, axis: Axis, ddof: A) -> Array where A: Float + FromPrimitive, @@ -306,31 +439,58 @@ where self.var_axis(axis, ddof).mapv_into(|x| x.sqrt()) } - /// Return `true` if the arrays' elementwise differences are all within - /// the given absolute tolerance, `false` otherwise. + /// Calculates the (forward) finite differences of order `n`, along the `axis`. + /// For the 1D-case, `n==1`, this means: `diff[i] == arr[i+1] - arr[i]` /// - /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. + /// For `n>=2`, the process is iterated: + /// ``` + /// use ndarray::{array, Axis}; + /// let arr = array![1.0, 2.0, 5.0]; + /// assert_eq!(arr.diff(2, Axis(0)), arr.diff(1, Axis(0)).diff(1, Axis(0))) + /// ``` + /// **Panics** if `axis` is out of bounds /// - /// **Panics** if broadcasting to the same shape isn’t possible. 
- #[deprecated( - note = "Use `abs_diff_eq` - it requires the `approx` crate feature", - since = "0.13.0" - )] - pub fn all_close(&self, rhs: &ArrayBase, tol: A) -> bool - where - A: Float, - S2: Data, - E: Dimension, + /// **Panics** if `n` is too big / the array is to short: + /// ```should_panic + /// use ndarray::{array, Axis}; + /// array![1.0, 2.0, 3.0].diff(10, Axis(0)); + /// ``` + pub fn diff(&self, n: usize, axis: Axis) -> Array + where A: Sub + Zero + Clone { - !Zip::from(self) - .and(rhs.broadcast_unwrap(self.raw_dim())) - .fold_while((), |_, x, y| { - if (*x - *y).abs() <= tol { - FoldWhile::Continue(()) - } else { - FoldWhile::Done(()) - } - }) - .is_done() + if n == 0 { + return self.to_owned(); + } + assert!(axis.0 < self.ndim(), "The array has only ndim {}, but `axis` {:?} is given.", self.ndim(), axis); + assert!( + n < self.shape()[axis.0], + "The array must have length at least `n+1`=={} in the direction of `axis`. It has length {}", + n + 1, + self.shape()[axis.0] + ); + + let mut inp = self.to_owned(); + let mut out = Array::zeros({ + let mut inp_dim = self.raw_dim(); + // inp_dim[axis.0] >= 1 as per the 2nd assertion. + inp_dim[axis.0] -= 1; + inp_dim + }); + for _ in 0..n { + let head = inp.slice_axis(axis, Slice::from(..-1)); + let tail = inp.slice_axis(axis, Slice::from(1..)); + + azip!((o in &mut out, h in head, t in tail) *o = t.clone() - h.clone()); + + // feed the output as the input to the next iteration + std::mem::swap(&mut inp, &mut out); + + // adjust the new output array width along `axis`. + // Current situation: width of `inp`: k, `out`: k+1 + // needed width: `inp`: k, `out`: k-1 + // slice is possible, since k >= 1. + out.slice_axis_inplace(axis, Slice::from(..-2)); + } + inp } } diff --git a/src/numeric/mod.rs b/src/numeric/mod.rs index b3da06746..c0a7228c5 100644 --- a/src/numeric/mod.rs +++ b/src/numeric/mod.rs @@ -1 +1,3 @@ mod impl_numeric; + +mod impl_float_maths; diff --git a/src/numeric_util.rs b/src/numeric_util.rs index b06850fd0..9d5ce66c5 100644 --- a/src/numeric_util.rs +++ b/src/numeric_util.rs @@ -20,16 +20,8 @@ where // eightfold unrolled so that floating point can be vectorized // (even with strict floating point accuracy semantics) let mut acc = init(); - let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) = ( - init(), - init(), - init(), - init(), - init(), - init(), - init(), - init(), - ); + let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) = + (init(), init(), init(), init(), init(), init(), init(), init()); while xs.len() >= 8 { p0 = f(p0, xs[0].clone()); p1 = f(p1, xs[1].clone()); @@ -62,8 +54,7 @@ where /// /// `xs` and `ys` must be the same length pub fn unrolled_dot(xs: &[A], ys: &[A]) -> A -where - A: LinalgScalar, +where A: LinalgScalar { debug_assert_eq!(xs.len(), ys.len()); // eightfold unrolled so that floating point can be vectorized @@ -72,16 +63,8 @@ where let mut xs = &xs[..len]; let mut ys = &ys[..len]; let mut sum = A::zero(); - let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) = ( - A::zero(), - A::zero(), - A::zero(), - A::zero(), - A::zero(), - A::zero(), - A::zero(), - A::zero(), - ); + let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) = + (A::zero(), A::zero(), A::zero(), A::zero(), A::zero(), A::zero(), A::zero(), A::zero()); while xs.len() >= 8 { p0 = p0 + xs[0] * ys[0]; p1 = p1 + xs[1] * ys[1]; @@ -113,8 +96,7 @@ where /// /// `xs` and `ys` must be the same length pub fn unrolled_eq(xs: &[A], ys: &[B]) -> bool -where - A: 
PartialEq, +where A: PartialEq { debug_assert_eq!(xs.len(), ys.len()); // eightfold unrolled for performance (this is not done by llvm automatically) diff --git a/src/order.rs b/src/order.rs new file mode 100644 index 000000000..a52a32e2c --- /dev/null +++ b/src/order.rs @@ -0,0 +1,93 @@ +/// Array order +/// +/// Order refers to indexing order, or how a linear sequence is translated +/// into a two-dimensional or multi-dimensional array. +/// +/// - `RowMajor` means that the index along the row is the most rapidly changing +/// - `ColumnMajor` means that the index along the column is the most rapidly changing +/// +/// Given a sequence like: 1, 2, 3, 4, 5, 6 +/// +/// If it is laid it out in a 2 x 3 matrix using row major ordering, it results in: +/// +/// ```text +/// 1 2 3 +/// 4 5 6 +/// ``` +/// +/// If it is laid using column major ordering, it results in: +/// +/// ```text +/// 1 3 5 +/// 2 4 6 +/// ``` +/// +/// It can be seen as filling in "rows first" or "columns first". +/// +/// `Order` can be used both to refer to logical ordering as well as memory ordering or memory +/// layout. The orderings have common short names, also seen in other environments, where +/// row major is called "C" order (after the C programming language) and column major is called "F" +/// or "Fortran" order. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum Order +{ + /// Row major or "C" order + RowMajor, + /// Column major or "F" order + ColumnMajor, +} + +impl Order +{ + /// "C" is an alias for row major ordering + pub const C: Order = Order::RowMajor; + + /// "F" (for Fortran) is an alias for column major ordering + pub const F: Order = Order::ColumnMajor; + + /// Return true if input is Order::RowMajor, false otherwise + #[inline] + pub fn is_row_major(self) -> bool + { + match self { + Order::RowMajor => true, + Order::ColumnMajor => false, + } + } + + /// Return true if input is Order::ColumnMajor, false otherwise + #[inline] + pub fn is_column_major(self) -> bool + { + !self.is_row_major() + } + + /// Return Order::RowMajor if the input is true, Order::ColumnMajor otherwise + #[inline] + pub fn row_major(row_major: bool) -> Order + { + if row_major { + Order::RowMajor + } else { + Order::ColumnMajor + } + } + + /// Return Order::ColumnMajor if the input is true, Order::RowMajor otherwise + #[inline] + pub fn column_major(column_major: bool) -> Order + { + Self::row_major(!column_major) + } + + /// Return the transpose: row major becomes column major and vice versa. + #[inline] + pub fn transpose(self) -> Order + { + match self { + Order::RowMajor => Order::ColumnMajor, + Order::ColumnMajor => Order::RowMajor, + } + } +} diff --git a/src/parallel/impl_par_methods.rs b/src/parallel/impl_par_methods.rs index 88fe769bf..c6af4e8f3 100644 --- a/src/parallel/impl_par_methods.rs +++ b/src/parallel/impl_par_methods.rs @@ -1,7 +1,12 @@ -use crate::{ArrayBase, DataMut, Dimension, NdProducer, Zip}; +use crate::AssignElem; +use crate::{Array, ArrayBase, DataMut, Dimension, IntoNdProducer, NdProducer, Zip}; +use super::send_producer::SendProducer; +use crate::parallel::par::ParallelSplits; use crate::parallel::prelude::*; +use crate::partial::Partial; + /// # Parallel methods /// /// These methods require crate feature `rayon`. @@ -17,8 +22,7 @@ where /// /// Elements are visited in arbitrary order. 
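+    ///
+    /// A minimal usage sketch (requires crate feature `rayon`):
+    ///
+    /// ```
+    /// use ndarray::Array2;
+    ///
+    /// let mut a = Array2::<f64>::from_elem((8, 8), 4.0);
+    /// a.par_map_inplace(|x| *x = x.sqrt());
+    /// assert!(a.iter().all(|&x| x == 2.0));
+    /// ```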
pub fn par_map_inplace(&mut self, f: F) - where - F: Fn(&mut A) + Sync + Send, + where F: Fn(&mut A) + Sync + Send { self.view_mut().into_par_iter().for_each(f) } @@ -42,8 +46,10 @@ where // Zip +const COLLECT_MAX_SPLITS: usize = 10; + macro_rules! zip_impl { - ($([$($p:ident)*],)+) => { + ($([$notlast:ident $($p:ident)*],)+) => { $( #[allow(non_snake_case)] impl Zip<($($p,)*), D> @@ -52,27 +58,146 @@ macro_rules! zip_impl { D: Dimension, $($p: NdProducer ,)* { - /// The `par_apply` method for `Zip`. + /// The `par_for_each` method for `Zip`. /// /// This is a shorthand for using `.into_par_iter().for_each()` on /// `Zip`. /// /// Requires crate feature `rayon`. - pub fn par_apply(self, function: F) + pub fn par_for_each(self, function: F) where F: Fn($($p::Item),*) + Sync + Send { self.into_par_iter().for_each(move |($($p,)*)| function($($p),*)) } + + expand_if!(@bool [$notlast] + + /// Map and collect the results into a new array, which has the same size as the + /// inputs. + /// + /// If all inputs are c- or f-order respectively, that is preserved in the output. + pub fn par_map_collect(self, f: impl Fn($($p::Item,)* ) -> R + Sync + Send) + -> Array + where R: Send + { + let mut output = self.uninitialized_for_current_layout::(); + let total_len = output.len(); + + // Create a parallel iterator that produces chunks of the zip with the output + // array. It's crucial that both parts split in the same way, and in a way + // so that the chunks of the output are still contig. + // + // Use a raw view so that we can alias the output data here and in the partial + // result. + let splits = unsafe { + ParallelSplits { + iter: self.and(SendProducer::new(output.raw_view_mut().cast::())), + // Keep it from splitting the Zip down too small + max_splits: COLLECT_MAX_SPLITS, + } + }; + + let collect_result = splits.map(move |zip| { + // Apply the mapping function on this chunk of the zip + // Create a partial result for the contiguous slice of data being written to + unsafe { + zip.collect_with_partial(&f) + } + }) + .reduce(Partial::stub, Partial::try_merge); + + if std::mem::needs_drop::() { + debug_assert_eq!(total_len, collect_result.len, + "collect len is not correct, expected {}", total_len); + assert!(collect_result.len == total_len, + "Collect: Expected number of writes not completed"); + } + + // Here the collect result is complete, and we release its ownership and transfer + // it to the output array. + collect_result.release_ownership(); + unsafe { + output.assume_init() + } + } + + /// Map and assign the results into the producer `into`, which should have the same + /// size as the other inputs. + /// + /// The producer should have assignable items as dictated by the `AssignElem` trait, + /// for example `&mut R`. + pub fn par_map_assign_into(self, into: Q, f: impl Fn($($p::Item,)* ) -> R + Sync + Send) + where Q: IntoNdProducer, + Q::Item: AssignElem + Send, + Q::Output: Send, + { + self.and(into) + .par_for_each(move |$($p, )* output_| { + output_.assign_elem(f($($p ),*)); + }); + } + + /// Parallel version of `fold`. + /// + /// Splits the producer in multiple tasks which each accumulate a single value + /// using the `fold` closure. Those tasks are executed in parallel and their results + /// are then combined to a single value using the `reduce` closure. + /// + /// The `identity` closure provides the initial values for each of the tasks and + /// for the final reduction. + /// + /// This is a shorthand for calling `self.into_par_iter().fold(...).reduce(...)`. 
+ /// + /// Note that it is often more efficient to parallelize not per-element but rather + /// based on larger chunks of an array like generalized rows and operating on each chunk + /// using a sequential variant of the accumulation. + /// For example, sum each row sequentially and in parallel, taking advantage of locality + /// and vectorization within each task, and then reduce their sums to the sum of the matrix. + /// + /// Also note that the splitting of the producer into multiple tasks is _not_ deterministic + /// which needs to be considered when the accuracy of such an operation is analyzed. + /// + /// ## Examples + /// + /// ```rust + /// use ndarray::{Array, Zip}; + /// + /// let a = Array::::ones((128, 1024)); + /// let b = Array::::ones(128); + /// + /// let weighted_sum = Zip::from(a.rows()).and(&b).par_fold( + /// || 0, + /// |sum, row, factor| sum + row.sum() * factor, + /// |sum, other_sum| sum + other_sum, + /// ); + /// + /// assert_eq!(weighted_sum, a.len()); + /// ``` + pub fn par_fold(self, identity: ID, fold: F, reduce: R) -> T + where + ID: Fn() -> T + Send + Sync + Clone, + F: Fn(T, $($p::Item),*) -> T + Send + Sync, + R: Fn(T, T) -> T + Send + Sync, + T: Send + { + self.into_par_iter() + .fold(identity.clone(), move |accumulator, ($($p,)*)| { + fold(accumulator, $($p),*) + }) + .reduce(identity, reduce) + } + + ); } )+ - } + }; } zip_impl! { - [P1], - [P1 P2], - [P1 P2 P3], - [P1 P2 P3 P4], - [P1 P2 P3 P4 P5], - [P1 P2 P3 P4 P5 P6], + [true P1], + [true P1 P2], + [true P1 P2 P3], + [true P1 P2 P3 P4], + [true P1 P2 P3 P4 P5], + [false P1 P2 P3 P4 P5 P6], } diff --git a/src/parallel/into_impls.rs b/src/parallel/into_impls.rs index c1a5388fd..75bded7de 100644 --- a/src/parallel/into_impls.rs +++ b/src/parallel/into_impls.rs @@ -11,7 +11,8 @@ where { type Item = &'a A; type Iter = Parallel>; - fn into_par_iter(self) -> Self::Iter { + fn into_par_iter(self) -> Self::Iter + { self.view().into_par_iter() } } @@ -25,7 +26,8 @@ where { type Item = &'a A; type Iter = Parallel>; - fn into_par_iter(self) -> Self::Iter { + fn into_par_iter(self) -> Self::Iter + { self.view().into_par_iter() } } @@ -38,7 +40,8 @@ where { type Item = &'a mut A; type Iter = Parallel>; - fn into_par_iter(self) -> Self::Iter { + fn into_par_iter(self) -> Self::Iter + { self.view_mut().into_par_iter() } } @@ -52,7 +55,8 @@ where { type Item = &'a mut A; type Iter = Parallel>; - fn into_par_iter(self) -> Self::Iter { + fn into_par_iter(self) -> Self::Iter + { self.view_mut().into_par_iter() } } diff --git a/src/parallel/mod.rs b/src/parallel/mod.rs index 60dbe4662..0c84baa91 100644 --- a/src/parallel/mod.rs +++ b/src/parallel/mod.rs @@ -10,18 +10,20 @@ //! The following types implement parallel iterators, accessed using these //! methods: //! -//! - [`Array`], [`ArcArray`]: `.par_iter()` and `.par_iter_mut()` -//! - [`ArrayView`](ArrayView): `.into_par_iter()` -//! - [`ArrayViewMut`](ArrayViewMut): `.into_par_iter()` -//! - [`AxisIter`](iter::AxisIter), [`AxisIterMut`](iter::AxisIterMut): `.into_par_iter()` -//! - [`AxisChunksIter`](iter::AxisChunksIter), [`AxisChunksIterMut`](iter::AxisChunksIterMut): `.into_par_iter()` +//! - [`Array`], [`ArcArray`] `.par_iter()` and `.par_iter_mut()` +//! - [`ArrayView`] `.into_par_iter()` +//! - [`ArrayViewMut`] `.into_par_iter()` +//! - [`AxisIter`], [`AxisIterMut`] `.into_par_iter()` +//! - [`AxisChunksIter`], [`AxisChunksIterMut`] `.into_par_iter()` //! - [`Zip`] `.into_par_iter()` //! //! The following other parallelized methods exist: //! //! 
- [`ArrayBase::par_map_inplace()`] //! - [`ArrayBase::par_mapv_inplace()`] -//! - [`Zip::par_apply()`] (all arities) +//! - [`Zip::par_for_each()`] (all arities) +//! - [`Zip::par_map_collect()`] (all arities) +//! - [`Zip::par_map_assign_into()`] (all arities) //! //! Note that you can use the parallel iterator for [Zip] to access all other //! rayon parallel iterator methods. @@ -30,6 +32,10 @@ //! “unindexed”. Use ndarray’s [Zip] for lock step parallel iteration of //! multiple arrays or producers at a time. //! +//! For the unindexed parallel iterators, an inherent method [`with_min_len`](Parallel::with_min_len) +//! is provided to limit the number of elements each parallel task processes in way that is +//! similar to Rayon's [`IndexedParallelIterator::with_min_len`](rayon::prelude::IndexedParallelIterator::with_min_len). +//! //! # Examples //! //! ## Arrays and array views @@ -37,21 +43,17 @@ //! Compute the exponential of each element in an array, parallelized. //! //! ``` -//! extern crate ndarray; -//! //! use ndarray::Array2; //! use ndarray::parallel::prelude::*; //! -//! fn main() { -//! let mut a = Array2::::zeros((128, 128)); +//! let mut a = Array2::::zeros((128, 128)); //! -//! // Parallel versions of regular array methods -//! a.par_map_inplace(|x| *x = x.exp()); -//! a.par_mapv_inplace(f64::exp); +//! // Parallel versions of regular array methods +//! a.par_map_inplace(|x| *x = x.exp()); +//! a.par_mapv_inplace(f64::exp); //! -//! // You can also use the parallel iterator directly -//! a.par_iter_mut().for_each(|x| *x = x.exp()); -//! } +//! // You can also use the parallel iterator directly +//! a.par_iter_mut().for_each(|x| *x = x.exp()); //! ``` //! //! ## Axis iterators @@ -59,22 +61,18 @@ //! Use the parallel `.axis_iter()` to compute the sum of each row. //! //! ``` -//! extern crate ndarray; -//! //! use ndarray::Array; //! use ndarray::Axis; //! use ndarray::parallel::prelude::*; //! -//! fn main() { -//! let a = Array::linspace(0., 63., 64).into_shape((4, 16)).unwrap(); -//! let mut sums = Vec::new(); -//! a.axis_iter(Axis(0)) -//! .into_par_iter() -//! .map(|row| row.sum()) -//! .collect_into_vec(&mut sums); +//! let a = Array::linspace(0., 63., 64).into_shape_with_order((4, 16)).unwrap(); +//! let mut sums = Vec::new(); +//! a.axis_iter(Axis(0)) +//! .into_par_iter() +//! .map(|row| row.sum()) +//! .collect_into_vec(&mut sums); //! -//! assert_eq!(sums, [120., 376., 632., 888.]); -//! } +//! assert_eq!(sums, [120., 376., 632., 888.]); //! ``` //! //! ## Axis chunks iterators @@ -82,22 +80,18 @@ //! Use the parallel `.axis_chunks_iter()` to process your data in chunks. //! //! ``` -//! extern crate ndarray; -//! //! use ndarray::Array; //! use ndarray::Axis; //! use ndarray::parallel::prelude::*; //! -//! fn main() { -//! let a = Array::linspace(0., 63., 64).into_shape((4, 16)).unwrap(); -//! let mut shapes = Vec::new(); -//! a.axis_chunks_iter(Axis(0), 3) -//! .into_par_iter() -//! .map(|chunk| chunk.shape().to_owned()) -//! .collect_into_vec(&mut shapes); +//! let a = Array::linspace(0., 63., 64).into_shape_with_order((4, 16)).unwrap(); +//! let mut shapes = Vec::new(); +//! a.axis_chunks_iter(Axis(0), 3) +//! .into_par_iter() +//! .map(|chunk| chunk.shape().to_owned()) +//! .collect_into_vec(&mut shapes); //! -//! assert_eq!(shapes, [vec![3, 16], vec![1, 16]]); -//! } +//! assert_eq!(shapes, [vec![3, 16], vec![1, 16]]); //! ``` //! //! ## Zip @@ -105,34 +99,39 @@ //! Use zip for lock step function application across several arrays //! //! ``` -//! 
extern crate ndarray; -//! //! use ndarray::Array3; //! use ndarray::Zip; //! //! type Array3f64 = Array3; //! -//! fn main() { -//! const N: usize = 128; -//! let a = Array3f64::from_elem((N, N, N), 1.); -//! let b = Array3f64::from_elem(a.dim(), 2.); -//! let mut c = Array3f64::zeros(a.dim()); -//! -//! Zip::from(&mut c) -//! .and(&a) -//! .and(&b) -//! .par_apply(|c, &a, &b| { -//! *c += a - b; -//! }); -//! } +//! const N: usize = 128; +//! let a = Array3f64::from_elem((N, N, N), 1.); +//! let b = Array3f64::from_elem(a.dim(), 2.); +//! let mut c = Array3f64::zeros(a.dim()); +//! +//! Zip::from(&mut c) +//! .and(&a) +//! .and(&b) +//! .par_for_each(|c, &a, &b| { +//! *c += a - b; +//! }); //! ``` +#[allow(unused_imports)] // used by rustdoc links +use crate::iter::{AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut}; +#[allow(unused_imports)] // used by rustdoc links +use crate::{ArcArray, Array, ArrayBase, ArrayView, ArrayViewMut, Zip}; + /// Into- traits for creating parallelized iterators and/or using [`par_azip!`] -pub mod prelude { +pub mod prelude +{ #[doc(no_inline)] pub use rayon::prelude::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, - IntoParallelRefMutIterator, ParallelIterator, + IndexedParallelIterator, + IntoParallelIterator, + IntoParallelRefIterator, + IntoParallelRefMutIterator, + ParallelIterator, }; pub use super::par_azip; @@ -144,4 +143,5 @@ pub use crate::par_azip; mod impl_par_methods; mod into_impls; mod par; +mod send_producer; mod zipmacro; diff --git a/src/parallel/par.rs b/src/parallel/par.rs index efd761acf..b59af4c8e 100644 --- a/src/parallel/par.rs +++ b/src/parallel/par.rs @@ -13,18 +13,23 @@ use crate::iter::AxisChunksIter; use crate::iter::AxisChunksIterMut; use crate::iter::AxisIter; use crate::iter::AxisIterMut; +use crate::split_at::SplitPreference; use crate::Dimension; use crate::{ArrayView, ArrayViewMut}; /// Parallel iterator wrapper. #[derive(Copy, Clone, Debug)] -pub struct Parallel { +pub struct Parallel +{ iter: I, + min_len: usize, } +const DEFAULT_MIN_LEN: usize = 1; + /// Parallel producer wrapper. #[derive(Copy, Clone, Debug)] -struct ParallelProducer(I); +struct ParallelProducer(I, usize); macro_rules! par_iter_wrapper { // thread_bounds are either Sync or Send + Sync @@ -39,6 +44,7 @@ macro_rules! par_iter_wrapper { fn into_par_iter(self) -> Self::Iter { Parallel { iter: self, + min_len: DEFAULT_MIN_LEN, } } } @@ -66,7 +72,7 @@ macro_rules! par_iter_wrapper { fn with_producer(self, callback: Cb) -> Cb::Output where Cb: ProducerCallback { - callback.callback(ParallelProducer(self.iter)) + callback.callback(ParallelProducer(self.iter, self.min_len)) } fn len(&self) -> usize { @@ -105,11 +111,11 @@ macro_rules! par_iter_wrapper { fn split_at(self, i: usize) -> (Self, Self) { let (a, b) = self.0.split_at(i); - (ParallelProducer(a), ParallelProducer(b)) + (ParallelProducer(a, self.1), ParallelProducer(b, self.1)) } } - } + }; } par_iter_wrapper!(AxisIter, [Sync]); @@ -130,11 +136,11 @@ macro_rules! par_iter_view_wrapper { fn into_par_iter(self) -> Self::Iter { Parallel { iter: self, + min_len: DEFAULT_MIN_LEN, } } } - impl<'a, A, D> ParallelIterator for Parallel<$view_name<'a, A, D>> where D: Dimension, A: $($thread_bounds)*, @@ -143,7 +149,7 @@ macro_rules! 
par_iter_view_wrapper { fn drive_unindexed(self, consumer: C) -> C::Result where C: UnindexedConsumer { - bridge_unindexed(ParallelProducer(self.iter), consumer) + bridge_unindexed(ParallelProducer(self.iter, self.min_len), consumer) } fn opt_len(&self) -> Option { @@ -151,26 +157,52 @@ macro_rules! par_iter_view_wrapper { } } + impl<'a, A, D> Parallel<$view_name<'a, A, D>> + where D: Dimension, + A: $($thread_bounds)*, + { + /// Sets the minimum number of elements desired to process in each job. This will not be + /// split any smaller than this length, but of course a producer could already be smaller + /// to begin with. + /// + /// ***Panics*** if `min_len` is zero. + pub fn with_min_len(self, min_len: usize) -> Self { + assert_ne!(min_len, 0, "Minimum number of elements must at least be one to avoid splitting off empty tasks."); + + Self { + min_len, + ..self + } + } + } + impl<'a, A, D> UnindexedProducer for ParallelProducer<$view_name<'a, A, D>> where D: Dimension, A: $($thread_bounds)*, { type Item = <$view_name<'a, A, D> as IntoIterator>::Item; fn split(self) -> (Self, Option) { - if self.0.len() <= 1 { + if self.0.len() <= self.1 { return (self, None) } let array = self.0; let max_axis = array.max_stride_axis(); let mid = array.len_of(max_axis) / 2; let (a, b) = array.split_at(max_axis, mid); - (ParallelProducer(a), Some(ParallelProducer(b))) + (ParallelProducer(a, self.1), Some(ParallelProducer(b, self.1))) } fn fold_with(self, folder: F) -> F where F: Folder, { - self.into_iter().fold(folder, move |f, elt| f.consume(elt)) + Zip::from(self.0).fold_while(folder, |mut folder, elt| { + folder = folder.consume(elt); + if folder.full() { + FoldWhile::Done(folder) + } else { + FoldWhile::Continue(folder) + } + }).into_inner() } } @@ -185,7 +217,7 @@ macro_rules! par_iter_view_wrapper { } } - } + }; } par_iter_view_wrapper!(ArrayView, [Sync]); @@ -209,6 +241,7 @@ macro_rules! zip_impl { fn into_par_iter(self) -> Self::Iter { Parallel { iter: self, + min_len: DEFAULT_MIN_LEN, } } } @@ -225,7 +258,7 @@ macro_rules! zip_impl { fn drive_unindexed(self, consumer: Cons) -> Cons::Result where Cons: UnindexedConsumer { - bridge_unindexed(ParallelProducer(self.iter), consumer) + bridge_unindexed(ParallelProducer(self.iter, self.min_len), consumer) } fn opt_len(&self) -> Option { @@ -243,11 +276,11 @@ macro_rules! zip_impl { type Item = ($($p::Item ,)*); fn split(self) -> (Self, Option) { - if self.0.size() <= 1 { + if self.0.size() <= self.1 { return (self, None) } let (a, b) = self.0.split(); - (ParallelProducer(a), Some(ParallelProducer(b))) + (ParallelProducer(a, self.1), Some(ParallelProducer(b, self.1))) } fn fold_with(self, folder: Fold) -> Fold @@ -264,7 +297,7 @@ macro_rules! zip_impl { } } )+ - } + }; } zip_impl! { @@ -275,3 +308,74 @@ zip_impl! { [P1 P2 P3 P4 P5], [P1 P2 P3 P4 P5 P6], } + +impl Parallel> +where D: Dimension +{ + /// Sets the minimum number of elements desired to process in each job. This will not be + /// split any smaller than this length, but of course a producer could already be smaller + /// to begin with. + /// + /// ***Panics*** if `min_len` is zero. + pub fn with_min_len(self, min_len: usize) -> Self + { + assert_ne!(min_len, 0, "Minimum number of elements must at least be one to avoid splitting off empty tasks."); + + Self { min_len, ..self } + } +} + +/// A parallel iterator (unindexed) that produces the splits of the array +/// or producer `P`. +pub(crate) struct ParallelSplits
+<P>
+{
+    pub(crate) iter: P,
+    pub(crate) max_splits: usize,
+}
+
+impl<P> ParallelIterator for ParallelSplits<P>
+where P: SplitPreference + Send
+{
+    type Item = P;
+
+    fn drive_unindexed<C>(self, consumer: C) -> C::Result
+    where C: UnindexedConsumer<Self::Item>
+    {
+        bridge_unindexed(self, consumer)
+    }
+
+    fn opt_len(&self) -> Option<usize>
+    {
+        None
+    }
+}
+
+impl<P> UnindexedProducer for ParallelSplits<P>
+where P: SplitPreference + Send
+{
+    type Item = P;
+
+    fn split(self) -> (Self, Option<Self>)
+    {
+        if self.max_splits == 0 || !self.iter.can_split() {
+            return (self, None);
+        }
+        let (a, b) = self.iter.split();
+        (
+            ParallelSplits {
+                iter: a,
+                max_splits: self.max_splits - 1,
+            },
+            Some(ParallelSplits {
+                iter: b,
+                max_splits: self.max_splits - 1,
+            }),
+        )
+    }
+
+    fn fold_with<Fold>(self, folder: Fold) -> Fold
+    where Fold: Folder<Self::Item>
+    {
+        folder.consume(self.iter)
+    }
+}
diff --git a/src/parallel/send_producer.rs b/src/parallel/send_producer.rs
new file mode 100644
index 000000000..ecfb77af0
--- /dev/null
+++ b/src/parallel/send_producer.rs
@@ -0,0 +1,103 @@
+use crate::imp_prelude::*;
+use crate::{Layout, NdProducer};
+use std::ops::{Deref, DerefMut};
+
+/// An NdProducer that is unconditionally `Send`.
+#[repr(transparent)]
+pub(crate) struct SendProducer<T>
+{
+    inner: T,
+}
+
+impl<T> SendProducer<T>
+{
+    /// Create an unconditionally `Send` ndproducer from the producer
+    pub(crate) unsafe fn new(producer: T) -> Self
+    {
+        Self { inner: producer }
+    }
+}
+
+unsafe impl<P> Send for SendProducer<P> {}
+
+impl<P> Deref for SendProducer<P>
+{
+    type Target = P;
+    fn deref(&self) -> &P
+    {
+        &self.inner
+    }
+}
+
+impl<P> DerefMut for SendProducer<P>
+{
+    fn deref_mut(&mut self) -> &mut P
+    {
+        &mut self.inner
+    }
+}
+
+impl<P> NdProducer for SendProducer<P>
+where P: NdProducer +{ + type Item = P::Item; + type Dim = P::Dim; + type Ptr = P::Ptr; + type Stride = P::Stride; + + private_impl! {} + + #[inline(always)] + fn raw_dim(&self) -> Self::Dim + { + self.inner.raw_dim() + } + + #[inline(always)] + fn equal_dim(&self, dim: &Self::Dim) -> bool + { + self.inner.equal_dim(dim) + } + + #[inline(always)] + fn as_ptr(&self) -> Self::Ptr + { + self.inner.as_ptr() + } + + #[inline(always)] + fn layout(&self) -> Layout + { + self.inner.layout() + } + + #[inline(always)] + unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item + { + self.inner.as_ref(ptr) + } + + #[inline(always)] + unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr + { + self.inner.uget_ptr(i) + } + + #[inline(always)] + fn stride_of(&self, axis: Axis) -> Self::Stride + { + self.inner.stride_of(axis) + } + + #[inline(always)] + fn contiguous_stride(&self) -> Self::Stride + { + self.inner.contiguous_stride() + } + + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + { + let (a, b) = self.inner.split_at(axis, index); + (Self { inner: a }, Self { inner: b }) + } +} diff --git a/src/parallel/zipmacro.rs b/src/parallel/zipmacro.rs index 99c3807f8..28188542f 100644 --- a/src/parallel/zipmacro.rs +++ b/src/parallel/zipmacro.rs @@ -24,7 +24,7 @@ /// Is equivalent to: /// /// ```rust,ignore -/// Zip::from(&mut a).and(&b).and(&c).par_apply(|a, &b, &c| { +/// Zip::from(&mut a).and(&b).and(&c).par_for_each(|a, &b, &c| { /// *a = b + c; /// }); /// ``` @@ -34,28 +34,24 @@ /// ## Examples /// /// ```rust -/// extern crate ndarray; -/// /// use ndarray::Array2; /// use ndarray::parallel::par_azip; /// /// type M = Array2; /// -/// fn main() { -/// let mut a = M::zeros((16, 16)); -/// let b = M::from_elem(a.dim(), 1.); -/// let c = M::from_elem(a.dim(), 2.); +/// let mut a = M::zeros((16, 16)); +/// let b = M::from_elem(a.dim(), 1.); +/// let c = M::from_elem(a.dim(), 2.); /// -/// // Compute a simple ternary operation: -/// // elementwise addition of b and c, stored in a +/// // Compute a simple ternary operation: +/// // elementwise addition of b and c, stored in a /// -/// par_azip!((a in &mut a, &b in &b, &c in &c) *a = b + c); +/// par_azip!((a in &mut a, &b in &b, &c in &c) *a = b + c); /// -/// assert_eq!(a, &b + &c); -/// } +/// assert_eq!(a, &b + &c); /// ``` macro_rules! par_azip { ($($t:tt)*) => { - $crate::azip!(@build par_apply $($t)*) + $crate::azip!(@build par_for_each $($t)*) }; } diff --git a/src/partial.rs b/src/partial.rs new file mode 100644 index 000000000..99aba75a8 --- /dev/null +++ b/src/partial.rs @@ -0,0 +1,97 @@ +// Copyright 2020 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::ptr; + +/// Partial is a partially written contiguous slice of data; +/// it is the owner of the elements, but not the allocation, +/// and will drop the elements on drop. +#[must_use] +pub(crate) struct Partial +{ + /// Data pointer + ptr: *mut T, + /// Current length + pub(crate) len: usize, +} + +impl Partial +{ + /// Create an empty partial for this data pointer + /// + /// ## Safety + /// + /// Unless ownership is released, the Partial acts as an owner of the slice of data (not the + /// allocation); and will free the elements on drop; the pointer must be dereferenceable and + /// the `len` elements following it valid. 
+ /// + /// The Partial has an accessible length field which must only be modified in trusted code. + pub(crate) unsafe fn new(ptr: *mut T) -> Self + { + Self { ptr, len: 0 } + } + + #[cfg(feature = "rayon")] + pub(crate) fn stub() -> Self + { + Self { + len: 0, + ptr: ptr::null_mut(), + } + } + + #[cfg(feature = "rayon")] + pub(crate) fn is_stub(&self) -> bool + { + self.ptr.is_null() + } + + /// Release Partial's ownership of the written elements, and return the current length + pub(crate) fn release_ownership(mut self) -> usize + { + let ret = self.len; + self.len = 0; + ret + } + + #[cfg(feature = "rayon")] + /// Merge if they are in order (left to right) and contiguous. + /// Skips merge if T does not need drop. + pub(crate) fn try_merge(mut left: Self, right: Self) -> Self + { + if !std::mem::needs_drop::() { + return left; + } + // Merge the partial collect results; the final result will be a slice that + // covers the whole output. + if left.is_stub() { + right + } else if left.ptr.wrapping_add(left.len) == right.ptr { + left.len += right.release_ownership(); + left + } else { + // failure to merge; this is a bug in collect, so we will never reach this + debug_assert!(false, "Partial: failure to merge left and right parts"); + left + } + } +} + +unsafe impl Send for Partial where T: Send {} + +impl Drop for Partial +{ + fn drop(&mut self) + { + if !self.ptr.is_null() { + unsafe { + ptr::drop_in_place(alloc::slice::from_raw_parts_mut(self.ptr, self.len)); + } + } + } +} diff --git a/src/prelude.rs b/src/prelude.rs index 8662a4d34..acf39da1a 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -12,16 +12,13 @@ //! and macros that you can import easily as a group. //! //! ``` -//! //! use ndarray::prelude::*; +//! +//! # let _ = arr0(1); // use the import //! ``` #[doc(no_inline)] -#[allow(deprecated)] -pub use crate::{ - ArcArray, Array, ArrayBase, ArrayView, ArrayViewMut, CowArray, RawArrayView, RawArrayViewMut, - RcArray, -}; +pub use crate::{ArcArray, Array, ArrayBase, ArrayView, ArrayViewMut, CowArray, RawArrayView, RawArrayViewMut}; #[doc(no_inline)] pub use crate::{Axis, Dim, Dimension}; @@ -30,14 +27,18 @@ pub use crate::{Axis, Dim, Dimension}; pub use crate::{Array0, Array1, Array2, Array3, Array4, Array5, Array6, ArrayD}; #[doc(no_inline)] -pub use crate::{ - ArrayView0, ArrayView1, ArrayView2, ArrayView3, ArrayView4, ArrayView5, ArrayView6, ArrayViewD, -}; +pub use crate::{ArrayView0, ArrayView1, ArrayView2, ArrayView3, ArrayView4, ArrayView5, ArrayView6, ArrayViewD}; #[doc(no_inline)] pub use crate::{ - ArrayViewMut0, ArrayViewMut1, ArrayViewMut2, ArrayViewMut3, ArrayViewMut4, ArrayViewMut5, - ArrayViewMut6, ArrayViewMutD, + ArrayViewMut0, + ArrayViewMut1, + ArrayViewMut2, + ArrayViewMut3, + ArrayViewMut4, + ArrayViewMut5, + ArrayViewMut6, + ArrayViewMutD, }; #[doc(no_inline)] @@ -52,4 +53,11 @@ pub use crate::{array, azip, s}; pub use crate::ShapeBuilder; #[doc(no_inline)] -pub use crate::{AsArray, NdFloat}; +pub use crate::NewAxis; + +#[doc(no_inline)] +pub use crate::AsArray; + +#[doc(no_inline)] +#[cfg(feature = "std")] +pub use crate::NdFloat; diff --git a/src/private.rs b/src/private.rs index ea13164e4..9dade0c48 100644 --- a/src/private.rs +++ b/src/private.rs @@ -11,15 +11,15 @@ macro_rules! private_decl { () => { /// This trait is private to implement; this method exists to make it /// impossible to implement outside the crate. + #[doc(hidden)] fn __private__(&self) -> crate::private::PrivateMarker; - } + }; } macro_rules! 
private_impl { () => { - #[doc(hidden)] fn __private__(&self) -> crate::private::PrivateMarker { crate::private::PrivateMarker } - } + }; } diff --git a/src/shape_builder.rs b/src/shape_builder.rs index bb5a949ab..cd790a25f 100644 --- a/src/shape_builder.rs +++ b/src/shape_builder.rs @@ -1,29 +1,117 @@ use crate::dimension::IntoDimension; +use crate::order::Order; use crate::Dimension; -use crate::{Shape, StrideShape}; + +/// A contiguous array shape of n dimensions. +/// +/// Either c- or f- memory ordered (*c* a.k.a *row major* is the default). +#[derive(Copy, Clone, Debug)] +pub struct Shape +{ + /// Shape (axis lengths) + pub(crate) dim: D, + /// Strides can only be C or F here + pub(crate) strides: Strides, +} + +#[derive(Copy, Clone, Debug)] +pub(crate) enum Contiguous {} + +impl Shape +{ + pub(crate) fn is_c(&self) -> bool + { + matches!(self.strides, Strides::C) + } +} + +/// An array shape of n dimensions in c-order, f-order or custom strides. +#[derive(Copy, Clone, Debug)] +pub struct StrideShape +{ + pub(crate) dim: D, + pub(crate) strides: Strides, +} + +impl StrideShape +where D: Dimension +{ + /// Return a reference to the dimension + pub fn raw_dim(&self) -> &D + { + &self.dim + } + /// Return the size of the shape in number of elements + pub fn size(&self) -> usize + { + self.dim.size() + } +} + +/// Stride description +#[derive(Copy, Clone, Debug)] +pub(crate) enum Strides +{ + /// Row-major ("C"-order) + C, + /// Column-major ("F"-order) + F, + /// Custom strides + Custom(D), +} + +impl Strides +{ + /// Return strides for `dim` (computed from dimension if c/f, else return the custom stride) + pub(crate) fn strides_for_dim(self, dim: &D) -> D + where D: Dimension + { + match self { + Strides::C => dim.default_strides(), + Strides::F => dim.fortran_strides(), + Strides::Custom(c) => { + debug_assert_eq!( + c.ndim(), + dim.ndim(), + "Custom strides given with {} dimensions, expected {}", + c.ndim(), + dim.ndim() + ); + c + } + } + } + + #[inline] + pub(crate) fn is_custom(&self) -> bool + { + matches!(*self, Strides::Custom(_)) + } +} /// A trait for `Shape` and `D where D: Dimension` that allows /// customizing the memory layout (strides) of an array shape. /// /// This trait is used together with array constructor methods like /// `Array::from_shape_vec`. -pub trait ShapeBuilder { +pub trait ShapeBuilder +{ type Dim: Dimension; type Strides; - fn into_shape(self) -> Shape; + fn into_shape_with_order(self) -> Shape; fn f(self) -> Shape; fn set_f(self, is_f: bool) -> Shape; fn strides(self, strides: Self::Strides) -> StrideShape; } impl From for Shape -where - D: Dimension, +where D: Dimension { /// Create a `Shape` from `dimension`, using the default memory layout. 
- fn from(dimension: D) -> Shape { - dimension.into_shape() + fn from(dimension: D) -> Shape + { + dimension.into_shape_with_order() } } @@ -32,94 +120,121 @@ where D: Dimension, T: ShapeBuilder, { - fn from(value: T) -> Self { - let shape = value.into_shape(); - let d = shape.dim; - let st = if shape.is_c { - d.default_strides() - } else { - d.fortran_strides() - }; - StrideShape { - strides: st, - dim: d, - custom: false, - } - } -} - -/* -impl From> for StrideShape - where D: Dimension -{ - fn from(shape: Shape) -> Self { - let d = shape.dim; - let st = if shape.is_c { d.default_strides() } else { d.fortran_strides() }; + fn from(value: T) -> Self + { + let shape = value.into_shape_with_order(); + let st = if shape.is_c() { Strides::C } else { Strides::F }; StrideShape { strides: st, - dim: d, - custom: false, + dim: shape.dim, } } } -*/ impl ShapeBuilder for T -where - T: IntoDimension, +where T: IntoDimension { type Dim = T::Dim; type Strides = T; - fn into_shape(self) -> Shape { + fn into_shape_with_order(self) -> Shape + { Shape { dim: self.into_dimension(), - is_c: true, + strides: Strides::C, } } - fn f(self) -> Shape { + fn f(self) -> Shape + { self.set_f(true) } - fn set_f(self, is_f: bool) -> Shape { - self.into_shape().set_f(is_f) + fn set_f(self, is_f: bool) -> Shape + { + self.into_shape_with_order().set_f(is_f) } - fn strides(self, st: T) -> StrideShape { - self.into_shape().strides(st.into_dimension()) + fn strides(self, st: T) -> StrideShape + { + self.into_shape_with_order().strides(st.into_dimension()) } } impl ShapeBuilder for Shape -where - D: Dimension, +where D: Dimension { type Dim = D; type Strides = D; - fn into_shape(self) -> Shape { + + fn into_shape_with_order(self) -> Shape + { self } - fn f(self) -> Self { + + fn f(self) -> Self + { self.set_f(true) } - fn set_f(mut self, is_f: bool) -> Self { - self.is_c = !is_f; + + fn set_f(mut self, is_f: bool) -> Self + { + self.strides = if !is_f { Strides::C } else { Strides::F }; self } - fn strides(self, st: D) -> StrideShape { + + fn strides(self, st: D) -> StrideShape + { StrideShape { dim: self.dim, - strides: st, - custom: true, + strides: Strides::Custom(st), } } } impl Shape -where - D: Dimension, +where D: Dimension { - // Return a reference to the dimension - //pub fn dimension(&self) -> &D { &self.dim } + /// Return a reference to the dimension + pub fn raw_dim(&self) -> &D + { + &self.dim + } /// Return the size of the shape in number of elements - pub fn size(&self) -> usize { + pub fn size(&self) -> usize + { self.dim.size() } } + +/// Array shape argument with optional order parameter +/// +/// Shape or array dimension argument, with optional [`Order`] parameter. +/// +/// This is an argument conversion trait that is used to accept an array shape and +/// (optionally) an ordering argument. +/// +/// See for example [`.to_shape()`](crate::ArrayBase::to_shape). +pub trait ShapeArg +{ + type Dim: Dimension; + fn into_shape_and_order(self) -> (Self::Dim, Option); +} + +impl ShapeArg for T +where T: IntoDimension +{ + type Dim = T::Dim; + + fn into_shape_and_order(self) -> (Self::Dim, Option) + { + (self.into_dimension(), None) + } +} + +impl ShapeArg for (T, Order) +where T: IntoDimension +{ + type Dim = T::Dim; + + fn into_shape_and_order(self) -> (Self::Dim, Option) + { + (self.0.into_dimension(), Some(self.1)) + } +} diff --git a/src/slice.rs b/src/slice.rs index 58bff48b5..e6c237a92 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -7,18 +7,24 @@ // except according to those terms. 
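// A minimal usage sketch of the reworked shape builders in shape_builder.rs above,
// assuming ndarray 0.16's public API; the function name below is only illustrative.
fn shape_order_sketch()
{
    use ndarray::{array, Array, Order, ShapeBuilder};

    // Column-major ("F") layout selected through `ShapeBuilder::f()`;
    // `into_shape` has been renamed to `into_shape_with_order`.
    let v = vec![1., 2., 3., 4., 5., 6.];
    let a = Array::from_shape_vec((2, 3).f(), v).unwrap();
    assert_eq!(a[[1, 0]], 2.);

    // `ShapeArg` lets `.to_shape()` take a shape together with an `Order`.
    let b = array![[1., 2., 3.], [4., 5., 6.]];
    let flat = b.to_shape((6, Order::ColumnMajor)).unwrap();
    assert_eq!(flat[1], 4.);
}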
use crate::dimension::slices_intersect; use crate::error::{ErrorKind, ShapeError}; -use crate::{ArrayViewMut, Dimension}; +#[cfg(doc)] +use crate::s; +use crate::{ArrayViewMut, DimAdd, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; + +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; +use std::convert::TryFrom; use std::fmt; use std::marker::PhantomData; use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; /// A slice (range with step size). /// -/// `end` is an exclusive index. Negative `begin` or `end` indexes are counted +/// `end` is an exclusive index. Negative `start` or `end` indexes are counted /// from the back of the axis. If `end` is `None`, the slice extends to the end /// of the axis. /// -/// See also the [`s![]`](macro.s.html) macro. +/// See also the [`s![]`](s!) macro. /// /// ## Examples /// @@ -33,13 +39,19 @@ use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, Rang /// reverse order. It can also be created with `Slice::from(a..).step_by(-1)`. /// The Python equivalent is `[a::-1]`. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -pub struct Slice { +pub struct Slice +{ + /// start index; negative are counted from the back of the axis pub start: isize, + /// end index; negative are counted from the back of the axis; when not present + /// the default is the full length of the axis. pub end: Option, + /// step size in elements; the default is 1, for every element. pub step: isize, } -impl Slice { +impl Slice +{ /// Create a new `Slice` with the given extents. /// /// See also the `From` impls, converting from ranges; for example @@ -47,7 +59,8 @@ impl Slice { /// /// `step` must be nonzero. /// (This method checks with a debug assertion that `step` is not zero.) - pub fn new(start: isize, end: Option, step: isize) -> Slice { + pub fn new(start: isize, end: Option, step: isize) -> Slice + { debug_assert_ne!(step, 0, "Slice::new: step must be nonzero"); Slice { start, end, step } } @@ -58,7 +71,8 @@ impl Slice { /// `step` must be nonzero. /// (This method checks with a debug assertion that `step` is not zero.) #[inline] - pub fn step_by(self, step: isize) -> Self { + pub fn step_by(self, step: isize) -> Self + { debug_assert_ne!(step, 0, "Slice::step_by: step must be nonzero"); Slice { step: self.step * step, @@ -67,91 +81,92 @@ impl Slice { } } -/// A slice (range with step) or an index. +/// Token to represent a new axis in a slice description. /// -/// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a -/// `&SliceInfo<[SliceOrIndex; n], D>`. +/// See also the [`s![]`](s!) macro. +#[derive(Clone, Copy, Debug)] +pub struct NewAxis; + +/// A slice (range with step), an index, or a new axis token. +/// +/// See also the [`s![]`](s!) macro for a convenient way to create a +/// `SliceInfo<[SliceInfoElem; n], Din, Dout>`. /// /// ## Examples /// -/// `SliceOrIndex::Index(a)` is the index `a`. It can also be created with -/// `SliceOrIndex::from(a)`. The Python equivalent is `[a]`. The macro +/// `SliceInfoElem::Index(a)` is the index `a`. It can also be created with +/// `SliceInfoElem::from(a)`. The Python equivalent is `[a]`. The macro /// equivalent is `s![a]`. /// -/// `SliceOrIndex::Slice { start: 0, end: None, step: 1 }` is the full range of -/// an axis. It can also be created with `SliceOrIndex::from(..)`. The Python -/// equivalent is `[:]`. The macro equivalent is `s![..]`. 
+/// `SliceInfoElem::Slice { start: 0, end: None, step: 1 }` is the full range +/// of an axis. It can also be created with `SliceInfoElem::from(..)`. The +/// Python equivalent is `[:]`. The macro equivalent is `s![..]`. /// -/// `SliceOrIndex::Slice { start: a, end: Some(b), step: 2 }` is every second +/// `SliceInfoElem::Slice { start: a, end: Some(b), step: 2 }` is every second /// element from `a` until `b`. It can also be created with -/// `SliceOrIndex::from(a..b).step_by(2)`. The Python equivalent is `[a:b:2]`. -/// The macro equivalent is `s![a..b;2]`. +/// `SliceInfoElem::from(Slice::from(a..b).step_by(2))`. The Python equivalent +/// is `[a:b:2]`. The macro equivalent is `s![a..b;2]`. /// -/// `SliceOrIndex::Slice { start: a, end: None, step: -1 }` is every element, +/// `SliceInfoElem::Slice { start: a, end: None, step: -1 }` is every element, /// from `a` until the end, in reverse order. It can also be created with -/// `SliceOrIndex::from(a..).step_by(-1)`. The Python equivalent is `[a::-1]`. -/// The macro equivalent is `s![a..;-1]`. +/// `SliceInfoElem::from(Slice::from(a..).step_by(-1))`. The Python equivalent +/// is `[a::-1]`. The macro equivalent is `s![a..;-1]`. +/// +/// `SliceInfoElem::NewAxis` is a new axis of length 1. It can also be created +/// with `SliceInfoElem::from(NewAxis)`. The Python equivalent is +/// `[np.newaxis]`. The macro equivalent is `s![NewAxis]`. #[derive(Debug, PartialEq, Eq, Hash)] -pub enum SliceOrIndex { - /// A range with step size. `end` is an exclusive index. Negative `begin` +pub enum SliceInfoElem +{ + /// A range with step size. `end` is an exclusive index. Negative `start` /// or `end` indexes are counted from the back of the axis. If `end` is /// `None`, the slice extends to the end of the axis. - Slice { + Slice + { + /// start index; negative are counted from the back of the axis start: isize, + /// end index; negative are counted from the back of the axis; when not present + /// the default is the full length of the axis. end: Option, + /// step size in elements; the default is 1, for every element. step: isize, }, /// A single index. Index(isize), + /// A new axis of length 1. + NewAxis, } -copy_and_clone! {SliceOrIndex} +copy_and_clone! {SliceInfoElem} -impl SliceOrIndex { +impl SliceInfoElem +{ /// Returns `true` if `self` is a `Slice` value. - pub fn is_slice(&self) -> bool { - match self { - SliceOrIndex::Slice { .. } => true, - _ => false, - } + pub fn is_slice(&self) -> bool + { + matches!(self, SliceInfoElem::Slice { .. }) } /// Returns `true` if `self` is an `Index` value. - pub fn is_index(&self) -> bool { - match self { - SliceOrIndex::Index(_) => true, - _ => false, - } + pub fn is_index(&self) -> bool + { + matches!(self, SliceInfoElem::Index(_)) } - /// Returns a new `SliceOrIndex` with the given step size (multiplied with - /// the previous step size). - /// - /// `step` must be nonzero. - /// (This method checks with a debug assertion that `step` is not zero.) - #[inline] - pub fn step_by(self, step: isize) -> Self { - debug_assert_ne!(step, 0, "SliceOrIndex::step_by: step must be nonzero"); - match self { - SliceOrIndex::Slice { - start, - end, - step: orig_step, - } => SliceOrIndex::Slice { - start, - end, - step: orig_step * step, - }, - SliceOrIndex::Index(s) => SliceOrIndex::Index(s), - } + /// Returns `true` if `self` is a `NewAxis` value. 
+ pub fn is_new_axis(&self) -> bool + { + matches!(self, SliceInfoElem::NewAxis) } } -impl fmt::Display for SliceOrIndex { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl fmt::Display for SliceInfoElem +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { match *self { - SliceOrIndex::Index(index) => write!(f, "{}", index)?, - SliceOrIndex::Slice { start, end, step } => { + SliceInfoElem::Index(index) => write!(f, "{}", index)?, + SliceInfoElem::Slice { start, end, step } => { if start != 0 { write!(f, "{}", start)?; } @@ -163,6 +178,7 @@ impl fmt::Display for SliceOrIndex { write!(f, ";{}", step)?; } } + SliceInfoElem::NewAxis => write!(f, stringify!(NewAxis))?, } Ok(()) } @@ -231,13 +247,15 @@ macro_rules! impl_slice_variant_from_range { impl_slice_variant_from_range!(Slice, Slice, isize); impl_slice_variant_from_range!(Slice, Slice, usize); impl_slice_variant_from_range!(Slice, Slice, i32); -impl_slice_variant_from_range!(SliceOrIndex, SliceOrIndex::Slice, isize); -impl_slice_variant_from_range!(SliceOrIndex, SliceOrIndex::Slice, usize); -impl_slice_variant_from_range!(SliceOrIndex, SliceOrIndex::Slice, i32); +impl_slice_variant_from_range!(SliceInfoElem, SliceInfoElem::Slice, isize); +impl_slice_variant_from_range!(SliceInfoElem, SliceInfoElem::Slice, usize); +impl_slice_variant_from_range!(SliceInfoElem, SliceInfoElem::Slice, i32); -impl From for Slice { +impl From for Slice +{ #[inline] - fn from(_: RangeFull) -> Slice { + fn from(_: RangeFull) -> Slice + { Slice { start: 0, end: None, @@ -246,10 +264,12 @@ impl From for Slice { } } -impl From for SliceOrIndex { +impl From for SliceInfoElem +{ #[inline] - fn from(_: RangeFull) -> SliceOrIndex { - SliceOrIndex::Slice { + fn from(_: RangeFull) -> SliceInfoElem + { + SliceInfoElem::Slice { start: 0, end: None, step: 1, @@ -257,10 +277,12 @@ impl From for SliceOrIndex { } } -impl From for SliceOrIndex { +impl From for SliceInfoElem +{ #[inline] - fn from(s: Slice) -> SliceOrIndex { - SliceOrIndex::Slice { + fn from(s: Slice) -> SliceInfoElem + { + SliceInfoElem::Slice { start: s.start, end: s.end, step: s.step, @@ -268,198 +290,443 @@ impl From for SliceOrIndex { } } -macro_rules! impl_sliceorindex_from_index { +macro_rules! impl_sliceinfoelem_from_index { ($index:ty) => { - impl From<$index> for SliceOrIndex { + impl From<$index> for SliceInfoElem { #[inline] - fn from(r: $index) -> SliceOrIndex { - SliceOrIndex::Index(r as isize) + fn from(r: $index) -> SliceInfoElem { + SliceInfoElem::Index(r as isize) } } }; } -impl_sliceorindex_from_index!(isize); -impl_sliceorindex_from_index!(usize); -impl_sliceorindex_from_index!(i32); +impl_sliceinfoelem_from_index!(isize); +impl_sliceinfoelem_from_index!(usize); +impl_sliceinfoelem_from_index!(i32); + +impl From for SliceInfoElem +{ + #[inline] + fn from(_: NewAxis) -> SliceInfoElem + { + SliceInfoElem::NewAxis + } +} + +/// A type that can slice an array of dimension `D`. +/// +/// This trait is unsafe to implement because the implementation must ensure +/// that `D`, `Self::OutDim`, `self.in_dim()`, and `self.out_ndim()` are +/// consistent with the `&[SliceInfoElem]` returned by `self.as_ref()` and that +/// `self.as_ref()` always returns the same value when called multiple times. +#[allow(clippy::missing_safety_doc)] // not implementable downstream +pub unsafe trait SliceArg: AsRef<[SliceInfoElem]> +{ + /// Dimensionality of the output array. + type OutDim: Dimension; + + /// Returns the number of axes in the input array. 
+ fn in_ndim(&self) -> usize; + + /// Returns the number of axes in the output array. + fn out_ndim(&self) -> usize; + + private_decl! {} +} + +unsafe impl SliceArg for &T +where + T: SliceArg + ?Sized, + D: Dimension, +{ + type OutDim = T::OutDim; + + fn in_ndim(&self) -> usize + { + T::in_ndim(self) + } + + fn out_ndim(&self) -> usize + { + T::out_ndim(self) + } + + private_impl! {} +} + +macro_rules! impl_slicearg_samedim { + ($in_dim:ty) => { + unsafe impl SliceArg<$in_dim> for SliceInfo + where + T: AsRef<[SliceInfoElem]>, + Dout: Dimension, + { + type OutDim = Dout; + + fn in_ndim(&self) -> usize { + self.in_ndim() + } + + fn out_ndim(&self) -> usize { + self.out_ndim() + } + + private_impl! {} + } + }; +} +impl_slicearg_samedim!(Ix0); +impl_slicearg_samedim!(Ix1); +impl_slicearg_samedim!(Ix2); +impl_slicearg_samedim!(Ix3); +impl_slicearg_samedim!(Ix4); +impl_slicearg_samedim!(Ix5); +impl_slicearg_samedim!(Ix6); + +unsafe impl SliceArg for SliceInfo +where + T: AsRef<[SliceInfoElem]>, + Din: Dimension, + Dout: Dimension, +{ + type OutDim = Dout; + + fn in_ndim(&self) -> usize + { + self.in_ndim() + } + + fn out_ndim(&self) -> usize + { + self.out_ndim() + } + + private_impl! {} +} + +unsafe impl SliceArg for [SliceInfoElem] +{ + type OutDim = IxDyn; + + fn in_ndim(&self) -> usize + { + self.iter().filter(|s| !s.is_new_axis()).count() + } + + fn out_ndim(&self) -> usize + { + self.iter().filter(|s| !s.is_index()).count() + } + + private_impl! {} +} /// Represents all of the necessary information to perform a slice. /// -/// The type `T` is typically `[SliceOrIndex; n]`, `[SliceOrIndex]`, or -/// `Vec`. The type `D` is the output dimension after calling -/// [`.slice()`]. +/// The type `T` is typically `[SliceInfoElem; n]`, `&[SliceInfoElem]`, or +/// `Vec`. The type `Din` is the dimension of the array to be +/// sliced, and `Dout` is the output dimension after calling [`.slice()`]. Note +/// that if `Din` is a fixed dimension type (`Ix0`, `Ix1`, `Ix2`, etc.), the +/// `SliceInfo` instance can still be used to slice an array with dimension +/// `IxDyn` as long as the number of axes matches. /// -/// [`.slice()`]: struct.ArrayBase.html#method.slice +/// [`.slice()`]: crate::ArrayBase::slice #[derive(Debug)] -#[repr(C)] -pub struct SliceInfo { - out_dim: PhantomData, +pub struct SliceInfo +{ + in_dim: PhantomData, + out_dim: PhantomData, indices: T, } -impl Deref for SliceInfo +impl Deref for SliceInfo where - D: Dimension, + Din: Dimension, + Dout: Dimension, { type Target = T; - fn deref(&self) -> &Self::Target { + fn deref(&self) -> &Self::Target + { &self.indices } } -impl SliceInfo +fn check_dims_for_sliceinfo(indices: &[SliceInfoElem]) -> Result<(), ShapeError> where - D: Dimension, + Din: Dimension, + Dout: Dimension, { - /// Returns a new `SliceInfo` instance. - /// - /// If you call this method, you are guaranteeing that `out_dim` is - /// consistent with `indices`. 
- #[doc(hidden)] - pub unsafe fn new_unchecked(indices: T, out_dim: PhantomData) -> SliceInfo { - SliceInfo { out_dim, indices } + if let Some(in_ndim) = Din::NDIM { + if in_ndim != indices.in_ndim() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + } + if let Some(out_ndim) = Dout::NDIM { + if out_ndim != indices.out_ndim() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } } + Ok(()) } -impl SliceInfo +impl SliceInfo where - T: AsRef<[SliceOrIndex]>, - D: Dimension, + T: AsRef<[SliceInfoElem]>, + Din: Dimension, + Dout: Dimension, { /// Returns a new `SliceInfo` instance. /// - /// Errors if `D` is not consistent with `indices`. - pub fn new(indices: T) -> Result, ShapeError> { - if let Some(ndim) = D::NDIM { - if ndim != indices.as_ref().iter().filter(|s| s.is_slice()).count() { - return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); - } + /// **Note:** only unchecked for non-debug builds of `ndarray`. + /// + /// # Safety + /// + /// The caller must ensure that `in_dim` and `out_dim` are consistent with + /// `indices` and that `indices.as_ref()` always returns the same value + /// when called multiple times. + #[doc(hidden)] + pub unsafe fn new_unchecked( + indices: T, in_dim: PhantomData, out_dim: PhantomData, + ) -> SliceInfo + { + if cfg!(debug_assertions) { + check_dims_for_sliceinfo::(indices.as_ref()) + .expect("`Din` and `Dout` must be consistent with `indices`."); } + SliceInfo { + in_dim, + out_dim, + indices, + } + } + + /// Returns a new `SliceInfo` instance. + /// + /// Errors if `Din` or `Dout` is not consistent with `indices`. + /// + /// For common types, a safe alternative is to use `TryFrom` instead. + /// + /// # Safety + /// + /// The caller must ensure `indices.as_ref()` always returns the same value + /// when called multiple times. + pub unsafe fn new(indices: T) -> Result, ShapeError> + { + check_dims_for_sliceinfo::(indices.as_ref())?; Ok(SliceInfo { + in_dim: PhantomData, out_dim: PhantomData, indices, }) } -} -impl SliceInfo -where - T: AsRef<[SliceOrIndex]>, - D: Dimension, -{ + /// Returns the number of dimensions of the input array for + /// [`.slice()`](crate::ArrayBase::slice). + /// + /// If `Din` is a fixed-size dimension type, then this is equivalent to + /// `Din::NDIM.unwrap()`. Otherwise, the value is calculated by iterating + /// over the `SliceInfoElem` elements. + pub fn in_ndim(&self) -> usize + { + if let Some(ndim) = Din::NDIM { + ndim + } else { + self.indices.as_ref().in_ndim() + } + } + /// Returns the number of dimensions after calling - /// [`.slice()`](struct.ArrayBase.html#method.slice) (including taking + /// [`.slice()`](crate::ArrayBase::slice) (including taking /// subviews). /// - /// If `D` is a fixed-size dimension type, then this is equivalent to - /// `D::NDIM.unwrap()`. Otherwise, the value is calculated by iterating - /// over the ranges/indices. - pub fn out_ndim(&self) -> usize { - D::NDIM.unwrap_or_else(|| { - self.indices - .as_ref() - .iter() - .filter(|s| s.is_slice()) - .count() - }) + /// If `Dout` is a fixed-size dimension type, then this is equivalent to + /// `Dout::NDIM.unwrap()`. Otherwise, the value is calculated by iterating + /// over the `SliceInfoElem` elements. 
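// Editor's note: an illustrative sketch, not part of this patch, showing how
// `in_ndim`/`out_ndim` count the `SliceInfoElem` elements produced by `s![]`.
// Assumes the public API added in this patch (`NewAxis`, `SliceInfo::in_ndim`,
// `SliceInfo::out_ndim`).
use ndarray::{s, NewAxis};

fn main() {
    let info = s![1, .., NewAxis];
    // The index and the range each consume one input axis; `NewAxis` consumes none.
    assert_eq!(info.in_ndim(), 2);
    // The range and `NewAxis` each produce one output axis; the index produces none.
    assert_eq!(info.out_ndim(), 2);
}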
+ pub fn out_ndim(&self) -> usize + { + if let Some(ndim) = Dout::NDIM { + ndim + } else { + self.indices.as_ref().out_ndim() + } } } -impl AsRef<[SliceOrIndex]> for SliceInfo +impl<'a, Din, Dout> TryFrom<&'a [SliceInfoElem]> for SliceInfo<&'a [SliceInfoElem], Din, Dout> where - T: AsRef<[SliceOrIndex]>, - D: Dimension, + Din: Dimension, + Dout: Dimension, { - fn as_ref(&self) -> &[SliceOrIndex] { - self.indices.as_ref() + type Error = ShapeError; + + fn try_from(indices: &'a [SliceInfoElem]) -> Result, ShapeError> + { + unsafe { + // This is okay because `&[SliceInfoElem]` always returns the same + // value for `.as_ref()`. + Self::new(indices) + } } } -impl AsRef> for SliceInfo +impl TryFrom> for SliceInfo, Din, Dout> where - T: AsRef<[SliceOrIndex]>, - D: Dimension, + Din: Dimension, + Dout: Dimension, { - fn as_ref(&self) -> &SliceInfo<[SliceOrIndex], D> { + type Error = ShapeError; + + fn try_from(indices: Vec) -> Result, Din, Dout>, ShapeError> + { unsafe { - // This is okay because the only non-zero-sized member of - // `SliceInfo` is `indices`, so `&SliceInfo<[SliceOrIndex], D>` - // should have the same bitwise representation as - // `&[SliceOrIndex]`. - &*(self.indices.as_ref() as *const [SliceOrIndex] - as *const SliceInfo<[SliceOrIndex], D>) + // This is okay because `Vec` always returns the same value for + // `.as_ref()`. + Self::new(indices) + } + } +} + +macro_rules! impl_tryfrom_array_for_sliceinfo { + ($len:expr) => { + impl TryFrom<[SliceInfoElem; $len]> + for SliceInfo<[SliceInfoElem; $len], Din, Dout> + where + Din: Dimension, + Dout: Dimension, + { + type Error = ShapeError; + + fn try_from( + indices: [SliceInfoElem; $len], + ) -> Result, ShapeError> { + unsafe { + // This is okay because `[SliceInfoElem; N]` always returns + // the same value for `.as_ref()`. + Self::new(indices) + } + } + } + }; +} +impl_tryfrom_array_for_sliceinfo!(0); +impl_tryfrom_array_for_sliceinfo!(1); +impl_tryfrom_array_for_sliceinfo!(2); +impl_tryfrom_array_for_sliceinfo!(3); +impl_tryfrom_array_for_sliceinfo!(4); +impl_tryfrom_array_for_sliceinfo!(5); +impl_tryfrom_array_for_sliceinfo!(6); +impl_tryfrom_array_for_sliceinfo!(7); +impl_tryfrom_array_for_sliceinfo!(8); + +impl AsRef<[SliceInfoElem]> for SliceInfo +where + T: AsRef<[SliceInfoElem]>, + Din: Dimension, + Dout: Dimension, +{ + fn as_ref(&self) -> &[SliceInfoElem] + { + self.indices.as_ref() + } +} + +impl<'a, T, Din, Dout> From<&'a SliceInfo> for SliceInfo<&'a [SliceInfoElem], Din, Dout> +where + T: AsRef<[SliceInfoElem]>, + Din: Dimension, + Dout: Dimension, +{ + fn from(info: &'a SliceInfo) -> SliceInfo<&'a [SliceInfoElem], Din, Dout> + { + SliceInfo { + in_dim: info.in_dim, + out_dim: info.out_dim, + indices: info.indices.as_ref(), } } } -impl Copy for SliceInfo +impl Copy for SliceInfo where T: Copy, - D: Dimension, + Din: Dimension, + Dout: Dimension, { } -impl Clone for SliceInfo +impl Clone for SliceInfo where T: Clone, - D: Dimension, + Din: Dimension, + Dout: Dimension, { - fn clone(&self) -> Self { + fn clone(&self) -> Self + { SliceInfo { + in_dim: PhantomData, out_dim: PhantomData, indices: self.indices.clone(), } } } +/// Trait for determining dimensionality of input and output for [`s!`] macro. #[doc(hidden)] -pub trait SliceNextDim { - fn next_dim(&self, _: PhantomData) -> PhantomData; +pub trait SliceNextDim +{ + /// Number of dimensions that this slicing argument consumes in the input array. + type InDim: Dimension; + /// Number of dimensions that this slicing argument produces in the output array. 
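// Editor's note: an illustrative sketch, not part of this patch, of the safe
// `TryFrom` constructors described above: build a `SliceInfo` at runtime from
// a `Vec<SliceInfoElem>` and use it to slice a dynamic-dimensional array.
use std::convert::TryFrom;

use ndarray::{Array, IxDyn, ShapeError, Slice, SliceInfo, SliceInfoElem};

fn main() -> Result<(), ShapeError> {
    let a = Array::from_shape_vec(IxDyn(&[3, 4]), (0..12).collect())?;
    let elems = vec![
        SliceInfoElem::Index(1),                         // select row 1 (drops that axis)
        SliceInfoElem::from(Slice::from(..).step_by(2)), // every other column
    ];
    let info: SliceInfo<_, IxDyn, IxDyn> = SliceInfo::try_from(elems)?;
    let v = a.slice(&info);
    assert_eq!(v.shape(), &[2]);
    Ok(())
}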
+ type OutDim: Dimension; + + fn next_in_dim(&self, _: PhantomData) -> PhantomData<>::Output> + where D: Dimension + DimAdd + { + PhantomData + } + + fn next_out_dim(&self, _: PhantomData) -> PhantomData<>::Output> + where D: Dimension + DimAdd + { + PhantomData + } } -macro_rules! impl_slicenextdim_equal { - ($self:ty) => { - impl SliceNextDim for $self { - fn next_dim(&self, _: PhantomData) -> PhantomData { - PhantomData - } +macro_rules! impl_slicenextdim { + (($($generics:tt)*), $self:ty, $in:ty, $out:ty) => { + impl<$($generics)*> SliceNextDim for $self { + type InDim = $in; + type OutDim = $out; } }; } -impl_slicenextdim_equal!(isize); -impl_slicenextdim_equal!(usize); -impl_slicenextdim_equal!(i32); -macro_rules! impl_slicenextdim_larger { - (($($generics:tt)*), $self:ty) => { - impl SliceNextDim for $self { - fn next_dim(&self, _: PhantomData) -> PhantomData { - PhantomData - } - } - } -} -impl_slicenextdim_larger!((T), Range); -impl_slicenextdim_larger!((T), RangeInclusive); -impl_slicenextdim_larger!((T), RangeFrom); -impl_slicenextdim_larger!((T), RangeTo); -impl_slicenextdim_larger!((T), RangeToInclusive); -impl_slicenextdim_larger!((), RangeFull); -impl_slicenextdim_larger!((), Slice); +impl_slicenextdim!((), isize, Ix1, Ix0); +impl_slicenextdim!((), usize, Ix1, Ix0); +impl_slicenextdim!((), i32, Ix1, Ix0); + +impl_slicenextdim!((T), Range, Ix1, Ix1); +impl_slicenextdim!((T), RangeInclusive, Ix1, Ix1); +impl_slicenextdim!((T), RangeFrom, Ix1, Ix1); +impl_slicenextdim!((T), RangeTo, Ix1, Ix1); +impl_slicenextdim!((T), RangeToInclusive, Ix1, Ix1); +impl_slicenextdim!((), RangeFull, Ix1, Ix1); +impl_slicenextdim!((), Slice, Ix1, Ix1); + +impl_slicenextdim!((), NewAxis, Ix0, Ix1); /// Slice argument constructor. /// -/// `s![]` takes a list of ranges/slices/indices, separated by comma, with -/// optional step sizes that are separated from the range by a semicolon. It is -/// converted into a [`&SliceInfo`] instance. -/// -/// [`&SliceInfo`]: struct.SliceInfo.html +/// `s![]` takes a list of ranges/slices/indices/new-axes, separated by comma, +/// with optional step sizes that are separated from the range by a semicolon. +/// It is converted into a [`SliceInfo`] instance. /// /// Each range/slice/index uses signed indices, where a negative value is /// counted from the end of the axis. Step sizes are also signed and may be /// negative, but must not be zero. /// -/// The syntax is `s![` *[ axis-slice-or-index [, axis-slice-or-index [ , ... ] -/// ] ]* `]`, where *axis-slice-or-index* is any of the following: +/// The syntax is `s![` *[ elem [, elem [ , ... ] ] ]* `]`, where *elem* is any +/// of the following: /// /// * *index*: an index to use for taking a subview with respect to that axis. /// (The index is selected. The axis is removed except with @@ -467,37 +734,37 @@ impl_slicenextdim_larger!((), Slice); /// * *range*: a range with step size 1 to use for slicing that axis. /// * *range* `;` *step*: a range with step size *step* to use for slicing that axis. /// * *slice*: a [`Slice`] instance to use for slicing that axis. -/// * *slice* `;` *step*: a range constructed from the start and end of a [`Slice`] -/// instance, with new step size *step*, to use for slicing that axis. -/// -/// [`Slice`]: struct.Slice.html -/// -/// The number of *axis-slice-or-index* must match the number of axes in the -/// array. *index*, *range*, *slice*, and *step* can be expressions. *index* -/// must be of type `isize`, `usize`, or `i32`. 
*range* must be of type -/// `Range`, `RangeTo`, `RangeFrom`, or `RangeFull` where `I` is -/// `isize`, `usize`, or `i32`. *step* must be a type that can be converted to -/// `isize` with the `as` keyword. -/// -/// For example `s![0..4;2, 6, 1..5]` is a slice of the first axis for 0..4 -/// with step size 2, a subview of the second axis at index 6, and a slice of -/// the third axis for 1..5 with default step size 1. The input array must have -/// 3 dimensions. The resulting slice would have shape `[2, 4]` for -/// [`.slice()`], [`.slice_mut()`], and [`.slice_move()`], and shape -/// `[2, 1, 4]` for [`.slice_collapse()`]. -/// -/// [`.slice()`]: struct.ArrayBase.html#method.slice -/// [`.slice_mut()`]: struct.ArrayBase.html#method.slice_mut -/// [`.slice_move()`]: struct.ArrayBase.html#method.slice_move -/// [`.slice_collapse()`]: struct.ArrayBase.html#method.slice_collapse -/// -/// See also [*Slicing*](struct.ArrayBase.html#slicing). +/// * *slice* `;` *step*: a range constructed from a [`Slice`] instance, +/// multiplying the step size by *step*, to use for slicing that axis. +/// * *new-axis*: a [`NewAxis`] instance that represents the creation of a new axis. +/// (Except for [`.slice_collapse()`], which panics on [`NewAxis`] elements.) +/// +/// The number of *elem*, not including *new-axis*, must match the +/// number of axes in the array. *index*, *range*, *slice*, *step*, and +/// *new-axis* can be expressions. *index* must be of type `isize`, `usize`, or +/// `i32`. *range* must be of type `Range`, `RangeTo`, `RangeFrom`, or +/// `RangeFull` where `I` is `isize`, `usize`, or `i32`. *step* must be a type +/// that can be converted to `isize` with the `as` keyword. +/// +/// For example, `s![0..4;2, 6, 1..5, NewAxis]` is a slice of the first axis +/// for 0..4 with step size 2, a subview of the second axis at index 6, a slice +/// of the third axis for 1..5 with default step size 1, and a new axis of +/// length 1 at the end of the shape. The input array must have 3 dimensions. +/// The resulting slice would have shape `[2, 4, 1]` for [`.slice()`], +/// [`.slice_mut()`], and [`.slice_move()`], while [`.slice_collapse()`] would +/// panic. Without the `NewAxis`, i.e. `s![0..4;2, 6, 1..5]`, +/// [`.slice_collapse()`] would result in an array of shape `[2, 1, 4]`. +/// +/// [`.slice()`]: crate::ArrayBase::slice +/// [`.slice_mut()`]: crate::ArrayBase::slice_mut +/// [`.slice_move()`]: crate::ArrayBase::slice_move +/// [`.slice_collapse()`]: crate::ArrayBase::slice_collapse +/// +/// See also [*Slicing*](crate::ArrayBase#slicing). /// /// # Example /// /// ``` -/// extern crate ndarray; -/// /// use ndarray::{s, Array2, ArrayView2}; /// /// fn laplacian(v: &ArrayView2) -> Array2 { @@ -507,7 +774,7 @@ impl_slicenextdim_larger!((), Slice); /// + v.slice(s![1..-1, 2.. ]) /// + v.slice(s![2.. , 1..-1]) /// } -/// # fn main() { } +/// # fn main() { let _ = laplacian; } /// ``` /// /// # Negative *step* @@ -528,8 +795,6 @@ impl_slicenextdim_larger!((), Slice); /// For example, /// /// ``` -/// # extern crate ndarray; -/// # /// # use ndarray::prelude::*; /// # /// # fn main() { @@ -544,49 +809,48 @@ impl_slicenextdim_larger!((), Slice); #[macro_export] macro_rules! 
s( // convert a..b;c into @convert(a..b, c), final item - (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr) => { + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr;$s:expr) => { match $r { r => { - let out_dim = $crate::SliceNextDim::next_dim(&r, $dim); - #[allow(unsafe_code)] - unsafe { - $crate::SliceInfo::new_unchecked( - [$($stack)* $crate::s!(@convert r, $s)], - out_dim, - ) - } + let in_dim = $crate::SliceNextDim::next_in_dim(&r, $in_dim); + let out_dim = $crate::SliceNextDim::next_out_dim(&r, $out_dim); + ( + [$($stack)* $crate::s!(@convert r, $s)], + in_dim, + out_dim, + ) } } }; // convert a..b into @convert(a..b), final item - (@parse $dim:expr, [$($stack:tt)*] $r:expr) => { + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr) => { match $r { r => { - let out_dim = $crate::SliceNextDim::next_dim(&r, $dim); - #[allow(unsafe_code)] - unsafe { - $crate::SliceInfo::new_unchecked( - [$($stack)* $crate::s!(@convert r)], - out_dim, - ) - } + let in_dim = $crate::SliceNextDim::next_in_dim(&r, $in_dim); + let out_dim = $crate::SliceNextDim::next_out_dim(&r, $out_dim); + ( + [$($stack)* $crate::s!(@convert r)], + in_dim, + out_dim, + ) } } }; // convert a..b;c into @convert(a..b, c), final item, trailing comma - (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr ,) => { - $crate::s![@parse $dim, [$($stack)*] $r;$s] + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr;$s:expr ,) => { + $crate::s![@parse $in_dim, $out_dim, [$($stack)*] $r;$s] }; // convert a..b into @convert(a..b), final item, trailing comma - (@parse $dim:expr, [$($stack:tt)*] $r:expr ,) => { - $crate::s![@parse $dim, [$($stack)*] $r] + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr ,) => { + $crate::s![@parse $in_dim, $out_dim, [$($stack)*] $r] }; // convert a..b;c into @convert(a..b, c) - (@parse $dim:expr, [$($stack:tt)*] $r:expr;$s:expr, $($t:tt)*) => { + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr;$s:expr, $($t:tt)*) => { match $r { r => { $crate::s![@parse - $crate::SliceNextDim::next_dim(&r, $dim), + $crate::SliceNextDim::next_in_dim(&r, $in_dim), + $crate::SliceNextDim::next_out_dim(&r, $out_dim), [$($stack)* $crate::s!(@convert r, $s),] $($t)* ] @@ -594,11 +858,12 @@ macro_rules! s( } }; // convert a..b into @convert(a..b) - (@parse $dim:expr, [$($stack:tt)*] $r:expr, $($t:tt)*) => { + (@parse $in_dim:expr, $out_dim:expr, [$($stack:tt)*] $r:expr, $($t:tt)*) => { match $r { r => { $crate::s![@parse - $crate::SliceNextDim::next_dim(&r, $dim), + $crate::SliceNextDim::next_in_dim(&r, $in_dim), + $crate::SliceNextDim::next_out_dim(&r, $out_dim), [$($stack)* $crate::s!(@convert r),] $($t)* ] @@ -606,28 +871,40 @@ macro_rules! s( } }; // empty call, i.e. 
`s![]` - (@parse ::std::marker::PhantomData::<$crate::Ix0>, []) => { - { - #[allow(unsafe_code)] - unsafe { - $crate::SliceInfo::new_unchecked([], ::std::marker::PhantomData::<$crate::Ix0>) - } - } + (@parse ::core::marker::PhantomData::<$crate::Ix0>, ::core::marker::PhantomData::<$crate::Ix0>, []) => { + ( + [], + ::core::marker::PhantomData::<$crate::Ix0>, + ::core::marker::PhantomData::<$crate::Ix0>, + ) }; // Catch-all clause for syntax errors (@parse $($t:tt)*) => { compile_error!("Invalid syntax in s![] call.") }; - // convert range/index into SliceOrIndex + // convert range/index/new-axis into SliceInfoElem (@convert $r:expr) => { - <$crate::SliceOrIndex as ::std::convert::From<_>>::from($r) + <$crate::SliceInfoElem as ::core::convert::From<_>>::from($r) }; - // convert range/index and step into SliceOrIndex + // convert range/index/new-axis and step into SliceInfoElem (@convert $r:expr, $s:expr) => { - <$crate::SliceOrIndex as ::std::convert::From<_>>::from($r).step_by($s as isize) + <$crate::SliceInfoElem as ::core::convert::From<_>>::from( + <$crate::Slice as ::core::convert::From<_>>::from($r).step_by($s as isize) + ) }; ($($t:tt)*) => { - // The extra `*&` is a workaround for this compiler bug: - // https://github.com/rust-lang/rust/issues/23014 - &*&$crate::s![@parse ::std::marker::PhantomData::<$crate::Ix0>, [] $($t)*] + { + let (indices, in_dim, out_dim) = $crate::s![@parse + ::core::marker::PhantomData::<$crate::Ix0>, + ::core::marker::PhantomData::<$crate::Ix0>, + [] + $($t)* + ]; + // Safety: The `s![@parse ...]` above always constructs the correct + // values to meet the constraints of `SliceInfo::new_unchecked`. + #[allow(unsafe_code)] + unsafe { + $crate::SliceInfo::new_unchecked(indices, in_dim, out_dim) + } + } }; ); @@ -635,7 +912,7 @@ macro_rules! s( /// /// It's unfortunate that we need `'a` and `A` to be parameters of the trait, /// but they're necessary until Rust supports generic associated types. -pub trait MultiSlice<'a, A, D> +pub trait MultiSliceArg<'a, A, D> where A: 'a, D: Dimension, @@ -648,9 +925,11 @@ where /// **Panics** if performing any individual slice panics or if the slices /// are not disjoint (i.e. if they intersect). fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output; + + private_decl! {} } -impl<'a, A, D> MultiSlice<'a, A, D> for () +impl<'a, A, D> MultiSliceArg<'a, A, D> for () where A: 'a, D: Dimension, @@ -658,19 +937,24 @@ where type Output = (); fn multi_slice_move(&self, _view: ArrayViewMut<'a, A, D>) -> Self::Output {} + + private_impl! {} } -impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo,) +impl<'a, A, D, I0> MultiSliceArg<'a, A, D> for (I0,) where A: 'a, D: Dimension, - Do0: Dimension, + I0: SliceArg, { - type Output = (ArrayViewMut<'a, A, Do0>,); + type Output = (ArrayViewMut<'a, A, I0::OutDim>,); - fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { - (view.slice_move(self.0),) + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output + { + (view.slice_move(&self.0),) } + + private_impl! {} } macro_rules! impl_multislice_tuple { @@ -678,13 +962,13 @@ macro_rules! 
impl_multislice_tuple { impl_multislice_tuple!(@def_impl ($($but_last,)* $last,), [$($but_last)*] $last); }; (@def_impl ($($all:ident,)*), [$($but_last:ident)*] $last:ident) => { - impl<'a, A, D, $($all,)*> MultiSlice<'a, A, D> for ($(&SliceInfo,)*) + impl<'a, A, D, $($all,)*> MultiSliceArg<'a, A, D> for ($($all,)*) where A: 'a, D: Dimension, - $($all: Dimension,)* + $($all: SliceArg,)* { - type Output = ($(ArrayViewMut<'a, A, $all>,)*); + type Output = ($(ArrayViewMut<'a, A, $all::OutDim>,)*); fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { #[allow(non_snake_case)] @@ -701,6 +985,8 @@ macro_rules! impl_multislice_tuple { ) } } + + private_impl! {} } }; (@intersects_self $shape:expr, ($head:expr,)) => { @@ -712,21 +998,24 @@ macro_rules! impl_multislice_tuple { }; } -impl_multislice_tuple!([Do0] Do1); -impl_multislice_tuple!([Do0 Do1] Do2); -impl_multislice_tuple!([Do0 Do1 Do2] Do3); -impl_multislice_tuple!([Do0 Do1 Do2 Do3] Do4); -impl_multislice_tuple!([Do0 Do1 Do2 Do3 Do4] Do5); +impl_multislice_tuple!([I0] I1); +impl_multislice_tuple!([I0 I1] I2); +impl_multislice_tuple!([I0 I1 I2] I3); +impl_multislice_tuple!([I0 I1 I2 I3] I4); +impl_multislice_tuple!([I0 I1 I2 I3 I4] I5); -impl<'a, A, D, T> MultiSlice<'a, A, D> for &T +impl<'a, A, D, T> MultiSliceArg<'a, A, D> for &T where A: 'a, D: Dimension, - T: MultiSlice<'a, A, D>, + T: MultiSliceArg<'a, A, D>, { type Output = T::Output; - fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output { + fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output + { T::multi_slice_move(self, view) } + + private_impl! {} } diff --git a/src/split_at.rs b/src/split_at.rs new file mode 100644 index 000000000..5dee44b63 --- /dev/null +++ b/src/split_at.rs @@ -0,0 +1,54 @@ +use crate::imp_prelude::*; + +/// Arrays and similar that can be split along an axis +pub(crate) trait SplitAt +{ + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + where Self: Sized; +} + +pub(crate) trait SplitPreference: SplitAt +{ + #[allow(dead_code)] // used only when Rayon support is enabled + fn can_split(&self) -> bool; + fn split_preference(&self) -> (Axis, usize); + fn split(self) -> (Self, Self) + where Self: Sized + { + let (axis, index) = self.split_preference(); + self.split_at(axis, index) + } +} + +impl SplitAt for D +where D: Dimension +{ + fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) + { + let mut d1 = self; + let mut d2 = d1.clone(); + let i = axis.index(); + let len = d1[i]; + d1[i] = index; + d2[i] = len - index; + (d1, d2) + } +} + +impl SplitAt for ArrayViewMut<'_, A, D> +where D: Dimension +{ + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + { + self.split_at(axis, index) + } +} + +impl SplitAt for RawArrayViewMut +where D: Dimension +{ + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + { + self.split_at(axis, index) + } +} diff --git a/src/stacking.rs b/src/stacking.rs index e998b6d15..8737d6f60 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -1,4 +1,4 @@ -// Copyright 2014-2016 bluss and ndarray developers. +// Copyright 2014-2020 bluss and ndarray developers. // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -6,10 +6,14 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; + +use crate::dimension; use crate::error::{from_kind, ErrorKind, ShapeError}; use crate::imp_prelude::*; -/// Stack arrays along the given axis. 
+/// Concatenate arrays along the given axis. /// /// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`. /// (may be made more flexible in the future).
@@ -17,24 +21,21 @@ use crate::imp_prelude::*; /// if the result is larger than is possible to represent. /// /// ``` -/// use ndarray::{arr2, Axis, stack}; +/// use ndarray::{arr2, Axis, concatenate}; /// /// let a = arr2(&[[2., 2.], /// [3., 3.]]); /// assert!( -/// stack(Axis(0), &[a.view(), a.view()]) +/// concatenate(Axis(0), &[a.view(), a.view()]) /// == Ok(arr2(&[[2., 2.], /// [3., 3.], /// [2., 2.], /// [3., 3.]])) /// ); /// ``` -pub fn stack<'a, A, D>( - axis: Axis, - arrays: &[ArrayView<'a, A, D>], -) -> Result, ShapeError> +pub fn concatenate(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> where - A: Copy, + A: Clone, D: RemoveAxis, { if arrays.is_empty() { @@ -54,58 +55,172 @@ where let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis)); res_dim.set_axis(axis, stacked_dim); + let new_len = dimension::size_of_shape_checked(&res_dim)?; - // we can safely use uninitialized values here because they are Copy - // and we will only ever write to them - let size = res_dim.size(); - let mut v = Vec::with_capacity(size); - unsafe { - v.set_len(size); + // start with empty array with precomputed capacity + // append's handling of empty arrays makes sure `axis` is ok for appending + res_dim.set_axis(axis, 0); + let mut res = unsafe { + // Safety: dimension is size 0 and vec is empty + Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len)) + }; + + for array in arrays { + res.append(axis, array.clone())?; } - let mut res = Array::from_shape_vec(res_dim, v)?; + debug_assert_eq!(res.len_of(axis), stacked_dim); + Ok(res) +} - { - let mut assign_view = res.view_mut(); - for array in arrays { - let len = array.len_of(axis); - let (mut front, rest) = assign_view.split_at(axis, len); - front.assign(array); - assign_view = rest; - } +/// Stack arrays along the new axis. +/// +/// ***Errors*** if the arrays have mismatching shapes. +/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds, +/// if the result is larger than is possible to represent. +/// +/// ``` +/// extern crate ndarray; +/// +/// use ndarray::{arr2, arr3, stack, Axis}; +/// +/// # fn main() { +/// +/// let a = arr2(&[[2., 2.], +/// [3., 3.]]); +/// assert!( +/// stack(Axis(0), &[a.view(), a.view()]) +/// == Ok(arr3(&[[[2., 2.], +/// [3., 3.]], +/// [[2., 2.], +/// [3., 3.]]])) +/// ); +/// # } +/// ``` +pub fn stack(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> +where + A: Clone, + D: Dimension, + D::Larger: RemoveAxis, +{ + if arrays.is_empty() { + return Err(from_kind(ErrorKind::Unsupported)); + } + let common_dim = arrays[0].raw_dim(); + // Avoid panic on `insert_axis` call, return an Err instead of it. + if axis.index() > common_dim.ndim() { + return Err(from_kind(ErrorKind::OutOfBounds)); + } + let mut res_dim = common_dim.insert_axis(axis); + + if arrays.iter().any(|a| a.raw_dim() != common_dim) { + return Err(from_kind(ErrorKind::IncompatibleShape)); } + + res_dim.set_axis(axis, arrays.len()); + + let new_len = dimension::size_of_shape_checked(&res_dim)?; + + // start with empty array with precomputed capacity + // append's handling of empty arrays makes sure `axis` is ok for appending + res_dim.set_axis(axis, 0); + let mut res = unsafe { + // Safety: dimension is size 0 and vec is empty + Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len)) + }; + + for array in arrays { + res.append(axis, array.clone().insert_axis(axis))?; + } + + debug_assert_eq!(res.len_of(axis), arrays.len()); Ok(res) } -/// Stack arrays along the given axis. 
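// Editor's note: an illustrative sketch, not part of this patch, contrasting
// the two functions defined above: `concatenate` grows an existing axis,
// while `stack` inserts a new axis of length `arrays.len()`.
use ndarray::{arr2, concatenate, stack, Axis};

fn main() {
    let a = arr2(&[[1., 2.], [3., 4.]]);

    let joined = concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
    assert_eq!(joined.shape(), &[4, 2]); // existing axis 0 grows from 2 to 4

    let stacked = stack(Axis(0), &[a.view(), a.view()]).unwrap();
    assert_eq!(stacked.shape(), &[2, 2, 2]); // a new axis 0 of length 2 is inserted
}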
+/// Stack arrays along the new axis. /// -/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each +/// Uses the [`stack()`] function, calling `ArrayView::from(&a)` on each /// argument `a`. /// -/// [1]: fn.stack.html -/// /// ***Panics*** if the `stack` function would return an error. /// /// ``` /// extern crate ndarray; /// -/// use ndarray::{arr2, stack, Axis}; +/// use ndarray::{arr2, arr3, stack, Axis}; /// /// # fn main() { /// -/// let a = arr2(&[[2., 2.], -/// [3., 3.]]); -/// assert!( -/// stack![Axis(0), a, a] -/// == arr2(&[[2., 2.], -/// [3., 3.], -/// [2., 2.], -/// [3., 3.]]) +/// let a = arr2(&[[1., 2.], +/// [3., 4.]]); +/// assert_eq!( +/// stack![Axis(0), a, a], +/// arr3(&[[[1., 2.], +/// [3., 4.]], +/// [[1., 2.], +/// [3., 4.]]]), +/// ); +/// assert_eq!( +/// stack![Axis(1), a, a,], +/// arr3(&[[[1., 2.], +/// [1., 2.]], +/// [[3., 4.], +/// [3., 4.]]]), +/// ); +/// assert_eq!( +/// stack![Axis(2), a, a], +/// arr3(&[[[1., 1.], +/// [2., 2.]], +/// [[3., 3.], +/// [4., 4.]]]), /// ); /// # } /// ``` #[macro_export] macro_rules! stack { + ($axis:expr, $( $array:expr ),+ ,) => { + $crate::stack!($axis, $($array),+) + }; ($axis:expr, $( $array:expr ),+ ) => { $crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap() - } + }; +} + +/// Concatenate arrays along the given axis. +/// +/// Uses the [`concatenate()`] function, calling `ArrayView::from(&a)` on each +/// argument `a`. +/// +/// ***Panics*** if the `concatenate` function would return an error. +/// +/// ``` +/// extern crate ndarray; +/// +/// use ndarray::{arr2, concatenate, Axis}; +/// +/// # fn main() { +/// +/// let a = arr2(&[[1., 2.], +/// [3., 4.]]); +/// assert_eq!( +/// concatenate![Axis(0), a, a], +/// arr2(&[[1., 2.], +/// [3., 4.], +/// [1., 2.], +/// [3., 4.]]), +/// ); +/// assert_eq!( +/// concatenate![Axis(1), a, a,], +/// arr2(&[[1., 2., 1., 2.], +/// [3., 4., 3., 4.]]), +/// ); +/// # } +/// ``` +#[macro_export] +macro_rules! concatenate { + ($axis:expr, $( $array:expr ),+ ,) => { + $crate::concatenate!($axis, $($array),+) + }; + ($axis:expr, $( $array:expr ),+ ) => { + $crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap() + }; } diff --git a/src/tri.rs b/src/tri.rs new file mode 100644 index 000000000..b7d297fcc --- /dev/null +++ b/src/tri.rs @@ -0,0 +1,367 @@ +// Copyright 2014-2024 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::cmp::min; + +use num_traits::Zero; + +use crate::{ + dimension::{is_layout_c, is_layout_f}, + Array, + ArrayBase, + Axis, + Data, + Dimension, + Zip, +}; + +impl ArrayBase +where + S: Data, + D: Dimension, + A: Clone + Zero, +{ + /// Upper triangular of an array. + /// + /// Return a copy of the array with elements below the *k*-th diagonal zeroed. + /// For arrays with `ndim` exceeding 2, `triu` will apply to the final two axes. + /// For 0D and 1D arrays, `triu` will return an unchanged clone. 
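// Editor's note: an illustrative sketch, not part of this patch, showing how a
// nonzero `k` shifts the diagonal that `triu`/`tril` zero against (positive
// `k` is above the main diagonal, negative `k` below), mirroring NumPy.
use ndarray::array;

fn main() {
    let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
    // triu(1) zeroes the main diagonal and everything below it.
    assert_eq!(x.triu(1), array![[0, 2, 3], [0, 0, 6], [0, 0, 0]]);
    // tril(-1) zeroes the main diagonal and everything above it.
    assert_eq!(x.tril(-1), array![[0, 0, 0], [4, 0, 0], [7, 8, 0]]);
}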
+ /// + /// See also [`ArrayBase::tril`] + /// + /// ``` + /// use ndarray::array; + /// + /// let arr = array![ + /// [1, 2, 3], + /// [4, 5, 6], + /// [7, 8, 9] + /// ]; + /// assert_eq!( + /// arr.triu(0), + /// array![ + /// [1, 2, 3], + /// [0, 5, 6], + /// [0, 0, 9] + /// ] + /// ); + /// ``` + pub fn triu(&self, k: isize) -> Array + { + if self.ndim() <= 1 { + return self.to_owned(); + } + + // Performance optimization for F-order arrays. + // C-order array check prevents infinite recursion in edge cases like [[1]]. + // k-size check prevents underflow when k == isize::MIN + let n = self.ndim(); + if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN { + let mut x = self.view(); + x.swap_axes(n - 2, n - 1); + let mut tril = x.tril(-k); + tril.swap_axes(n - 2, n - 1); + + return tril; + } + + let mut res = Array::zeros(self.raw_dim()); + let ncols = self.len_of(Axis(n - 1)); + let nrows = self.len_of(Axis(n - 2)); + let indices = Array::from_iter(0..nrows); + Zip::from(self.rows()) + .and(res.rows_mut()) + .and_broadcast(&indices) + .for_each(|src, mut dst, row_num| { + let mut lower = match k >= 0 { + true => row_num.saturating_add(k as usize), // Avoid overflow + false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0 + }; + lower = min(lower, ncols); + dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..])); + }); + + res + } + + /// Lower triangular of an array. + /// + /// Return a copy of the array with elements above the *k*-th diagonal zeroed. + /// For arrays with `ndim` exceeding 2, `tril` will apply to the final two axes. + /// For 0D and 1D arrays, `tril` will return an unchanged clone. + /// + /// See also [`ArrayBase::triu`] + /// + /// ``` + /// use ndarray::array; + /// + /// let arr = array![ + /// [1, 2, 3], + /// [4, 5, 6], + /// [7, 8, 9] + /// ]; + /// assert_eq!( + /// arr.tril(0), + /// array![ + /// [1, 0, 0], + /// [4, 5, 0], + /// [7, 8, 9] + /// ] + /// ); + /// ``` + pub fn tril(&self, k: isize) -> Array + { + if self.ndim() <= 1 { + return self.to_owned(); + } + + // Performance optimization for F-order arrays. + // C-order array check prevents infinite recursion in edge cases like [[1]]. 
+ // k-size check prevents underflow when k == isize::MIN + let n = self.ndim(); + if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN { + let mut x = self.view(); + x.swap_axes(n - 2, n - 1); + let mut tril = x.triu(-k); + tril.swap_axes(n - 2, n - 1); + + return tril; + } + + let mut res = Array::zeros(self.raw_dim()); + let ncols = self.len_of(Axis(n - 1)); + let nrows = self.len_of(Axis(n - 2)); + let indices = Array::from_iter(0..nrows); + Zip::from(self.rows()) + .and(res.rows_mut()) + .and_broadcast(&indices) + .for_each(|src, mut dst, row_num| { + // let row_num = i.into_dimension().last_elem(); + let mut upper = match k >= 0 { + true => row_num.saturating_add(k as usize).saturating_add(1), // Avoid overflow + false => row_num.saturating_sub((k + 1).unsigned_abs()), // Avoid underflow + }; + upper = min(upper, ncols); + dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper])); + }); + + res + } +} + +#[cfg(test)] +mod tests +{ + use core::isize; + + use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder}; + use alloc::vec; + + #[test] + fn test_keep_order() + { + let x = Array2::::ones((3, 3).f()); + let res = x.triu(0); + assert!(dimension::is_layout_f(&res.dim, &res.strides)); + + let res = x.tril(0); + assert!(dimension::is_layout_f(&res.dim, &res.strides)); + } + + #[test] + fn test_0d() + { + let x = Array0::::ones(()); + let res = x.triu(0); + assert_eq!(res, x); + + let res = x.tril(0); + assert_eq!(res, x); + + let x = Array0::::ones(().f()); + let res = x.triu(0); + assert_eq!(res, x); + + let res = x.tril(0); + assert_eq!(res, x); + } + + #[test] + fn test_1d() + { + let x = array![1, 2, 3]; + let res = x.triu(0); + assert_eq!(res, x); + + let res = x.triu(0); + assert_eq!(res, x); + + let x = Array1::::ones(3.f()); + let res = x.triu(0); + assert_eq!(res, x); + + let res = x.triu(0); + assert_eq!(res, x); + } + + #[test] + fn test_2d() + { + let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + + // Upper + let res = x.triu(0); + assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]); + + // Lower + let res = x.tril(0); + assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]); + + let x = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap(); + + // Upper + let res = x.triu(0); + assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]); + + // Lower + let res = x.tril(0); + assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]); + } + + #[test] + fn test_2d_single() + { + let x = array![[1]]; + + assert_eq!(x.triu(0), array![[1]]); + assert_eq!(x.tril(0), array![[1]]); + assert_eq!(x.triu(1), array![[0]]); + assert_eq!(x.tril(1), array![[1]]); + assert_eq!(x.triu(-1), array![[1]]); + assert_eq!(x.tril(-1), array![[0]]); + } + + #[test] + fn test_3d() + { + let x = array![ + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]] + ]; + + // Upper + let res = x.triu(0); + assert_eq!( + res, + array![ + [[1, 2, 3], [0, 5, 6], [0, 0, 9]], + [[10, 11, 12], [0, 14, 15], [0, 0, 18]], + [[19, 20, 21], [0, 23, 24], [0, 0, 27]] + ] + ); + + // Lower + let res = x.tril(0); + assert_eq!( + res, + array![ + [[1, 0, 0], [4, 5, 0], [7, 8, 9]], + [[10, 0, 0], [13, 14, 0], [16, 17, 18]], + [[19, 0, 0], [22, 23, 0], [25, 26, 27]] + ] + ); + + let x = Array3::from_shape_vec( + (3, 3, 3).f(), + vec![1, 10, 19, 4, 13, 22, 7, 16, 25, 2, 11, 20, 5, 14, 23, 8, 17, 26, 3, 12, 21, 6, 15, 24, 9, 18, 27], + ) + .unwrap(); + + // 
Upper + let res = x.triu(0); + assert_eq!( + res, + array![ + [[1, 2, 3], [0, 5, 6], [0, 0, 9]], + [[10, 11, 12], [0, 14, 15], [0, 0, 18]], + [[19, 20, 21], [0, 23, 24], [0, 0, 27]] + ] + ); + + // Lower + let res = x.tril(0); + assert_eq!( + res, + array![ + [[1, 0, 0], [4, 5, 0], [7, 8, 9]], + [[10, 0, 0], [13, 14, 0], [16, 17, 18]], + [[19, 0, 0], [22, 23, 0], [25, 26, 27]] + ] + ); + } + + #[test] + fn test_off_axis() + { + let x = array![ + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]] + ]; + + let res = x.triu(1); + assert_eq!( + res, + array![ + [[0, 2, 3], [0, 0, 6], [0, 0, 0]], + [[0, 11, 12], [0, 0, 15], [0, 0, 0]], + [[0, 20, 21], [0, 0, 24], [0, 0, 0]] + ] + ); + + let res = x.triu(-1); + assert_eq!( + res, + array![ + [[1, 2, 3], [4, 5, 6], [0, 8, 9]], + [[10, 11, 12], [13, 14, 15], [0, 17, 18]], + [[19, 20, 21], [22, 23, 24], [0, 26, 27]] + ] + ); + } + + #[test] + fn test_odd_shape() + { + let x = array![[1, 2, 3], [4, 5, 6]]; + let res = x.triu(0); + assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]); + + let res = x.tril(0); + assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]); + + let x = array![[1, 2], [3, 4], [5, 6]]; + let res = x.triu(0); + assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]); + + let res = x.tril(0); + assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]); + } + + #[test] + fn test_odd_k() + { + let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + let z = Array2::zeros([3, 3]); + assert_eq!(x.triu(isize::MIN), x); + assert_eq!(x.tril(isize::MIN), z); + assert_eq!(x.triu(isize::MAX), z); + assert_eq!(x.tril(isize::MAX), x); + } +} diff --git a/src/zip/mod.rs b/src/zip/mod.rs index 33317253e..b58752f66 100644 --- a/src/zip/mod.rs +++ b/src/zip/mod.rs @@ -8,15 +8,22 @@ #[macro_use] mod zipmacro; +mod ndproducer; + +#[cfg(feature = "rayon")] +use std::mem::MaybeUninit; use crate::imp_prelude::*; +use crate::partial::Partial; +use crate::AssignElem; use crate::IntoDimension; use crate::Layout; -use crate::NdIndex; +use crate::dimension; use crate::indexes::{indices, Indices}; -use crate::layout::LayoutPriv; -use crate::layout::{CORDER, FORDER}; +use crate::split_at::{SplitAt, SplitPreference}; + +pub use self::ndproducer::{IntoNdProducer, NdProducer, Offset}; /// Return if the expression is a break value. macro_rules! fold_while { @@ -30,38 +37,53 @@ macro_rules! fold_while { /// Broadcast an array so that it acts like a larger size and/or shape array. /// -/// See [broadcasting][1] for more information. -/// -/// [1]: struct.ArrayBase.html#broadcasting +/// See [broadcasting](ArrayBase#broadcasting) for more information. trait Broadcast -where - E: IntoDimension, +where E: IntoDimension { type Output: NdProducer; /// Broadcast the array to the new dimensions `shape`. /// /// ***Panics*** if broadcasting isn’t possible. + #[track_caller] fn broadcast_unwrap(self, shape: E) -> Self::Output; private_decl! 
{} } +/// Compute `Layout` hints for array shape dim, strides +fn array_layout(dim: &D, strides: &D) -> Layout +{ + let n = dim.ndim(); + if dimension::is_layout_c(dim, strides) { + // effectively one-dimensional => C and F layout compatible + if n <= 1 || dim.slice().iter().filter(|&&len| len > 1).count() <= 1 { + Layout::one_dimensional() + } else { + Layout::c() + } + } else if n > 1 && dimension::is_layout_f(dim, strides) { + Layout::f() + } else if n > 1 { + if dim[0] > 1 && strides[0] == 1 { + Layout::fpref() + } else if dim[n - 1] > 1 && strides[n - 1] == 1 { + Layout::cpref() + } else { + Layout::none() + } + } else { + Layout::none() + } +} + impl ArrayBase where S: RawData, D: Dimension, { - pub(crate) fn layout_impl(&self) -> Layout { - Layout::new(if self.is_standard_layout() { - if self.ndim() <= 1 { - FORDER | CORDER - } else { - CORDER - } - } else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() { - FORDER - } else { - 0 - }) + pub(crate) fn layout_impl(&self) -> Layout + { + array_layout(&self.dim, &self.strides) } } @@ -71,145 +93,17 @@ where D: Dimension, { type Output = ArrayView<'a, A, E::Dim>; - fn broadcast_unwrap(self, shape: E) -> Self::Output { + fn broadcast_unwrap(self, shape: E) -> Self::Output + { + #[allow(clippy::needless_borrow)] let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension()); unsafe { ArrayView::new(res.ptr, res.dim, res.strides) } } private_impl! {} } -pub trait Splittable: Sized { - fn split_at(self, axis: Axis, index: Ix) -> (Self, Self); -} - -impl Splittable for D -where - D: Dimension, -{ - fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { - let mut d1 = self; - let mut d2 = d1.clone(); - let i = axis.index(); - let len = d1[i]; - d1[i] = index; - d2[i] = len - index; - (d1, d2) - } -} - -/// Argument conversion into a producer. -/// -/// Slices and vectors can be used (equivalent to 1-dimensional array views). -/// -/// This trait is like `IntoIterator` for `NdProducers` instead of iterators. -pub trait IntoNdProducer { - /// The element produced per iteration. - type Item; - /// Dimension type of the producer - type Dim: Dimension; - type Output: NdProducer; - /// Convert the value into an `NdProducer`. - fn into_producer(self) -> Self::Output; -} - -impl

(&self, part: &P) - where - P: NdProducer, - { - ndassert!( - part.equal_dim(&self.dimension), - "Zip: Producer dimension mismatch, expected: {:?}, got: {:?}", - self.dimension, - part.raw_dim() - ); - } + ndassert!( + part.equal_dim(dimension), + "Zip: Producer dimension mismatch, expected: {:?}, got: {:?}", + dimension, + part.raw_dim() + ); +} +impl Zip +where D: Dimension +{ /// Return a the number of element tuples in the Zip - pub fn size(&self) -> usize { + pub fn size(&self) -> usize + { self.dimension.size() } /// Return the length of `axis` /// /// ***Panics*** if `axis` is out of bounds. - fn len_of(&self, axis: Axis) -> usize { + #[track_caller] + fn len_of(&self, axis: Axis) -> usize + { self.dimension[axis.index()] } + fn prefer_f(&self) -> bool + { + !self.layout.is(Layout::CORDER) && (self.layout.is(Layout::FORDER) || self.layout_tendency < 0) + } + /// Return an *approximation* to the max stride axis; if /// component arrays disagree, there may be no choice better than the /// others. - fn max_stride_axis(&self) -> Axis { - let i = match self.layout.flag() { - FORDER => self - .dimension + fn max_stride_axis(&self) -> Axis + { + let i = if self.prefer_f() { + self.dimension .slice() .iter() .rposition(|&len| len > 1) - .unwrap_or(self.dimension.ndim() - 1), + .unwrap_or(self.dimension.ndim() - 1) + } else { /* corder or default */ - _ => self - .dimension + self.dimension .slice() .iter() .position(|&len| len > 1) - .unwrap_or(0), + .unwrap_or(0) }; Axis(i) } } impl Zip -where - D: Dimension, +where D: Dimension { - fn apply_core(&mut self, acc: Acc, function: F) -> FoldWhile + fn for_each_core(&mut self, acc: Acc, mut function: F) -> FoldWhile where F: FnMut(Acc, P::Item) -> FoldWhile, P: ZippableTuple, { - if self.layout.is(CORDER | FORDER) { - self.apply_core_contiguous(acc, function) + if self.dimension.ndim() == 0 { + function(acc, unsafe { self.parts.as_ref(self.parts.as_ptr()) }) + } else if self.layout.is(Layout::CORDER | Layout::FORDER) { + self.for_each_core_contiguous(acc, function) } else { - self.apply_core_strided(acc, function) + self.for_each_core_strided(acc, function) } } - fn apply_core_contiguous(&mut self, mut acc: Acc, mut function: F) -> FoldWhile + + fn for_each_core_contiguous(&mut self, acc: Acc, mut function: F) -> FoldWhile where F: FnMut(Acc, P::Item) -> FoldWhile, P: ZippableTuple, { - debug_assert!(self.layout.is(CORDER | FORDER)); + debug_assert!(self.layout.is(Layout::CORDER | Layout::FORDER)); let size = self.dimension.size(); let ptrs = self.parts.as_ptr(); let inner_strides = self.parts.contiguous_stride(); - for i in 0..size { - unsafe { - let ptr_i = ptrs.stride_offset(inner_strides, i); - acc = fold_while![function(acc, self.parts.as_ref(ptr_i))]; - } + unsafe { self.inner(acc, ptrs, inner_strides, size, &mut function) } + } + + /// The innermost loop of the Zip for_each methods + /// + /// Run the fold while operation on a stretch of elements with constant strides + /// + /// `ptr`: base pointer for the first element in this stretch + /// `strides`: strides for the elements in this stretch + /// `len`: number of elements + /// `function`: closure + unsafe fn inner( + &self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride, len: usize, function: &mut F, + ) -> FoldWhile + where + F: FnMut(Acc, P::Item) -> FoldWhile, + P: ZippableTuple, + { + let mut i = 0; + while i < len { + let p = ptr.stride_offset(strides, i); + acc = fold_while!(function(acc, self.parts.as_ref(p))); + i += 1; } FoldWhile::Continue(acc) } - fn apply_core_strided(&mut 
self, mut acc: Acc, mut function: F) -> FoldWhile + fn for_each_core_strided(&mut self, acc: Acc, function: F) -> FoldWhile where F: FnMut(Acc, P::Item) -> FoldWhile, P: ZippableTuple, @@ -715,26 +367,90 @@ where if n == 0 { panic!("Unreachable: ndim == 0 is contiguous") } + if n == 1 || self.layout_tendency >= 0 { + self.for_each_core_strided_c(acc, function) + } else { + self.for_each_core_strided_f(acc, function) + } + } + + // Non-contiguous but preference for C - unroll over Axis(ndim - 1) + fn for_each_core_strided_c(&mut self, mut acc: Acc, mut function: F) -> FoldWhile + where + F: FnMut(Acc, P::Item) -> FoldWhile, + P: ZippableTuple, + { + let n = self.dimension.ndim(); let unroll_axis = n - 1; let inner_len = self.dimension[unroll_axis]; self.dimension[unroll_axis] = 1; let mut index_ = self.dimension.first_index(); let inner_strides = self.parts.stride_of(unroll_axis); + // Loop unrolled over closest axis while let Some(index) = index_ { - // Let's “unroll” the loop over the innermost axis unsafe { let ptr = self.parts.uget_ptr(&index); - for i in 0..inner_len { - let p = ptr.stride_offset(inner_strides, i); - acc = fold_while!(function(acc, self.parts.as_ref(p))); - } + acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)]; } index_ = self.dimension.next_for(index); } - self.dimension[unroll_axis] = inner_len; FoldWhile::Continue(acc) } + + // Non-contiguous but preference for F - unroll over Axis(0) + fn for_each_core_strided_f(&mut self, mut acc: Acc, mut function: F) -> FoldWhile + where + F: FnMut(Acc, P::Item) -> FoldWhile, + P: ZippableTuple, + { + let unroll_axis = 0; + let inner_len = self.dimension[unroll_axis]; + self.dimension[unroll_axis] = 1; + let index_ = self.dimension.first_index(); + let inner_strides = self.parts.stride_of(unroll_axis); + // Loop unrolled over closest axis + if let Some(mut index) = index_ { + loop { + unsafe { + let ptr = self.parts.uget_ptr(&index); + acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)]; + } + + if !self.dimension.next_for_f(&mut index) { + break; + } + } + } + FoldWhile::Continue(acc) + } + + #[cfg(feature = "rayon")] + pub(crate) fn uninitialized_for_current_layout(&self) -> Array, D> + { + let is_f = self.prefer_f(); + Array::uninit(self.dimension.clone().set_f(is_f)) + } +} + +impl Zip<(P1, P2), D> +where + D: Dimension, + P1: NdProducer, + P1: NdProducer, +{ + /// Debug assert traversal order is like c (including 1D case) + // Method placement: only used for binary Zip at the moment. + #[inline] + pub(crate) fn debug_assert_c_order(self) -> Self + { + debug_assert!(self.layout.is(Layout::CORDER) || self.layout_tendency >= 0 || + self.dimension.slice().iter().filter(|&&d| d > 1).count() <= 1, + "Assertion failed: traversal is not c-order or 1D for \ + layout {:?}, tendency {}, dimension {:?}", + self.layout, self.layout_tendency, self.dimension); + self + } } /* @@ -752,14 +468,17 @@ impl Offset for *mut T { } */ -trait OffsetTuple { +trait OffsetTuple +{ type Args; unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self; } -impl OffsetTuple for *mut T { +impl OffsetTuple for *mut T +{ type Args = isize; - unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self { + unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self + { self.offset(index as isize * stride) } } @@ -777,7 +496,7 @@ macro_rules! offset_impl { } } )+ - } + }; } offset_impl! { @@ -836,7 +555,7 @@ macro_rules! zipt_impl { } } )+ - } + }; } zipt_impl! 
{ @@ -858,10 +577,10 @@ macro_rules! map_impl { { /// Apply a function to all elements of the input arrays, /// visiting elements in lock step. - pub fn apply(mut self, mut function: F) + pub fn for_each(mut self, mut function: F) where F: FnMut($($p::Item),*) { - self.apply_core((), move |(), args| { + self.for_each_core((), move |(), args| { let ($($p,)*) = args; FoldWhile::Continue(function($($p),*)) }); @@ -899,7 +618,7 @@ macro_rules! map_impl { where F: FnMut(Acc, $($p::Item),*) -> Acc, { - self.apply_core(acc, move |acc, args| { + self.for_each_core(acc, move |acc, args| { let ($($p,)*) = args; FoldWhile::Continue(function(acc, $($p),*)) }).into_inner() @@ -912,7 +631,7 @@ macro_rules! map_impl { -> FoldWhile where F: FnMut(Acc, $($p::Item),*) -> FoldWhile { - self.apply_core(acc, move |acc, args| { + self.for_each_core(acc, move |acc, args| { let ($($p,)*) = args; function(acc, $($p),*) }) @@ -934,7 +653,7 @@ macro_rules! map_impl { pub fn all(mut self, mut predicate: F) -> bool where F: FnMut($($p::Item),*) -> bool { - !self.apply_core((), move |_, args| { + !self.for_each_core((), move |_, args| { let ($($p,)*) = args; if predicate($($p),*) { FoldWhile::Continue(()) @@ -944,22 +663,64 @@ macro_rules! map_impl { }).is_done() } + /// Tests if at least one element of the iterator matches a predicate. + /// + /// Returns `true` if `predicate` evaluates to `true` for at least one element. + /// Returns `false` if the input arrays are empty. + /// + /// Example: + /// + /// ``` + /// use ndarray::{array, Zip}; + /// let a = array![1, 2, 3]; + /// let b = array![1, 4, 9]; + /// assert!(Zip::from(&a).and(&b).any(|&a, &b| a == b)); + /// assert!(!Zip::from(&a).and(&b).any(|&a, &b| a - 1 == b)); + /// ``` + pub fn any(mut self, mut predicate: F) -> bool + where F: FnMut($($p::Item),*) -> bool + { + self.for_each_core((), move |_, args| { + let ($($p,)*) = args; + if predicate($($p),*) { + FoldWhile::Done(()) + } else { + FoldWhile::Continue(()) + } + }).is_done() + } + expand_if!(@bool [$notlast] /// Include the producer `p` in the Zip. /// /// ***Panics*** if `p`’s shape doesn’t match the Zip’s exactly. + #[track_caller] pub fn and
<P>
(self, p: P) -> Zip<($($p,)* P::Output, ), D> where P: IntoNdProducer, { - let array = p.into_producer(); - self.check(&array); - let part_layout = array.layout(); - let ($($p,)*) = self.parts; - Zip { - parts: ($($p,)* array, ), - layout: self.layout.and(part_layout), - dimension: self.dimension, + let part = p.into_producer(); + zip_dimension_check(&self.dimension, &part); + self.build_and(part) + } + + /// Include the producer `p` in the Zip. + /// + /// ## Safety + /// + /// The caller must ensure that the producer's shape is equal to the Zip's shape. + /// Uses assertions when debug assertions are enabled. + #[allow(unused)] + pub(crate) unsafe fn and_unchecked
<P>
(self, p: P) -> Zip<($($p,)* P::Output, ), D> + where P: IntoNdProducer, + { + #[cfg(debug_assertions)] + { + self.and(p) + } + #[cfg(not(debug_assertions))] + { + self.build_and(p.into_producer()) } } @@ -968,20 +729,74 @@ macro_rules! map_impl { /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. /// /// ***Panics*** if broadcasting isn’t possible. + #[track_caller] pub fn and_broadcast<'a, P, D2, Elem>(self, p: P) -> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D> where P: IntoNdProducer, Item=&'a Elem>, D2: Dimension, { - let array = p.into_producer().broadcast_unwrap(self.dimension.clone()); - let part_layout = array.layout(); + let part = p.into_producer().broadcast_unwrap(self.dimension.clone()); + self.build_and(part) + } + + fn build_and
<P>
(self, part: P) -> Zip<($($p,)* P, ), D> + where P: NdProducer, + { + let part_layout = part.layout(); let ($($p,)*) = self.parts; Zip { - parts: ($($p,)* array, ), - layout: self.layout.and(part_layout), + parts: ($($p,)* part, ), + layout: self.layout.intersect(part_layout), dimension: self.dimension, + layout_tendency: self.layout_tendency + part_layout.tendency(), } } + + /// Map and collect the results into a new array, which has the same size as the + /// inputs. + /// + /// If all inputs are c- or f-order respectively, that is preserved in the output. + pub fn map_collect(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array { + self.map_collect_owned(f) + } + + pub(crate) fn map_collect_owned(self, f: impl FnMut($($p::Item,)* ) -> R) + -> ArrayBase + where S: DataOwned + { + // safe because: all elements are written before the array is completed + + let shape = self.dimension.clone().set_f(self.prefer_f()); + let output = >::build_uninit(shape, |output| { + // Use partial to count the number of filled elements, and can drop the right + // number of elements on unwinding (if it happens during apply/collect). + unsafe { + let output_view = output.into_raw_view_mut().cast::(); + self.and(output_view) + .collect_with_partial(f) + .release_ownership(); + } + }); + unsafe { + output.assume_init() + } + } + + /// Map and assign the results into the producer `into`, which should have the same + /// size as the other inputs. + /// + /// The producer should have assignable items as dictated by the `AssignElem` trait, + /// for example `&mut R`. + pub fn map_assign_into(self, into: Q, mut f: impl FnMut($($p::Item,)* ) -> R) + where Q: IntoNdProducer, + Q::Item: AssignElem + { + self.and(into) + .for_each(move |$($p, )* output_| { + output_.assign_elem(f($($p ),*)); + }); + } + ); /// Split the `Zip` evenly in two. @@ -990,25 +805,108 @@ macro_rules! map_impl { pub fn split(self) -> (Self, Self) { debug_assert_ne!(self.size(), 0, "Attempt to split empty zip"); debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem"); + SplitPreference::split(self) + } + } + + expand_if!(@bool [$notlast] + // For collect; Last producer is a RawViewMut + #[allow(non_snake_case)] + impl Zip<($($p,)* PLast), D> + where D: Dimension, + $($p: NdProducer ,)* + PLast: NdProducer, + { + /// The inner workings of map_collect and par_map_collect + /// + /// Apply the function and collect the results into the output (last producer) + /// which should be a raw array view; a Partial that owns the written + /// elements is returned. + /// + /// Elements will be overwritten in place (in the sense of std::ptr::write). + /// + /// ## Safety + /// + /// The last producer is a RawArrayViewMut and must be safe to write into. + /// The producer must be c- or f-contig and have the same layout tendency + /// as the whole Zip. + /// + /// The returned Partial's proxy ownership of the elements must be handled, + /// before the array the raw view points to realizes its ownership. 
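// Editor's note: an illustrative sketch, not part of this patch, of the public
// entry points built on top of this machinery: `map_collect` allocates the
// output array, while `map_assign_into` writes into an existing producer.
use ndarray::{array, Array2, Zip};

fn main() {
    let a = array![[1., 2.], [3., 4.]];
    let b = array![[10., 20.], [30., 40.]];

    // Collect `a + b` into a freshly allocated array (layout follows the inputs).
    let sum = Zip::from(&a).and(&b).map_collect(|&x, &y| x + y);
    assert_eq!(sum, array![[11., 22.], [33., 44.]]);

    // Assign `a * b` into a preallocated output.
    let mut out = Array2::<f64>::zeros((2, 2));
    Zip::from(&a).and(&b).map_assign_into(&mut out, |&x, &y| x * y);
    assert_eq!(out, array![[10., 40.], [90., 160.]]);
}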
+ pub(crate) unsafe fn collect_with_partial(self, mut f: F) -> Partial + where F: FnMut($($p::Item,)* ) -> R + { + // Get the last producer; and make a Partial that aliases its data pointer + let (.., ref output) = &self.parts; + + // debug assert that the output is contiguous in the memory layout we need + if cfg!(debug_assertions) { + let out_layout = output.layout(); + assert!(out_layout.is(Layout::CORDER | Layout::FORDER)); + assert!( + (self.layout_tendency <= 0 && out_layout.tendency() <= 0) || + (self.layout_tendency >= 0 && out_layout.tendency() >= 0), + "layout tendency violation for self layout {:?}, output layout {:?},\ + output shape {:?}", + self.layout, out_layout, output.raw_dim()); + } + + let mut partial = Partial::new(output.as_ptr()); + + // Apply the mapping function on this zip + // if we panic with unwinding; Partial will drop the written elements. + let partial_len = &mut partial.len; + self.for_each(move |$($p,)* output_elem: *mut R| { + output_elem.write(f($($p),*)); + if std::mem::needs_drop::() { + *partial_len += 1; + } + }); + + partial + } + } + ); + + impl SplitPreference for Zip<($($p,)*), D> + where D: Dimension, + $($p: NdProducer ,)* + { + fn can_split(&self) -> bool { self.size() > 1 } + + fn split_preference(&self) -> (Axis, usize) { // Always split in a way that preserves layout (if any) let axis = self.max_stride_axis(); let index = self.len_of(axis) / 2; + (axis, index) + } + } + + impl SplitAt for Zip<($($p,)*), D> + where D: Dimension, + $($p: NdProducer ,)* + { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { let (p1, p2) = self.parts.split_at(axis, index); let (d1, d2) = self.dimension.split_at(axis, index); (Zip { dimension: d1, layout: self.layout, parts: p1, + layout_tendency: self.layout_tendency, }, Zip { dimension: d2, layout: self.layout, parts: p2, + layout_tendency: self.layout_tendency, }) } + } + )+ - } + }; } map_impl! { @@ -1022,23 +920,27 @@ map_impl! { /// Value controlling the execution of `.fold_while` on `Zip`. #[derive(Debug, Copy, Clone)] -pub enum FoldWhile { +pub enum FoldWhile +{ /// Continue folding with this value Continue(T), /// Fold is complete and will return this value Done(T), } -impl FoldWhile { +impl FoldWhile +{ /// Return the inner value - pub fn into_inner(self) -> T { + pub fn into_inner(self) -> T + { match self { FoldWhile::Continue(x) | FoldWhile::Done(x) => x, } } /// Return true if it is `Done`, false if `Continue` - pub fn is_done(&self) -> bool { + pub fn is_done(&self) -> bool + { match *self { FoldWhile::Continue(_) => false, FoldWhile::Done(_) => true, diff --git a/src/zip/ndproducer.rs b/src/zip/ndproducer.rs new file mode 100644 index 000000000..1d1b3391b --- /dev/null +++ b/src/zip/ndproducer.rs @@ -0,0 +1,453 @@ +use crate::imp_prelude::*; +use crate::Layout; +use crate::NdIndex; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; + +/// Argument conversion into a producer. +/// +/// Slices and vectors can be used (equivalent to 1-dimensional array views). +/// +/// This trait is like `IntoIterator` for `NdProducers` instead of iterators. +pub trait IntoNdProducer +{ + /// The element produced per iteration. + type Item; + /// Dimension type of the producer + type Dim: Dimension; + type Output: NdProducer; + /// Convert the value into an `NdProducer`. + fn into_producer(self) -> Self::Output; +} + +impl

IntoNdProducer for P +where P: NdProducer +{ + type Item = P::Item; + type Dim = P::Dim; + type Output = Self; + fn into_producer(self) -> Self::Output + { + self + } +} + +/// A producer of an n-dimensional set of elements; +/// for example an array view, mutable array view or an iterator +/// that yields chunks. +/// +/// Producers are used as a arguments to [`Zip`](crate::Zip) and +/// [`azip!()`]. +/// +/// # Comparison to `IntoIterator` +/// +/// Most `NdProducers` are *iterable* (implement `IntoIterator`) but not directly +/// iterators. This separation is needed because the producer represents +/// a multidimensional set of items, it can be split along a particular axis for +/// parallelization, and it has no fixed correspondence to a sequence. +/// +/// The natural exception is one dimensional producers, like `AxisIter`, which +/// implement `Iterator` directly +/// (`AxisIter` traverses a one dimensional sequence, along an axis, while +/// *producing* multidimensional items). +/// +/// See also [`IntoNdProducer`] +pub trait NdProducer +{ + /// The element produced per iteration. + type Item; + // Internal use / Pointee type + /// Dimension type + type Dim: Dimension; + + // The pointer Ptr is used by an array view to simply point to the + // current element. It doesn't have to be a pointer (see Indices). + // Its main function is that it can be incremented with a particular + // stride (= along a particular axis) + #[doc(hidden)] + /// Pointer or stand-in for pointer + type Ptr: Offset; + #[doc(hidden)] + /// Pointer stride + type Stride: Copy; + + #[doc(hidden)] + fn layout(&self) -> Layout; + /// Return the shape of the producer. + fn raw_dim(&self) -> Self::Dim; + #[doc(hidden)] + fn equal_dim(&self, dim: &Self::Dim) -> bool + { + self.raw_dim() == *dim + } + #[doc(hidden)] + fn as_ptr(&self) -> Self::Ptr; + #[doc(hidden)] + unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item; + #[doc(hidden)] + unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr; + #[doc(hidden)] + fn stride_of(&self, axis: Axis) -> ::Stride; + #[doc(hidden)] + fn contiguous_stride(&self) -> Self::Stride; + #[doc(hidden)] + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + where Self: Sized; + + private_decl! {} +} + +pub trait Offset: Copy +{ + type Stride: Copy; + unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self; + private_decl! {} +} + +impl Offset for *const T +{ + type Stride = isize; + unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self + { + self.offset(s * (index as isize)) + } + private_impl! {} +} + +impl Offset for *mut T +{ + type Stride = isize; + unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self + { + self.offset(s * (index as isize)) + } + private_impl! {} +} + +/// An array reference is an n-dimensional producer of element references +/// (like ArrayView). +impl<'a, A: 'a, S, D> IntoNdProducer for &'a ArrayBase +where + D: Dimension, + S: Data, +{ + type Item = &'a A; + type Dim = D; + type Output = ArrayView<'a, A, D>; + fn into_producer(self) -> Self::Output + { + self.view() + } +} + +/// A mutable array reference is an n-dimensional producer of mutable element +/// references (like ArrayViewMut). 
+impl<'a, A: 'a, S, D> IntoNdProducer for &'a mut ArrayBase +where + D: Dimension, + S: DataMut, +{ + type Item = &'a mut A; + type Dim = D; + type Output = ArrayViewMut<'a, A, D>; + fn into_producer(self) -> Self::Output + { + self.view_mut() + } +} + +/// A slice is a one-dimensional producer +impl<'a, A: 'a> IntoNdProducer for &'a [A] +{ + type Item = ::Item; + type Dim = Ix1; + type Output = ArrayView1<'a, A>; + fn into_producer(self) -> Self::Output + { + <_>::from(self) + } +} + +/// A mutable slice is a mutable one-dimensional producer +impl<'a, A: 'a> IntoNdProducer for &'a mut [A] +{ + type Item = ::Item; + type Dim = Ix1; + type Output = ArrayViewMut1<'a, A>; + fn into_producer(self) -> Self::Output + { + <_>::from(self) + } +} + +/// A one-dimensional array is a one-dimensional producer +impl<'a, A: 'a, const N: usize> IntoNdProducer for &'a [A; N] +{ + type Item = ::Item; + type Dim = Ix1; + type Output = ArrayView1<'a, A>; + fn into_producer(self) -> Self::Output + { + <_>::from(self) + } +} + +/// A mutable one-dimensional array is a mutable one-dimensional producer +impl<'a, A: 'a, const N: usize> IntoNdProducer for &'a mut [A; N] +{ + type Item = ::Item; + type Dim = Ix1; + type Output = ArrayViewMut1<'a, A>; + fn into_producer(self) -> Self::Output + { + <_>::from(self) + } +} + +/// A Vec is a one-dimensional producer +impl<'a, A: 'a> IntoNdProducer for &'a Vec +{ + type Item = ::Item; + type Dim = Ix1; + type Output = ArrayView1<'a, A>; + fn into_producer(self) -> Self::Output + { + <_>::from(self) + } +} + +/// A mutable Vec is a mutable one-dimensional producer +impl<'a, A: 'a> IntoNdProducer for &'a mut Vec +{ + type Item = ::Item; + type Dim = Ix1; + type Output = ArrayViewMut1<'a, A>; + fn into_producer(self) -> Self::Output + { + <_>::from(self) + } +} + +impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> +{ + type Item = &'a A; + type Dim = D; + type Ptr = *mut A; + type Stride = isize; + + private_impl! {} + + fn raw_dim(&self) -> Self::Dim + { + self.raw_dim() + } + + fn equal_dim(&self, dim: &Self::Dim) -> bool + { + self.dim.equal(dim) + } + + fn as_ptr(&self) -> *mut A + { + self.as_ptr() as _ + } + + fn layout(&self) -> Layout + { + self.layout_impl() + } + + unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item + { + &*ptr + } + + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A + { + self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + } + + fn stride_of(&self, axis: Axis) -> isize + { + self.stride_of(axis) + } + + #[inline(always)] + fn contiguous_stride(&self) -> Self::Stride + { + 1 + } + + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + { + self.split_at(axis, index) + } +} + +impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> +{ + type Item = &'a mut A; + type Dim = D; + type Ptr = *mut A; + type Stride = isize; + + private_impl! 
{} + + fn raw_dim(&self) -> Self::Dim + { + self.raw_dim() + } + + fn equal_dim(&self, dim: &Self::Dim) -> bool + { + self.dim.equal(dim) + } + + fn as_ptr(&self) -> *mut A + { + self.as_ptr() as _ + } + + fn layout(&self) -> Layout + { + self.layout_impl() + } + + unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item + { + &mut *ptr + } + + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A + { + self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + } + + fn stride_of(&self, axis: Axis) -> isize + { + self.stride_of(axis) + } + + #[inline(always)] + fn contiguous_stride(&self) -> Self::Stride + { + 1 + } + + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + { + self.split_at(axis, index) + } +} + +impl NdProducer for RawArrayView +{ + type Item = *const A; + type Dim = D; + type Ptr = *const A; + type Stride = isize; + + private_impl! {} + + fn raw_dim(&self) -> Self::Dim + { + self.raw_dim() + } + + fn equal_dim(&self, dim: &Self::Dim) -> bool + { + self.dim.equal(dim) + } + + fn as_ptr(&self) -> *const A + { + self.as_ptr() + } + + fn layout(&self) -> Layout + { + self.layout_impl() + } + + unsafe fn as_ref(&self, ptr: *const A) -> *const A + { + ptr + } + + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A + { + self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + } + + fn stride_of(&self, axis: Axis) -> isize + { + self.stride_of(axis) + } + + #[inline(always)] + fn contiguous_stride(&self) -> Self::Stride + { + 1 + } + + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + { + self.split_at(axis, index) + } +} + +impl NdProducer for RawArrayViewMut +{ + type Item = *mut A; + type Dim = D; + type Ptr = *mut A; + type Stride = isize; + + private_impl! {} + + fn raw_dim(&self) -> Self::Dim + { + self.raw_dim() + } + + fn equal_dim(&self, dim: &Self::Dim) -> bool + { + self.dim.equal(dim) + } + + fn as_ptr(&self) -> *mut A + { + self.as_ptr() as _ + } + + fn layout(&self) -> Layout + { + self.layout_impl() + } + + unsafe fn as_ref(&self, ptr: *mut A) -> *mut A + { + ptr + } + + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A + { + self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + } + + fn stride_of(&self, axis: Axis) -> isize + { + self.stride_of(axis) + } + + #[inline(always)] + fn contiguous_stride(&self) -> Self::Stride + { + 1 + } + + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + { + self.split_at(axis, index) + } +} diff --git a/src/zip/zipmacro.rs b/src/zip/zipmacro.rs index ea616a05e..0bbe956b3 100644 --- a/src/zip/zipmacro.rs +++ b/src/zip/zipmacro.rs @@ -1,7 +1,7 @@ /// Array zip macro: lock step function application across several arrays and /// producers. /// -/// This is a shorthand for [`Zip`](struct.Zip.html). +/// This is a shorthand for [`Zip`](crate::Zip). /// /// This example: /// @@ -12,7 +12,7 @@ /// Is equivalent to: /// /// ```rust,ignore -/// Zip::from(&mut a).and(&b).and(&c).apply(|a, &b, &c| { +/// Zip::from(&mut a).and(&b).and(&c).for_each(|a, &b, &c| { /// *a = b + c /// }); /// ``` @@ -27,8 +27,8 @@ /// /// The *expr* are expressions whose types must implement `IntoNdProducer`, the /// *pat* are the patterns of the parameters to the closure called by -/// `Zip::apply`, and *body_expr* is the body of the closure called by -/// `Zip::apply`. You can think of each *pat* `in` *expr* as being analogous to +/// `Zip::for_each`, and *body_expr* is the body of the closure called by +/// `Zip::for_each`. 
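As a brief illustration (not part of this patch) of the producer arguments: plain slices and `Vec`s implement `IntoNdProducer` as one-dimensional producers, so they can be zipped directly with array views:

    use ndarray::{azip, Array1};

    let xs = vec![1.0, 2.0, 3.0];
    let ys: &[f64] = &[10.0, 20.0, 30.0];
    let mut out = Array1::<f64>::zeros(3);
    // `&Vec<f64>`, `&[f64]` and `&mut Array1<f64>` are all producers here.
    azip!((o in &mut out, &x in &xs, &y in ys) *o = x + y);
    assert_eq!(out, Array1::from(vec![11.0, 22.0, 33.0]));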
You can think of each *pat* `in` *expr* as being analogous to /// the `pat in expr` of a normal loop `for pat in expr { statements }`: a /// pattern, followed by `in`, followed by an expression that implements /// `IntoNdProducer` (analogous to `IntoIterator` for a `for` loop). @@ -38,66 +38,61 @@ /// ## Examples /// /// ```rust -/// extern crate ndarray; -/// /// use ndarray::{azip, Array1, Array2, Axis}; /// /// type M = Array2; /// -/// fn main() { -/// // Setup example arrays -/// let mut a = M::zeros((16, 16)); -/// let mut b = M::zeros(a.dim()); -/// let mut c = M::zeros(a.dim()); -/// -/// // assign values -/// b.fill(1.); -/// for ((i, j), elt) in c.indexed_iter_mut() { -/// *elt = (i + 10 * j) as f32; -/// } +/// // Setup example arrays +/// let mut a = M::zeros((16, 16)); +/// let mut b = M::zeros(a.dim()); +/// let mut c = M::zeros(a.dim()); /// -/// // Example 1: Compute a simple ternary operation: -/// // elementwise addition of b and c, stored in a -/// azip!((a in &mut a, &b in &b, &c in &c) *a = b + c); +/// // assign values +/// b.fill(1.); +/// for ((i, j), elt) in c.indexed_iter_mut() { +/// *elt = (i + 10 * j) as f32; +/// } /// -/// assert_eq!(a, &b + &c); +/// // Example 1: Compute a simple ternary operation: +/// // elementwise addition of b and c, stored in a +/// azip!((a in &mut a, &b in &b, &c in &c) *a = b + c); /// -/// // Example 2: azip!() with index -/// azip!((index (i, j), &b in &b, &c in &c) { -/// a[[i, j]] = b - c; -/// }); +/// assert_eq!(a, &b + &c); /// -/// assert_eq!(a, &b - &c); +/// // Example 2: azip!() with index +/// azip!((index (i, j), &b in &b, &c in &c) { +/// a[[i, j]] = b - c; +/// }); /// +/// assert_eq!(a, &b - &c); /// -/// // Example 3: azip!() on references -/// // See the definition of the function below -/// borrow_multiply(&mut a, &b, &c); /// -/// assert_eq!(a, &b * &c); +/// // Example 3: azip!() on references +/// // See the definition of the function below +/// borrow_multiply(&mut a, &b, &c); /// +/// assert_eq!(a, &b * &c); /// -/// // Since this function borrows its inputs, the `IntoNdProducer` -/// // expressions don't need to explicitly include `&mut` or `&`. -/// fn borrow_multiply(a: &mut M, b: &M, c: &M) { -/// azip!((a in a, &b in b, &c in c) *a = b * c); -/// } /// +/// // Since this function borrows its inputs, the `IntoNdProducer` +/// // expressions don't need to explicitly include `&mut` or `&`. +/// fn borrow_multiply(a: &mut M, b: &M, c: &M) { +/// azip!((a in a, &b in b, &c in c) *a = b * c); +/// } /// -/// // Example 4: using azip!() without dereference in pattern. -/// // -/// // Create a new array `totals` with one entry per row of `a`. -/// // Use azip to traverse the rows of `a` and assign to the corresponding -/// // entry in `totals` with the sum across each row. -/// // -/// // The row is an array view; it doesn't need to be dereferenced. -/// let mut totals = Array1::zeros(a.rows()); -/// azip!((totals in &mut totals, row in a.genrows()) *totals = row.sum()); /// -/// // Check the result against the built in `.sum_axis()` along axis 1. -/// assert_eq!(totals, a.sum_axis(Axis(1))); -/// } +/// // Example 4: using azip!() without dereference in pattern. +/// // +/// // Create a new array `totals` with one entry per row of `a`. +/// // Use azip to traverse the rows of `a` and assign to the corresponding +/// // entry in `totals` with the sum across each row. +/// // +/// // The row is an array view; it doesn't need to be dereferenced. 
+/// let mut totals = Array1::zeros(a.nrows()); +/// azip!((totals in &mut totals, row in a.rows()) *totals = row.sum()); /// +/// // Check the result against the built in `.sum_axis()` along axis 1. +/// assert_eq!(totals, a.sum_axis(Axis(1))); /// ``` #[macro_export] macro_rules! azip { @@ -122,9 +117,15 @@ macro_rules! azip { $(.and($prod))* .$apply(|$first_pat, $($pat),*| $body) }; + + // Unindexed with one or more producer, no loop body + (@build $apply:ident $first_prod:expr $(, $prod:expr)* $(,)?) => { + $crate::Zip::from($first_prod) + $(.and($prod))* + }; // catch-all rule (@build $($t:tt)*) => { compile_error!("Invalid syntax in azip!()") }; ($($t:tt)*) => { - $crate::azip!(@build apply $($t)*) + $crate::azip!(@build for_each $($t)*) }; } diff --git a/tests/append.rs b/tests/append.rs new file mode 100644 index 000000000..cf5397de1 --- /dev/null +++ b/tests/append.rs @@ -0,0 +1,453 @@ +use ndarray::prelude::*; +use ndarray::{ErrorKind, ShapeError}; + +#[test] +fn push_row() +{ + let mut a = Array::zeros((0, 4)); + a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); + a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a.shape(), &[2, 4]); + + assert_eq!(a, + array![[0., 1., 2., 3.], + [4., 5., 6., 7.]]); + + assert_eq!(a.push_row(aview1(&[1.])), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1.])), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1., 2.])), + Ok(())); + assert_eq!(a, + array![[0., 1., 2., 3., 1.], + [4., 5., 6., 7., 2.]]); +} + +#[test] +fn push_row_wrong_layout() +{ + let mut a = Array::zeros((0, 4)); + a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); + a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a.shape(), &[2, 4]); + + assert_eq!(a, + array![[0., 1., 2., 3.], + [4., 5., 6., 7.]]); + assert_eq!(a.strides(), &[4, 1]); + + // Changing the memory layout to fit the next append + let mut a2 = a.clone(); + a2.push_column(aview1(&[1., 2.])).unwrap(); + assert_eq!(a2, + array![[0., 1., 2., 3., 1.], + [4., 5., 6., 7., 2.]]); + assert_eq!(a2.strides(), &[1, 2]); + + // Clone the array + + let mut dim = a.raw_dim(); + dim[1] = 0; + let mut b = Array::zeros(dim); + b.append(Axis(1), a.view()).unwrap(); + assert_eq!(b.push_column(aview1(&[1., 2.])), Ok(())); + assert_eq!(b, + array![[0., 1., 2., 3., 1.], + [4., 5., 6., 7., 2.]]); +} + +#[test] +fn push_row_neg_stride_1() +{ + let mut a = Array::zeros((0, 4)); + a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); + a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a.shape(), &[2, 4]); + + assert_eq!(a, + array![[0., 1., 2., 3.], + [4., 5., 6., 7.]]); + assert_eq!(a.strides(), &[4, 1]); + + a.invert_axis(Axis(0)); + + // Changing the memory layout to fit the next append + let mut a2 = a.clone(); + println!("a = {:?}", a); + println!("a2 = {:?}", a2); + a2.push_column(aview1(&[1., 2.])).unwrap(); + assert_eq!(a2, + array![[4., 5., 6., 7., 1.], + [0., 1., 2., 3., 2.]]); + assert_eq!(a2.strides(), &[1, 2]); + + a.invert_axis(Axis(1)); + let mut a3 = a.clone(); + a3.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a3, + array![[7., 6., 5., 4.], + [3., 2., 1., 0.], + [4., 5., 6., 7.]]); + assert_eq!(a3.strides(), &[4, 1]); + + a.invert_axis(Axis(0)); + let mut a4 = a.clone(); + a4.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a4, + array![[3., 2., 1., 0.], + [7., 6., 5., 4.], + [4., 5., 6., 7.]]); + assert_eq!(a4.strides(), &[4, -1]); +} + +#[test] +fn push_row_neg_stride_2() +{ + let 
mut a = Array::zeros((0, 4)); + a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); + a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a.shape(), &[2, 4]); + + assert_eq!(a, + array![[0., 1., 2., 3.], + [4., 5., 6., 7.]]); + assert_eq!(a.strides(), &[4, 1]); + + a.invert_axis(Axis(1)); + + // Changing the memory layout to fit the next append + let mut a2 = a.clone(); + println!("a = {:?}", a); + println!("a2 = {:?}", a2); + a2.push_column(aview1(&[1., 2.])).unwrap(); + assert_eq!(a2, + array![[3., 2., 1., 0., 1.], + [7., 6., 5., 4., 2.]]); + assert_eq!(a2.strides(), &[1, 2]); + + a.invert_axis(Axis(0)); + let mut a3 = a.clone(); + a3.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a3, + array![[7., 6., 5., 4.], + [3., 2., 1., 0.], + [4., 5., 6., 7.]]); + assert_eq!(a3.strides(), &[4, 1]); + + a.invert_axis(Axis(1)); + let mut a4 = a.clone(); + a4.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a4, + array![[4., 5., 6., 7.], + [0., 1., 2., 3.], + [4., 5., 6., 7.]]); + assert_eq!(a4.strides(), &[4, 1]); +} + +#[test] +fn push_row_error() +{ + let mut a = Array::zeros((3, 4)); + + assert_eq!(a.push_row(aview1(&[1.])), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1.])), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1., 2., 3.])), + Ok(())); + assert_eq!(a.t(), + array![[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [1., 2., 3.]]); +} + +#[test] +fn push_row_existing() +{ + let mut a = Array::zeros((1, 4)); + a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); + a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a.shape(), &[3, 4]); + + assert_eq!(a, + array![[0., 0., 0., 0.], + [0., 1., 2., 3.], + [4., 5., 6., 7.]]); + + assert_eq!(a.push_row(aview1(&[1.])), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1.])), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + assert_eq!(a.push_column(aview1(&[1., 2., 3.])), + Ok(())); + assert_eq!(a, + array![[0., 0., 0., 0., 1.], + [0., 1., 2., 3., 2.], + [4., 5., 6., 7., 3.]]); +} + +#[test] +fn push_row_col_len_1() +{ + // Test appending 1 row and then cols from shape 1 x 1 + let mut a = Array::zeros((1, 1)); + a.push_row(aview1(&[1.])).unwrap(); // shape 2 x 1 + a.push_column(aview1(&[2., 3.])).unwrap(); // shape 2 x 2 + assert_eq!(a.push_row(aview1(&[1.])), + Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); + //assert_eq!(a.push_row(aview1(&[1., 2.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout))); + a.push_column(aview1(&[4., 5.])).unwrap(); // shape 2 x 3 + assert_eq!(a.shape(), &[2, 3]); + + assert_eq!(a, + array![[0., 2., 4.], + [1., 3., 5.]]); +} + +#[test] +fn push_column() +{ + let mut a = Array::zeros((4, 0)); + a.push_column(aview1(&[0., 1., 2., 3.])).unwrap(); + a.push_column(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a.shape(), &[4, 2]); + + assert_eq!(a.t(), + array![[0., 1., 2., 3.], + [4., 5., 6., 7.]]); +} + +#[test] +fn append_array1() +{ + let mut a = Array::zeros((0, 4)); + a.append(Axis(0), aview2(&[[0., 1., 2., 3.]])).unwrap(); + println!("{:?}", a); + a.append(Axis(0), aview2(&[[4., 5., 6., 7.]])).unwrap(); + println!("{:?}", a); + //a.push_column(aview1(&[4., 5., 6., 7.])).unwrap(); + //assert_eq!(a.shape(), &[4, 2]); + + assert_eq!(a, + array![[0., 1., 2., 3.], + [4., 5., 6., 7.]]); + + a.append(Axis(0), aview2(&[[5., 5., 4., 4.], [3., 3., 2., 2.]])) + .unwrap(); + println!("{:?}", a); + 
assert_eq!(a, + array![[0., 1., 2., 3.], + [4., 5., 6., 7.], + [5., 5., 4., 4.], + [3., 3., 2., 2.]]); +} + +#[test] +fn append_array_3d() +{ + let mut a = Array::zeros((0, 2, 2)); + a.append(Axis(0), array![[[0, 1], [2, 3]]].view()).unwrap(); + println!("{:?}", a); + + let aa = array![[[51, 52], [53, 54]], [[55, 56], [57, 58]]]; + let av = aa.view(); + println!("Send {:?} to append", av); + a.append(Axis(0), av.clone()).unwrap(); + + a.swap_axes(0, 1); + let aa = array![[[71, 72], [73, 74]], [[75, 76], [77, 78]]]; + let mut av = aa.view(); + av.swap_axes(0, 1); + println!("Send {:?} to append", av); + a.append(Axis(1), av.clone()).unwrap(); + println!("{:?}", a); + let aa = array![[[81, 82], [83, 84]], [[85, 86], [87, 88]]]; + let mut av = aa.view(); + av.swap_axes(0, 1); + println!("Send {:?} to append", av); + a.append(Axis(1), av).unwrap(); + println!("{:?}", a); + assert_eq!(a, + array![[[0, 1], + [51, 52], + [55, 56], + [71, 72], + [75, 76], + [81, 82], + [85, 86]], + [[2, 3], + [53, 54], + [57, 58], + [73, 74], + [77, 78], + [83, 84], + [87, 88]]]); +} + +#[test] +fn test_append_2d() +{ + // create an empty array and append + let mut a = Array::zeros((0, 4)); + let ones = ArrayView::from(&[1.; 12]) + .into_shape_with_order((3, 4)) + .unwrap(); + let zeros = ArrayView::from(&[0.; 8]) + .into_shape_with_order((2, 4)) + .unwrap(); + a.append(Axis(0), ones).unwrap(); + a.append(Axis(0), zeros).unwrap(); + a.append(Axis(0), ones).unwrap(); + println!("{:?}", a); + assert_eq!(a.shape(), &[8, 4]); + for (i, row) in a.rows().into_iter().enumerate() { + let ones = i < 3 || i >= 5; + assert!(row.iter().all(|&x| x == ones as i32 as f64), "failed on lane {}", i); + } + + let mut a = Array::zeros((0, 4)); + a = a.reversed_axes(); + let ones = ones.reversed_axes(); + let zeros = zeros.reversed_axes(); + a.append(Axis(1), ones).unwrap(); + a.append(Axis(1), zeros).unwrap(); + a.append(Axis(1), ones).unwrap(); + println!("{:?}", a); + assert_eq!(a.shape(), &[4, 8]); + + for (i, row) in a.columns().into_iter().enumerate() { + let ones = i < 3 || i >= 5; + assert!(row.iter().all(|&x| x == ones as i32 as f64), "failed on lane {}", i); + } +} + +#[test] +fn test_append_middle_axis() +{ + // ensure we can append to Axis(1) by letting it become outermost + let mut a = Array::::zeros((3, 0, 2)); + a.append( + Axis(1), + Array::from_iter(0..12) + .into_shape_with_order((3, 2, 2)) + .unwrap() + .view(), + ) + .unwrap(); + println!("{:?}", a); + a.append( + Axis(1), + Array::from_iter(12..24) + .into_shape_with_order((3, 2, 2)) + .unwrap() + .view(), + ) + .unwrap(); + println!("{:?}", a); + + // ensure we can append to Axis(1) by letting it become outermost + let mut a = Array::::zeros((3, 1, 2)); + a.append( + Axis(1), + Array::from_iter(0..12) + .into_shape_with_order((3, 2, 2)) + .unwrap() + .view(), + ) + .unwrap(); + println!("{:?}", a); + a.append( + Axis(1), + Array::from_iter(12..24) + .into_shape_with_order((3, 2, 2)) + .unwrap() + .view(), + ) + .unwrap(); + println!("{:?}", a); +} + +#[test] +fn test_append_zero_size() +{ + { + let mut a = Array::::zeros((0, 0)); + a.append(Axis(0), aview2(&[[]])).unwrap(); + a.append(Axis(0), aview2(&[[]])).unwrap(); + assert_eq!(a.len(), 0); + assert_eq!(a.shape(), &[2, 0]); + } + + { + let mut a = Array::::zeros((0, 0)); + a.append(Axis(1), ArrayView1::from(&[]).into_shape_with_order((0, 1)).unwrap()) + .unwrap(); + a.append(Axis(1), ArrayView1::from(&[]).into_shape_with_order((0, 1)).unwrap()) + .unwrap(); + assert_eq!(a.len(), 0); + assert_eq!(a.shape(), 
&[0, 2]); + } +} + +#[test] +fn push_row_neg_stride_3() +{ + let mut a = Array::zeros((0, 4)); + a.push_row(aview1(&[0., 1., 2., 3.])).unwrap(); + a.invert_axis(Axis(1)); + a.push_row(aview1(&[4., 5., 6., 7.])).unwrap(); + assert_eq!(a.shape(), &[2, 4]); + assert_eq!(a, array![[3., 2., 1., 0.], [4., 5., 6., 7.]]); + assert_eq!(a.strides(), &[4, -1]); +} + +#[test] +fn push_row_ignore_strides_length_one_axes() +{ + let strides = &[0, 1, 10, 20]; + for invert in &[vec![], vec![0], vec![1], vec![0, 1]] { + for &stride0 in strides { + for &stride1 in strides { + let mut a = Array::from_shape_vec([1, 1].strides([stride0, stride1]), vec![0.]).unwrap(); + for &ax in invert { + a.invert_axis(Axis(ax)); + } + a.push_row(aview1(&[1.])).unwrap(); + assert_eq!(a.shape(), &[2, 1]); + assert_eq!(a, array![[0.], [1.]]); + assert_eq!(a.stride_of(Axis(0)), 1); + } + } + } +} + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn zero_dimensional_error1() +{ + let mut a = Array::zeros(()).into_dyn(); + a.append(Axis(0), arr0(0).into_dyn().view()).unwrap(); +} + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn zero_dimensional_error2() +{ + let mut a = Array::zeros(()).into_dyn(); + a.push(Axis(0), arr0(0).into_dyn().view()).unwrap(); +} + +#[test] +fn zero_dimensional_ok() +{ + let mut a = Array::zeros(0); + let one = aview0(&1); + let two = aview0(&2); + a.push(Axis(0), two).unwrap(); + a.push(Axis(0), one).unwrap(); + a.push(Axis(0), one).unwrap(); + assert_eq!(a, array![2, 1, 1]); +} diff --git a/tests/array-construct.rs b/tests/array-construct.rs index 6c1caf1ef..9f8418467 100644 --- a/tests/array-construct.rs +++ b/tests/array-construct.rs @@ -1,33 +1,33 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] use defmac::defmac; +use ndarray::arr3; use ndarray::prelude::*; +use ndarray::Zip; #[test] -fn test_from_shape_fn() { +fn test_from_shape_fn() +{ let step = 3.1; - let h = Array::from_shape_fn((5, 5), |(i, j)| { - f64::sin(i as f64 / step) * f64::cos(j as f64 / step) - }); + let h = Array::from_shape_fn((5, 5), |(i, j)| f64::sin(i as f64 / step) * f64::cos(j as f64 / step)); assert_eq!(h.shape(), &[5, 5]); } #[test] -fn test_dimension_zero() { +fn test_dimension_zero() +{ let a: Array2 = Array2::from(vec![[], [], []]); - assert_eq!(vec![0.; 0], a.into_raw_vec()); + assert_eq!((vec![0.; 0], None), a.into_raw_vec_and_offset()); let a: Array3 = Array3::from(vec![[[]], [[]], [[]]]); - assert_eq!(vec![0.; 0], a.into_raw_vec()); + assert_eq!((vec![0.; 0], None), a.into_raw_vec_and_offset()); } #[test] #[cfg(feature = "approx")] -fn test_arc_into_owned() { +fn test_arc_into_owned() +{ use approx::assert_abs_diff_ne; let a = Array2::from_elem((5, 5), 1.).into_shared(); @@ -40,7 +40,8 @@ fn test_arc_into_owned() { } #[test] -fn test_arcarray_thread_safe() { +fn test_arcarray_thread_safe() +{ fn is_send(_t: &T) {} fn is_sync(_t: &T) {} let a = Array2::from_elem((5, 5), 1.).into_shared(); @@ -50,22 +51,8 @@ fn test_arcarray_thread_safe() { } #[test] -fn test_uninit() { - unsafe { - let mut a = Array::::uninitialized((3, 4).f()); - assert_eq!(a.dim(), (3, 4)); - assert_eq!(a.strides(), &[1, 3]); - let b = Array::::linspace(0., 25., a.len()) - .into_shape(a.dim()) - .unwrap(); - a.assign(&b); - assert_eq!(&a, &b); - assert_eq!(a.t(), b.t()); - } -} - -#[test] -fn test_from_fn_c0() { +fn 
test_from_fn_c0() +{ let a = Array::from_shape_fn((), |i| i); assert_eq!(a[()], ()); assert_eq!(a.len(), 1); @@ -73,7 +60,8 @@ fn test_from_fn_c0() { } #[test] -fn test_from_fn_c1() { +fn test_from_fn_c1() +{ let a = Array::from_shape_fn(28, |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -81,7 +69,8 @@ fn test_from_fn_c1() { } #[test] -fn test_from_fn_c() { +fn test_from_fn_c() +{ let a = Array::from_shape_fn((4, 7), |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -89,7 +78,8 @@ fn test_from_fn_c() { } #[test] -fn test_from_fn_c3() { +fn test_from_fn_c3() +{ let a = Array::from_shape_fn((4, 3, 7), |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -97,7 +87,8 @@ fn test_from_fn_c3() { } #[test] -fn test_from_fn_f0() { +fn test_from_fn_f0() +{ let a = Array::from_shape_fn(().f(), |i| i); assert_eq!(a[()], ()); assert_eq!(a.len(), 1); @@ -105,7 +96,8 @@ fn test_from_fn_f0() { } #[test] -fn test_from_fn_f1() { +fn test_from_fn_f1() +{ let a = Array::from_shape_fn(28.f(), |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -113,7 +105,8 @@ fn test_from_fn_f1() { } #[test] -fn test_from_fn_f() { +fn test_from_fn_f() +{ let a = Array::from_shape_fn((4, 7).f(), |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -121,7 +114,8 @@ fn test_from_fn_f() { } #[test] -fn test_from_fn_f_with_zero() { +fn test_from_fn_f_with_zero() +{ defmac!(test_from_fn_f_with_zero shape => { let a = Array::from_shape_fn(shape.f(), |i| i); assert_eq!(a.len(), 0); @@ -136,7 +130,8 @@ fn test_from_fn_f_with_zero() { } #[test] -fn test_from_fn_f3() { +fn test_from_fn_f3() +{ let a = Array::from_shape_fn((4, 2, 7).f(), |i| i); for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt); @@ -144,53 +139,169 @@ fn test_from_fn_f3() { } #[test] -fn deny_wraparound_from_vec() { +fn deny_wraparound_from_vec() +{ let five = vec![0; 5]; let five_large = Array::from_shape_vec((3, 7, 29, 36760123, 823996703), five.clone()); + println!("{:?}", five_large); assert!(five_large.is_err()); let six = Array::from_shape_vec(6, five.clone()); assert!(six.is_err()); } #[test] -fn test_ones() { +fn test_ones() +{ let mut a = Array::::zeros((2, 3, 4)); a.fill(1.0); let b = Array::::ones((2, 3, 4)); assert_eq!(a, b); } +#[test] +fn test_from_shape_empty_with_neg_stride() +{ + // Issue #998, negative strides for an axis where it doesn't matter. + let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; + let v = s[..12].to_vec(); + let v_ptr = v.as_ptr(); + let a = Array::from_shape_vec((2, 0, 2).strides((1, -4isize as usize, 2)), v).unwrap(); + assert_eq!(a, arr3(&[[[0; 2]; 0]; 2])); + assert_eq!(a.as_ptr(), v_ptr); +} + +#[test] +fn test_from_shape_with_neg_stride() +{ + // Issue #998, negative strides for an axis where it doesn't matter. + let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; + let v = s[..12].to_vec(); + let v_ptr = v.as_ptr(); + let a = Array::from_shape_vec((2, 1, 2).strides((1, -4isize as usize, 2)), v).unwrap(); + assert_eq!(a, arr3(&[[[0, 2]], + [[1, 3]]])); + assert_eq!(a.as_ptr(), v_ptr); +} + +#[test] +fn test_from_shape_2_2_2_with_neg_stride() +{ + // Issue #998, negative strides for an axis where it doesn't matter. 
+ let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; + let v = s[..12].to_vec(); + let v_ptr = v.as_ptr(); + let a = Array::from_shape_vec((2, 2, 2).strides((1, -4isize as usize, 2)), v).unwrap(); + assert_eq!(a, arr3(&[[[4, 6], + [0, 2]], + [[5, 7], + [1, 3]]])); + assert_eq!(a.as_ptr(), v_ptr.wrapping_add(4)); +} + #[should_panic] #[test] -fn deny_wraparound_zeros() { +fn deny_wraparound_zeros() +{ //2^64 + 5 = 18446744073709551621 = 3×7×29×36760123×823996703 (5 distinct prime factors) let _five_large = Array::::zeros((3, 7, 29, 36760123, 823996703)); } #[should_panic] #[test] -fn deny_wraparound_reshape() { +fn deny_wraparound_reshape() +{ //2^64 + 5 = 18446744073709551621 = 3×7×29×36760123×823996703 (5 distinct prime factors) let five = Array::::zeros(5); - let _five_large = five.into_shape((3, 7, 29, 36760123, 823996703)).unwrap(); + let _five_large = five + .into_shape_with_order((3, 7, 29, 36760123, 823996703)) + .unwrap(); } #[should_panic] #[test] -fn deny_wraparound_default() { +fn deny_wraparound_default() +{ let _five_large = Array::::default((3, 7, 29, 36760123, 823996703)); } #[should_panic] #[test] -fn deny_wraparound_from_shape_fn() { +fn deny_wraparound_from_shape_fn() +{ let _five_large = Array::::from_shape_fn((3, 7, 29, 36760123, 823996703), |_| 0.); } #[should_panic] #[test] -fn deny_wraparound_uninit() { +fn deny_wraparound_uninit() +{ + let _five_large = Array::::uninit((3, 7, 29, 36760123, 823996703)); +} + +#[should_panic] +#[test] +fn deny_slice_with_too_many_rows_to_arrayview2() +{ + let _view = ArrayView2::from(&[[0u8; 0]; usize::MAX][..]); +} + +#[should_panic] +#[test] +fn deny_slice_with_too_many_zero_sized_elems_to_arrayview2() +{ + let _view = ArrayView2::from(&[[(); isize::MAX as usize]; isize::MAX as usize][..]); +} + +#[should_panic] +#[test] +fn deny_slice_with_too_many_rows_to_arrayviewmut2() +{ + let _view = ArrayViewMut2::from(&mut [[0u8; 0]; usize::MAX][..]); +} + +#[should_panic] +#[test] +fn deny_slice_with_too_many_zero_sized_elems_to_arrayviewmut2() +{ + let _view = ArrayViewMut2::from(&mut [[(); isize::MAX as usize]; isize::MAX as usize][..]); +} + +#[test] +fn maybe_uninit_1() +{ + use std::mem::MaybeUninit; + unsafe { - let _five_large = Array::::uninitialized((3, 7, 29, 36760123, 823996703)); + // Array + type Mat = Array; + + let mut a = Mat::uninit((10, 10)); + a.mapv_inplace(|_| MaybeUninit::new(1.)); + + let a_init = a.assume_init(); + assert_eq!(a_init, Array2::from_elem(a_init.dim(), 1.)); + + // ArcArray + type ArcMat = ArcArray; + + let mut a = ArcMat::uninit((10, 10)); + a.mapv_inplace(|_| MaybeUninit::new(1.)); + let a2 = a.clone(); + + let a_init = a.assume_init(); + assert_eq!(a_init, Array2::from_elem(a_init.dim(), 1.)); + + // ArrayView + let av_init = a2.view().assume_init(); + assert_eq!(av_init, Array2::from_elem(a_init.dim(), 1.)); + + // RawArrayViewMut + let mut a = Mat::uninit((10, 10)); + let v = a.raw_view_mut(); + Zip::from(v).for_each(|ptr| *(*ptr).as_mut_ptr() = 1.); + + let u = a.raw_view_mut().assume_init(); + + Zip::from(u).for_each(|ptr| assert_eq!(*ptr, 1.)); } } diff --git a/tests/array.rs b/tests/array.rs index db54b7e5f..ac38fdd03 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1,19 +1,20 @@ #![allow(non_snake_case)] #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names, + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names, clippy::float_cmp )] +use 
approx::assert_relative_eq; use defmac::defmac; -use itertools::{enumerate, zip, Itertools}; +#[allow(deprecated)] +use itertools::{zip, Itertools}; use ndarray::indices; use ndarray::prelude::*; +use ndarray::ErrorKind; use ndarray::{arr3, rcarr2}; -use ndarray::{Slice, SliceInfo, SliceOrIndex}; -use std::iter::FromIterator; +use ndarray::{Slice, SliceInfo, SliceInfoElem}; +use num_complex::Complex; +use std::convert::TryFrom; macro_rules! assert_panics { ($body:expr) => { @@ -30,7 +31,8 @@ macro_rules! assert_panics { } #[test] -fn test_matmul_arcarray() { +fn test_matmul_arcarray() +{ let mut A = ArcArray::::zeros((2, 3)); for (i, elt) in A.iter_mut().enumerate() { *elt = i; @@ -46,8 +48,7 @@ fn test_matmul_arcarray() { println!("B = \n{:?}", B); println!("A x B = \n{:?}", c); unsafe { - let result = - ArcArray::from_shape_vec_unchecked((2, 4), vec![20, 23, 26, 29, 56, 68, 80, 92]); + let result = ArcArray::from_shape_vec_unchecked((2, 4), vec![20, 23, 26, 29, 56, 68, 80, 92]); assert_eq!(c.shape(), result.shape()); assert!(c.iter().zip(result.iter()).all(|(a, b)| a == b)); assert!(c == result); @@ -55,22 +56,26 @@ fn test_matmul_arcarray() { } #[allow(unused)] -fn arrayview_shrink_lifetime<'a, 'b: 'a>(view: ArrayView1<'b, f64>) -> ArrayView1<'a, f64> { +fn arrayview_shrink_lifetime<'a, 'b: 'a>(view: ArrayView1<'b, f64>) -> ArrayView1<'a, f64> +{ view.reborrow() } #[allow(unused)] -fn arrayviewmut_shrink_lifetime<'a, 'b: 'a>( - view: ArrayViewMut1<'b, f64>, -) -> ArrayViewMut1<'a, f64> { +fn arrayviewmut_shrink_lifetime<'a, 'b: 'a>(view: ArrayViewMut1<'b, f64>) -> ArrayViewMut1<'a, f64> +{ view.reborrow() } #[test] -fn test_mat_mul() { +#[cfg(feature = "std")] +fn test_mat_mul() +{ // smoke test, a big matrix multiplication of uneven size let (n, m) = (45, 33); - let a = ArcArray::linspace(0., ((n * m) - 1) as f32, n as usize * m as usize).reshape((n, m)); + let a = ArcArray::linspace(0., ((n * m) - 1) as f32, n as usize * m as usize) + .into_shape_with_order((n, m)) + .unwrap(); let b = ArcArray::eye(m); assert_eq!(a.dot(&b), a); let c = ArcArray::eye(n); @@ -79,14 +84,15 @@ fn test_mat_mul() { #[deny(unsafe_code)] #[test] -fn test_slice() { +fn test_slice() +{ let mut A = ArcArray::::zeros((3, 4, 5)); for (i, elt) in A.iter_mut().enumerate() { *elt = i; } - let vi = A.slice(s![1.., ..;2, Slice::new(0, None, 2)]); - assert_eq!(vi.shape(), &[2, 2, 3]); + let vi = A.slice(s![1.., ..;2, NewAxis, Slice::new(0, None, 2)]); + assert_eq!(vi.shape(), &[2, 2, 1, 3]); let vi = A.slice(s![.., .., ..]); assert_eq!(vi.shape(), A.shape()); assert!(vi.iter().zip(A.iter()).all(|(a, b)| a == b)); @@ -94,13 +100,15 @@ fn test_slice() { #[deny(unsafe_code)] #[test] -fn test_slice_ix0() { +fn test_slice_ix0() +{ let arr = arr0(5); assert_eq!(arr.slice(s![]), aview0(&5)); } #[test] -fn test_slice_edge_cases() { +fn test_slice_edge_cases() +{ let mut arr = Array3::::zeros((3, 4, 5)); arr.slice_collapse(s![0..0;-1, .., ..]); assert_eq!(arr.shape(), &[0, 4, 5]); @@ -110,7 +118,8 @@ fn test_slice_edge_cases() { } #[test] -fn test_slice_inclusive_range() { +fn test_slice_inclusive_range() +{ let arr = array![[1, 2, 3], [4, 5, 6]]; assert_eq!(arr.slice(s![1..=1, 1..=2]), array![[5, 6]]); assert_eq!(arr.slice(s![1..=-1, -2..=2;-1]), array![[6, 5]]); @@ -124,7 +133,8 @@ fn test_slice_inclusive_range() { /// `ArrayView1` and `ArrayView2`, so the compiler needs to determine which /// type is the correct result for the `.slice()` call. 
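A small sketch (illustration only, not part of this patch) of the `NewAxis` slice element exercised by the updated tests above: it inserts a length-1 axis at that position of the sliced result.

    use ndarray::{s, Array3, NewAxis};

    let a = Array3::<f64>::zeros((3, 4, 5));
    // `..` keeps axis 0, `1` drops axis 1, `NewAxis` inserts a new length-1
    // axis, and `..;2` takes every other index along axis 2.
    let v = a.slice(s![.., 1, NewAxis, ..;2]);
    assert_eq!(v.shape(), &[3, 1, 3]);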
#[test] -fn test_slice_infer() { +fn test_slice_infer() +{ let a = array![1., 2.]; let b = array![[3., 4.], [5., 6.]]; b.slice(s![..-1, ..]).dot(&a); @@ -132,20 +142,21 @@ fn test_slice_infer() { } #[test] -fn test_slice_with_many_dim() { +fn test_slice_with_many_dim() +{ let mut A = ArcArray::::zeros(&[3, 1, 4, 1, 3, 2, 1][..]); for (i, elt) in A.iter_mut().enumerate() { *elt = i; } - let vi = A.slice(s![..2, .., ..;2, ..1, ..1, 1.., ..]); - let new_shape = &[2, 1, 2, 1, 1, 1, 1][..]; + let vi = A.slice(s![..2, NewAxis, .., ..;2, NewAxis, ..1, ..1, 1.., ..]); + let new_shape = &[2, 1, 1, 2, 1, 1, 1, 1, 1][..]; assert_eq!(vi.shape(), new_shape); let correct = array![ [A[&[0, 0, 0, 0, 0, 1, 0][..]], A[&[0, 0, 2, 0, 0, 1, 0][..]]], [A[&[1, 0, 0, 0, 0, 1, 0][..]], A[&[1, 0, 2, 0, 0, 1, 0][..]]] ] - .into_shape(new_shape) + .into_shape_with_order(new_shape) .unwrap(); assert_eq!(vi, correct); @@ -159,14 +170,16 @@ fn test_slice_with_many_dim() { } #[test] -fn test_slice_range_variable() { +fn test_slice_range_variable() +{ let range = 1..4; let arr = array![0, 1, 2, 3, 4]; assert_eq!(arr.slice(s![range]), array![1, 2, 3]); } #[test] -fn test_slice_args_eval_range_once() { +fn test_slice_args_eval_range_once() +{ let mut eval_count = 0; { let mut range = || { @@ -180,7 +193,8 @@ fn test_slice_args_eval_range_once() { } #[test] -fn test_slice_args_eval_step_once() { +fn test_slice_args_eval_step_once() +{ let mut eval_count = 0; { let mut step = || { @@ -194,107 +208,150 @@ fn test_slice_args_eval_step_once() { } #[test] -fn test_slice_array_fixed() { +fn test_slice_array_fixed() +{ let mut arr = Array3::::zeros((5, 2, 5)); - let info = s![1.., 1, ..;2]; + let info = s![1.., 1, NewAxis, ..;2]; arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + let info2 = s![1.., 1, ..;2]; + arr.view().slice_collapse(info2); } #[test] -fn test_slice_dyninput_array_fixed() { +fn test_slice_dyninput_array_fixed() +{ let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = s![1.., 1, ..;2]; + let info = s![1.., 1, NewAxis, ..;2]; arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info.as_ref()); + let info2 = s![1.., 1, ..;2]; + arr.view().slice_collapse(info2); } #[test] -fn test_slice_array_dyn() { +fn test_slice_array_dyn() +{ let mut arr = Array3::::zeros((5, 2, 5)); - let info = &SliceInfo::<_, IxDyn>::new([ - SliceOrIndex::from(1..), - SliceOrIndex::from(1), - SliceOrIndex::from(..).step_by(2), + let info = SliceInfo::<_, Ix3, IxDyn>::try_from([ + SliceInfoElem::from(1..), + SliceInfoElem::from(1), + SliceInfoElem::from(NewAxis), + SliceInfoElem::from(Slice::from(..).step_by(2)), ]) .unwrap(); arr.slice(info); arr.slice_mut(info); arr.view().slice_move(info); - arr.view().slice_collapse(info); + let info2 = SliceInfo::<_, Ix3, IxDyn>::try_from([ + SliceInfoElem::from(1..), + SliceInfoElem::from(1), + SliceInfoElem::from(Slice::from(..).step_by(2)), + ]) + .unwrap(); + arr.view().slice_collapse(info2); } #[test] -fn test_slice_dyninput_array_dyn() { +fn test_slice_dyninput_array_dyn() +{ let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, IxDyn>::new([ - SliceOrIndex::from(1..), - SliceOrIndex::from(1), - SliceOrIndex::from(..).step_by(2), + let info = SliceInfo::<_, Ix3, IxDyn>::try_from([ + SliceInfoElem::from(1..), + SliceInfoElem::from(1), + SliceInfoElem::from(NewAxis), + SliceInfoElem::from(Slice::from(..).step_by(2)), ]) .unwrap(); arr.slice(info); arr.slice_mut(info); 
arr.view().slice_move(info); - arr.view().slice_collapse(info.as_ref()); + let info2 = SliceInfo::<_, Ix3, IxDyn>::try_from([ + SliceInfoElem::from(1..), + SliceInfoElem::from(1), + SliceInfoElem::from(Slice::from(..).step_by(2)), + ]) + .unwrap(); + arr.view().slice_collapse(info2); } #[test] -fn test_slice_dyninput_vec_fixed() { +fn test_slice_dyninput_vec_fixed() +{ let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, Ix2>::new(vec![ - SliceOrIndex::from(1..), - SliceOrIndex::from(1), - SliceOrIndex::from(..).step_by(2), + let info = &SliceInfo::<_, Ix3, Ix3>::try_from(vec![ + SliceInfoElem::from(1..), + SliceInfoElem::from(1), + SliceInfoElem::from(NewAxis), + SliceInfoElem::from(Slice::from(..).step_by(2)), + ]) + .unwrap(); + arr.slice(info); + arr.slice_mut(info); + arr.view().slice_move(info); + let info2 = SliceInfo::<_, Ix3, Ix2>::try_from(vec![ + SliceInfoElem::from(1..), + SliceInfoElem::from(1), + SliceInfoElem::from(Slice::from(..).step_by(2)), ]) .unwrap(); - arr.slice(info.as_ref()); - arr.slice_mut(info.as_ref()); - arr.view().slice_move(info.as_ref()); - arr.view().slice_collapse(info.as_ref()); + arr.view().slice_collapse(info2); } #[test] -fn test_slice_dyninput_vec_dyn() { +fn test_slice_dyninput_vec_dyn() +{ let mut arr = Array3::::zeros((5, 2, 5)).into_dyn(); - let info = &SliceInfo::<_, IxDyn>::new(vec![ - SliceOrIndex::from(1..), - SliceOrIndex::from(1), - SliceOrIndex::from(..).step_by(2), + let info = &SliceInfo::<_, Ix3, IxDyn>::try_from(vec![ + SliceInfoElem::from(1..), + SliceInfoElem::from(1), + SliceInfoElem::from(NewAxis), + SliceInfoElem::from(Slice::from(..).step_by(2)), + ]) + .unwrap(); + arr.slice(info); + arr.slice_mut(info); + arr.view().slice_move(info); + let info2 = SliceInfo::<_, Ix3, IxDyn>::try_from(vec![ + SliceInfoElem::from(1..), + SliceInfoElem::from(1), + SliceInfoElem::from(Slice::from(..).step_by(2)), ]) .unwrap(); - arr.slice(info.as_ref()); - arr.slice_mut(info.as_ref()); - arr.view().slice_move(info.as_ref()); - arr.view().slice_collapse(info.as_ref()); + arr.view().slice_collapse(info2); } #[test] -fn test_slice_with_subview() { +fn test_slice_with_subview_and_new_axis() +{ let mut arr = ArcArray::::zeros((3, 5, 4)); for (i, elt) in arr.iter_mut().enumerate() { *elt = i; } - let vi = arr.slice(s![1.., 2, ..;2]); - assert_eq!(vi.shape(), &[2, 2]); + let vi = arr.slice(s![NewAxis, 1.., 2, ..;2]); + assert_eq!(vi.shape(), &[1, 2, 2]); assert!(vi .iter() - .zip(arr.index_axis(Axis(1), 2).slice(s![1.., ..;2]).iter()) + .zip( + arr.index_axis(Axis(1), 2) + .slice(s![1.., ..;2]) + .insert_axis(Axis(0)) + .iter() + ) .all(|(a, b)| a == b)); - let vi = arr.slice(s![1, 2, ..;2]); - assert_eq!(vi.shape(), &[2]); + let vi = arr.slice(s![1, NewAxis, 2, ..;2]); + assert_eq!(vi.shape(), &[1, 2]); assert!(vi .iter() .zip( arr.index_axis(Axis(0), 1) .index_axis(Axis(0), 2) .slice(s![..;2]) + .insert_axis(Axis(0)) .iter() ) .all(|(a, b)| a == b)); @@ -305,7 +362,8 @@ fn test_slice_with_subview() { } #[test] -fn test_slice_collapse_with_indices() { +fn test_slice_collapse_with_indices() +{ let mut arr = ArcArray::::zeros((3, 5, 4)); for (i, elt) in arr.iter_mut().enumerate() { *elt = i; @@ -343,7 +401,16 @@ fn test_slice_collapse_with_indices() { } #[test] -fn test_multislice() { +#[should_panic] +fn test_slice_collapse_with_newaxis() +{ + let mut arr = Array2::::zeros((2, 3)); + arr.slice_collapse(s![0, 0, NewAxis]); +} + +#[test] +fn test_multislice() +{ macro_rules! 
do_test { ($arr:expr, $($s:expr),*) => { { @@ -357,7 +424,9 @@ fn test_multislice() { }; } - let mut arr = Array1::from_iter(0..48).into_shape((8, 6)).unwrap(); + let mut arr = Array1::from_iter(0..48) + .into_shape_with_order((8, 6)) + .unwrap(); assert_eq!( (arr.clone().view_mut(),), @@ -379,10 +448,11 @@ fn test_multislice() { } #[test] -fn test_multislice_intersecting() { +fn test_multislice_intersecting() +{ assert_panics!({ let mut arr = Array2::::zeros((8, 6)); - arr.multi_slice_mut((s![3, ..], s![3, ..])); + arr.multi_slice_mut((s![3, .., NewAxis], s![3, ..])); }); assert_panics!({ let mut arr = Array2::::zeros((8, 6)); @@ -390,7 +460,7 @@ fn test_multislice_intersecting() { }); assert_panics!({ let mut arr = Array2::::zeros((8, 6)); - arr.multi_slice_mut((s![3, ..], s![..;3, ..])); + arr.multi_slice_mut((s![3, ..], s![..;3, NewAxis, ..])); }); assert_panics!({ let mut arr = Array2::::zeros((8, 6)); @@ -420,45 +490,52 @@ fn test_multislice_intersecting() { #[should_panic] #[test] -fn index_out_of_bounds() { +fn index_out_of_bounds() +{ let mut a = Array::::zeros((3, 4)); a[[3, 2]] = 1; } #[should_panic] #[test] -fn slice_oob() { +fn slice_oob() +{ let a = ArcArray::::zeros((3, 4)); let _vi = a.slice(s![..10, ..]); } #[should_panic] #[test] -fn slice_axis_oob() { +fn slice_axis_oob() +{ let a = ArcArray::::zeros((3, 4)); let _vi = a.slice_axis(Axis(0), Slice::new(0, Some(10), 1)); } #[should_panic] #[test] -fn slice_wrong_dim() { +fn slice_wrong_dim() +{ let a = ArcArray::::zeros(vec![3, 4, 5]); let _vi = a.slice(s![.., ..]); } #[test] -fn test_index() { +fn test_index() +{ let mut A = ArcArray::::zeros((2, 3)); for (i, elt) in A.iter_mut().enumerate() { *elt = i; } + #[allow(deprecated)] for ((i, j), a) in zip(indices((2, 3)), &A) { assert_eq!(*a, A[[i, j]]); } let vi = A.slice(s![1.., ..;2]); let mut it = vi.iter(); + #[allow(deprecated)] for ((i, j), x) in zip(indices((1, 2)), &mut it) { assert_eq!(*x, vi[[i, j]]); } @@ -466,18 +543,20 @@ fn test_index() { } #[test] -fn test_index_arrays() { +fn test_index_arrays() +{ let a = Array1::from_iter(0..12); assert_eq!(a[1], a[[1]]); - let v = a.view().into_shape((3, 4)).unwrap(); + let v = a.view().into_shape_with_order((3, 4)).unwrap(); assert_eq!(a[1], v[[0, 1]]); - let w = v.into_shape((2, 2, 3)).unwrap(); + let w = v.into_shape_with_order((2, 2, 3)).unwrap(); assert_eq!(a[1], w[[0, 0, 1]]); } #[test] #[allow(clippy::assign_op_pattern)] -fn test_add() { +fn test_add() +{ let mut A = ArcArray::::zeros((2, 2)); for (i, elt) in A.iter_mut().enumerate() { *elt = i; @@ -492,8 +571,11 @@ fn test_add() { } #[test] -fn test_multidim() { - let mut mat = ArcArray::zeros(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6)); +fn test_multidim() +{ + let mut mat = ArcArray::zeros(2 * 3 * 4 * 5 * 6) + .into_shape_with_order((2, 3, 4, 5, 6)) + .unwrap(); mat[(0, 0, 0, 0, 0)] = 22u8; { for (i, elt) in mat.iter_mut().enumerate() { @@ -515,7 +597,8 @@ array([[[ 7, 6], [ 9, 8]]]) */ #[test] -fn test_negative_stride_arcarray() { +fn test_negative_stride_arcarray() +{ let mut mat = ArcArray::zeros((2, 4, 2)); mat[[0, 0, 0]] = 1.0f32; for (i, elt) in mat.iter_mut().enumerate() { @@ -526,9 +609,7 @@ fn test_negative_stride_arcarray() { let vi = mat.slice(s![.., ..;-1, ..;-1]); assert_eq!(vi.shape(), &[2, 4, 2]); // Test against sequential iterator - let seq = [ - 7f32, 6., 5., 4., 3., 2., 1., 0., 15., 14., 13., 12., 11., 10., 9., 8., - ]; + let seq = [7f32, 6., 5., 4., 3., 2., 1., 0., 15., 14., 13., 12., 11., 10., 9., 8.]; for (a, b) in vi.iter().zip(seq.iter()) { 
assert_eq!(*a, *b); } @@ -543,7 +624,8 @@ fn test_negative_stride_arcarray() { } #[test] -fn test_cow() { +fn test_cow() +{ let mut mat = ArcArray::zeros((2, 2)); mat[[0, 0]] = 1; let n = mat.clone(); @@ -555,7 +637,7 @@ fn test_cow() { assert_eq!(n[[0, 0]], 1); assert_eq!(n[[0, 1]], 0); assert_eq!(n.get((0, 1)), Some(&0)); - let mut rev = mat.reshape(4); + let mut rev = mat.into_shape_with_order(4).unwrap(); rev.slice_collapse(s![..;-1]); assert_eq!(rev[0], 4); assert_eq!(rev[1], 3); @@ -575,7 +657,8 @@ fn test_cow() { } #[test] -fn test_cow_shrink() { +fn test_cow_shrink() +{ // A test for clone-on-write in the case that // mutation shrinks the array and gives it different strides // @@ -594,7 +677,7 @@ fn test_cow_shrink() { assert_eq!(n[[0, 1]], 0); assert_eq!(n.get((0, 1)), Some(&0)); // small has non-C strides this way - let mut small = mat.reshape(6); + let mut small = mat.into_shape_with_order(6).unwrap(); small.slice_collapse(s![4..;-1]); assert_eq!(small[0], 6); assert_eq!(small[1], 5); @@ -609,40 +692,45 @@ fn test_cow_shrink() { } #[test] -fn test_sub() { - let mat = ArcArray::linspace(0., 15., 16).reshape((2, 4, 2)); +#[cfg(feature = "std")] +fn test_sub() +{ + let mat = ArcArray::linspace(0., 15., 16) + .into_shape_with_order((2, 4, 2)) + .unwrap(); let s1 = mat.index_axis(Axis(0), 0); let s2 = mat.index_axis(Axis(0), 1); assert_eq!(s1.shape(), &[4, 2]); assert_eq!(s2.shape(), &[4, 2]); - let n = ArcArray::linspace(8., 15., 8).reshape((4, 2)); + let n = ArcArray::linspace(8., 15., 8) + .into_shape_with_order((4, 2)) + .unwrap(); assert_eq!(n, s2); - let m = ArcArray::from(vec![2., 3., 10., 11.]).reshape((2, 2)); + let m = ArcArray::from(vec![2., 3., 10., 11.]) + .into_shape_with_order((2, 2)) + .unwrap(); assert_eq!(m, mat.index_axis(Axis(1), 1)); } #[should_panic] #[test] -fn test_sub_oob_1() { - let mat = ArcArray::linspace(0., 15., 16).reshape((2, 4, 2)); +#[cfg(feature = "std")] +fn test_sub_oob_1() +{ + let mat = ArcArray::linspace(0., 15., 16) + .into_shape_with_order((2, 4, 2)) + .unwrap(); mat.index_axis(Axis(0), 2); } #[test] #[cfg(feature = "approx")] -fn test_select() { +fn test_select() +{ use approx::assert_abs_diff_eq; // test for 2-d array - let x = arr2(&[ - [0., 1.], - [1., 0.], - [1., 0.], - [1., 0.], - [1., 0.], - [0., 1.], - [0., 1.], - ]); + let x = arr2(&[[0., 1.], [1., 0.], [1., 0.], [1., 0.], [1., 0.], [0., 1.], [0., 1.]]); let r = x.select(Axis(0), &[1, 3, 5]); let c = x.select(Axis(1), &[1]); let r_target = arr2(&[[1., 0.], [1., 0.], [0., 1.]]); @@ -651,10 +739,7 @@ fn test_select() { assert_abs_diff_eq!(c, c_target.t()); // test for 3-d array - let y = arr3(&[ - [[1., 2., 3.], [1.5, 1.5, 3.]], - [[1., 2., 8.], [1., 2.5, 3.]], - ]); + let y = arr3(&[[[1., 2., 3.], [1.5, 1.5, 3.]], [[1., 2., 8.], [1., 2.5, 3.]]]); let r = y.select(Axis(1), &[1]); let c = y.select(Axis(2), &[1]); let r_target = arr3(&[[[1.5, 1.5, 3.]], [[1., 2.5, 3.]]]); @@ -664,13 +749,28 @@ fn test_select() { } #[test] -fn diag() { +fn test_select_1d() +{ + let x = arr1(&[0, 1, 2, 3, 4, 5, 6]); + let r1 = x.select(Axis(0), &[1, 3, 4, 2, 2, 5]); + assert_eq!(r1, arr1(&[1, 3, 4, 2, 2, 5])); + // select nothing + let r2 = x.select(Axis(0), &[]); + assert_eq!(r2, arr1(&[])); + // select nothing from empty + let r3 = r2.select(Axis(0), &[]); + assert_eq!(r3, arr1(&[])); +} + +#[test] +fn diag() +{ let d = arr2(&[[1., 2., 3.0f32]]).into_diag(); assert_eq!(d.dim(), 1); let a = arr2(&[[1., 2., 3.0f32], [0., 0., 0.]]); let d = a.view().into_diag(); assert_eq!(d.dim(), 2); - let d = 
arr2::(&[[]]).into_diag(); + let d = arr2::(&[[]]).into_diag(); assert_eq!(d.dim(), 0); let d = ArcArray::::zeros(()).into_diag(); assert_eq!(d.dim(), 1); @@ -681,7 +781,8 @@ fn diag() { /// Note that this does not check the strides in the "merged" case! #[test] #[allow(clippy::cognitive_complexity)] -fn merge_axes() { +fn merge_axes() +{ macro_rules! assert_merged { ($arr:expr, $slice:expr, $take:expr, $into:expr) => { let mut v = $arr.slice($slice); @@ -769,7 +870,8 @@ fn merge_axes() { } #[test] -fn swapaxes() { +fn swapaxes() +{ let mut a = arr2(&[[1., 2.], [3., 4.0f32]]); let b = arr2(&[[1., 3.], [2., 4.0f32]]); assert!(a != b); @@ -782,7 +884,8 @@ fn swapaxes() { } #[test] -fn permuted_axes() { +fn permuted_axes() +{ let a = array![1].index_axis_move(Axis(0), 0); let permuted = a.view().permuted_axes([]); assert_eq!(a, permuted); @@ -791,7 +894,9 @@ fn permuted_axes() { let permuted = a.view().permuted_axes([0]); assert_eq!(a, permuted); - let a = Array::from_iter(0..24).into_shape((2, 3, 4)).unwrap(); + let a = Array::from_iter(0..24) + .into_shape_with_order((2, 3, 4)) + .unwrap(); let permuted = a.view().permuted_axes([2, 1, 0]); for ((i0, i1, i2), elem) in a.indexed_iter() { assert_eq!(*elem, permuted[(i2, i1, i0)]); @@ -801,7 +906,9 @@ fn permuted_axes() { assert_eq!(*elem, permuted[&[i0, i2, i1][..]]); } - let a = Array::from_iter(0..120).into_shape((2, 3, 4, 5)).unwrap(); + let a = Array::from_iter(0..120) + .into_shape_with_order((2, 3, 4, 5)) + .unwrap(); let permuted = a.view().permuted_axes([1, 0, 3, 2]); for ((i0, i1, i2, i3), elem) in a.indexed_iter() { assert_eq!(*elem, permuted[(i1, i0, i3, i2)]); @@ -814,16 +921,20 @@ fn permuted_axes() { #[should_panic] #[test] -fn permuted_axes_repeated_axis() { - let a = Array::from_iter(0..24).into_shape((2, 3, 4)).unwrap(); +fn permuted_axes_repeated_axis() +{ + let a = Array::from_iter(0..24) + .into_shape_with_order((2, 3, 4)) + .unwrap(); a.view().permuted_axes([1, 0, 1]); } #[should_panic] #[test] -fn permuted_axes_missing_axis() { +fn permuted_axes_missing_axis() +{ let a = Array::from_iter(0..24) - .into_shape((2, 3, 4)) + .into_shape_with_order((2, 3, 4)) .unwrap() .into_dyn(); a.view().permuted_axes(&[2, 0][..]); @@ -831,13 +942,17 @@ fn permuted_axes_missing_axis() { #[should_panic] #[test] -fn permuted_axes_oob() { - let a = Array::from_iter(0..24).into_shape((2, 3, 4)).unwrap(); +fn permuted_axes_oob() +{ + let a = Array::from_iter(0..24) + .into_shape_with_order((2, 3, 4)) + .unwrap(); a.view().permuted_axes([1, 0, 3]); } #[test] -fn standard_layout() { +fn standard_layout() +{ let mut a = arr2(&[[1., 2.], [3., 4.0]]); assert!(a.is_standard_layout()); a.swap_axes(0, 1); @@ -855,32 +970,8 @@ fn standard_layout() { } #[test] -fn assign() { - let mut a = arr2(&[[1., 2.], [3., 4.]]); - let b = arr2(&[[1., 3.], [2., 4.]]); - a.assign(&b); - assert_eq!(a, b); - - /* Test broadcasting */ - a.assign(&ArcArray::zeros(1)); - assert_eq!(a, ArcArray::zeros((2, 2))); - - /* Test other type */ - a.assign(&Array::from_elem((2, 2), 3.)); - assert_eq!(a, ArcArray::from_elem((2, 2), 3.)); - - /* Test mut view */ - let mut a = arr2(&[[1, 2], [3, 4]]); - { - let mut v = a.view_mut(); - v.slice_collapse(s![..1, ..]); - v.fill(0); - } - assert_eq!(a, arr2(&[[0, 0], [3, 4]])); -} - -#[test] -fn iter_size_hint() { +fn iter_size_hint() +{ let mut a = arr2(&[[1., 2.], [3., 4.]]); { let mut it = a.iter(); @@ -915,16 +1006,17 @@ fn iter_size_hint() { } #[test] -fn zero_axes() { +fn zero_axes() +{ let mut a = arr1::(&[]); for _ in a.iter() { 
panic!(); } a.map(|_| panic!()); a.map_inplace(|_| panic!()); - a.visit(|_| panic!()); + a.for_each(|_| panic!()); println!("{:?}", a); - let b = arr2::(&[[], [], [], []]); + let b = arr2::(&[[], [], [], []]); println!("{:?}\n{:?}", b.shape(), b); // we can even get a subarray of b @@ -933,7 +1025,8 @@ fn zero_axes() { } #[test] -fn equality() { +fn equality() +{ let a = arr2(&[[1., 2.], [3., 4.]]); let mut b = arr2(&[[1., 2.], [2., 4.]]); assert!(a != b); @@ -946,7 +1039,8 @@ fn equality() { } #[test] -fn map1() { +fn map1() +{ let a = arr2(&[[1., 2.], [3., 4.]]); let b = a.map(|&x| (x / 3.) as isize); assert_eq!(b, arr2(&[[0, 0], [1, 1]])); @@ -956,8 +1050,25 @@ fn map1() { } #[test] -fn as_slice_memory_order() { - // test that mutation breaks sharing +fn mapv_into_any_same_type() +{ + let a: Array = array![[1., 2., 3.], [4., 5., 6.]]; + let a_plus_one: Array = array![[2., 3., 4.], [5., 6., 7.]]; + assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one); +} + +#[test] +fn mapv_into_any_diff_types() +{ + let a: Array = array![[1., 2., 3.], [4., 5., 6.]]; + let a_even: Array = array![[false, true, false], [true, false, true]]; + assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even); +} + +#[test] +fn as_slice_memory_order_mut_arcarray() +{ + // Test that mutation breaks sharing for `ArcArray`. let a = rcarr2(&[[1., 2.], [3., 4.0f32]]); let mut b = a.clone(); for elt in b.as_slice_memory_order_mut().unwrap() { @@ -967,11 +1078,90 @@ fn as_slice_memory_order() { } #[test] -fn array0_into_scalar() { +fn as_slice_memory_order_mut_cowarray() +{ + // Test that mutation breaks sharing for `CowArray`. + let a = arr2(&[[1., 2.], [3., 4.0f32]]); + let mut b = CowArray::from(a.view()); + for elt in b.as_slice_memory_order_mut().unwrap() { + *elt = 0.; + } + assert!(a != b, "{:?} != {:?}", a, b); +} + +#[test] +fn as_slice_memory_order_mut_contiguous_arcarray() +{ + // Test that unsharing preserves the strides in the contiguous case for `ArcArray`. + let a = rcarr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes(); + let mut b = a.clone().slice_move(s![.., ..2]); + assert_eq!(b.strides(), &[1, 2]); + b.as_slice_memory_order_mut().unwrap(); + assert_eq!(b.strides(), &[1, 2]); +} + +#[test] +fn as_slice_memory_order_mut_contiguous_cowarray() +{ + // Test that unsharing preserves the strides in the contiguous case for `CowArray`. 
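// Aside — illustrative sketch, not part of the diff above: requesting a mutable
// slice from a `CowArray` that still borrows its data first copies the elements
// into an owned array ("unsharing"), so the array it was borrowed from is never
// modified. Assumes the same imports as this test file (ndarray prelude, CowArray).
let sketch_owner = arr2(&[[1., 2.], [3., 4.]]);
let mut sketch_cow = CowArray::from(sketch_owner.view()); // starts as a view
sketch_cow.as_slice_memory_order_mut().unwrap()[0] = 9.;
assert!(!sketch_cow.is_view()); // the mutable access forced a copy
assert_eq!(sketch_owner[[0, 0]], 1.); // the original array is untouched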
+ let a = arr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes(); + let mut b = CowArray::from(a.slice(s![.., ..2])); + assert!(b.is_view()); + assert_eq!(b.strides(), &[1, 2]); + b.as_slice_memory_order_mut().unwrap(); + assert_eq!(b.strides(), &[1, 2]); +} + +#[test] +fn to_slice_memory_order() +{ + for shape in vec![[2, 0, 3, 5], [2, 1, 3, 5], [2, 4, 3, 5]] { + let data: Vec = (0..shape.iter().product()).collect(); + let mut orig = Array1::from(data.clone()) + .into_shape_with_order(shape) + .unwrap(); + for perm in vec![[0, 1, 2, 3], [0, 2, 1, 3], [2, 0, 1, 3]] { + let mut a = orig.view_mut().permuted_axes(perm); + assert_eq!(a.as_slice_memory_order().unwrap(), &data); + assert_eq!(a.as_slice_memory_order_mut().unwrap(), &data); + assert_eq!(a.view().to_slice_memory_order().unwrap(), &data); + assert_eq!(a.view_mut().into_slice_memory_order().unwrap(), &data); + } + } +} + +#[test] +fn to_slice_memory_order_discontiguous() +{ + let mut orig = Array3::::zeros([3, 2, 4]); + assert!(orig + .slice(s![.., 1.., ..]) + .as_slice_memory_order() + .is_none()); + assert!(orig + .slice_mut(s![.., 1.., ..]) + .as_slice_memory_order_mut() + .is_none()); + assert!(orig + .slice(s![.., 1.., ..]) + .to_slice_memory_order() + .is_none()); + assert!(orig + .slice_mut(s![.., 1.., ..]) + .into_slice_memory_order() + .is_none()); +} + +#[test] +fn array0_into_scalar() +{ // With this kind of setup, the `Array`'s pointer is not the same as the // underlying `Vec`'s pointer. let a: Array0 = array![4, 5, 6, 7].index_axis_move(Axis(0), 2); - assert_ne!(a.as_ptr(), a.into_raw_vec().as_ptr()); + let a_ptr = a.as_ptr(); + let (raw_vec, offset) = a.into_raw_vec_and_offset(); + assert_ne!(a_ptr, raw_vec.as_ptr()); + assert_eq!(offset, Some(2)); // `.into_scalar()` should still work correctly. let a: Array0 = array![4, 5, 6, 7].index_axis_move(Axis(0), 2); assert_eq!(a.into_scalar(), 6); @@ -982,11 +1172,15 @@ fn array0_into_scalar() { } #[test] -fn array_view0_into_scalar() { +fn array_view0_into_scalar() +{ // With this kind of setup, the `Array`'s pointer is not the same as the // underlying `Vec`'s pointer. let a: Array0 = array![4, 5, 6, 7].index_axis_move(Axis(0), 2); - assert_ne!(a.as_ptr(), a.into_raw_vec().as_ptr()); + let a_ptr = a.as_ptr(); + let (raw_vec, offset) = a.into_raw_vec_and_offset(); + assert_ne!(a_ptr, raw_vec.as_ptr()); + assert_eq!(offset, Some(2)); // `.into_scalar()` should still work correctly. let a: Array0 = array![4, 5, 6, 7].index_axis_move(Axis(0), 2); assert_eq!(a.view().into_scalar(), &6); @@ -997,11 +1191,12 @@ fn array_view0_into_scalar() { } #[test] -fn array_view_mut0_into_scalar() { +fn array_view_mut0_into_scalar() +{ // With this kind of setup, the `Array`'s pointer is not the same as the // underlying `Vec`'s pointer. let a: Array0 = array![4, 5, 6, 7].index_axis_move(Axis(0), 2); - assert_ne!(a.as_ptr(), a.into_raw_vec().as_ptr()); + assert_ne!(a.as_ptr(), a.into_raw_vec_and_offset().0.as_ptr()); // `.into_scalar()` should still work correctly. 
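// Aside — illustrative sketch, not part of the diff above: `into_raw_vec_and_offset`
// (the replacement for `into_raw_vec` used in these rewritten assertions) also
// reports where the first logical element sits inside the backing `Vec`.
let sketch_owned = Array1::from(vec![10, 20, 30]).slice_move(s![1..]);
let (sketch_vec, sketch_offset) = sketch_owned.into_raw_vec_and_offset();
assert_eq!(sketch_vec, vec![10, 20, 30]); // the allocation is returned whole
assert_eq!(sketch_offset, Some(1)); // the array's first element was at index 1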
let mut a: Array0 = array![4, 5, 6, 7].index_axis_move(Axis(0), 2); assert_eq!(a.view_mut().into_scalar(), &6); @@ -1012,7 +1207,18 @@ fn array_view_mut0_into_scalar() { } #[test] -fn owned_array1() { +fn array1_into_raw_vec() +{ + let data = vec![4, 5, 6, 7]; + let array = Array::from(data.clone()); + let (raw_vec, offset) = array.into_raw_vec_and_offset(); + assert_eq!(data, raw_vec); + assert_eq!(offset, Some(0)); +} + +#[test] +fn owned_array1() +{ let mut a = Array::from(vec![1, 2, 3, 4]); for elt in a.iter_mut() { *elt = 2; @@ -1037,7 +1243,8 @@ fn owned_array1() { } #[test] -fn owned_array_with_stride() { +fn owned_array_with_stride() +{ let v: Vec<_> = (0..12).collect(); let dim = (2, 3, 2); let strides = (1, 4, 2); @@ -1047,7 +1254,8 @@ fn owned_array_with_stride() { } #[test] -fn owned_array_discontiguous() { +fn owned_array_discontiguous() +{ use std::iter::repeat; let v: Vec<_> = (0..12).flat_map(|x| repeat(x).take(2)).collect(); let dim = (3, 2, 2); @@ -1060,14 +1268,17 @@ fn owned_array_discontiguous() { } #[test] -fn owned_array_discontiguous_drop() { +fn owned_array_discontiguous_drop() +{ use std::cell::RefCell; use std::collections::BTreeSet; use std::rc::Rc; struct InsertOnDrop(Rc>>, Option); - impl Drop for InsertOnDrop { - fn drop(&mut self) { + impl Drop for InsertOnDrop + { + fn drop(&mut self) + { let InsertOnDrop(ref set, ref mut value) = *self; set.borrow_mut().insert(value.take().expect("double drop!")); } @@ -1101,13 +1312,15 @@ macro_rules! assert_matches { } #[test] -fn from_vec_dim_stride_empty_1d() { +fn from_vec_dim_stride_empty_1d() +{ let empty: [f32; 0] = []; assert_matches!(Array::from_shape_vec(0.strides(1), empty.to_vec()), Ok(_)); } #[test] -fn from_vec_dim_stride_0d() { +fn from_vec_dim_stride_0d() +{ let empty: [f32; 0] = []; let one = [1.]; let two = [1., 2.]; @@ -1123,7 +1336,8 @@ fn from_vec_dim_stride_0d() { } #[test] -fn from_vec_dim_stride_2d_1() { +fn from_vec_dim_stride_2d_1() +{ let two = [1., 2.]; let d = Ix2(2, 1); let s = d.default_strides(); @@ -1131,7 +1345,8 @@ fn from_vec_dim_stride_2d_1() { } #[test] -fn from_vec_dim_stride_2d_2() { +fn from_vec_dim_stride_2d_2() +{ let two = [1., 2.]; let d = Ix2(1, 2); let s = d.default_strides(); @@ -1139,7 +1354,8 @@ fn from_vec_dim_stride_2d_2() { } #[test] -fn from_vec_dim_stride_2d_3() { +fn from_vec_dim_stride_2d_3() +{ let a = arr3(&[[[1]], [[2]], [[3]]]); let d = a.raw_dim(); let s = d.default_strides(); @@ -1150,7 +1366,8 @@ fn from_vec_dim_stride_2d_3() { } #[test] -fn from_vec_dim_stride_2d_4() { +fn from_vec_dim_stride_2d_4() +{ let a = arr3(&[[[1]], [[2]], [[3]]]); let d = a.raw_dim(); let s = d.fortran_strides(); @@ -1161,7 +1378,8 @@ fn from_vec_dim_stride_2d_4() { } #[test] -fn from_vec_dim_stride_2d_5() { +fn from_vec_dim_stride_2d_5() +{ let a = arr3(&[[[1, 2, 3]]]); let d = a.raw_dim(); let s = d.fortran_strides(); @@ -1172,7 +1390,8 @@ fn from_vec_dim_stride_2d_5() { } #[test] -fn from_vec_dim_stride_2d_6() { +fn from_vec_dim_stride_2d_6() +{ let a = [1., 2., 3., 4., 5., 6.]; let d = (2, 1, 1); let s = (2, 2, 1); @@ -1184,7 +1403,8 @@ fn from_vec_dim_stride_2d_6() { } #[test] -fn from_vec_dim_stride_2d_7() { +fn from_vec_dim_stride_2d_7() +{ // empty arrays can have 0 strides let a: [f32; 0] = []; // [[]] shape=[4, 0], strides=[0, 1] @@ -1194,7 +1414,8 @@ fn from_vec_dim_stride_2d_7() { } #[test] -fn from_vec_dim_stride_2d_8() { +fn from_vec_dim_stride_2d_8() +{ // strides of length 1 axes can be zero let a = [1.]; let d = (1, 1); @@ -1203,7 +1424,8 @@ fn 
from_vec_dim_stride_2d_8() { } #[test] -fn from_vec_dim_stride_2d_rejects() { +fn from_vec_dim_stride_2d_rejects() +{ let two = [1., 2.]; let d = (2, 2); let s = (1, 0); @@ -1215,8 +1437,11 @@ fn from_vec_dim_stride_2d_rejects() { } #[test] -fn views() { - let a = ArcArray::from(vec![1, 2, 3, 4]).reshape((2, 2)); +fn views() +{ + let a = ArcArray::from(vec![1, 2, 3, 4]) + .into_shape_with_order((2, 2)) + .unwrap(); let b = a.view(); assert_eq!(a, b); assert_eq!(a.shape(), b.shape()); @@ -1232,8 +1457,11 @@ fn views() { } #[test] -fn view_mut() { - let mut a = ArcArray::from(vec![1, 2, 3, 4]).reshape((2, 2)); +fn view_mut() +{ + let mut a = ArcArray::from(vec![1, 2, 3, 4]) + .into_shape_with_order((2, 2)) + .unwrap(); for elt in &mut a.view_mut() { *elt = 0; } @@ -1251,8 +1479,11 @@ fn view_mut() { } #[test] -fn slice_mut() { - let mut a = ArcArray::from(vec![1, 2, 3, 4]).reshape((2, 2)); +fn slice_mut() +{ + let mut a = ArcArray::from(vec![1, 2, 3, 4]) + .into_shape_with_order((2, 2)) + .unwrap(); for elt in a.slice_mut(s![.., ..]) { *elt = 0; } @@ -1273,7 +1504,8 @@ fn slice_mut() { } #[test] -fn assign_ops() { +fn assign_ops() +{ let mut a = arr2(&[[1., 2.], [3., 4.]]); let b = arr2(&[[1., 3.], [2., 4.]]); (*&mut a.view_mut()) += &b; @@ -1291,7 +1523,8 @@ fn assign_ops() { } #[test] -fn aview() { +fn aview() +{ let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let data = [[1., 2., 3.], [4., 5., 6.]]; let b = aview2(&data); @@ -1300,10 +1533,11 @@ fn aview() { } #[test] -fn aview_mut() { +fn aview_mut() +{ let mut data = [0; 16]; { - let mut a = aview_mut1(&mut data).into_shape((4, 4)).unwrap(); + let mut a = aview_mut1(&mut data).into_shape_with_order((4, 4)).unwrap(); { let mut slc = a.slice_mut(s![..2, ..;2]); slc += 1; @@ -1313,7 +1547,8 @@ fn aview_mut() { } #[test] -fn transpose_view() { +fn transpose_view() +{ let a = arr2(&[[1, 2], [3, 4]]); let at = a.view().reversed_axes(); assert_eq!(at, arr2(&[[1, 3], [2, 4]])); @@ -1324,7 +1559,8 @@ fn transpose_view() { } #[test] -fn transpose_view_mut() { +fn transpose_view_mut() +{ let mut a = arr2(&[[1, 2], [3, 4]]); let mut at = a.view_mut().reversed_axes(); at[[0, 1]] = 5; @@ -1336,67 +1572,10 @@ fn transpose_view_mut() { assert_eq!(at, arr2(&[[1, 4], [2, 5], [3, 7]])); } -#[test] -fn reshape() { - let data = [1, 2, 3, 4, 5, 6, 7, 8]; - let v = aview1(&data); - let u = v.into_shape((3, 3)); - assert!(u.is_err()); - let u = v.into_shape((2, 2, 2)); - assert!(u.is_ok()); - let u = u.unwrap(); - assert_eq!(u.shape(), &[2, 2, 2]); - let s = u.into_shape((4, 2)).unwrap(); - assert_eq!(s.shape(), &[4, 2]); - assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); -} - -#[test] -#[should_panic(expected = "IncompatibleShape")] -fn reshape_error1() { - let data = [1, 2, 3, 4, 5, 6, 7, 8]; - let v = aview1(&data); - let _u = v.into_shape((2, 5)).unwrap(); -} - -#[test] -#[should_panic(expected = "IncompatibleLayout")] -fn reshape_error2() { - let data = [1, 2, 3, 4, 5, 6, 7, 8]; - let v = aview1(&data); - let mut u = v.into_shape((2, 2, 2)).unwrap(); - u.swap_axes(0, 1); - let _s = u.into_shape((2, 4)).unwrap(); -} - -#[test] -fn reshape_f() { - let mut u = Array::zeros((3, 4).f()); - for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) { - *elt = i as i32; - } - let v = u.view(); - println!("{:?}", v); - - // noop ok - let v2 = v.into_shape((3, 4)); - assert!(v2.is_ok()); - assert_eq!(v, v2.unwrap()); - - let u = v.into_shape((3, 2, 2)); - assert!(u.is_ok()); - let u = u.unwrap(); - println!("{:?}", u); - assert_eq!(u.shape(), &[3, 2, 
2]); - let s = u.into_shape((4, 3)).unwrap(); - println!("{:?}", s); - assert_eq!(s.shape(), &[4, 3]); - assert_eq!(s, aview2(&[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]])); -} - #[test] #[allow(clippy::cognitive_complexity)] -fn insert_axis() { +fn insert_axis() +{ defmac!(test_insert orig, index, new => { let res = orig.insert_axis(Axis(index)); assert_eq!(res, new); @@ -1491,7 +1670,8 @@ fn insert_axis() { } #[test] -fn insert_axis_f() { +fn insert_axis_f() +{ defmac!(test_insert_f orig, index, new => { let res = orig.insert_axis(Axis(index)); assert_eq!(res, new); @@ -1538,7 +1718,8 @@ fn insert_axis_f() { } #[test] -fn insert_axis_view() { +fn insert_axis_view() +{ let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]; assert_eq!( @@ -1556,24 +1737,78 @@ fn insert_axis_view() { } #[test] -fn arithmetic_broadcast() { +fn arithmetic_broadcast() +{ let mut a = arr2(&[[1., 2.], [3., 4.]]); let b = a.clone() * aview0(&1.); assert_eq!(a, b); a.swap_axes(0, 1); let b = a.clone() / aview0(&1.); assert_eq!(a, b); + + // reference + let a = arr2(&[[2], [3], [4]]); + let b = arr1(&[5, 6, 7]); + assert_eq!(&a + &b, arr2(&[[7, 8, 9], [8, 9, 10], [9, 10, 11]])); + assert_eq!( + a.clone() - &b, + arr2(&[[-3, -4, -5], [-2, -3, -4], [-1, -2, -3]]) + ); + assert_eq!( + a.clone() * b.clone(), + arr2(&[[10, 12, 14], [15, 18, 21], [20, 24, 28]]) + ); + assert_eq!(&b / a, arr2(&[[2, 3, 3], [1, 2, 2], [1, 1, 1]])); + + // Negative strides and non-contiguous memory + let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap(); + let a = s.slice(s![..;-1,..;2,..]); + let b = s.slice(s![..2, -1, ..]); + let mut c = s.clone(); + c.collapse_axis(Axis(2), 1); + let c = c.slice(s![1,..;2,..]); + assert_eq!( + &a.to_owned() + &b, + arr3(&[[[11, 15], [20, 24]], [[10, 14], [19, 23]]]) + ); + assert_eq!( + &a + b.into_owned() + c, + arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) + ); + + // shared array + let sa = a.to_shared(); + let sa2 = sa.to_shared(); + let sb = b.to_shared(); + let sb2 = sb.to_shared(); + let sc = c.to_shared(); + let sc2 = sc.into_shared(); + assert_eq!( + sa2 + &sb2 + sc2.into_owned(), + arr3(&[[[15, 19], [32, 36]], [[14, 18], [31, 35]]]) + ); + + // Same shape + let a = s.slice(s![..;-1, ..;2, ..]); + let b = s.slice(s![.., ..;2, ..]); + assert_eq!(a.shape(), b.shape()); + assert_eq!(&a + &b, arr3(&[[[3, 7], [19, 23]], [[3, 7], [19, 23]]])); } #[test] -fn char_array() { +fn char_array() +{ // test compilation & basics of non-numerical array - let cc = ArcArray::from_iter("alphabet".chars()).reshape((4, 2)); + let cc = ArcArray::from_iter("alphabet".chars()) + .into_shape_with_order((4, 2)) + .unwrap(); assert!(cc.index_axis(Axis(1), 0) == ArcArray::from_iter("apae".chars())); } #[test] -fn scalar_ops() { +fn scalar_ops() +{ let a = Array::::zeros((5, 5)); let b = &a + 1; let c = (&a + &a + 2) - 3; @@ -1610,7 +1845,9 @@ fn scalar_ops() { } #[test] -fn split_at() { +#[cfg(feature = "std")] +fn split_at() +{ let mut a = arr2(&[[1., 2.], [3., 4.]]); { @@ -1627,7 +1864,9 @@ fn split_at() { } assert_eq!(a, arr2(&[[1., 5.], [8., 4.]])); - let b = ArcArray::linspace(0., 59., 60).reshape((3, 4, 5)); + let b = ArcArray::linspace(0., 59., 60) + .into_shape_with_order((3, 4, 5)) + .unwrap(); let (left, right) = b.view().split_at(Axis(2), 2); assert_eq!(left.shape(), [3, 4, 2]); @@ -1648,20 +1887,24 @@ fn split_at() { #[test] #[should_panic] -fn deny_split_at_axis_out_of_bounds() { +fn 
deny_split_at_axis_out_of_bounds() +{ let a = arr2(&[[1., 2.], [3., 4.]]); a.view().split_at(Axis(2), 0); } #[test] #[should_panic] -fn deny_split_at_index_out_of_bounds() { +fn deny_split_at_index_out_of_bounds() +{ let a = arr2(&[[1., 2.], [3., 4.]]); a.view().split_at(Axis(1), 3); } #[test] -fn test_range() { +#[cfg(feature = "std")] +fn test_range() +{ let a = Array::range(0., 5., 1.); assert_eq!(a.len(), 5); assert_eq!(a[0], 0.); @@ -1690,7 +1933,8 @@ fn test_range() { } #[test] -fn test_f_order() { +fn test_f_order() +{ // Test that arrays are logically equal in every way, // even if the underlying memory order is different let c = arr2(&[[1, 2, 3], [4, 5, 6]]); @@ -1701,7 +1945,7 @@ fn test_f_order() { assert_eq!(c.strides(), &[3, 1]); assert_eq!(f.strides(), &[1, 2]); itertools::assert_equal(f.iter(), c.iter()); - itertools::assert_equal(f.genrows(), c.genrows()); + itertools::assert_equal(f.rows(), c.rows()); itertools::assert_equal(f.outer_iter(), c.outer_iter()); itertools::assert_equal(f.axis_iter(Axis(0)), c.axis_iter(Axis(0))); itertools::assert_equal(f.axis_iter(Axis(1)), c.axis_iter(Axis(1))); @@ -1712,27 +1956,39 @@ fn test_f_order() { } #[test] -fn to_owned_memory_order() { +fn to_owned_memory_order() +{ // check that .to_owned() makes f-contiguous arrays out of f-contiguous // input. let c = arr2(&[[1, 2, 3], [4, 5, 6]]); let mut f = c.view(); + + // transposed array f.swap_axes(0, 1); let fo = f.to_owned(); assert_eq!(f, fo); assert_eq!(f.strides(), fo.strides()); + + // negated stride axis + f.invert_axis(Axis(1)); + let fo2 = f.to_owned(); + assert_eq!(f, fo2); + assert_eq!(f.strides(), fo2.strides()); } #[test] -fn to_owned_neg_stride() { +fn to_owned_neg_stride() +{ let mut c = arr2(&[[1, 2, 3], [4, 5, 6]]); c.slice_collapse(s![.., ..;-1]); let co = c.to_owned(); assert_eq!(c, co); + assert_eq!(c.strides(), co.strides()); } #[test] -fn discontiguous_owned_to_owned() { +fn discontiguous_owned_to_owned() +{ let mut c = arr2(&[[1, 2, 3], [4, 5, 6]]); c.slice_collapse(s![.., ..;2]); @@ -1743,7 +1999,8 @@ fn discontiguous_owned_to_owned() { } #[test] -fn map_memory_order() { +fn map_memory_order() +{ let a = arr3(&[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [0, -1, -2]]]); let mut v = a.view(); v.swap_axes(0, 1); @@ -1753,7 +2010,76 @@ fn map_memory_order() { } #[test] -fn test_contiguous() { +fn map_mut_with_unsharing() +{ + // Fortran-layout `ArcArray`. + let a = rcarr2(&[[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]).reversed_axes(); + assert_eq!(a.shape(), &[2, 5]); + assert_eq!(a.strides(), &[1, 2]); + assert_eq!( + a.as_slice_memory_order(), + Some(&[0, 5, 1, 6, 2, 7, 3, 8, 4, 9][..]) + ); + + // Shared reference of a portion of `a`. + let mut b = a.clone().slice_move(s![.., ..2]); + assert_eq!(b.shape(), &[2, 2]); + assert_eq!(b.strides(), &[1, 2]); + assert_eq!(b.as_slice_memory_order(), Some(&[0, 5, 1, 6][..])); + assert_eq!(b, array![[0, 1], [5, 6]]); + + // `.map_mut()` unshares the data. Earlier versions of `ndarray` failed + // this assertion. See #1018. + assert_eq!(b.map_mut(|&mut x| x + 10), array![[10, 11], [15, 16]]); + + // The strides should be preserved. 
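// Aside — illustrative sketch, not part of the diff above: `ArcArray` is
// copy-on-write, so a mutating call (here `map_mut`, as in the test above)
// unshares the data first and leaves other handles untouched.
let sketch_shared = rcarr2(&[[1, 2], [3, 4]]);
let mut sketch_clone = sketch_shared.clone(); // both handles share one allocation
let _ = sketch_clone.map_mut(|x| *x *= 10); // mutation forces a private copy first
assert_eq!(sketch_shared, rcarr2(&[[1, 2], [3, 4]]));
assert_eq!(sketch_clone, rcarr2(&[[10, 20], [30, 40]]));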
+ assert_eq!(b.shape(), &[2, 2]); + assert_eq!(b.strides(), &[1, 2]); +} + +#[test] +fn test_view_from_shape() +{ + let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; + let a = ArrayView::from_shape((2, 3, 2), &s).unwrap(); + let mut answer = Array::from(s.to_vec()) + .into_shape_with_order((2, 3, 2)) + .unwrap(); + assert_eq!(a, answer); + + // custom strides (row major) + let a = ArrayView::from_shape((2, 3, 2).strides((6, 2, 1)), &s).unwrap(); + assert_eq!(a, answer); + + // custom strides (col major) + let a = ArrayView::from_shape((2, 3, 2).strides((1, 2, 6)), &s).unwrap(); + assert_eq!(a, answer.t()); + + // negative strides + let a = ArrayView::from_shape((2, 3, 2).strides((6, (-2isize) as usize, 1)), &s).unwrap(); + answer.invert_axis(Axis(1)); + assert_eq!(a, answer); +} + +#[test] +fn test_view_from_shape_allow_overlap() +{ + let data = [0, 1, 2]; + let view = ArrayView::from_shape((2, 3).strides((0, 1)), &data).unwrap(); + assert_eq!(view, aview2(&[data; 2])); +} + +#[test] +fn test_view_mut_from_shape_deny_overlap() +{ + let mut data = [0, 1, 2]; + let result = ArrayViewMut::from_shape((2, 3).strides((0, 1)), &mut data); + assert_matches!(result.map_err(|e| e.kind()), Err(ErrorKind::Unsupported)); +} + +#[test] +fn test_contiguous() +{ let c = arr3(&[[[1, 2, 3], [4, 5, 6]], [[4, 5, 6], [7, 7, 7]]]); assert!(c.is_standard_layout()); assert!(c.as_slice_memory_order().is_some()); @@ -1787,18 +2113,83 @@ fn test_contiguous() { } #[test] -#[allow(deprecated)] -fn test_all_close() { - let c = arr3(&[ - [[1., 2., 3.], [1.5, 1.5, 3.]], - [[1., 2., 3.], [1., 2.5, 3.]], - ]); - assert!(c.all_close(&aview1(&[1., 2., 3.]), 1.)); - assert!(!c.all_close(&aview1(&[1., 2., 3.]), 0.1)); +fn test_contiguous_single_element() +{ + assert_matches!(array![1].as_slice_memory_order(), Some(&[1])); + + let arr1 = array![1, 2, 3]; + assert_matches!(arr1.slice(s![0..1]).as_slice_memory_order(), Some(&[1])); + assert_matches!(arr1.slice(s![1..2]).as_slice_memory_order(), Some(&[2])); + assert_matches!(arr1.slice(s![2..3]).as_slice_memory_order(), Some(&[3])); + assert_matches!(arr1.slice(s![0..0]).as_slice_memory_order(), Some(&[])); + + let arr2 = array![[1, 2, 3], [4, 5, 6]]; + assert_matches!(arr2.slice(s![.., 2..3]).as_slice_memory_order(), None); + assert_matches!(arr2.slice(s![1, 2..3]).as_slice_memory_order(), Some(&[6])); } #[test] -fn test_swap() { +fn test_contiguous_neg_strides() +{ + let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; + let a = ArrayView::from_shape((2, 3, 2).strides((1, 4, 2)), &s).unwrap(); + assert_eq!( + a, + arr3(&[[[0, 2], [4, 6], [8, 10]], [[1, 3], [5, 7], [9, 11]]]) + ); + assert!(a.as_slice_memory_order().is_some()); + + let mut b = a.slice(s![..;1, ..;-1, ..;-1]); + assert_eq!( + b, + arr3(&[[[10, 8], [6, 4], [2, 0]], [[11, 9], [7, 5], [3, 1]]]) + ); + assert!(b.as_slice_memory_order().is_some()); + + b.swap_axes(1, 2); + assert_eq!(b, arr3(&[[[10, 6, 2], [8, 4, 0]], [[11, 7, 3], [9, 5, 1]]])); + assert!(b.as_slice_memory_order().is_some()); + + b.invert_axis(Axis(0)); + assert_eq!(b, arr3(&[[[11, 7, 3], [9, 5, 1]], [[10, 6, 2], [8, 4, 0]]])); + assert!(b.as_slice_memory_order().is_some()); + + let mut c = b.reversed_axes(); + assert_eq!( + c, + arr3(&[[[11, 10], [9, 8]], [[7, 6], [5, 4]], [[3, 2], [1, 0]]]) + ); + assert!(c.as_slice_memory_order().is_some()); + + c.merge_axes(Axis(1), Axis(2)); + assert_eq!(c, arr3(&[[[11, 10, 9, 8]], [[7, 6, 5, 4]], [[3, 2, 1, 0]]])); + assert!(c.as_slice_memory_order().is_some()); + + let d = b.remove_axis(Axis(1)); + 
assert_eq!(d, arr2(&[[11, 7, 3], [10, 6, 2]])); + assert!(d.as_slice_memory_order().is_none()); + + let e = b.remove_axis(Axis(2)); + assert_eq!(e, arr2(&[[11, 9], [10, 8]])); + assert!(e.as_slice_memory_order().is_some()); + + let f = e.insert_axis(Axis(2)); + assert_eq!(f, arr3(&[[[11], [9]], [[10], [8]]])); + assert!(f.as_slice_memory_order().is_some()); + + let mut g = b.clone(); + g.collapse_axis(Axis(1), 0); + assert_eq!(g, arr3(&[[[11, 7, 3]], [[10, 6, 2]]])); + assert!(g.as_slice_memory_order().is_none()); + + b.collapse_axis(Axis(2), 0); + assert_eq!(b, arr3(&[[[11], [9]], [[10], [8]]])); + assert!(b.as_slice_memory_order().is_some()); +} + +#[test] +fn test_swap() +{ let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9]]); let b = a.clone(); @@ -1811,7 +2202,8 @@ fn test_swap() { } #[test] -fn test_uswap() { +fn test_uswap() +{ let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9]]); let b = a.clone(); @@ -1824,7 +2216,8 @@ fn test_uswap() { } #[test] -fn test_shape() { +fn test_shape() +{ let data = [0, 1, 2, 3, 4, 5]; let a = Array::from_shape_vec((1, 2, 3), data.to_vec()).unwrap(); let b = Array::from_shape_vec((1, 2, 3).f(), data.to_vec()).unwrap(); @@ -1838,7 +2231,8 @@ fn test_shape() { } #[test] -fn test_view_from_shape_ptr() { +fn test_view_from_shape_ptr() +{ let data = [0, 1, 2, 3, 4, 5]; let view = unsafe { ArrayView::from_shape_ptr((2, 3), data.as_ptr()) }; assert_eq!(view, aview2(&[[0, 1, 2], [3, 4, 5]])); @@ -1851,8 +2245,68 @@ fn test_view_from_shape_ptr() { assert_eq!(view, aview2(&[[0, 0, 2], [3, 4, 6]])); } +#[should_panic(expected = "Unsupported")] +#[cfg(debug_assertions)] +#[test] +fn test_view_from_shape_ptr_deny_neg_strides() +{ + let data = [0, 1, 2, 3, 4, 5]; + let _view = unsafe { ArrayView::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_ptr()) }; +} + +#[should_panic(expected = "Unsupported")] +#[cfg(debug_assertions)] +#[test] +fn test_view_mut_from_shape_ptr_deny_neg_strides() +{ + let mut data = [0, 1, 2, 3, 4, 5]; + let _view = unsafe { ArrayViewMut::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_mut_ptr()) }; +} + +#[should_panic(expected = "Unsupported")] +#[cfg(debug_assertions)] +#[test] +fn test_raw_view_from_shape_ptr_deny_neg_strides() +{ + let data = [0, 1, 2, 3, 4, 5]; + let _view = unsafe { RawArrayView::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_ptr()) }; +} + +#[should_panic(expected = "Unsupported")] +#[cfg(debug_assertions)] +#[test] +fn test_raw_view_mut_from_shape_ptr_deny_neg_strides() +{ + let mut data = [0, 1, 2, 3, 4, 5]; + let _view = unsafe { RawArrayViewMut::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_mut_ptr()) }; +} + +#[test] +fn test_raw_view_from_shape_allow_overlap() +{ + let data = [0, 1, 2]; + let view; + unsafe { + let raw_view = RawArrayView::from_shape_ptr((2, 3).strides((0, 1)), data.as_ptr()); + view = raw_view.deref_into_view(); + } + assert_eq!(view, aview2(&[data, data])); +} + +#[should_panic(expected = "strides must not allow any element")] +#[cfg(debug_assertions)] +#[test] +fn test_raw_view_mut_from_shape_deny_overlap() +{ + let mut data = [0, 1, 2]; + unsafe { + RawArrayViewMut::from_shape_ptr((2, 3).strides((0, 1)), data.as_mut_ptr()); + } +} + #[test] -fn test_default() { +fn test_default() +{ let a = as Default>::default(); assert_eq!(a, aview2(&[[0.0; 0]; 0])); @@ -1863,14 +2317,16 @@ fn test_default() { } #[test] -fn test_default_ixdyn() { +fn test_default_ixdyn() +{ let a = as Default>::default(); let b = >::zeros(IxDyn(&[0])); 
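// Aside — illustrative sketch, not part of the diff above: unlike the unsafe
// `from_shape_ptr` calls exercised a few tests up, the safe `ArrayView::from_shape`
// validates the strides and returns a `ShapeError` instead of risking UB.
let sketch_data = [1, 2, 3, 4, 5, 6];
let sketch_view = ArrayView::from_shape((2, 3).strides((1, 2)), &sketch_data).unwrap();
assert_eq!(sketch_view, aview2(&[[1, 3, 5], [2, 4, 6]])); // column-major strides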
assert_eq!(a, b); } #[test] -fn test_map_axis() { +fn test_map_axis() +{ let a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); let b = a.map_axis(Axis(0), |view| view.sum()); @@ -1903,7 +2359,8 @@ fn test_map_axis() { } #[test] -fn test_accumulate_axis_inplace_noop() { +fn test_accumulate_axis_inplace_noop() +{ let mut a = Array2::::zeros((0, 3)); a.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); assert_eq!(a, Array2::zeros((0, 3))); @@ -1945,7 +2402,8 @@ fn test_accumulate_axis_inplace_nonstandard_layout() { } #[test] -fn test_to_vec() { +fn test_to_vec() +{ let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); a.slice_collapse(s![..;-1, ..]); @@ -1956,7 +2414,8 @@ fn test_to_vec() { } #[test] -fn test_array_clone_unalias() { +fn test_array_clone_unalias() +{ let a = Array::::zeros((3, 3)); let mut b = a.clone(); b.fill(1); @@ -1965,15 +2424,19 @@ fn test_array_clone_unalias() { } #[test] -fn test_array_clone_same_view() { - let mut a = Array::from_iter(0..9).into_shape((3, 3)).unwrap(); +fn test_array_clone_same_view() +{ + let mut a = Array::from_iter(0..9) + .into_shape_with_order((3, 3)) + .unwrap(); a.slice_collapse(s![..;-1, ..;-1]); let b = a.clone(); assert_eq!(a, b); } #[test] -fn test_array2_from_diag() { +fn test_array2_from_diag() +{ let diag = arr1(&[0, 1, 2]); let x = Array2::from_diag(&diag); let x_exp = arr2(&[[0, 0, 0], [0, 1, 0], [0, 0, 2]]); @@ -1987,7 +2450,8 @@ fn test_array2_from_diag() { } #[test] -fn array_macros() { +fn array_macros() +{ // array let a1 = array![1, 2, 3]; assert_eq!(a1, arr1(&[1, 2, 3])); @@ -2015,7 +2479,8 @@ fn array_macros() { } #[cfg(test)] -mod as_standard_layout_tests { +mod as_standard_layout_tests +{ use super::*; use ndarray::Data; use std::fmt::Debug; @@ -2034,7 +2499,8 @@ mod as_standard_layout_tests { } #[test] - fn test_f_layout() { + fn test_f_layout() + { let shape = (2, 2).f(); let arr = Array::::from_shape_vec(shape, vec![1, 2, 3, 4]).unwrap(); assert!(!arr.is_standard_layout()); @@ -2042,14 +2508,16 @@ mod as_standard_layout_tests { } #[test] - fn test_c_layout() { + fn test_c_layout() + { let arr = Array::::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap(); assert!(arr.is_standard_layout()); test_as_standard_layout_for(arr); } #[test] - fn test_f_layout_view() { + fn test_f_layout_view() + { let shape = (2, 2).f(); let arr = Array::::from_shape_vec(shape, vec![1, 2, 3, 4]).unwrap(); let arr_view = arr.view(); @@ -2058,7 +2526,8 @@ mod as_standard_layout_tests { } #[test] - fn test_c_layout_view() { + fn test_c_layout_view() + { let arr = Array::::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap(); let arr_view = arr.view(); assert!(arr_view.is_standard_layout()); @@ -2066,14 +2535,16 @@ mod as_standard_layout_tests { } #[test] - fn test_zero_dimensional_array() { + fn test_zero_dimensional_array() + { let arr_view = ArrayView1::::from(&[]); assert!(arr_view.is_standard_layout()); test_as_standard_layout_for(arr_view); } #[test] - fn test_custom_layout() { + fn test_custom_layout() + { let shape = (1, 2, 3, 2).strides((12, 1, 2, 6)); let arr_data: Vec = (0..12).collect(); let arr = Array::::from_shape_vec(shape, arr_data).unwrap(); @@ -2083,11 +2554,13 @@ mod as_standard_layout_tests { } #[cfg(test)] -mod array_cow_tests { +mod array_cow_tests +{ use super::*; #[test] - fn test_is_variant() { + fn test_is_variant() + { let arr: Array = array![[1, 2], [3, 4]]; let arr_cow = CowArray::::from(arr.view()); assert!(arr_cow.is_view()); @@ -2097,7 +2570,8 @@ mod array_cow_tests { 
assert!(!arr_cow.is_view()); } - fn run_with_various_layouts(mut f: impl FnMut(Array2)) { + fn run_with_various_layouts(mut f: impl FnMut(Array2)) + { for all in vec![ Array2::from_shape_vec((7, 8), (0..7 * 8).collect()).unwrap(), Array2::from_shape_vec((7, 8).f(), (0..7 * 8).collect()).unwrap(), @@ -2115,7 +2589,8 @@ mod array_cow_tests { } #[test] - fn test_element_mutation() { + fn test_element_mutation() + { run_with_various_layouts(|arr: Array2| { let mut expected = arr.clone(); expected[(1, 1)] = 2; @@ -2135,7 +2610,8 @@ mod array_cow_tests { } #[test] - fn test_clone() { + fn test_clone() + { run_with_various_layouts(|arr: Array2| { let arr_cow = CowArray::::from(arr.view()); let arr_cow_clone = arr_cow.clone(); @@ -2153,12 +2629,12 @@ mod array_cow_tests { }); } + #[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] - fn test_clone_from() { - fn assert_eq_contents_and_layout( - arr1: &CowArray<'_, i32, Ix2>, - arr2: &CowArray<'_, i32, Ix2>, - ) { + fn test_clone_from() + { + fn assert_eq_contents_and_layout(arr1: &CowArray<'_, i32, Ix2>, arr2: &CowArray<'_, i32, Ix2>) + { assert_eq!(arr1, arr2); assert_eq!(arr1.dim(), arr2.dim()); assert_eq!(arr1.strides(), arr2.strides()); @@ -2194,7 +2670,8 @@ mod array_cow_tests { } #[test] - fn test_into_owned() { + fn test_into_owned() + { run_with_various_layouts(|arr: Array2| { let before = CowArray::::from(arr.view()); let after = before.into_owned(); @@ -2208,3 +2685,138 @@ mod array_cow_tests { }); } } + +#[test] +fn test_remove_index() +{ + let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); + a.remove_index(Axis(0), 1); + a.remove_index(Axis(1), 2); + assert_eq!(a.shape(), &[3, 2]); + assert_eq!(a, + array![[1, 2], + [7, 8], + [10,11]]); + + let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); + a.invert_axis(Axis(0)); + a.remove_index(Axis(0), 1); + a.remove_index(Axis(1), 2); + assert_eq!(a.shape(), &[3, 2]); + assert_eq!(a, + array![[10,11], + [4, 5], + [1, 2]]); + + a.remove_index(Axis(1), 1); + + assert_eq!(a.shape(), &[3, 1]); + assert_eq!(a, + array![[10], + [4], + [1]]); + a.remove_index(Axis(1), 0); + assert_eq!(a.shape(), &[3, 0]); + assert_eq!(a, + array![[], + [], + []]); +} + +#[should_panic(expected = "must be less")] +#[test] +fn test_remove_index_oob1() +{ + let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); + a.remove_index(Axis(0), 4); +} + +#[should_panic(expected = "must be less")] +#[test] +fn test_remove_index_oob2() +{ + let mut a = array![[10], [4], [1]]; + a.remove_index(Axis(1), 0); + assert_eq!(a.shape(), &[3, 0]); + assert_eq!(a, + array![[], + [], + []]); + a.remove_index(Axis(0), 1); // ok + assert_eq!(a, + array![[], + []]); + a.remove_index(Axis(1), 0); // oob +} + +#[should_panic(expected = "index out of bounds")] +#[test] +fn test_remove_index_oob3() +{ + let mut a = array![[10], [4], [1]]; + a.remove_index(Axis(2), 0); +} + +#[test] +fn test_split_complex_view() +{ + let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| Complex::::new(i as f32 * j as f32, k as f32)); + let Complex { re, im } = a.view().split_complex(); + assert_relative_eq!(re.sum(), 90.); + assert_relative_eq!(im.sum(), 120.); +} + +#[test] +fn test_split_complex_view_roundtrip() +{ + let a_re = Array3::from_shape_fn((3, 1, 5), |(i, j, _k)| i * j); + let a_im = Array3::from_shape_fn((3, 1, 5), |(_i, _j, k)| k); + let a = Array3::from_shape_fn((3, 1, 5), |(i, j, k)| Complex::new(a_re[[i, j, k]], a_im[[i, j, k]])); + let Complex { re, im } = a.view().split_complex(); + 
assert_eq!(a_re, re); + assert_eq!(a_im, im); +} + +#[test] +fn test_split_complex_view_mut() +{ + let eye_scalar = Array2::::eye(4); + let eye_complex = Array2::>::eye(4); + let mut a = Array2::>::zeros((4, 4)); + let Complex { mut re, im } = a.view_mut().split_complex(); + re.assign(&eye_scalar); + assert_eq!(im.sum(), 0); + assert_eq!(a, eye_complex); +} + +#[test] +fn test_split_complex_zerod() +{ + let mut a = Array0::from_elem((), Complex::new(42, 32)); + let Complex { re, im } = a.view().split_complex(); + assert_eq!(re.get(()), Some(&42)); + assert_eq!(im.get(()), Some(&32)); + let cmplx = a.view_mut().split_complex(); + cmplx.re.assign_to(cmplx.im); + assert_eq!(a.get(()).unwrap().im, 42); +} + +#[test] +fn test_split_complex_permuted() +{ + let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| Complex::new(i * k + j, k)); + let permuted = a.view().permuted_axes([1, 0, 2]); + let Complex { re, im } = permuted.split_complex(); + assert_eq!(re.get((3,2,4)).unwrap(), &11); + assert_eq!(im.get((3,2,4)).unwrap(), &4); +} + +#[test] +fn test_split_complex_invert_axis() +{ + let mut a = Array::from_shape_fn((2, 3, 2), |(i, j, k)| Complex::new(i as f64 + j as f64, i as f64 + k as f64)); + a.invert_axis(Axis(1)); + let cmplx = a.view().split_complex(); + assert_eq!(cmplx.re, a.mapv(|z| z.re)); + assert_eq!(cmplx.im, a.mapv(|z| z.im)); +} diff --git a/tests/assign.rs b/tests/assign.rs new file mode 100644 index 000000000..29a6b851a --- /dev/null +++ b/tests/assign.rs @@ -0,0 +1,282 @@ +use ndarray::prelude::*; + +use std::sync::atomic::{AtomicUsize, Ordering}; + +#[test] +fn assign() +{ + let mut a = arr2(&[[1., 2.], [3., 4.]]); + let b = arr2(&[[1., 3.], [2., 4.]]); + a.assign(&b); + assert_eq!(a, b); + + /* Test broadcasting */ + a.assign(&ArcArray::zeros(1)); + assert_eq!(a, ArcArray::zeros((2, 2))); + + /* Test other type */ + a.assign(&Array::from_elem((2, 2), 3.)); + assert_eq!(a, ArcArray::from_elem((2, 2), 3.)); + + /* Test mut view */ + let mut a = arr2(&[[1, 2], [3, 4]]); + { + let mut v = a.view_mut(); + v.slice_collapse(s![..1, ..]); + v.fill(0); + } + assert_eq!(a, arr2(&[[0, 0], [3, 4]])); +} + +#[test] +fn assign_to() +{ + let mut a = arr2(&[[1., 2.], [3., 4.]]); + let b = arr2(&[[0., 3.], [2., 0.]]); + b.assign_to(&mut a); + assert_eq!(a, b); +} + +#[test] +fn move_into_copy() +{ + let a = arr2(&[[1., 2.], [3., 4.]]); + let acopy = a.clone(); + let mut b = Array::uninit(a.dim()); + a.move_into_uninit(b.view_mut()); + let b = unsafe { b.assume_init() }; + assert_eq!(acopy, b); + + let a = arr2(&[[1., 2.], [3., 4.]]).reversed_axes(); + let acopy = a.clone(); + let mut b = Array::uninit(a.dim()); + a.move_into_uninit(b.view_mut()); + let b = unsafe { b.assume_init() }; + assert_eq!(acopy, b); +} + +#[test] +fn move_into_owned() +{ + // Test various memory layouts and holes while moving String elements. 
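// Aside — illustrative sketch, not part of the diff above: the basic
// `move_into_uninit` pattern that the following tests stress under various
// layouts — move the elements into an uninitialized destination of the same
// shape, then mark it as initialized with `assume_init`.
let sketch_src = array![[1.0, 2.0], [3.0, 4.0]];
let mut sketch_dst = Array::uninit(sketch_src.dim());
sketch_src.move_into_uninit(sketch_dst.view_mut());
let sketch_dst = unsafe { sketch_dst.assume_init() };
assert_eq!(sketch_dst, array![[1.0, 2.0], [3.0, 4.0]]);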
+ for &use_f_order in &[false, true] { + for &invert_axis in &[0b00, 0b01, 0b10, 0b11] { + // bitmask for axis to invert + for &slice in &[false, true] { + let mut a = Array::from_shape_fn((5, 4).set_f(use_f_order), |idx| format!("{:?}", idx)); + if slice { + a.slice_collapse(s![1..-1, ..;2]); + } + + if invert_axis & 0b01 != 0 { + a.invert_axis(Axis(0)); + } + if invert_axis & 0b10 != 0 { + a.invert_axis(Axis(1)); + } + + let acopy = a.clone(); + let mut b = Array::uninit(a.dim()); + a.move_into_uninit(b.view_mut()); + let b = unsafe { b.assume_init() }; + + assert_eq!(acopy, b); + } + } + } +} + +#[test] +fn move_into_slicing() +{ + // Count correct number of drops when using move_into_uninit and discontiguous arrays (with holes). + for &use_f_order in &[false, true] { + for &invert_axis in &[0b00, 0b01, 0b10, 0b11] { + // bitmask for axis to invert + let counter = DropCounter::default(); + { + let (m, n) = (5, 4); + + let mut a = Array::from_shape_fn((m, n).set_f(use_f_order), |_idx| counter.element()); + a.slice_collapse(s![1..-1, ..;2]); + if invert_axis & 0b01 != 0 { + a.invert_axis(Axis(0)); + } + if invert_axis & 0b10 != 0 { + a.invert_axis(Axis(1)); + } + + let mut b = Array::uninit(a.dim()); + a.move_into_uninit(b.view_mut()); + let b = unsafe { b.assume_init() }; + + let total = m * n; + let dropped_1 = total - (m - 2) * (n - 2); + assert_eq!(counter.created(), total); + assert_eq!(counter.dropped(), dropped_1); + drop(b); + } + counter.assert_drop_count(); + } + } +} + +#[test] +fn move_into_diag() +{ + // Count correct number of drops when using move_into_uninit and discontiguous arrays (with holes). + for &use_f_order in &[false, true] { + let counter = DropCounter::default(); + { + let (m, n) = (5, 4); + + let a = Array::from_shape_fn((m, n).set_f(use_f_order), |_idx| counter.element()); + let a = a.into_diag(); + + let mut b = Array::uninit(a.dim()); + a.move_into_uninit(b.view_mut()); + let b = unsafe { b.assume_init() }; + + let total = m * n; + let dropped_1 = total - Ord::min(m, n); + assert_eq!(counter.created(), total); + assert_eq!(counter.dropped(), dropped_1); + drop(b); + } + counter.assert_drop_count(); + } +} + +#[test] +fn move_into_0dim() +{ + // Count correct number of drops when using move_into_uninit and discontiguous arrays (with holes). + for &use_f_order in &[false, true] { + let counter = DropCounter::default(); + { + let (m, n) = (5, 4); + + // slice into a 0-dim array + let a = Array::from_shape_fn((m, n).set_f(use_f_order), |_idx| counter.element()); + let a = a.slice_move(s![2, 2]); + + assert_eq!(a.ndim(), 0); + let mut b = Array::uninit(a.dim()); + a.move_into_uninit(b.view_mut()); + let b = unsafe { b.assume_init() }; + + let total = m * n; + let dropped_1 = total - 1; + assert_eq!(counter.created(), total); + assert_eq!(counter.dropped(), dropped_1); + drop(b); + } + counter.assert_drop_count(); + } +} + +#[test] +fn move_into_empty() +{ + // Count correct number of drops when using move_into_uninit and discontiguous arrays (with holes). 
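// Aside — a note on the convention used by these `move_into_*` tests: `invert_axis`
// is a two-bit mask, bit 0 meaning "invert Axis(0)" and bit 1 meaning "invert
// Axis(1)", so looping over 0b00..=0b11 covers all four axis orientations.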
+ for &use_f_order in &[false, true] { + let counter = DropCounter::default(); + { + let (m, n) = (5, 4); + + // slice into an empty array; + let a = Array::from_shape_fn((m, n).set_f(use_f_order), |_idx| counter.element()); + let a = a.slice_move(s![..0, 1..1]); + assert!(a.is_empty()); + let mut b = Array::uninit(a.dim()); + a.move_into_uninit(b.view_mut()); + let b = unsafe { b.assume_init() }; + + let total = m * n; + let dropped_1 = total; + assert_eq!(counter.created(), total); + assert_eq!(counter.dropped(), dropped_1); + drop(b); + } + counter.assert_drop_count(); + } +} + +#[test] +fn move_into() +{ + // Test various memory layouts and holes while moving String elements with move_into + for &use_f_order in &[false, true] { + for &invert_axis in &[0b00, 0b01, 0b10, 0b11] { + // bitmask for axis to invert + for &slice in &[false, true] { + let mut a = Array::from_shape_fn((5, 4).set_f(use_f_order), |idx| format!("{:?}", idx)); + if slice { + a.slice_collapse(s![1..-1, ..;2]); + } + + if invert_axis & 0b01 != 0 { + a.invert_axis(Axis(0)); + } + if invert_axis & 0b10 != 0 { + a.invert_axis(Axis(1)); + } + + let acopy = a.clone(); + let mut b = Array::default(a.dim().set_f(!use_f_order ^ !slice)); + a.move_into(&mut b); + + assert_eq!(acopy, b); + } + } + } +} + +/// This counter can create elements, and then count and verify +/// the number of which have actually been dropped again. +#[derive(Default)] +struct DropCounter +{ + created: AtomicUsize, + dropped: AtomicUsize, +} + +struct Element<'a>(&'a AtomicUsize); + +impl DropCounter +{ + fn created(&self) -> usize + { + self.created.load(Ordering::Relaxed) + } + + fn dropped(&self) -> usize + { + self.dropped.load(Ordering::Relaxed) + } + + fn element(&self) -> Element<'_> + { + self.created.fetch_add(1, Ordering::Relaxed); + Element(&self.dropped) + } + + fn assert_drop_count(&self) + { + assert_eq!( + self.created(), + self.dropped(), + "Expected {} dropped elements, but found {}", + self.created(), + self.dropped() + ); + } +} + +impl<'a> Drop for Element<'a> +{ + fn drop(&mut self) + { + self.0.fetch_add(1, Ordering::Relaxed); + } +} diff --git a/tests/azip.rs b/tests/azip.rs index 885db85b7..9d8bebab7 100644 --- a/tests/azip.rs +++ b/tests/azip.rs @@ -1,21 +1,18 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names, + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names, clippy::float_cmp )] use ndarray::prelude::*; use ndarray::Zip; -use std::iter::FromIterator; use itertools::{assert_equal, cloned}; use std::mem::swap; #[test] -fn test_azip1() { +fn test_azip1() +{ let mut a = Array::zeros(62); let mut x = 0; azip!((a in &mut a) { *a = x; x += 1; }); @@ -23,7 +20,8 @@ fn test_azip1() { } #[test] -fn test_azip2() { +fn test_azip2() +{ let mut a = Array::zeros((5, 7)); let b = Array::from_shape_fn(a.dim(), |(i, j)| 1. / (i + 2 * j) as f32); azip!((a in &mut a, &b in &b) *a = b); @@ -31,7 +29,8 @@ fn test_azip2() { } #[test] -fn test_azip2_1() { +fn test_azip2_1() +{ let mut a = Array::zeros((5, 7)); let b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j) as f32); let b = b.slice(s![..;-1, 3..]); @@ -40,7 +39,8 @@ fn test_azip2_1() { } #[test] -fn test_azip2_3() { +fn test_azip2_3() +{ let mut b = Array::from_shape_fn((5, 10), |(i, j)| 1. 
/ (i + 2 * j) as f32); let mut c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32)); let a = b.clone(); @@ -50,7 +50,148 @@ fn test_azip2_3() { } #[test] -fn test_azip_syntax_trailing_comma() { +#[cfg(feature = "approx")] +fn test_zip_collect() +{ + use approx::assert_abs_diff_eq; + + // test Zip::map_collect and that it preserves c/f layout. + + let b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j + 1) as f32); + let c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32)); + + { + let a = Zip::from(&b).and(&c).map_collect(|x, y| x + y); + + assert_abs_diff_eq!(a, &b + &c, epsilon = 1e-6); + assert_eq!(a.strides(), b.strides()); + } + + { + let b = b.t(); + let c = c.t(); + + let a = Zip::from(&b).and(&c).map_collect(|x, y| x + y); + + assert_abs_diff_eq!(a, &b + &c, epsilon = 1e-6); + assert_eq!(a.strides(), b.strides()); + } +} + +#[test] +#[cfg(feature = "approx")] +fn test_zip_assign_into() +{ + use approx::assert_abs_diff_eq; + + let mut a = Array::::zeros((5, 10)); + let b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j + 1) as f32); + let c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32)); + + Zip::from(&b).and(&c).map_assign_into(&mut a, |x, y| x + y); + + assert_abs_diff_eq!(a, &b + &c, epsilon = 1e-6); +} + +#[test] +#[cfg(feature = "approx")] +fn test_zip_assign_into_cell() +{ + use approx::assert_abs_diff_eq; + use std::cell::Cell; + + let a = Array::, _>::default((5, 10)); + let b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j + 1) as f32); + let c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32)); + + Zip::from(&b).and(&c).map_assign_into(&a, |x, y| x + y); + let a2 = a.mapv(|elt| elt.get()); + + assert_abs_diff_eq!(a2, &b + &c, epsilon = 1e-6); +} + +#[test] +fn test_zip_collect_drop() +{ + use std::cell::RefCell; + use std::panic; + + struct Recorddrop<'a>((usize, usize), &'a RefCell>); + + impl Drop for Recorddrop<'_> + { + fn drop(&mut self) + { + self.1.borrow_mut().push(self.0); + } + } + + #[derive(Copy, Clone)] + enum Config + { + CC, + CF, + FF, + } + + impl Config + { + fn a_is_f(self) -> bool + { + match self { + Config::CC | Config::CF => false, + _ => true, + } + } + fn b_is_f(self) -> bool + { + match self { + Config::CC => false, + _ => true, + } + } + } + + let test_collect_panic = |config: Config, will_panic: bool, slice: bool| { + let mut inserts = RefCell::new(Vec::new()); + let mut drops = RefCell::new(Vec::new()); + + let mut a = Array::from_shape_fn((5, 10).set_f(config.a_is_f()), |idx| idx); + let mut b = Array::from_shape_fn((5, 10).set_f(config.b_is_f()), |_| 0); + if slice { + a = a.slice_move(s![.., ..-1]); + b = b.slice_move(s![.., ..-1]); + } + + let _result = panic::catch_unwind(panic::AssertUnwindSafe(|| { + Zip::from(&a).and(&b).map_collect(|&elt, _| { + if elt.0 > 3 && will_panic { + panic!(); + } + inserts.borrow_mut().push(elt); + Recorddrop(elt, &drops) + }); + })); + + println!("{:?}", inserts.get_mut()); + println!("{:?}", drops.get_mut()); + + assert_eq!(inserts.get_mut().len(), drops.get_mut().len(), "Incorrect number of drops"); + assert_eq!(inserts.get_mut(), drops.get_mut(), "Incorrect order of drops"); + }; + + for &should_panic in &[true, false] { + for &should_slice in &[false, true] { + test_collect_panic(Config::CC, should_panic, should_slice); + test_collect_panic(Config::CF, should_panic, should_slice); + test_collect_panic(Config::FF, should_panic, should_slice); + } + } +} + +#[test] +fn test_azip_syntax_trailing_comma() +{ let mut b 
= Array::::zeros((5, 5)); let mut c = Array::::ones((5, 5)); let a = b.clone(); @@ -61,7 +202,8 @@ fn test_azip_syntax_trailing_comma() { #[test] #[cfg(feature = "approx")] -fn test_azip2_sum() { +fn test_azip2_sum() +{ use approx::assert_abs_diff_eq; let c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32)); @@ -74,8 +216,9 @@ fn test_azip2_sum() { } #[test] -#[cfg(feature = "approx")] -fn test_azip3_slices() { +#[cfg(all(feature = "approx", feature = "std"))] +fn test_azip3_slices() +{ use approx::assert_abs_diff_eq; let mut a = [0.; 32]; @@ -89,13 +232,14 @@ fn test_azip3_slices() { *a += b / 10.; *c = a.sin(); }); - let res = Array::linspace(0., 3.1, 32).mapv_into(f32::sin); + let res = Array::from_iter(0..32).mapv(|x| f32::sin(x as f32 / 10.)); assert_abs_diff_eq!(res, ArrayView::from(&c), epsilon = 1e-4); } #[test] #[cfg(feature = "approx")] -fn test_broadcast() { +fn test_broadcast() +{ use approx::assert_abs_diff_eq; let n = 16; @@ -112,7 +256,7 @@ fn test_broadcast() { .and_broadcast(&b) .and_broadcast(&d) .and_broadcast(&e); - z.apply(|x, &y, &z, &w| *x = y + z + w); + z.for_each(|x, &y, &z, &w| *x = y + z + w); } let sum = &b + &d + &e; assert_abs_diff_eq!(a, sum.broadcast((n, n)).unwrap(), epsilon = 1e-4); @@ -120,7 +264,8 @@ fn test_broadcast() { #[should_panic] #[test] -fn test_zip_dim_mismatch_1() { +fn test_zip_dim_mismatch_1() +{ let mut a = Array::zeros((5, 7)); let mut d = a.raw_dim(); d[0] += 1; @@ -132,8 +277,11 @@ fn test_zip_dim_mismatch_1() { // Zip::from(A).and(B) // where A is F-contiguous and B contiguous but neither F nor C contiguous. #[test] -fn test_contiguous_but_not_c_or_f() { - let a = Array::from_iter(0..27).into_shape((3, 3, 3)).unwrap(); +fn test_contiguous_but_not_c_or_f() +{ + let a = Array::from_iter(0..27) + .into_shape_with_order((3, 3, 3)) + .unwrap(); // both F order let a = a.reversed_axes(); @@ -156,31 +304,49 @@ fn test_contiguous_but_not_c_or_f() { } #[test] -fn test_clone() { - let a = Array::from_iter(0..27).into_shape((3, 3, 3)).unwrap(); +fn test_clone() +{ + let a = Array::from_iter(0..27) + .into_shape_with_order((3, 3, 3)) + .unwrap(); let z = Zip::from(&a).and(a.exact_chunks((1, 1, 1))); let w = z.clone(); let mut result = Vec::new(); - z.apply(|x, y| { + z.for_each(|x, y| { result.push((x, y)); }); let mut i = 0; - w.apply(|x, y| { + w.for_each(|x, y| { assert_eq!(result[i], (x, y)); i += 1; }); } #[test] -fn test_indices_1() { +fn test_indices_0() +{ + let a1 = arr0(3); + + let mut count = 0; + Zip::indexed(&a1).for_each(|i, elt| { + count += 1; + assert_eq!(i, ()); + assert_eq!(*elt, 3); + }); + assert_eq!(count, 1); +} + +#[test] +fn test_indices_1() +{ let mut a1 = Array::default(12); for (i, elt) in a1.indexed_iter_mut() { *elt = i; } let mut count = 0; - Zip::indexed(&a1).apply(|i, elt| { + Zip::indexed(&a1).for_each(|i, elt| { count += 1; assert_eq!(*elt, i); }); @@ -190,12 +356,12 @@ fn test_indices_1() { let len = a1.len(); let (x, y) = Zip::indexed(&mut a1).split(); - x.apply(|i, elt| { + x.for_each(|i, elt| { count += 1; assert_eq!(*elt, i); }); assert_eq!(count, len / 2); - y.apply(|i, elt| { + y.for_each(|i, elt| { count += 1; assert_eq!(*elt, i); }); @@ -203,7 +369,8 @@ fn test_indices_1() { } #[test] -fn test_indices_2() { +fn test_indices_2() +{ let mut a1 = Array::default((10, 12)); for (i, elt) in a1.indexed_iter_mut() { *elt = i; @@ -220,12 +387,12 @@ fn test_indices_2() { let len = a1.len(); let (x, y) = Zip::indexed(&mut a1).split(); - x.apply(|i, elt| { + x.for_each(|i, elt| { count += 1; 
assert_eq!(*elt, i); }); assert_eq!(count, len / 2); - y.apply(|i, elt| { + y.for_each(|i, elt| { count += 1; assert_eq!(*elt, i); }); @@ -233,14 +400,15 @@ fn test_indices_2() { } #[test] -fn test_indices_3() { +fn test_indices_3() +{ let mut a1 = Array::default((4, 5, 6)); for (i, elt) in a1.indexed_iter_mut() { *elt = i; } let mut count = 0; - Zip::indexed(&a1).apply(|i, elt| { + Zip::indexed(&a1).for_each(|i, elt| { count += 1; assert_eq!(*elt, i); }); @@ -250,12 +418,12 @@ fn test_indices_3() { let len = a1.len(); let (x, y) = Zip::indexed(&mut a1).split(); - x.apply(|i, elt| { + x.for_each(|i, elt| { count += 1; assert_eq!(*elt, i); }); assert_eq!(count, len / 2); - y.apply(|i, elt| { + y.for_each(|i, elt| { count += 1; assert_eq!(*elt, i); }); @@ -263,7 +431,8 @@ fn test_indices_3() { } #[test] -fn test_indices_split_1() { +fn test_indices_split_1() +{ for m in (0..4).chain(10..12) { for n in (0..4).chain(10..12) { let a1 = Array::::default((m, n)); @@ -274,12 +443,12 @@ fn test_indices_split_1() { let mut seen = Vec::new(); let mut ac = 0; - a.apply(|i, _| { + a.for_each(|i, _| { ac += 1; seen.push(i); }); let mut bc = 0; - b.apply(|i, _| { + b.for_each(|i, _| { bc += 1; seen.push(i); }); @@ -295,20 +464,22 @@ fn test_indices_split_1() { } #[test] -fn test_zip_all() { +fn test_zip_all() +{ let a = Array::::zeros(62); let b = Array::::ones(62); let mut c = Array::::ones(62); c[5] = 0.0; - assert_eq!(true, Zip::from(&a).and(&b).all(|&x, &y| x + y == 1.0)); - assert_eq!(false, Zip::from(&a).and(&b).all(|&x, &y| x == y)); - assert_eq!(false, Zip::from(&a).and(&c).all(|&x, &y| x + y == 1.0)); + assert!(Zip::from(&a).and(&b).all(|&x, &y| x + y == 1.0)); + assert!(!Zip::from(&a).and(&b).all(|&x, &y| x == y)); + assert!(!Zip::from(&a).and(&c).all(|&x, &y| x + y == 1.0)); } #[test] -fn test_zip_all_empty_array() { +fn test_zip_all_empty_array() +{ let a = Array::::zeros(0); let b = Array::::ones(0); - assert_eq!(true, Zip::from(&a).and(&b).all(|&_x, &_y| true)); - assert_eq!(true, Zip::from(&a).and(&b).all(|&_x, &_y| false)); + assert!(Zip::from(&a).and(&b).all(|&_x, &_y| true)); + assert!(Zip::from(&a).and(&b).all(|&_x, &_y| false)); } diff --git a/tests/broadcast.rs b/tests/broadcast.rs index 6840947bb..288ccb38a 100644 --- a/tests/broadcast.rs +++ b/tests/broadcast.rs @@ -1,15 +1,23 @@ use ndarray::prelude::*; #[test] -fn broadcast_1() { +#[cfg(feature = "std")] +fn broadcast_1() +{ let a_dim = Dim([2, 4, 2, 2]); let b_dim = Dim([2, 1, 2, 1]); - let a = ArcArray::linspace(0., 1., a_dim.size()).reshape(a_dim); - let b = ArcArray::linspace(0., 1., b_dim.size()).reshape(b_dim); + let a = ArcArray::linspace(0., 1., a_dim.size()) + .into_shape_with_order(a_dim) + .unwrap(); + let b = ArcArray::linspace(0., 1., b_dim.size()) + .into_shape_with_order(b_dim) + .unwrap(); assert!(b.broadcast(a.dim()).is_some()); let c_dim = Dim([2, 1]); - let c = ArcArray::linspace(0., 1., c_dim.size()).reshape(c_dim); + let c = ArcArray::linspace(0., 1., c_dim.size()) + .into_shape_with_order(c_dim) + .unwrap(); assert!(c.broadcast(1).is_none()); assert!(c.broadcast(()).is_none()); assert!(c.broadcast((2, 1)).is_some()); @@ -26,11 +34,17 @@ fn broadcast_1() { } #[test] -fn test_add() { +#[cfg(feature = "std")] +fn test_add() +{ let a_dim = Dim([2, 4, 2, 2]); let b_dim = Dim([2, 1, 2, 1]); - let mut a = ArcArray::linspace(0.0, 1., a_dim.size()).reshape(a_dim); - let b = ArcArray::linspace(0.0, 1., b_dim.size()).reshape(b_dim); + let mut a = ArcArray::linspace(0.0, 1., a_dim.size()) + .into_shape_with_order(a_dim) 
+ .unwrap(); + let b = ArcArray::linspace(0.0, 1., b_dim.size()) + .into_shape_with_order(b_dim) + .unwrap(); a += &b; let t = ArcArray::from_elem((), 1.0f32); a += &t; @@ -38,15 +52,20 @@ fn test_add() { #[test] #[should_panic] -fn test_add_incompat() { +#[cfg(feature = "std")] +fn test_add_incompat() +{ let a_dim = Dim([2, 4, 2, 2]); - let mut a = ArcArray::linspace(0.0, 1., a_dim.size()).reshape(a_dim); + let mut a = ArcArray::linspace(0.0, 1., a_dim.size()) + .into_shape_with_order(a_dim) + .unwrap(); let incompat = ArcArray::from_elem(3, 1.0f32); a += &incompat; } #[test] -fn test_broadcast() { +fn test_broadcast() +{ let (_, n, k) = (16, 16, 16); let x1 = 1.; // b0 broadcast 1 -> n, k @@ -66,7 +85,8 @@ fn test_broadcast() { } #[test] -fn test_broadcast_1d() { +fn test_broadcast_1d() +{ let n = 16; let x1 = 1.; // b0 broadcast 1 -> n diff --git a/tests/clone.rs b/tests/clone.rs index e1914ba7f..4a7e50b8e 100644 --- a/tests/clone.rs +++ b/tests/clone.rs @@ -1,7 +1,8 @@ use ndarray::arr2; #[test] -fn test_clone_from() { +fn test_clone_from() +{ let a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9]]); let b = arr2(&[[7, 7, 7]]); let mut c = b.clone(); diff --git a/tests/complex.rs b/tests/complex.rs index 543889dd7..824e296a4 100644 --- a/tests/complex.rs +++ b/tests/complex.rs @@ -3,14 +3,16 @@ use ndarray::{arr1, arr2, Axis}; use num_complex::Complex; use num_traits::Num; -fn c(re: T, im: T) -> Complex { +fn c(re: T, im: T) -> Complex +{ Complex::new(re, im) } #[test] -fn complex_mat_mul() { +fn complex_mat_mul() +{ let a = arr2(&[[c(3., 4.), c(2., 0.)], [c(0., -2.), c(3., 0.)]]); - let b = (&a * c(3., 0.)).map(|c| 5. * c / c.norm()); + let b = (&a * c(3., 0.)).map(|c| 5. * c / c.norm_sqr()); println!("{:>8.2}", b); let e = Array::eye(2); let r = a.dot(&e); diff --git a/tests/dimension.rs b/tests/dimension.rs index 7e76132aa..fe53d96b3 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -2,12 +2,13 @@ use defmac::defmac; -use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, IntoDimension, IxDyn, RemoveAxis}; +use ndarray::{arr2, ArcArray, Array, Axis, Dim, Dimension, IxDyn, RemoveAxis}; use std::hash::{Hash, Hasher}; #[test] -fn insert_axis() { +fn insert_axis() +{ assert_eq!(Dim([]).insert_axis(Axis(0)), Dim([1])); assert_eq!(Dim([3]).insert_axis(Axis(0)), Dim([1, 3])); @@ -41,7 +42,8 @@ fn insert_axis() { } #[test] -fn remove_axis() { +fn remove_axis() +{ assert_eq!(Dim([3]).remove_axis(Axis(0)), Dim([])); assert_eq!(Dim([1, 2]).remove_axis(Axis(0)), Dim([2])); assert_eq!(Dim([4, 5, 6]).remove_axis(Axis(1)), Dim([4, 6])); @@ -55,14 +57,19 @@ fn remove_axis() { let a = ArcArray::::zeros(vec![4, 5, 6]); let _b = a .index_axis_move(Axis(1), 0) - .reshape((4, 6)) - .reshape(vec![2, 3, 4]); + .to_shape((4, 6)) + .unwrap() + .to_shape(vec![2, 3, 4]) + .unwrap(); } #[test] #[allow(clippy::eq_op)] -fn dyn_dimension() { - let a = arr2(&[[1., 2.], [3., 4.0]]).into_shape(vec![2, 2]).unwrap(); +fn dyn_dimension() +{ + let a = arr2(&[[1., 2.], [3., 4.0]]) + .into_shape_with_order(vec![2, 2]) + .unwrap(); assert_eq!(&a - &a, Array::zeros(vec![2, 2])); assert_eq!(a[&[0, 0][..]], 1.); assert_eq!(a[[0, 0]], 1.); @@ -75,7 +82,8 @@ fn dyn_dimension() { } #[test] -fn dyn_insert() { +fn dyn_insert() +{ let mut v = vec![2, 3, 4, 5]; let mut dim = Dim(v.clone()); defmac!(test_insert index => { @@ -94,7 +102,8 @@ fn dyn_insert() { } #[test] -fn dyn_remove() { +fn dyn_remove() +{ let mut v = vec![1, 2, 3, 4, 5, 6, 7]; let mut dim = Dim(v.clone()); defmac!(test_remove index => { @@ -113,24 +122,42 @@ 
fn dyn_remove() { } #[test] -fn fastest_varying_order() { +fn fastest_varying_order() +{ let strides = Dim([2, 8, 4, 1]); let order = strides._fastest_varying_stride_order(); assert_eq!(order.slice(), &[3, 0, 2, 1]); + let strides = Dim([-2isize as usize, 8, -4isize as usize, -1isize as usize]); + let order = strides._fastest_varying_stride_order(); + assert_eq!(order.slice(), &[3, 0, 2, 1]); + assert_eq!(Dim([1, 3])._fastest_varying_stride_order(), Dim([0, 1])); + assert_eq!( + Dim([1, -3isize as usize])._fastest_varying_stride_order(), + Dim([0, 1]) + ); assert_eq!(Dim([7, 2])._fastest_varying_stride_order(), Dim([1, 0])); + assert_eq!( + Dim([-7isize as usize, 2])._fastest_varying_stride_order(), + Dim([1, 0]) + ); assert_eq!( Dim([6, 1, 3])._fastest_varying_stride_order(), Dim([1, 2, 0]) ); + assert_eq!( + Dim([-6isize as usize, 1, -3isize as usize])._fastest_varying_stride_order(), + Dim([1, 2, 0]) + ); // it's important that it produces distinct indices. Prefer the stable order // where 0 is before 1 when they are equal. assert_eq!(Dim([2, 2])._fastest_varying_stride_order(), [0, 1]); assert_eq!(Dim([2, 2, 1])._fastest_varying_stride_order(), [2, 0, 1]); assert_eq!( - Dim([2, 2, 3, 1, 2])._fastest_varying_stride_order(), + Dim([-2isize as usize, -2isize as usize, 3, 1, -2isize as usize]) + ._fastest_varying_stride_order(), [3, 0, 1, 4, 2] ); } @@ -169,7 +196,8 @@ fn min_stride_axis() { */ #[test] -fn max_stride_axis() { +fn max_stride_axis() +{ let a = ArrayF32::zeros(10); assert_eq!(a.max_stride_axis(), Axis(0)); @@ -196,7 +224,8 @@ fn max_stride_axis() { } #[test] -fn test_indexing() { +fn test_indexing() +{ let mut x = Dim([1, 2]); assert_eq!(x[0], 1); @@ -207,7 +236,8 @@ fn test_indexing() { } #[test] -fn test_operations() { +fn test_operations() +{ let mut x = Dim([1, 2]); let mut y = Dim([1, 1]); @@ -224,8 +254,10 @@ fn test_operations() { #[test] #[allow(clippy::cognitive_complexity)] -fn test_hash() { - fn calc_hash(value: &T) -> u64 { +fn test_hash() +{ + fn calc_hash(value: &T) -> u64 + { let mut hasher = std::collections::hash_map::DefaultHasher::new(); value.hash(&mut hasher); hasher.finish() @@ -260,8 +292,10 @@ fn test_hash() { } #[test] -fn test_generic_operations() { - fn test_dim(d: &D) { +fn test_generic_operations() +{ + fn test_dim(d: &D) + { let mut x = d.clone(); x[0] += 1; assert_eq!(x[0], 3); @@ -275,8 +309,10 @@ fn test_generic_operations() { } #[test] -fn test_array_view() { - fn test_dim(d: &D) { +fn test_array_view() +{ + fn test_dim(d: &D) + { assert_eq!(d.as_array_view().sum(), 7); assert_eq!(d.as_array_view().strides(), &[1]); } @@ -287,10 +323,14 @@ fn test_array_view() { } #[test] +#[cfg_attr(miri, ignore)] // Very slow on CI/CD machines +#[cfg(feature = "std")] #[allow(clippy::cognitive_complexity)] -fn test_all_ndindex() { +fn test_all_ndindex() +{ + use ndarray::IntoDimension; macro_rules! 
ndindex { - ($($i:expr),*) => { + ($($i:expr),*) => { for &rev in &[false, true] { // rev is for C / F order let size = $($i *)* 1; @@ -312,8 +352,8 @@ fn test_all_ndindex() { assert_eq!(elt, b[dim]); } } + }; } -} ndindex!(10); ndindex!(10, 4); ndindex!(10, 4, 3); diff --git a/tests/format.rs b/tests/format.rs index 5c2e2b6f4..35909871f 100644 --- a/tests/format.rs +++ b/tests/format.rs @@ -2,11 +2,12 @@ use ndarray::prelude::*; use ndarray::rcarr1; #[test] -fn formatting() { +fn formatting() +{ let a = rcarr1::(&[1., 2., 3., 4.]); assert_eq!(format!("{}", a), "[1, 2, 3, 4]"); assert_eq!(format!("{:4}", a), "[ 1, 2, 3, 4]"); - let a = a.reshape((4, 1, 1)); + let a = a.into_shape_clone((4, 1, 1)).unwrap(); assert_eq!( format!("{}", a), "\ @@ -30,7 +31,7 @@ fn formatting() { [[ 4]]]", ); - let a = a.reshape((2, 2)); + let a = a.into_shape_clone((2, 2)).unwrap(); assert_eq!( format!("{}", a), "\ @@ -55,20 +56,21 @@ fn formatting() { } #[test] -fn debug_format() { +fn debug_format() +{ let a = Array2::::zeros((3, 4)); assert_eq!( format!("{:?}", a), "\ [[0, 0, 0, 0], [0, 0, 0, 0], - [0, 0, 0, 0]], shape=[3, 4], strides=[4, 1], layout=C (0x1), const ndim=2" + [0, 0, 0, 0]], shape=[3, 4], strides=[4, 1], layout=Cc (0x5), const ndim=2" ); assert_eq!( format!("{:?}", a.into_dyn()), "\ [[0, 0, 0, 0], [0, 0, 0, 0], - [0, 0, 0, 0]], shape=[3, 4], strides=[4, 1], layout=C (0x1), dynamic ndim=2" + [0, 0, 0, 0]], shape=[3, 4], strides=[4, 1], layout=Cc (0x5), dynamic ndim=2" ); } diff --git a/tests/higher_order_f.rs b/tests/higher_order_f.rs index c567eb3e0..72245412f 100644 --- a/tests/higher_order_f.rs +++ b/tests/higher_order_f.rs @@ -2,7 +2,8 @@ use ndarray::prelude::*; #[test] #[should_panic] -fn test_fold_axis_oob() { +fn test_fold_axis_oob() +{ let a = arr2(&[[1., 2.], [3., 4.]]); a.fold_axis(Axis(2), 0., |x, y| x + y); } diff --git a/tests/indices.rs b/tests/indices.rs index 3e2c0796c..a9414f9a7 100644 --- a/tests/indices.rs +++ b/tests/indices.rs @@ -1,15 +1,17 @@ use ndarray::indices_of; use ndarray::prelude::*; +use ndarray::Order; #[test] -fn test_ixdyn_index_iterate() { - for &rev in &[false, true] { - let mut a = Array::zeros((2, 3, 4).set_f(rev)); +fn test_ixdyn_index_iterate() +{ + for &order in &[Order::C, Order::F] { + let mut a = Array::zeros((2, 3, 4).set_f(order.is_column_major())); let dim = a.shape().to_vec(); for ((i, j, k), elt) in a.indexed_iter_mut() { *elt = i + 10 * j + 100 * k; } - let a = a.into_shape(dim).unwrap(); + let a = a.into_shape_with_order((dim, order)).unwrap(); println!("{:?}", a.dim()); let mut c = 0; for i in indices_of(&a) { diff --git a/tests/into-ixdyn.rs b/tests/into-ixdyn.rs index a9383a0e6..6e7bf9607 100644 --- a/tests/into-ixdyn.rs +++ b/tests/into-ixdyn.rs @@ -1,20 +1,19 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names, + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names, clippy::float_cmp )] use ndarray::prelude::*; #[test] -fn test_arr0_into_dyn() { +fn test_arr0_into_dyn() +{ assert!(arr0(1.234).into_dyn()[IxDyn(&[])] == 1.234); } #[test] -fn test_arr2_into_arrd_nonstandard_strides() { +fn test_arr2_into_arrd_nonstandard_strides() +{ let arr = Array2::from_shape_fn((12, 34).f(), |(i, j)| i * 34 + j).into_dyn(); let brr = ArrayD::from_shape_fn(vec![12, 34], |d| d[0] * 34 + d[1]); diff --git a/tests/iterator_chunks.rs b/tests/iterator_chunks.rs index 8ac885022..79b5403ef 100644 --- 
a/tests/iterator_chunks.rs +++ b/tests/iterator_chunks.rs @@ -1,18 +1,17 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names, + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names, clippy::float_cmp )] use ndarray::prelude::*; -use ndarray::NdProducer; #[test] -fn chunks() { +#[cfg(feature = "std")] +fn chunks() +{ + use ndarray::NdProducer; let a = >::linspace(1., 100., 10 * 10) - .into_shape((10, 10)) + .into_shape_with_order((10, 10)) .unwrap(); let (m, n) = a.dim(); @@ -48,13 +47,15 @@ fn chunks() { #[should_panic] #[test] -fn chunks_different_size_1() { +fn chunks_different_size_1() +{ let a = Array::::zeros(vec![2, 3]); a.exact_chunks(vec![2]); } #[test] -fn chunks_ok_size() { +fn chunks_ok_size() +{ let mut a = Array::::zeros(vec![2, 3]); a.fill(1.); let mut c = 0; @@ -68,13 +69,15 @@ fn chunks_ok_size() { #[should_panic] #[test] -fn chunks_different_size_2() { +fn chunks_different_size_2() +{ let a = Array::::zeros(vec![2, 3]); a.exact_chunks(vec![2, 3, 4]); } #[test] -fn chunks_mut() { +fn chunks_mut() +{ let mut a = Array::zeros((7, 8)); for (i, mut chunk) in a.exact_chunks_mut((2, 3)).into_iter().enumerate() { chunk.fill(i); @@ -94,7 +97,8 @@ fn chunks_mut() { #[should_panic] #[test] -fn chunks_different_size_3() { +fn chunks_different_size_3() +{ let mut a = Array::::zeros(vec![2, 3]); a.exact_chunks_mut(vec![2, 3, 4]); } diff --git a/tests/iterators.rs b/tests/iterators.rs index 371339b96..bdfd3ee50 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -1,17 +1,11 @@ -#![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names -)] +#![allow(clippy::deref_addrof, clippy::unreadable_literal)] use ndarray::prelude::*; -use ndarray::Ix; -use ndarray::{arr2, arr3, aview1, indices, s, Axis, Data, Dimension, Slice, Zip}; +use ndarray::{arr3, indices, s, Slice, Zip}; use itertools::assert_equal; -use itertools::{enumerate, rev}; -use std::iter::FromIterator; +use itertools::enumerate; +use std::cell::Cell; macro_rules! assert_panics { ($body:expr) => { @@ -28,7 +22,9 @@ macro_rules! 
assert_panics { } #[test] -fn double_ended() { +#[cfg(feature = "std")] +fn double_ended() +{ let a = ArcArray::linspace(0., 7., 8); let mut it = a.iter().cloned(); assert_eq!(it.next(), Some(0.)); @@ -36,13 +32,38 @@ fn double_ended() { assert_eq!(it.next(), Some(1.)); assert_eq!(it.rev().last(), Some(2.)); assert_equal(aview1(&[1, 2, 3]), &[1, 2, 3]); - assert_equal(rev(aview1(&[1, 2, 3])), rev(&[1, 2, 3])); + assert_equal(aview1(&[1, 2, 3]).into_iter().rev(), [1, 2, 3].iter().rev()); +} + +#[test] +fn double_ended_rows() +{ + let a = ArcArray::from_iter(0..8).into_shape_clone((4, 2)).unwrap(); + let mut row_it = a.rows().into_iter(); + assert_equal(row_it.next_back().unwrap(), &[6, 7]); + assert_equal(row_it.next().unwrap(), &[0, 1]); + assert_equal(row_it.next_back().unwrap(), &[4, 5]); + assert_equal(row_it.next_back().unwrap(), &[2, 3]); + assert!(row_it.next().is_none()); + assert!(row_it.next_back().is_none()); + + for (row, check) in a + .rows() + .into_iter() + .rev() + .zip(&[[6, 7], [4, 5], [2, 3], [0, 1]]) + { + assert_equal(row, check); + } } #[test] -fn iter_size_hint() { +fn iter_size_hint() +{ // Check that the size hint is correctly computed - let a = ArcArray::from_iter(0..24).reshape((2, 3, 4)); + let a = ArcArray::from_iter(0..24) + .into_shape_with_order((2, 3, 4)) + .unwrap(); let mut data = [0; 24]; for (i, elt) in enumerate(&mut data) { *elt = i as i32; @@ -58,12 +79,14 @@ fn iter_size_hint() { } #[test] -fn indexed() { +#[cfg(feature = "std")] +fn indexed() +{ let a = ArcArray::linspace(0., 7., 8); for (i, elt) in a.indexed_iter() { - assert_eq!(i, *elt as Ix); + assert_eq!(i, *elt as usize); } - let a = a.reshape((2, 4, 1)); + let a = a.into_shape_with_order((2, 4, 1)).unwrap(); let (mut i, mut j, k) = (0, 0, 0); for (idx, elt) in a.indexed_iter() { assert_eq!(idx, (i, j, k)); @@ -76,27 +99,31 @@ fn indexed() { } } -fn assert_slice_correct(v: &ArrayBase) -where - S: Data, - D: Dimension, - A: PartialEq + std::fmt::Debug, +#[test] +#[cfg(feature = "std")] +fn as_slice() { - let slc = v.as_slice(); - assert!(slc.is_some()); - let slc = slc.unwrap(); - assert_eq!(v.len(), slc.len()); - assert_equal(v.iter(), slc); -} + use ndarray::Data; + + fn assert_slice_correct(v: &ArrayBase) + where + S: Data, + D: Dimension, + A: PartialEq + std::fmt::Debug, + { + let slc = v.as_slice(); + assert!(slc.is_some()); + let slc = slc.unwrap(); + assert_eq!(v.len(), slc.len()); + assert_equal(v.iter(), slc); + } -#[test] -fn as_slice() { let a = ArcArray::linspace(0., 7., 8); - let a = a.reshape((2, 4, 1)); + let a = a.into_shape_with_order((2, 4, 1)).unwrap(); assert_slice_correct(&a); - let a = a.reshape((2, 4)); + let a = a.into_shape_with_order((2, 4)).unwrap(); assert_slice_correct(&a); assert!(a.view().index_axis(Axis(1), 0).as_slice().is_none()); @@ -119,7 +146,7 @@ fn as_slice() { assert!(u.as_slice().is_some()); assert_slice_correct(&u); - let a = a.reshape((8, 1)); + let a = a.into_shape_with_order((8, 1)).unwrap(); assert_slice_correct(&a); let u = a.slice(s![..;2, ..]); println!( @@ -132,60 +159,59 @@ fn as_slice() { } #[test] -fn inner_iter() { +fn inner_iter() +{ let a = ArcArray::from_iter(0..12); - let a = a.reshape((2, 3, 2)); + let a = a.into_shape_with_order((2, 3, 2)).unwrap(); // [[[0, 1], // [2, 3], // [4, 5]], // [[6, 7], // [8, 9], // ... 
- assert_equal( - a.genrows(), - vec![ + assert_equal(a.rows(), vec![ aview1(&[0, 1]), aview1(&[2, 3]), aview1(&[4, 5]), aview1(&[6, 7]), aview1(&[8, 9]), aview1(&[10, 11]), - ], - ); + ]); let mut b = ArcArray::zeros((2, 3, 2)); b.swap_axes(0, 2); b.assign(&a); - assert_equal( - b.genrows(), - vec![ + assert_equal(b.rows(), vec![ aview1(&[0, 1]), aview1(&[2, 3]), aview1(&[4, 5]), aview1(&[6, 7]), aview1(&[8, 9]), aview1(&[10, 11]), - ], - ); + ]); } #[test] -fn inner_iter_corner_cases() { +fn inner_iter_corner_cases() +{ let a0 = ArcArray::::zeros(()); - assert_equal(a0.genrows(), vec![aview1(&[0])]); + assert_equal(a0.rows(), vec![aview1(&[0])]); let a2 = ArcArray::::zeros((0, 3)); - assert_equal(a2.genrows(), vec![aview1(&[]); 0]); + assert_equal(a2.rows(), vec![aview1(&[]); 0]); let a2 = ArcArray::::zeros((3, 0)); - assert_equal(a2.genrows(), vec![aview1(&[]); 3]); + assert_equal(a2.rows(), vec![aview1(&[]); 3]); } #[test] -fn inner_iter_size_hint() { +fn inner_iter_size_hint() +{ // Check that the size hint is correctly computed - let a = ArcArray::from_iter(0..24).reshape((2, 3, 4)); + let a = ArcArray::from_iter(0..24) + .into_shape_with_order((2, 3, 4)) + .unwrap(); let mut len = 6; - let mut it = a.genrows().into_iter(); + let mut it = a.rows().into_iter(); assert_eq!(it.len(), len); while len > 0 { it.next(); @@ -196,26 +222,21 @@ fn inner_iter_size_hint() { #[allow(deprecated)] // into_outer_iter #[test] -fn outer_iter() { +fn outer_iter() +{ let a = ArcArray::from_iter(0..12); - let a = a.reshape((2, 3, 2)); + let a = a.into_shape_with_order((2, 3, 2)).unwrap(); // [[[0, 1], // [2, 3], // [4, 5]], // [[6, 7], // [8, 9], // ... - assert_equal( - a.outer_iter(), - vec![a.index_axis(Axis(0), 0), a.index_axis(Axis(0), 1)], - ); + assert_equal(a.outer_iter(), vec![a.index_axis(Axis(0), 0), a.index_axis(Axis(0), 1)]); let mut b = ArcArray::zeros((2, 3, 2)); b.swap_axes(0, 2); b.assign(&a); - assert_equal( - b.outer_iter(), - vec![a.index_axis(Axis(0), 0), a.index_axis(Axis(0), 1)], - ); + assert_equal(b.outer_iter(), vec![a.index_axis(Axis(0), 0), a.index_axis(Axis(0), 1)]); let mut found_rows = Vec::new(); for sub in b.outer_iter() { @@ -223,7 +244,7 @@ fn outer_iter() { found_rows.push(row); } } - assert_equal(a.genrows(), found_rows.clone()); + assert_equal(a.rows(), found_rows.clone()); let mut found_rows_rev = Vec::new(); for sub in b.outer_iter().rev() { @@ -239,10 +260,7 @@ fn outer_iter() { let mut cv = c.slice_mut(s![..;-1, ..;-1, ..;-1]); cv.assign(&a); assert_eq!(&a, &cv); - assert_equal( - cv.outer_iter(), - vec![a.index_axis(Axis(0), 0), a.index_axis(Axis(0), 1)], - ); + assert_equal(cv.outer_iter(), vec![a.index_axis(Axis(0), 0), a.index_axis(Axis(0), 1)]); let mut found_rows = Vec::new(); for sub in cv.outer_iter() { @@ -251,31 +269,30 @@ fn outer_iter() { } } println!("{:#?}", found_rows); - assert_equal(a.genrows(), found_rows); + assert_equal(a.rows(), found_rows); } #[test] -fn axis_iter() { +fn axis_iter() +{ let a = ArcArray::from_iter(0..12); - let a = a.reshape((2, 3, 2)); + let a = a.into_shape_with_order((2, 3, 2)).unwrap(); // [[[0, 1], // [2, 3], // [4, 5]], // [[6, 7], // [8, 9], // ... 
- assert_equal( - a.axis_iter(Axis(1)), - vec![ + assert_equal(a.axis_iter(Axis(1)), vec![ a.index_axis(Axis(1), 0), a.index_axis(Axis(1), 1), a.index_axis(Axis(1), 2), - ], - ); + ]); } #[test] -fn axis_iter_split_at() { +fn axis_iter_split_at() +{ let a = Array::from_iter(0..5); let iter = a.axis_iter(Axis(0)); let all: Vec<_> = iter.clone().collect(); @@ -287,7 +304,8 @@ fn axis_iter_split_at() { } #[test] -fn axis_iter_split_at_partially_consumed() { +fn axis_iter_split_at_partially_consumed() +{ let a = Array::from_iter(0..5); let mut iter = a.axis_iter(Axis(0)); while iter.next().is_some() { @@ -301,29 +319,34 @@ fn axis_iter_split_at_partially_consumed() { } #[test] -fn axis_iter_zip() { +fn axis_iter_zip() +{ let a = Array::from_iter(0..5); let iter = a.axis_iter(Axis(0)); let mut b = Array::zeros(5); - Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]); + Zip::from(&mut b).and(iter).for_each(|b, a| *b = a[()]); assert_eq!(a, b); } #[test] -fn axis_iter_zip_partially_consumed() { +fn axis_iter_zip_partially_consumed() +{ let a = Array::from_iter(0..5); let mut iter = a.axis_iter(Axis(0)); let mut consumed = 0; while iter.next().is_some() { consumed += 1; let mut b = Array::zeros(a.len() - consumed); - Zip::from(&mut b).and(iter.clone()).apply(|b, a| *b = a[()]); + Zip::from(&mut b) + .and(iter.clone()) + .for_each(|b, a| *b = a[()]); assert_eq!(a.slice(s![consumed..]), b); } } #[test] -fn axis_iter_zip_partially_consumed_discontiguous() { +fn axis_iter_zip_partially_consumed_discontiguous() +{ let a = Array::from_iter(0..5); let mut iter = a.axis_iter(Axis(0)); let mut consumed = 0; @@ -331,13 +354,16 @@ fn axis_iter_zip_partially_consumed_discontiguous() { consumed += 1; let mut b = Array::zeros((a.len() - consumed) * 2); b.slice_collapse(s![..;2]); - Zip::from(&mut b).and(iter.clone()).apply(|b, a| *b = a[()]); + Zip::from(&mut b) + .and(iter.clone()) + .for_each(|b, a| *b = a[()]); assert_eq!(a.slice(s![consumed..]), b); } } #[test] -fn outer_iter_corner_cases() { +fn outer_iter_corner_cases() +{ let a2 = ArcArray::::zeros((0, 3)); assert_equal(a2.outer_iter(), vec![aview1(&[]); 0]); @@ -347,9 +373,10 @@ fn outer_iter_corner_cases() { #[allow(deprecated)] #[test] -fn outer_iter_mut() { +fn outer_iter_mut() +{ let a = ArcArray::from_iter(0..12); - let a = a.reshape((2, 3, 2)); + let a = a.into_shape_with_order((2, 3, 2)).unwrap(); // [[[0, 1], // [2, 3], // [4, 5]], @@ -359,10 +386,7 @@ fn outer_iter_mut() { let mut b = ArcArray::zeros((2, 3, 2)); b.swap_axes(0, 2); b.assign(&a); - assert_equal( - b.outer_iter_mut(), - vec![a.index_axis(Axis(0), 0), a.index_axis(Axis(0), 1)], - ); + assert_equal(b.outer_iter_mut(), vec![a.index_axis(Axis(0), 0), a.index_axis(Axis(0), 1)]); let mut found_rows = Vec::new(); for sub in b.outer_iter_mut() { @@ -370,13 +394,14 @@ fn outer_iter_mut() { found_rows.push(row); } } - assert_equal(a.genrows(), found_rows); + assert_equal(a.rows(), found_rows); } #[test] -fn axis_iter_mut() { +fn axis_iter_mut() +{ let a = ArcArray::from_iter(0..12); - let a = a.reshape((2, 3, 2)); + let a = a.into_shape_with_order((2, 3, 2)).unwrap(); // [[[0, 1], // [2, 3], // [4, 5]], @@ -394,44 +419,36 @@ fn axis_iter_mut() { } #[test] -fn axis_chunks_iter() { +fn axis_chunks_iter() +{ let a = ArcArray::from_iter(0..24); - let a = a.reshape((2, 6, 2)); + let a = a.into_shape_with_order((2, 6, 2)).unwrap(); let it = a.axis_chunks_iter(Axis(1), 2); - assert_equal( - it, - vec![ + assert_equal(it, vec![ arr3(&[[[0, 1], [2, 3]], [[12, 13], [14, 15]]]), arr3(&[[[4, 5], [6, 
7]], [[16, 17], [18, 19]]]), arr3(&[[[8, 9], [10, 11]], [[20, 21], [22, 23]]]), - ], - ); + ]); let a = ArcArray::from_iter(0..28); - let a = a.reshape((2, 7, 2)); + let a = a.into_shape_with_order((2, 7, 2)).unwrap(); let it = a.axis_chunks_iter(Axis(1), 2); - assert_equal( - it, - vec![ + assert_equal(it, vec![ arr3(&[[[0, 1], [2, 3]], [[14, 15], [16, 17]]]), arr3(&[[[4, 5], [6, 7]], [[18, 19], [20, 21]]]), arr3(&[[[8, 9], [10, 11]], [[22, 23], [24, 25]]]), arr3(&[[[12, 13]], [[26, 27]]]), - ], - ); + ]); let it = a.axis_chunks_iter(Axis(1), 2).rev(); - assert_equal( - it, - vec![ + assert_equal(it, vec![ arr3(&[[[12, 13]], [[26, 27]]]), arr3(&[[[8, 9], [10, 11]], [[22, 23], [24, 25]]]), arr3(&[[[4, 5], [6, 7]], [[18, 19], [20, 21]]]), arr3(&[[[0, 1], [2, 3]], [[14, 15], [16, 17]]]), - ], - ); + ]); let it = a.axis_chunks_iter(Axis(1), 7); assert_equal(it, vec![a.view()]); @@ -441,7 +458,8 @@ fn axis_chunks_iter() { } #[test] -fn axis_iter_mut_split_at() { +fn axis_iter_mut_split_at() +{ let mut a = Array::from_iter(0..5); let mut a_clone = a.clone(); let all: Vec<_> = a_clone.axis_iter_mut(Axis(0)).collect(); @@ -453,7 +471,8 @@ fn axis_iter_mut_split_at() { } #[test] -fn axis_iter_mut_split_at_partially_consumed() { +fn axis_iter_mut_split_at_partially_consumed() +{ let mut a = Array::from_iter(0..5); for consumed in 1..=a.len() { for mid in 0..=(a.len() - consumed) { @@ -479,12 +498,13 @@ fn axis_iter_mut_split_at_partially_consumed() { } #[test] -fn axis_iter_mut_zip() { +fn axis_iter_mut_zip() +{ let orig = Array::from_iter(0..5); let mut cloned = orig.clone(); let iter = cloned.axis_iter_mut(Axis(0)); let mut b = Array::zeros(5); - Zip::from(&mut b).and(iter).apply(|b, mut a| { + Zip::from(&mut b).and(iter).for_each(|b, mut a| { a[()] += 1; *b = a[()]; }); @@ -493,7 +513,8 @@ fn axis_iter_mut_zip() { } #[test] -fn axis_iter_mut_zip_partially_consumed() { +fn axis_iter_mut_zip_partially_consumed() +{ let mut a = Array::from_iter(0..5); for consumed in 1..=a.len() { let remaining = a.len() - consumed; @@ -502,13 +523,14 @@ fn axis_iter_mut_zip_partially_consumed() { iter.next(); } let mut b = Array::zeros(remaining); - Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]); + Zip::from(&mut b).and(iter).for_each(|b, a| *b = a[()]); assert_eq!(a.slice(s![consumed..]), b); } } #[test] -fn axis_iter_mut_zip_partially_consumed_discontiguous() { +fn axis_iter_mut_zip_partially_consumed_discontiguous() +{ let mut a = Array::from_iter(0..5); for consumed in 1..=a.len() { let remaining = a.len() - consumed; @@ -518,33 +540,34 @@ fn axis_iter_mut_zip_partially_consumed_discontiguous() { } let mut b = Array::zeros(remaining * 2); b.slice_collapse(s![..;2]); - Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]); + Zip::from(&mut b).and(iter).for_each(|b, a| *b = a[()]); assert_eq!(a.slice(s![consumed..]), b); } } #[test] -fn axis_chunks_iter_corner_cases() { +#[cfg(feature = "std")] +fn axis_chunks_iter_corner_cases() +{ // examples provided by @bluss in PR #65 // these tests highlight corner cases of the axis_chunks_iter implementation - // and enable checking if no pointer offseting is out of bounds. However - // checking the absence of of out of bounds offseting cannot (?) be + // and enable checking if no pointer offsetting is out of bounds. However + // checking the absence of of out of bounds offsetting cannot (?) be // done automatically, so one has to launch this test in a debugger. 
- let a = ArcArray::::linspace(0., 7., 8).reshape((8, 1)); + let a = ArcArray::::linspace(0., 7., 8) + .into_shape_with_order((8, 1)) + .unwrap(); let it = a.axis_chunks_iter(Axis(0), 4); assert_equal(it, vec![a.slice(s![..4, ..]), a.slice(s![4.., ..])]); let a = a.slice(s![..;-1,..]); let it = a.axis_chunks_iter(Axis(0), 8); assert_equal(it, vec![a.view()]); let it = a.axis_chunks_iter(Axis(0), 3); - assert_equal( - it, - vec![ - arr2(&[[7.], [6.], [5.]]), - arr2(&[[4.], [3.], [2.]]), - arr2(&[[1.], [0.]]), - ], - ); + assert_equal(it, vec![ + array![[7.], [6.], [5.]], + array![[4.], [3.], [2.]], + array![[1.], [0.]], + ]); let b = ArcArray::::zeros((8, 2)); let a = b.slice(s![1..;2,..]); @@ -556,10 +579,13 @@ fn axis_chunks_iter_corner_cases() { } #[test] -fn axis_chunks_iter_zero_stride() { +fn axis_chunks_iter_zero_stride() +{ { // stride 0 case - let b = Array::from(vec![0f32; 0]).into_shape((5, 0, 3)).unwrap(); + let b = Array::from(vec![0f32; 0]) + .into_shape_with_order((5, 0, 3)) + .unwrap(); let shapes: Vec<_> = b .axis_chunks_iter(Axis(0), 2) .map(|v| v.raw_dim()) @@ -569,7 +595,9 @@ fn axis_chunks_iter_zero_stride() { { // stride 0 case reverse - let b = Array::from(vec![0f32; 0]).into_shape((5, 0, 3)).unwrap(); + let b = Array::from(vec![0f32; 0]) + .into_shape_with_order((5, 0, 3)) + .unwrap(); let shapes: Vec<_> = b .axis_chunks_iter(Axis(0), 2) .rev() @@ -588,19 +616,22 @@ fn axis_chunks_iter_zero_stride() { #[should_panic] #[test] -fn axis_chunks_iter_zero_chunk_size() { +fn axis_chunks_iter_zero_chunk_size() +{ let a = Array::from_iter(0..5); a.axis_chunks_iter(Axis(0), 0); } #[test] -fn axis_chunks_iter_zero_axis_len() { +fn axis_chunks_iter_zero_axis_len() +{ let a = Array::from_iter(0..0); assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none()); } #[test] -fn axis_chunks_iter_split_at() { +fn axis_chunks_iter_split_at() +{ let mut a = Array2::::zeros((11, 3)); a.iter_mut().enumerate().for_each(|(i, elt)| *elt = i); for source in &[ @@ -627,9 +658,10 @@ fn axis_chunks_iter_split_at() { } #[test] -fn axis_chunks_iter_mut() { +fn axis_chunks_iter_mut() +{ let a = ArcArray::from_iter(0..24); - let mut a = a.reshape((2, 6, 2)); + let mut a = a.into_shape_with_order((2, 6, 2)).unwrap(); let mut it = a.axis_chunks_iter_mut(Axis(1), 2); let mut col0 = it.next().unwrap(); @@ -639,21 +671,26 @@ fn axis_chunks_iter_mut() { #[should_panic] #[test] -fn axis_chunks_iter_mut_zero_chunk_size() { +fn axis_chunks_iter_mut_zero_chunk_size() +{ let mut a = Array::from_iter(0..5); a.axis_chunks_iter_mut(Axis(0), 0); } #[test] -fn axis_chunks_iter_mut_zero_axis_len() { +fn axis_chunks_iter_mut_zero_axis_len() +{ let mut a = Array::from_iter(0..0); assert!(a.axis_chunks_iter_mut(Axis(0), 5).next().is_none()); } #[test] -fn outer_iter_size_hint() { +fn outer_iter_size_hint() +{ // Check that the size hint is correctly computed - let a = ArcArray::from_iter(0..24).reshape((4, 3, 2)); + let a = ArcArray::from_iter(0..24) + .into_shape_with_order((4, 3, 2)) + .unwrap(); let mut len = 4; let mut it = a.outer_iter(); assert_eq!(it.len(), len); @@ -684,8 +721,11 @@ fn outer_iter_size_hint() { } #[test] -fn outer_iter_split_at() { - let a = ArcArray::from_iter(0..30).reshape((5, 3, 2)); +fn outer_iter_split_at() +{ + let a = ArcArray::from_iter(0..30) + .into_shape_with_order((5, 3, 2)) + .unwrap(); let it = a.outer_iter(); let (mut itl, mut itr) = it.clone().split_at(2); @@ -706,16 +746,22 @@ fn outer_iter_split_at() { #[test] #[should_panic] -fn outer_iter_split_at_panics() { - let a = 
ArcArray::from_iter(0..30).reshape((5, 3, 2)); +fn outer_iter_split_at_panics() +{ + let a = ArcArray::from_iter(0..30) + .into_shape_with_order((5, 3, 2)) + .unwrap(); let it = a.outer_iter(); it.split_at(6); } #[test] -fn outer_iter_mut_split_at() { - let mut a = ArcArray::from_iter(0..30).reshape((5, 3, 2)); +fn outer_iter_mut_split_at() +{ + let mut a = ArcArray::from_iter(0..30) + .into_shape_with_order((5, 3, 2)) + .unwrap(); { let it = a.outer_iter_mut(); @@ -734,12 +780,15 @@ fn outer_iter_mut_split_at() { } #[test] -fn iterators_are_send_sync() { +fn iterators_are_send_sync() +{ // When the element type is Send + Sync, then the iterators and views // are too. fn _send_sync(_: &T) {} - let mut a = ArcArray::from_iter(0..30).into_shape((5, 3, 2)).unwrap(); + let mut a = ArcArray::from_iter(0..30) + .into_shape_with_order((5, 3, 2)) + .unwrap(); _send_sync(&a.view()); _send_sync(&a.view_mut()); @@ -747,8 +796,8 @@ fn iterators_are_send_sync() { _send_sync(&a.iter_mut()); _send_sync(&a.indexed_iter()); _send_sync(&a.indexed_iter_mut()); - _send_sync(&a.genrows()); - _send_sync(&a.genrows_mut()); + _send_sync(&a.rows()); + _send_sync(&a.rows_mut()); _send_sync(&a.outer_iter()); _send_sync(&a.outer_iter_mut()); _send_sync(&a.axis_iter(Axis(1))); @@ -764,7 +813,8 @@ fn iterators_are_send_sync() { #[test] #[allow(clippy::unnecessary_fold)] -fn test_fold() { +fn test_fold() +{ let mut a = Array2::::default((20, 20)); a += 1; let mut iter = a.iter(); @@ -777,7 +827,8 @@ fn test_fold() { } #[test] -fn nth_back_examples() { +fn nth_back_examples() +{ let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); assert_eq!(a.iter().nth_back(0), Some(&a[a.len() - 1])); @@ -790,7 +841,8 @@ fn nth_back_examples() { } #[test] -fn nth_back_zero_n() { +fn nth_back_zero_n() +{ let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); let mut iter1 = a.iter(); @@ -802,7 +854,8 @@ fn nth_back_zero_n() { } #[test] -fn nth_back_nonzero_n() { +fn nth_back_nonzero_n() +{ let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); let mut iter1 = a.iter(); @@ -818,7 +871,8 @@ fn nth_back_nonzero_n() { } #[test] -fn nth_back_past_end() { +fn nth_back_past_end() +{ let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); let mut iter = a.iter(); @@ -827,7 +881,8 @@ fn nth_back_past_end() { } #[test] -fn nth_back_partially_consumed() { +fn nth_back_partially_consumed() +{ let mut a: Array1 = (0..256).collect(); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); let mut iter = a.iter(); @@ -845,7 +900,8 @@ fn nth_back_partially_consumed() { } #[test] -fn test_rfold() { +fn test_rfold() +{ { let mut a = Array1::::default(256); a += 1; @@ -889,3 +945,142 @@ fn test_rfold() { ); } } + +#[test] +fn test_into_iter() +{ + let a = Array1::from(vec![1, 2, 3, 4]); + let v = a.into_iter().collect::>(); + assert_eq!(v, [1, 2, 3, 4]); +} + +#[test] +fn test_into_iter_2d() +{ + let a = Array1::from(vec![1, 2, 3, 4]) + .into_shape_with_order((2, 2)) + .unwrap(); + let v = a.into_iter().collect::>(); + assert_eq!(v, [1, 2, 3, 4]); + + let a = Array1::from(vec![1, 2, 3, 4]) + .into_shape_with_order((2, 2)) + .unwrap() + .reversed_axes(); + let v = a.into_iter().collect::>(); + assert_eq!(v, [1, 3, 2, 4]); +} + +#[cfg_attr(miri, ignore)] // Very slow on CI/CD machines +#[test] +fn test_into_iter_sliced() +{ + let (m, n) = (4, 5); + let drops = Cell::new(0); + + for i in 0..m - 1 { + 
for j in 0..n - 1 { + for i2 in i + 1..m { + for j2 in j + 1..n { + for invert in 0..3 { + drops.set(0); + let i = i as isize; + let j = j as isize; + let i2 = i2 as isize; + let j2 = j2 as isize; + let mut a = Array1::from_iter(0..(m * n) as i32) + .mapv(|v| DropCount::new(v, &drops)) + .into_shape_with_order((m, n)) + .unwrap(); + a.slice_collapse(s![i..i2, j..j2]); + if invert < a.ndim() { + a.invert_axis(Axis(invert)); + } + + println!("{:?}, {:?}", i..i2, j..j2); + println!("{:?}", a); + let answer = a.iter().cloned().collect::>(); + let v = a.into_iter().collect::>(); + assert_eq!(v, answer); + + assert_eq!(drops.get(), m * n - v.len()); + drop(v); + assert_eq!(drops.get(), m * n); + } + } + } + } + } +} + +/// Helper struct that counts its drops Asserts that it's not dropped twice. Also global number of +/// drops is counted in the cell. +/// +/// Compares equal by its "represented value". +#[derive(Clone, Debug)] +struct DropCount<'a> +{ + value: i32, + my_drops: usize, + drops: &'a Cell, +} + +impl PartialEq for DropCount<'_> +{ + fn eq(&self, other: &Self) -> bool + { + self.value == other.value + } +} + +impl<'a> DropCount<'a> +{ + fn new(value: i32, drops: &'a Cell) -> Self + { + DropCount { + value, + my_drops: 0, + drops, + } + } +} + +impl Drop for DropCount<'_> +{ + fn drop(&mut self) + { + assert_eq!(self.my_drops, 0); + self.my_drops += 1; + self.drops.set(self.drops.get() + 1); + } +} + +#[test] +fn test_impl_iter_compiles() +{ + // Requires that the iterators are covariant in the element type + + // base case: std + fn slice_iter_non_empty_indices<'s, 'a>(array: &'a Vec<&'s str>) -> impl Iterator + 'a + { + array + .iter() + .enumerate() + .filter(|(_index, elem)| !elem.is_empty()) + .map(|(index, _elem)| index) + } + + let _ = slice_iter_non_empty_indices; + + // ndarray case + fn array_iter_non_empty_indices<'s, 'a>(array: &'a Array<&'s str, Ix1>) -> impl Iterator + 'a + { + array + .iter() + .enumerate() + .filter(|(_index, elem)| !elem.is_empty()) + .map(|(index, _elem)| index) + } + + let _ = array_iter_non_empty_indices; +} diff --git a/tests/ix0.rs b/tests/ix0.rs index c8c6c73aa..f1038556a 100644 --- a/tests/ix0.rs +++ b/tests/ix0.rs @@ -1,8 +1,5 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names, + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names, clippy::float_cmp )] @@ -11,7 +8,8 @@ use ndarray::Ix0; use ndarray::ShapeBuilder; #[test] -fn test_ix0() { +fn test_ix0() +{ let mut a = Array::zeros(Ix0()); assert_eq!(a[()], 0.); a[()] = 1.; @@ -30,7 +28,8 @@ fn test_ix0() { } #[test] -fn test_ix0_add() { +fn test_ix0_add() +{ let mut a = Array::zeros(Ix0()); a += 1.; assert_eq!(a[()], 1.); @@ -39,7 +38,8 @@ fn test_ix0_add() { } #[test] -fn test_ix0_add_add() { +fn test_ix0_add_add() +{ let mut a = Array::zeros(Ix0()); a += 1.; let mut b = Array::zeros(Ix0()); @@ -49,7 +49,8 @@ fn test_ix0_add_add() { } #[test] -fn test_ix0_add_broad() { +fn test_ix0_add_broad() +{ let mut b = Array::from(vec![5., 6.]); let mut a = Array::zeros(Ix0()); a += 1.; diff --git a/tests/ixdyn.rs b/tests/ixdyn.rs index 3d96967a0..05f123ba1 100644 --- a/tests/ixdyn.rs +++ b/tests/ixdyn.rs @@ -1,18 +1,17 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names, + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, 
clippy::many_single_char_names, clippy::float_cmp )] use ndarray::Array; use ndarray::IntoDimension; +use ndarray::Ix3; +use ndarray::Order; use ndarray::ShapeBuilder; -use ndarray::{Ix0, Ix1, Ix2, Ix3, IxDyn}; #[test] -fn test_ixdyn() { +fn test_ixdyn() +{ // check that we can use fixed size arrays for indexing let mut a = Array::zeros(vec![2, 3, 4]); a[[1, 1, 1]] = 1.; @@ -21,7 +20,8 @@ fn test_ixdyn() { #[should_panic] #[test] -fn test_ixdyn_wrong_dim() { +fn test_ixdyn_wrong_dim() +{ // check that we can use but it panics at runtime, if number of axes is wrong let mut a = Array::zeros(vec![2, 3, 4]); a[[1, 1, 1]] = 1.; @@ -30,7 +30,8 @@ fn test_ixdyn_wrong_dim() { } #[test] -fn test_ixdyn_out_of_bounds() { +fn test_ixdyn_out_of_bounds() +{ // check that we are out of bounds let a = Array::::zeros(vec![2, 3, 4]); let res = a.get([0, 3, 0]); @@ -38,15 +39,16 @@ fn test_ixdyn_out_of_bounds() { } #[test] -fn test_ixdyn_iterate() { - for &rev in &[false, true] { - let mut a = Array::zeros((2, 3, 4).set_f(rev)); +fn test_ixdyn_iterate() +{ + for &order in &[Order::C, Order::F] { + let mut a = Array::zeros((2, 3, 4).set_f(order.is_column_major())); let dim = a.shape().to_vec(); for (i, elt) in a.iter_mut().enumerate() { *elt = i; } println!("{:?}", a.dim()); - let mut a = a.into_shape(dim).unwrap(); + let mut a = a.into_shape_with_order((dim, order)).unwrap(); println!("{:?}", a.dim()); let mut c = 0; for (i, elt) in a.iter_mut().enumerate() { @@ -58,14 +60,15 @@ fn test_ixdyn_iterate() { } #[test] -fn test_ixdyn_index_iterate() { - for &rev in &[false, true] { - let mut a = Array::zeros((2, 3, 4).set_f(rev)); +fn test_ixdyn_index_iterate() +{ + for &order in &[Order::C, Order::F] { + let mut a = Array::zeros((2, 3, 4).set_f(order.is_column_major())); let dim = a.shape().to_vec(); for ((i, j, k), elt) in a.indexed_iter_mut() { *elt = i + 10 * j + 100 * k; } - let a = a.into_shape(dim).unwrap(); + let a = a.into_shape_with_order((dim, order)).unwrap(); println!("{:?}", a.dim()); let mut c = 0; for (i, elt) in a.indexed_iter() { @@ -77,7 +80,8 @@ fn test_ixdyn_index_iterate() { } #[test] -fn test_ixdyn_uget() { +fn test_ixdyn_uget() +{ // check that we are out of bounds let mut a = Array::::zeros(vec![2, 3, 4]); @@ -106,7 +110,8 @@ fn test_ixdyn_uget() { } #[test] -fn test_0() { +fn test_0() +{ let mut a = Array::zeros(vec![]); let z = vec![].into_dimension(); assert_eq!(a[z.clone()], 0.); @@ -126,7 +131,8 @@ fn test_0() { } #[test] -fn test_0_add() { +fn test_0_add() +{ let mut a = Array::zeros(vec![]); a += 1.; assert_eq!(a[[]], 1.); @@ -135,7 +141,8 @@ fn test_0_add() { } #[test] -fn test_0_add_add() { +fn test_0_add_add() +{ let mut a = Array::zeros(vec![]); a += 1.; let mut b = Array::zeros(vec![]); @@ -145,7 +152,8 @@ fn test_0_add_add() { } #[test] -fn test_0_add_broad() { +fn test_0_add_broad() +{ let mut b = Array::from(vec![5., 6.]); let mut a = Array::zeros(vec![]); a += 1.; @@ -155,9 +163,15 @@ fn test_0_add_broad() { } #[test] -fn test_into_dimension() { - let a = Array::linspace(0., 41., 6 * 7).into_shape((6, 7)).unwrap(); - let a2 = a.clone().into_shape(IxDyn(&[6, 7])).unwrap(); +#[cfg(feature = "std")] +fn test_into_dimension() +{ + use ndarray::{Ix0, Ix1, Ix2, IxDyn}; + + let a = Array::linspace(0., 41., 6 * 7) + .into_shape_with_order((6, 7)) + .unwrap(); + let a2 = a.clone().into_shape_with_order(IxDyn(&[6, 7])).unwrap(); let b = a2.clone().into_dimensionality::().unwrap(); assert_eq!(a, b); diff --git a/tests/numeric.rs b/tests/numeric.rs index 7c6f1441e..839aba58e 100644 
--- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -1,8 +1,5 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names, + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names, clippy::float_cmp )] @@ -11,19 +8,22 @@ use ndarray::{arr0, arr1, arr2, array, aview1, Array, Array1, Array2, Array3, Ax use std::f64; #[test] -fn test_mean_with_nan_values() { +fn test_mean_with_nan_values() +{ let a = array![f64::NAN, 1.]; assert!(a.mean().unwrap().is_nan()); } #[test] -fn test_mean_with_empty_array_of_floats() { +fn test_mean_with_empty_array_of_floats() +{ let a: Array1 = array![]; assert!(a.mean().is_none()); } #[test] -fn test_mean_with_array_of_floats() { +fn test_mean_with_array_of_floats() +{ let a: Array1 = array![ 0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936, 0.22879429, 0.43997224, 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487, @@ -39,34 +39,133 @@ fn test_mean_with_array_of_floats() { } #[test] -fn sum_mean() { - let a = arr2(&[[1., 2.], [3., 4.]]); +fn sum_mean_prod() +{ + let a: Array2 = arr2(&[[1., 2.], [3., 4.]]); assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.])); assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.])); + assert_eq!(a.product_axis(Axis(0)), arr1(&[3., 8.])); + assert_eq!(a.product_axis(Axis(1)), arr1(&[2., 12.])); assert_eq!(a.mean_axis(Axis(0)), Some(arr1(&[2., 3.]))); assert_eq!(a.mean_axis(Axis(1)), Some(arr1(&[1.5, 3.5]))); assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.)); + assert_eq!(a.product_axis(Axis(1)).product_axis(Axis(0)), arr0(24.)); assert_eq!(a.view().mean_axis(Axis(1)).unwrap(), aview1(&[1.5, 3.5])); assert_eq!(a.sum(), 10.); } #[test] -fn sum_mean_empty() { +fn sum_mean_prod_empty() +{ assert_eq!(Array3::::ones((2, 0, 3)).sum(), 0.); + assert_eq!(Array3::::ones((2, 0, 3)).product(), 1.); assert_eq!(Array1::::ones(0).sum_axis(Axis(0)), arr0(0.)); + assert_eq!(Array1::::ones(0).product_axis(Axis(0)), arr0(1.)); assert_eq!( Array3::::ones((2, 0, 3)).sum_axis(Axis(1)), Array::zeros((2, 3)), ); + assert_eq!( + Array3::::ones((2, 0, 3)).product_axis(Axis(1)), + Array::ones((2, 3)), + ); let a = Array1::::ones(0).mean_axis(Axis(0)); assert_eq!(a, None); let a = Array3::::ones((2, 0, 3)).mean_axis(Axis(1)); assert_eq!(a, None); } +#[test] +#[cfg(feature = "std")] +fn var() +{ + let a = array![1., -4.32, 1.14, 0.32]; + assert_abs_diff_eq!(a.var(0.), 5.049875, epsilon = 1e-8); +} + +#[test] +#[cfg(feature = "std")] +#[should_panic] +fn var_negative_ddof() +{ + let a = array![1., 2., 3.]; + a.var(-1.); +} + +#[test] +#[cfg(feature = "std")] +#[should_panic] +fn var_too_large_ddof() +{ + let a = array![1., 2., 3.]; + a.var(4.); +} + +#[test] +#[cfg(feature = "std")] +fn var_nan_ddof() +{ + let a = Array2::::zeros((2, 3)); + let v = a.var(::std::f64::NAN); + assert!(v.is_nan()); +} + +#[test] +#[cfg(feature = "std")] +fn var_empty_arr() +{ + let a: Array1 = array![]; + assert!(a.var(0.0).is_nan()); +} + +#[test] +#[cfg(feature = "std")] +fn std() +{ + let a = array![1., -4.32, 1.14, 0.32]; + assert_abs_diff_eq!(a.std(0.), 2.24719, epsilon = 1e-5); +} + +#[test] +#[cfg(feature = "std")] +#[should_panic] +fn std_negative_ddof() +{ + let a = array![1., 2., 3.]; + a.std(-1.); +} + +#[test] +#[cfg(feature = "std")] +#[should_panic] +fn std_too_large_ddof() +{ + let a = array![1., 2., 3.]; + a.std(4.); +} + +#[test] +#[cfg(feature = "std")] +fn std_nan_ddof() +{ + let a = 
Array2::::zeros((2, 3)); + let v = a.std(::std::f64::NAN); + assert!(v.is_nan()); +} + +#[test] +#[cfg(feature = "std")] +fn std_empty_arr() +{ + let a: Array1 = array![]; + assert!(a.std(0.0).is_nan()); +} + #[test] #[cfg(feature = "approx")] -fn var_axis() { +#[cfg(feature = "std")] +fn var_axis() +{ use ndarray::{aview0, aview2}; let a = array![ @@ -124,7 +223,9 @@ fn var_axis() { #[test] #[cfg(feature = "approx")] -fn std_axis() { +#[cfg(feature = "std")] +fn std_axis() +{ use ndarray::aview2; let a = array![ @@ -183,20 +284,26 @@ fn std_axis() { #[test] #[should_panic] -fn var_axis_negative_ddof() { +#[cfg(feature = "std")] +fn var_axis_negative_ddof() +{ let a = array![1., 2., 3.]; a.var_axis(Axis(0), -1.); } #[test] #[should_panic] -fn var_axis_too_large_ddof() { +#[cfg(feature = "std")] +fn var_axis_too_large_ddof() +{ let a = array![1., 2., 3.]; a.var_axis(Axis(0), 4.); } #[test] -fn var_axis_nan_ddof() { +#[cfg(feature = "std")] +fn var_axis_nan_ddof() +{ let a = Array2::::zeros((2, 3)); let v = a.var_axis(Axis(1), ::std::f64::NAN); assert_eq!(v.shape(), &[2]); @@ -204,7 +311,9 @@ fn var_axis_nan_ddof() { } #[test] -fn var_axis_empty_axis() { +#[cfg(feature = "std")] +fn var_axis_empty_axis() +{ let a = Array2::::zeros((2, 0)); let v = a.var_axis(Axis(1), 0.); assert_eq!(v.shape(), &[2]); @@ -213,15 +322,87 @@ fn var_axis_empty_axis() { #[test] #[should_panic] -fn std_axis_bad_dof() { +#[cfg(feature = "std")] +fn std_axis_bad_dof() +{ let a = array![1., 2., 3.]; a.std_axis(Axis(0), 4.); } #[test] -fn std_axis_empty_axis() { +#[cfg(feature = "std")] +fn std_axis_empty_axis() +{ let a = Array2::::zeros((2, 0)); let v = a.std_axis(Axis(1), 0.); assert_eq!(v.shape(), &[2]); v.mapv(|x| assert!(x.is_nan())); } + +#[test] +fn diff_1d_order1() +{ + let data = array![1.0, 2.0, 4.0, 7.0]; + let expected = array![1.0, 2.0, 3.0]; + assert_eq!(data.diff(1, Axis(0)), expected); +} + +#[test] +fn diff_1d_order2() +{ + let data = array![1.0, 2.0, 4.0, 7.0]; + assert_eq!( + data.diff(2, Axis(0)), + data.diff(1, Axis(0)).diff(1, Axis(0)) + ); +} + +#[test] +fn diff_1d_order3() +{ + let data = array![1.0, 2.0, 4.0, 7.0]; + assert_eq!( + data.diff(3, Axis(0)), + data.diff(1, Axis(0)).diff(1, Axis(0)).diff(1, Axis(0)) + ); +} + +#[test] +fn diff_2d_order1_ax0() +{ + let data = array![ + [1.0, 2.0, 4.0, 7.0], + [1.0, 3.0, 6.0, 6.0], + [1.5, 3.5, 5.5, 5.5] + ]; + let expected = array![[0.0, 1.0, 2.0, -1.0], [0.5, 0.5, -0.5, -0.5]]; + assert_eq!(data.diff(1, Axis(0)), expected); +} + +#[test] +fn diff_2d_order1_ax1() +{ + let data = array![ + [1.0, 2.0, 4.0, 7.0], + [1.0, 3.0, 6.0, 6.0], + [1.5, 3.5, 5.5, 5.5] + ]; + let expected = array![[1.0, 2.0, 3.0], [2.0, 3.0, 0.0], [2.0, 2.0, 0.0]]; + assert_eq!(data.diff(1, Axis(1)), expected); +} + +#[test] +#[should_panic] +fn diff_panic_n_too_big() +{ + let data = array![1.0, 2.0, 4.0, 7.0]; + data.diff(10, Axis(0)); +} + +#[test] +#[should_panic] +fn diff_panic_axis_out_of_bounds() +{ + let data = array![1, 2, 4, 7]; + data.diff(1, Axis(2)); +} diff --git a/tests/oper.rs b/tests/oper.rs index 89c45888c..401913e2b 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -1,43 +1,42 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] - +#![cfg(feature = "std")] use ndarray::linalg::general_mat_mul; +use ndarray::linalg::kron; use ndarray::prelude::*; +#[cfg(feature = 
"approx")] +use ndarray::Order; use ndarray::{rcarr1, rcarr2}; use ndarray::{Data, LinalgScalar}; use ndarray::{Ix, Ixs}; -use std::iter::FromIterator; +use ndarray_gen::array_builder::ArrayBuilder; use approx::assert_abs_diff_eq; use defmac::defmac; -use std::ops::Neg; +use num_traits::Num; +use num_traits::Zero; -fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32]) { - let aa = rcarr1(a); - let bb = rcarr1(b); - let cc = rcarr1(c); +fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32]) +{ + let aa = CowArray::from(arr1(a)); + let bb = CowArray::from(arr1(b)); + let cc = CowArray::from(arr1(c)); test_oper_arr::(op, aa.clone(), bb.clone(), cc.clone()); let dim = (2, 2); - let aa = aa.reshape(dim); - let bb = bb.reshape(dim); - let cc = cc.reshape(dim); + let aa = aa.to_shape(dim).unwrap(); + let bb = bb.to_shape(dim).unwrap(); + let cc = cc.to_shape(dim).unwrap(); test_oper_arr::(op, aa.clone(), bb.clone(), cc.clone()); let dim = (1, 2, 1, 2); - let aa = aa.reshape(dim); - let bb = bb.reshape(dim); - let cc = cc.reshape(dim); + let aa = aa.to_shape(dim).unwrap(); + let bb = bb.to_shape(dim).unwrap(); + let cc = cc.to_shape(dim).unwrap(); test_oper_arr::(op, aa.clone(), bb.clone(), cc.clone()); } -fn test_oper_arr(op: &str, mut aa: ArcArray, bb: ArcArray, cc: ArcArray) -where - A: NdFloat, - for<'a> &'a A: Neg, - D: Dimension, +fn test_oper_arr(op: &str, mut aa: CowArray, bb: CowArray, cc: CowArray) +where D: Dimension { match op { "+" => { @@ -67,54 +66,26 @@ where } "neg" => { assert_eq!(-&aa, cc); - assert_eq!(-aa.clone(), cc); + assert_eq!(-aa.into_owned(), cc); } _ => panic!(), } } #[test] -fn operations() { - test_oper( - "+", - &[1.0, 2.0, 3.0, 4.0], - &[0.0, 1.0, 2.0, 3.0], - &[1.0, 3.0, 5.0, 7.0], - ); - test_oper( - "-", - &[1.0, 2.0, 3.0, 4.0], - &[0.0, 1.0, 2.0, 3.0], - &[1.0, 1.0, 1.0, 1.0], - ); - test_oper( - "*", - &[1.0, 2.0, 3.0, 4.0], - &[0.0, 1.0, 2.0, 3.0], - &[0.0, 2.0, 6.0, 12.0], - ); - test_oper( - "/", - &[1.0, 2.0, 3.0, 4.0], - &[1.0, 1.0, 2.0, 3.0], - &[1.0, 2.0, 3.0 / 2.0, 4.0 / 3.0], - ); - test_oper( - "%", - &[1.0, 2.0, 3.0, 4.0], - &[1.0, 1.0, 2.0, 3.0], - &[0.0, 0.0, 1.0, 1.0], - ); - test_oper( - "neg", - &[1.0, 2.0, 3.0, 4.0], - &[1.0, 1.0, 2.0, 3.0], - &[-1.0, -2.0, -3.0, -4.0], - ); +fn operations() +{ + test_oper("+", &[1.0, 2.0, 3.0, 4.0], &[0.0, 1.0, 2.0, 3.0], &[1.0, 3.0, 5.0, 7.0]); + test_oper("-", &[1.0, 2.0, 3.0, 4.0], &[0.0, 1.0, 2.0, 3.0], &[1.0, 1.0, 1.0, 1.0]); + test_oper("*", &[1.0, 2.0, 3.0, 4.0], &[0.0, 1.0, 2.0, 3.0], &[0.0, 2.0, 6.0, 12.0]); + test_oper("/", &[1.0, 2.0, 3.0, 4.0], &[1.0, 1.0, 2.0, 3.0], &[1.0, 2.0, 3.0 / 2.0, 4.0 / 3.0]); + test_oper("%", &[1.0, 2.0, 3.0, 4.0], &[1.0, 1.0, 2.0, 3.0], &[0.0, 0.0, 1.0, 1.0]); + test_oper("neg", &[1.0, 2.0, 3.0, 4.0], &[1.0, 1.0, 2.0, 3.0], &[-1.0, -2.0, -3.0, -4.0]); } #[test] -fn scalar_operations() { +fn scalar_operations() +{ let a = arr0::(1.); let b = rcarr1::(&[1., 1.]); let c = rcarr2(&[[1., 1.], [1., 1.]]); @@ -147,21 +118,21 @@ fn scalar_operations() { } } -fn reference_dot<'a, A, V1, V2>(a: V1, b: V2) -> A +fn reference_dot<'a, V1, V2>(a: V1, b: V2) -> f32 where - A: NdFloat, - V1: AsArray<'a, A>, - V2: AsArray<'a, A>, + V1: AsArray<'a, f32>, + V2: AsArray<'a, f32>, { let a = a.into(); let b = b.into(); a.iter() .zip(b.iter()) - .fold(A::zero(), |acc, (&x, &y)| acc + x * y) + .fold(f32::zero(), |acc, (&x, &y)| acc + x * y) } #[test] -fn dot_product() { +fn dot_product() +{ let a = Array::range(0., 69., 1.); let b = &a * 2. 
- 7.; let dot = 197846.; @@ -199,7 +170,8 @@ fn dot_product() { // test that we can dot product with a broadcast array #[test] -fn dot_product_0() { +fn dot_product_0() +{ let a = Array::range(0., 69., 1.); let x = 1.5; let b = aview0(&x); @@ -219,7 +191,8 @@ fn dot_product_0() { } #[test] -fn dot_product_neg_stride() { +fn dot_product_neg_stride() +{ // test that we can dot with negative stride let a = Array::range(0., 69., 1.); let b = &a * 2. - 7.; @@ -238,8 +211,11 @@ fn dot_product_neg_stride() { } #[test] -fn fold_and_sum() { - let a = Array::linspace(0., 127., 128).into_shape((8, 16)).unwrap(); +fn fold_and_sum() +{ + let a = Array::linspace(0., 127., 128) + .into_shape_with_order((8, 16)) + .unwrap(); assert_abs_diff_eq!(a.fold(0., |acc, &x| acc + x), a.sum(), epsilon = 1e-5); // test different strides @@ -277,8 +253,11 @@ fn fold_and_sum() { } #[test] -fn product() { - let a = Array::linspace(0.5, 2., 128).into_shape((8, 16)).unwrap(); +fn product() +{ + let a = Array::linspace(0.5, 2., 128) + .into_shape_with_order((8, 16)) + .unwrap(); assert_abs_diff_eq!(a.fold(1., |acc, &x| acc * x), a.product(), epsilon = 1e-5); // test different strides @@ -296,27 +275,20 @@ fn product() { } } -fn range_mat(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f32 - 1., m * n) - .into_shape((m, n)) - .unwrap() -} - -fn range_mat64(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f64 - 1., m * n) - .into_shape((m, n)) - .unwrap() +fn range_mat(m: Ix, n: Ix) -> Array2 +{ + ArrayBuilder::new((m, n)).build() } #[cfg(feature = "approx")] -fn range1_mat64(m: Ix) -> Array1 { - Array::linspace(0., m as f64 - 1., m) +fn range1_mat64(m: Ix) -> Array1 +{ + ArrayBuilder::new(m).build() } -fn range_i32(m: Ix, n: Ix) -> Array2 { - Array::from_iter(0..(m * n) as i32) - .into_shape((m, n)) - .unwrap() +fn range_i32(m: Ix, n: Ix) -> Array2 +{ + ArrayBuilder::new((m, n)).build() } // simple, slow, correct (hopefully) mat mul @@ -338,9 +310,7 @@ where let mut j = 0; for rr in &mut res_elems { unsafe { - *rr = (0..k).fold(A::zero(), move |s, x| { - s + *lhs.uget((i, x)) * *rhs.uget((x, j)) - }); + *rr = (0..k).fold(A::zero(), move |s, x| s + *lhs.uget((i, x)) * *rhs.uget((x, j))); } j += 1; if j == n { @@ -352,10 +322,11 @@ where } #[test] -fn mat_mul() { +fn mat_mul() +{ let (m, n, k) = (8, 8, 8); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut b = b / 4.; { let mut c = b.column_mut(0); @@ -373,8 +344,8 @@ fn mat_mul() { assert_eq!(ab, af.dot(&bf)); let (m, n, k) = (10, 5, 11); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut b = b / 4.; { let mut c = b.column_mut(0); @@ -392,8 +363,8 @@ fn mat_mul() { assert_eq!(ab, af.dot(&bf)); let (m, n, k) = (10, 8, 1); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut b = b / 4.; { let mut c = b.column_mut(0); @@ -414,10 +385,11 @@ fn mat_mul() { // Check that matrix multiplication of contiguous matrices returns a // matrix with the same order #[test] -fn mat_mul_order() { +fn mat_mul_order() +{ let (m, n, k) = (8, 8, 8); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut af = Array::zeros(a.dim().f()); let mut bf = Array::zeros(b.dim().f()); af.assign(&a); @@ -433,30 +405,33 @@ fn mat_mul_order() { // test matrix multiplication shape mismatch #[test] #[should_panic] -fn 
mat_mul_shape_mismatch() { +fn mat_mul_shape_mismatch() +{ let (m, k, k2, n) = (8, 8, 9, 8); - let a = range_mat(m, k); - let b = range_mat(k2, n); + let a = range_mat::(m, k); + let b = range_mat::(k2, n); a.dot(&b); } // test matrix multiplication shape mismatch #[test] #[should_panic] -fn mat_mul_shape_mismatch_2() { +fn mat_mul_shape_mismatch_2() +{ let (m, k, k2, n) = (8, 8, 8, 8); - let a = range_mat(m, k); - let b = range_mat(k2, n); - let mut c = range_mat(m, n + 1); + let a = range_mat::(m, k); + let b = range_mat::(k2, n); + let mut c = range_mat::(m, n + 1); general_mat_mul(1., &a, &b, 1., &mut c); } // Check that matrix multiplication // supports broadcast arrays. #[test] -fn mat_mul_broadcast() { +fn mat_mul_broadcast() +{ let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); + let a = range_mat::(m, n); let x1 = 1.; let x = Array::from(vec![x1]); let b0 = x.broadcast((n, k)).unwrap(); @@ -473,10 +448,11 @@ fn mat_mul_broadcast() { // Check that matrix multiplication supports reversed axes #[test] -fn mat_mul_rev() { +fn mat_mul_rev() +{ let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut rev = Array::zeros(b.dim()); let mut rev = rev.slice_mut(s![..;-1, ..]); rev.assign(&b); @@ -489,7 +465,8 @@ fn mat_mul_rev() { // Check that matrix multiplication supports arrays with zero rows or columns #[test] -fn mat_mut_zero_len() { +fn mat_mut_zero_len() +{ defmac!(mat_mul_zero_len range_mat_fn => { for n in 0..4 { for m in 0..4 { @@ -504,13 +481,14 @@ fn mat_mut_zero_len() { } } }); - mat_mul_zero_len!(range_mat); - mat_mul_zero_len!(range_mat64); + mat_mul_zero_len!(range_mat::); + mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_i32); } #[test] -fn scaled_add() { +fn scaled_add() +{ let a = range_mat(16, 15); let mut b = range_mat(16, 15); b.mapv_inplace(f32::exp); @@ -524,8 +502,10 @@ fn scaled_add() { } #[cfg(feature = "approx")] +#[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] -fn scaled_add_2() { +fn scaled_add_2() +{ let beta = -2.3; let sizes = vec![ (4, 4, 1, 4), @@ -542,9 +522,9 @@ fn scaled_add_2() { for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k, n, q) in &sizes { - let mut a = range_mat64(m, k); + let mut a = range_mat::(m, k); let mut answer = a.clone(); - let c = range_mat64(n, q); + let c = range_mat::(n, q); { let mut av = a.slice_mut(s![..;s1, ..;s2]); @@ -561,10 +541,13 @@ fn scaled_add_2() { } #[cfg(feature = "approx")] +#[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] -fn scaled_add_3() { +fn scaled_add_3() +{ use approx::assert_relative_eq; - use ndarray::{SliceInfo, SliceOrIndex}; + use ndarray::{Slice, SliceInfo, SliceInfoElem}; + use std::convert::TryFrom; let beta = -2.3; let sizes = vec![ @@ -582,23 +565,23 @@ fn scaled_add_3() { for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k, n, q) in &sizes { - let mut a = range_mat64(m, k); + let mut a = range_mat::(m, k); let mut answer = a.clone(); let cdim = if n == 1 { vec![q] } else { vec![n, q] }; - let cslice = if n == 1 { - vec![SliceOrIndex::from(..).step_by(s2)] + let cslice: Vec = if n == 1 { + vec![Slice::from(..).step_by(s2).into()] } else { vec![ - SliceOrIndex::from(..).step_by(s1), - SliceOrIndex::from(..).step_by(s2), + Slice::from(..).step_by(s1).into(), + Slice::from(..).step_by(s2).into(), ] }; - let c = range_mat64(n, q).into_shape(cdim).unwrap(); + let c = range_mat::(n, q).into_shape_with_order(cdim).unwrap(); { let 
mut av = a.slice_mut(s![..;s1, ..;s2]); - let c = c.slice(SliceInfo::<_, IxDyn>::new(cslice).unwrap().as_ref()); + let c = c.slice(SliceInfo::<_, IxDyn, IxDyn>::try_from(cslice).unwrap()); let mut answerv = answer.slice_mut(s![..;s1, ..;s2]); answerv += &(beta * &c); @@ -611,8 +594,10 @@ fn scaled_add_3() { } #[cfg(feature = "approx")] +#[cfg_attr(miri, ignore)] #[test] -fn gen_mat_mul() { +fn gen_mat_mul() +{ let alpha = -2.3; let beta = 3.14; let sizes = vec![ @@ -630,9 +615,9 @@ fn gen_mat_mul() { for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k, n) in &sizes { - let a = range_mat64(m, k); - let b = range_mat64(k, n); - let mut c = range_mat64(m, n); + let a = range_mat::(m, k); + let b = range_mat::(k, n); + let mut c = range_mat::(m, n); let mut answer = c.clone(); { @@ -654,32 +639,38 @@ fn gen_mat_mul() { // Test y = A x where A is f-order #[cfg(feature = "approx")] #[test] -fn gemm_64_1_f() { - let a = range_mat64(64, 64).reversed_axes(); +fn gemm_64_1_f() +{ + let a = range_mat::(64, 64).reversed_axes(); let (m, n) = a.dim(); // m x n times n x 1 == m x 1 - let x = range_mat64(n, 1); - let mut y = range_mat64(m, 1); + let x = range_mat::(n, 1); + let mut y = range_mat::(m, 1); let answer = reference_mat_mul(&a, &x) + &y; general_mat_mul(1.0, &a, &x, 1.0, &mut y); approx::assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7); } #[test] -fn gen_mat_mul_i32() { +fn gen_mat_mul_i32() +{ let alpha = -1; let beta = 2; - let sizes = vec![ - (4, 4, 4), - (8, 8, 8), - (17, 15, 16), - (4, 17, 3), - (17, 3, 22), - (19, 18, 2), - (16, 17, 15), - (15, 16, 17), - (67, 63, 62), - ]; + let sizes = if cfg!(miri) { + vec![(4, 4, 4), (4, 7, 3)] + } else { + vec![ + (4, 4, 4), + (8, 8, 8), + (17, 15, 16), + (4, 17, 3), + (17, 3, 22), + (19, 18, 2), + (16, 17, 15), + (15, 16, 17), + (67, 63, 62), + ] + }; for &(m, k, n) in &sizes { let a = range_i32(m, k); let b = range_i32(k, n); @@ -693,24 +684,28 @@ fn gen_mat_mul_i32() { #[cfg(feature = "approx")] #[test] -fn gen_mat_vec_mul() { +#[cfg_attr(miri, ignore)] // Takes too long +fn gen_mat_vec_mul() +{ use approx::assert_relative_eq; use ndarray::linalg::general_mat_vec_mul; // simple, slow, correct (hopefully) mat mul - fn reference_mat_vec_mul( - lhs: &ArrayBase, - rhs: &ArrayBase, - ) -> Array1 + fn reference_mat_vec_mul(lhs: &ArrayBase, rhs: &ArrayBase) -> Array1 where A: LinalgScalar, S: Data, S2: Data, { let ((m, _), k) = (lhs.dim(), rhs.dim()); - reference_mat_mul(lhs, &rhs.to_owned().into_shape((k, 1)).unwrap()) - .into_shape(m) - .unwrap() + reference_mat_mul( + lhs, + &rhs.as_standard_layout() + .into_shape_with_order((k, 1)) + .unwrap(), + ) + .into_shape_with_order(m) + .unwrap() } let alpha = -2.3; @@ -730,11 +725,8 @@ fn gen_mat_vec_mul() { for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k) in &sizes { - for &rev in &[false, true] { - let mut a = range_mat64(m, k); - if rev { - a = a.reversed_axes(); - } + for order in [Order::C, Order::F] { + let a = ArrayBuilder::new((m, k)).memory_order(order).build(); let (m, k) = a.dim(); let b = range1_mat64(k); let mut c = range1_mat64(m); @@ -758,24 +750,28 @@ fn gen_mat_vec_mul() { } #[cfg(feature = "approx")] +#[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] -fn vec_mat_mul() { +fn vec_mat_mul() +{ use approx::assert_relative_eq; // simple, slow, correct (hopefully) mat mul - fn reference_vec_mat_mul( - lhs: &ArrayBase, - rhs: &ArrayBase, - ) -> Array1 + fn reference_vec_mat_mul(lhs: &ArrayBase, rhs: &ArrayBase) -> Array1 
where A: LinalgScalar, S: Data, S2: Data, { let (m, (_, n)) = (lhs.dim(), rhs.dim()); - reference_mat_mul(&lhs.to_owned().into_shape((1, m)).unwrap(), rhs) - .into_shape(n) - .unwrap() + reference_mat_mul( + &lhs.as_standard_layout() + .into_shape_with_order((1, m)) + .unwrap(), + rhs, + ) + .into_shape_with_order(n) + .unwrap() } let sizes = vec![ @@ -793,11 +789,8 @@ fn vec_mat_mul() { for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, n) in &sizes { - for &rev in &[false, true] { - let mut b = range_mat64(m, n); - if rev { - b = b.reversed_axes(); - } + for order in [Order::C, Order::F] { + let b = ArrayBuilder::new((m, n)).memory_order(order).build(); let (m, n) = b.dim(); let a = range1_mat64(m); let mut c = range1_mat64(n); @@ -818,3 +811,61 @@ fn vec_mat_mul() { } } } + +#[test] +fn kron_square_f64() +{ + let a = arr2(&[[1.0, 0.0], [0.0, 1.0]]); + let b = arr2(&[[0.0, 1.0], [1.0, 0.0]]); + + assert_eq!( + kron(&a, &b), + arr2(&[ + [0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 0.0] + ]), + ); + + assert_eq!( + kron(&b, &a), + arr2(&[ + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0] + ]), + ) +} + +#[test] +fn kron_square_i64() +{ + let a = arr2(&[[1, 0], [0, 1]]); + let b = arr2(&[[0, 1], [1, 0]]); + + assert_eq!( + kron(&a, &b), + arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]), + ); + + assert_eq!( + kron(&b, &a), + arr2(&[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]), + ) +} + +#[test] +fn kron_i64() +{ + let a = arr2(&[[1, 0]]); + let b = arr2(&[[0, 1], [1, 0]]); + let r = arr2(&[[0, 1, 0, 0], [1, 0, 0, 0]]); + assert_eq!(kron(&a, &b), r); + + let a = arr2(&[[1, 0], [0, 0], [0, 1]]); + let b = arr2(&[[0, 1], [1, 0]]); + let r = arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]); + assert_eq!(kron(&a, &b), r); +} diff --git a/tests/par_azip.rs b/tests/par_azip.rs index e5dc02c4e..418c21ef8 100644 --- a/tests/par_azip.rs +++ b/tests/par_azip.rs @@ -7,7 +7,8 @@ use ndarray::prelude::*; use std::sync::atomic::{AtomicUsize, Ordering}; #[test] -fn test_par_azip1() { +fn test_par_azip1() +{ let mut a = Array::zeros(62); let b = Array::from_elem(62, 42); par_azip!((a in &mut a) { *a = 42 }); @@ -15,7 +16,8 @@ fn test_par_azip1() { } #[test] -fn test_par_azip2() { +fn test_par_azip2() +{ let mut a = Array::zeros((5, 7)); let b = Array::from_shape_fn(a.dim(), |(i, j)| 1. 
/ (i + 2 * j) as f32); par_azip!((a in &mut a, &b in &b, ) *a = b ); @@ -24,7 +26,8 @@ fn test_par_azip2() { #[test] #[cfg(feature = "approx")] -fn test_par_azip3() { +fn test_par_azip3() +{ use approx::assert_abs_diff_eq; let mut a = [0.; 32]; @@ -44,7 +47,8 @@ fn test_par_azip3() { #[should_panic] #[test] -fn test_zip_dim_mismatch_1() { +fn test_zip_dim_mismatch_1() +{ let mut a = Array::zeros((5, 7)); let mut d = a.raw_dim(); d[0] += 1; @@ -53,7 +57,8 @@ fn test_zip_dim_mismatch_1() { } #[test] -fn test_indices_1() { +fn test_indices_1() +{ let mut a1 = Array::default(12); for (i, elt) in a1.indexed_iter_mut() { *elt = i; diff --git a/tests/par_rayon.rs b/tests/par_rayon.rs index 4d5a8f1a9..13669763f 100644 --- a/tests/par_rayon.rs +++ b/tests/par_rayon.rs @@ -9,7 +9,8 @@ const CHUNK_SIZE: usize = 100; const N_CHUNKS: usize = (M + CHUNK_SIZE - 1) / CHUNK_SIZE; #[test] -fn test_axis_iter() { +fn test_axis_iter() +{ let mut a = Array2::::zeros((M, N)); for (i, mut v) in a.axis_iter_mut(Axis(0)).enumerate() { v.fill(i as _); @@ -22,10 +23,11 @@ fn test_axis_iter() { #[test] #[cfg(feature = "approx")] -fn test_axis_iter_mut() { +fn test_axis_iter_mut() +{ use approx::assert_abs_diff_eq; let mut a = Array::linspace(0., 1.0f64, M * N) - .into_shape((M, N)) + .into_shape_with_order((M, N)) .unwrap(); let b = a.mapv(|x| x.exp()); a.axis_iter_mut(Axis(0)) @@ -36,7 +38,8 @@ fn test_axis_iter_mut() { } #[test] -fn test_regular_iter() { +fn test_regular_iter() +{ let mut a = Array2::::zeros((M, N)); for (i, mut v) in a.axis_iter_mut(Axis(0)).enumerate() { v.fill(i as _); @@ -47,7 +50,8 @@ fn test_regular_iter() { } #[test] -fn test_regular_iter_collect() { +fn test_regular_iter_collect() +{ let mut a = Array2::::zeros((M, N)); for (i, mut v) in a.axis_iter_mut(Axis(0)).enumerate() { v.fill(i as _); @@ -57,7 +61,8 @@ fn test_regular_iter_collect() { } #[test] -fn test_axis_chunks_iter() { +fn test_axis_chunks_iter() +{ let mut a = Array2::::zeros((M, N)); for (i, mut v) in a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE).enumerate() { v.fill(i as _); @@ -74,10 +79,11 @@ fn test_axis_chunks_iter() { #[test] #[cfg(feature = "approx")] -fn test_axis_chunks_iter_mut() { +fn test_axis_chunks_iter_mut() +{ use approx::assert_abs_diff_eq; let mut a = Array::linspace(0., 1.0f64, M * N) - .into_shape((M, N)) + .into_shape_with_order((M, N)) .unwrap(); let b = a.mapv(|x| x.exp()); a.axis_chunks_iter_mut(Axis(0), CHUNK_SIZE) diff --git a/tests/par_zip.rs b/tests/par_zip.rs index f10f3acde..9f10d9fd5 100644 --- a/tests/par_zip.rs +++ b/tests/par_zip.rs @@ -8,17 +8,19 @@ const M: usize = 1024 * 10; const N: usize = 100; #[test] -fn test_zip_1() { +fn test_zip_1() +{ let mut a = Array2::::zeros((M, N)); - Zip::from(&mut a).par_apply(|x| *x = x.exp()); + Zip::from(&mut a).par_for_each(|x| *x = x.exp()); } #[test] -fn test_zip_index_1() { +fn test_zip_index_1() +{ let mut a = Array2::default((10, 10)); - Zip::indexed(&mut a).par_apply(|i, x| { + Zip::indexed(&mut a).par_for_each(|i, x| { *x = i; }); @@ -28,10 +30,11 @@ fn test_zip_index_1() { } #[test] -fn test_zip_index_2() { +fn test_zip_index_2() +{ let mut a = Array2::default((M, N)); - Zip::indexed(&mut a).par_apply(|i, x| { + Zip::indexed(&mut a).par_for_each(|i, x| { *x = i; }); @@ -41,10 +44,11 @@ fn test_zip_index_2() { } #[test] -fn test_zip_index_3() { +fn test_zip_index_3() +{ let mut a = Array::default((1, 2, 1, 2, 3)); - Zip::indexed(&mut a).par_apply(|i, x| { + Zip::indexed(&mut a).par_for_each(|i, x| { *x = i; }); @@ -54,14 +58,17 @@ fn 
test_zip_index_3() { } #[test] -fn test_zip_index_4() { +fn test_zip_index_4() +{ let mut a = Array2::zeros((M, N)); let mut b = Array2::zeros((M, N)); - Zip::indexed(&mut a).and(&mut b).par_apply(|(i, j), x, y| { - *x = i; - *y = j; - }); + Zip::indexed(&mut a) + .and(&mut b) + .par_for_each(|(i, j), x, y| { + *x = i; + *y = j; + }); for ((i, _), elt) in a.indexed_iter() { assert_eq!(*elt, i); @@ -70,3 +77,76 @@ fn test_zip_index_4() { assert_eq!(*elt, j); } } + +#[test] +#[cfg(feature = "approx")] +fn test_zip_collect() +{ + use approx::assert_abs_diff_eq; + + // test Zip::map_collect and that it preserves c/f layout. + + let b = Array::from_shape_fn((M, N), |(i, j)| 1. / (i + 2 * j + 1) as f32); + let c = Array::from_shape_fn((M, N), |(i, j)| f32::ln((1 + i + j) as f32)); + + { + let a = Zip::from(&b).and(&c).par_map_collect(|x, y| x + y); + + assert_abs_diff_eq!(a, &b + &c, epsilon = 1e-6); + assert_eq!(a.strides(), b.strides()); + } + + { + let b = b.t(); + let c = c.t(); + + let a = Zip::from(&b).and(&c).par_map_collect(|x, y| x + y); + + assert_abs_diff_eq!(a, &b + &c, epsilon = 1e-6); + assert_eq!(a.strides(), b.strides()); + } +} + +#[test] +#[cfg(feature = "approx")] +fn test_zip_small_collect() +{ + use approx::assert_abs_diff_eq; + + for m in 0..32 { + for n in 0..16 { + for &is_f in &[false, true] { + let dim = (m, n).set_f(is_f); + let b = Array::from_shape_fn(dim, |(i, j)| 1. / (i + 2 * j + 1) as f32); + let c = Array::from_shape_fn(dim, |(i, j)| f32::ln((1 + i + j) as f32)); + + { + let a = Zip::from(&b).and(&c).par_map_collect(|x, y| x + y); + + assert_abs_diff_eq!(a, &b + &c, epsilon = 1e-6); + if m > 1 && n > 1 { + assert_eq!(a.strides(), b.strides(), + "Failure for {}x{} c/f: {:?}", m, n, is_f); + } + } + } + } + } +} + +#[test] +#[cfg(feature = "approx")] +fn test_zip_assign_into() +{ + use approx::assert_abs_diff_eq; + + let mut a = Array::::zeros((M, N)); + let b = Array::from_shape_fn((M, N), |(i, j)| 1. 
/ (i + 2 * j + 1) as f32); + let c = Array::from_shape_fn((M, N), |(i, j)| f32::ln((1 + i + j) as f32)); + + Zip::from(&b) + .and(&c) + .par_map_assign_into(&mut a, |x, y| x + y); + + assert_abs_diff_eq!(a, &b + &c, epsilon = 1e-6); +} diff --git a/tests/raw_views.rs b/tests/raw_views.rs index b63e42926..929e969d7 100644 --- a/tests/raw_views.rs +++ b/tests/raw_views.rs @@ -4,7 +4,8 @@ use ndarray::Zip; use std::cell::Cell; #[test] -fn raw_view_cast_cell() { +fn raw_view_cast_cell() +{ // Test .cast() by creating an ArrayView> let mut a = Array::from_shape_fn((10, 5), |(i, j)| (i * j) as f32); @@ -14,13 +15,14 @@ fn raw_view_cast_cell() { let raw_cell_view = a.raw_view_mut().cast::>(); let cell_view = unsafe { raw_cell_view.deref_into_view() }; - Zip::from(cell_view).apply(|elt| elt.set(elt.get() + 1.)); + Zip::from(cell_view).for_each(|elt| elt.set(elt.get() + 1.)); } assert_eq!(a, answer); } #[test] -fn raw_view_cast_reinterpret() { +fn raw_view_cast_reinterpret() +{ // Test .cast() by reinterpreting u16 as [u8; 2] let a = Array::from_shape_fn((5, 5).f(), |(i, j)| (i as u16) << 8 | j as u16); let answer = a.mapv(u16::to_ne_bytes); @@ -31,7 +33,8 @@ fn raw_view_cast_reinterpret() { } #[test] -fn raw_view_cast_zst() { +fn raw_view_cast_zst() +{ struct Zst; let a = Array::<(), _>::default((250, 250)); @@ -42,14 +45,16 @@ fn raw_view_cast_zst() { #[test] #[should_panic] -fn raw_view_invalid_size_cast() { +fn raw_view_invalid_size_cast() +{ let data = [0i32; 16]; ArrayView::from(&data[..]).raw_view().cast::(); } #[test] #[should_panic] -fn raw_view_mut_invalid_size_cast() { +fn raw_view_mut_invalid_size_cast() +{ let mut data = [0i32; 16]; ArrayViewMut::from(&mut data[..]) .raw_view_mut() @@ -57,7 +62,8 @@ fn raw_view_mut_invalid_size_cast() { } #[test] -fn raw_view_misaligned() { +fn raw_view_misaligned() +{ let data: [u16; 2] = [0x0011, 0x2233]; let ptr: *const u16 = data.as_ptr(); unsafe { @@ -69,8 +75,10 @@ fn raw_view_misaligned() { #[test] #[cfg(debug_assertions)] #[should_panic = "The pointer must be aligned."] -fn raw_view_deref_into_view_misaligned() { - fn misaligned_deref(data: &[u16; 2]) -> ArrayView1<'_, u16> { +fn raw_view_deref_into_view_misaligned() +{ + fn misaligned_deref(data: &[u16; 2]) -> ArrayView1<'_, u16> + { let ptr: *const u16 = data.as_ptr(); unsafe { let misaligned_ptr = (ptr as *const u8).add(1) as *const u16; @@ -81,3 +89,20 @@ fn raw_view_deref_into_view_misaligned() { let data: [u16; 2] = [0x0011, 0x2233]; misaligned_deref(&data); } + +#[test] +#[cfg(debug_assertions)] +#[should_panic = "Unsupported"] +fn raw_view_negative_strides() +{ + fn misaligned_deref(data: &[u16; 2]) -> ArrayView1<'_, u16> + { + let ptr: *const u16 = data.as_ptr(); + unsafe { + let raw_view = RawArrayView::from_shape_ptr(1.strides((-1isize) as usize), ptr); + raw_view.deref_into_view() + } + } + let data: [u16; 2] = [0x0011, 0x2233]; + misaligned_deref(&data); +} diff --git a/tests/reserve.rs b/tests/reserve.rs new file mode 100644 index 000000000..108620014 --- /dev/null +++ b/tests/reserve.rs @@ -0,0 +1,70 @@ +use ndarray::prelude::*; + +fn into_raw_vec_capacity(a: Array) -> usize +{ + a.into_raw_vec_and_offset().0.capacity() +} + +#[test] +fn reserve_1d() +{ + let mut a = Array1::::zeros((4,)); + a.reserve(Axis(0), 1000).unwrap(); + assert_eq!(a.shape(), &[4]); + assert!(into_raw_vec_capacity(a) >= 1004); +} + +#[test] +fn reserve_3d() +{ + let mut a = Array3::::zeros((0, 4, 8)); + a.reserve(Axis(0), 10).unwrap(); + assert_eq!(a.shape(), &[0, 4, 8]); + assert!(into_raw_vec_capacity(a) 
>= 4 * 8 * 10); +} + +#[test] +fn reserve_empty_3d() +{ + let mut a = Array3::::zeros((0, 0, 0)); + a.reserve(Axis(0), 10).unwrap(); +} + +#[test] +fn reserve_3d_axis1() +{ + let mut a = Array3::::zeros((2, 4, 8)); + a.reserve(Axis(1), 10).unwrap(); + assert!(into_raw_vec_capacity(a) >= 2 * 8 * 10); +} + +#[test] +fn reserve_3d_repeat() +{ + let mut a = Array3::::zeros((2, 4, 8)); + a.reserve(Axis(1), 10).unwrap(); + a.reserve(Axis(2), 30).unwrap(); + assert!(into_raw_vec_capacity(a) >= 2 * 4 * 30); +} + +#[test] +fn reserve_2d_with_data() +{ + let mut a = array![[1, 2], [3, 4], [5, 6]]; + a.reserve(Axis(1), 100).unwrap(); + assert_eq!(a, array![[1, 2], [3, 4], [5, 6]]); + assert!(into_raw_vec_capacity(a) >= 3 * 100); +} + +#[test] +fn reserve_2d_inverted_with_data() +{ + let mut a = array![[1, 2], [3, 4], [5, 6]]; + a.invert_axis(Axis(1)); + assert_eq!(a, array![[2, 1], [4, 3], [6, 5]]); + a.reserve(Axis(1), 100).unwrap(); + assert_eq!(a, array![[2, 1], [4, 3], [6, 5]]); + let (raw_vec, offset) = a.into_raw_vec_and_offset(); + assert!(raw_vec.capacity() >= 3 * 100); + assert_eq!(offset, Some(1)); +} diff --git a/tests/reshape.rs b/tests/reshape.rs new file mode 100644 index 000000000..a13a5c05f --- /dev/null +++ b/tests/reshape.rs @@ -0,0 +1,331 @@ +use ndarray::prelude::*; + +use itertools::enumerate; + +use ndarray::Order; + +#[test] +fn reshape() +{ + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.into_shape_with_order((3, 3)); + assert!(u.is_err()); + let u = v.into_shape_with_order((2, 2, 2)); + assert!(u.is_ok()); + let u = u.unwrap(); + assert_eq!(u.shape(), &[2, 2, 2]); + let s = u.into_shape_with_order((4, 2)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); +} + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn reshape_error1() +{ + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let _u = v.into_shape_with_order((2, 5)).unwrap(); +} + +#[test] +#[should_panic(expected = "IncompatibleLayout")] +fn reshape_error2() +{ + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let mut u = v.into_shape_with_order((2, 2, 2)).unwrap(); + u.swap_axes(0, 1); + let _s = u.into_shape_with_order((2, 4)).unwrap(); +} + +#[test] +fn reshape_f() +{ + let mut u = Array::zeros((3, 4).f()); + for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) { + *elt = i as i32; + } + let v = u.view(); + println!("{:?}", v); + + // noop ok + let v2 = v.into_shape_with_order(((3, 4), Order::F)); + assert!(v2.is_ok()); + assert_eq!(v, v2.unwrap()); + + let u = v.into_shape_with_order(((3, 2, 2), Order::F)); + assert!(u.is_ok()); + let u = u.unwrap(); + println!("{:?}", u); + assert_eq!(u.shape(), &[3, 2, 2]); + let s = u.into_shape_with_order(((4, 3), Order::F)).unwrap(); + println!("{:?}", s); + assert_eq!(s.shape(), &[4, 3]); + assert_eq!(s, aview2(&[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]])); +} + +#[test] +fn to_shape_easy() +{ + // 1D -> C -> C + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.to_shape(((3, 3), Order::RowMajor)); + assert!(u.is_err()); + + let u = v.to_shape(((2, 2, 2), Order::C)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert!(u.is_view()); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + + let s = u.to_shape((4, 2)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); + + // 1D -> F -> F + let data = [1, 2, 3, 4, 5, 6, 
7, 8]; + let v = aview1(&data); + let u = v.to_shape(((3, 3), Order::ColumnMajor)); + assert!(u.is_err()); + + let u = v.to_shape(((2, 2, 2), Order::ColumnMajor)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert!(u.is_view()); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]); + + let s = u.to_shape(((4, 2), Order::ColumnMajor)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, array![[1, 5], [2, 6], [3, 7], [4, 8]]); +} + +#[test] +fn to_shape_copy() +{ + // 1D -> C -> F + let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]); + let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap(); + assert_eq!(u.shape(), &[4, 2]); + assert_eq!(u, array![[1, 2], [3, 4], [5, 6], [7, 8]]); + + let u = u.to_shape(((2, 4), Order::ColumnMajor)).unwrap(); + assert_eq!(u.shape(), &[2, 4]); + assert_eq!(u, array![[1, 5, 2, 6], [3, 7, 4, 8]]); + + // 1D -> F -> C + let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]); + let u = v.to_shape(((4, 2), Order::ColumnMajor)).unwrap(); + assert_eq!(u.shape(), &[4, 2]); + assert_eq!(u, array![[1, 5], [2, 6], [3, 7], [4, 8]]); + + let u = u.to_shape((2, 4)).unwrap(); + assert_eq!(u.shape(), &[2, 4]); + assert_eq!(u, array![[1, 5, 2, 6], [3, 7, 4, 8]]); +} + +#[test] +fn to_shape_add_axis() +{ + // 1D -> C -> C + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap(); + + assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view()); + assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_view()); +} + +#[test] +fn to_shape_copy_stride() +{ + let v = array![[1, 2, 3, 4], [5, 6, 7, 8]]; + let vs = v.slice(s![.., ..3]); + let lin1 = vs.to_shape(6).unwrap(); + assert_eq!(lin1, array![1, 2, 3, 5, 6, 7]); + assert!(lin1.is_owned()); + + let lin2 = vs.to_shape((6, Order::ColumnMajor)).unwrap(); + assert_eq!(lin2, array![1, 5, 2, 6, 3, 7]); + assert!(lin2.is_owned()); +} + +#[test] +fn to_shape_zero_len() +{ + let v = array![[1, 2, 3, 4], [5, 6, 7, 8]]; + let vs = v.slice(s![.., ..0]); + let lin1 = vs.to_shape(0).unwrap(); + assert_eq!(lin1, array![]); + assert!(lin1.is_view()); +} + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn to_shape_error1() +{ + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let _u = v.to_shape((2, 5)).unwrap(); +} + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn to_shape_error2() +{ + // overflow + let data = [3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let _u = v.to_shape((2, usize::MAX)).unwrap(); +} + +#[test] +fn to_shape_discontig() +{ + for &create_order in &[Order::C, Order::F] { + let a = Array::from_iter(0..64); + let mut a1 = a.to_shape(((4, 4, 4), create_order)).unwrap(); + a1.slice_collapse(s![.., ..;2, ..]); // now shape (4, 2, 4) + assert!(a1.as_slice_memory_order().is_none()); + + for &order in &[Order::C, Order::F] { + let v1 = a1.to_shape(((2, 2, 2, 2, 2), order)).unwrap(); + assert!(v1.is_view()); + let v1 = a1.to_shape(((4, 1, 2, 1, 2, 2), order)).unwrap(); + assert!(v1.is_view()); + let v1 = a1.to_shape(((4, 2, 4), order)).unwrap(); + assert!(v1.is_view()); + let v1 = a1.to_shape(((8, 4), order)).unwrap(); + assert_eq!(v1.is_view(), order == create_order && create_order == Order::C, + "failed for {:?}, {:?}", create_order, order); + let v1 = a1.to_shape(((4, 8), order)).unwrap(); + assert_eq!(v1.is_view(), order == create_order && create_order == Order::F, + "failed for {:?}, {:?}", create_order, order); + let v1 = a1.to_shape((32, order)).unwrap(); 
+ assert!(!v1.is_view()); + } + } +} + +#[test] +fn to_shape_broadcast() +{ + for &create_order in &[Order::C, Order::F] { + let a = Array::from_iter(0..64); + let mut a1 = a.to_shape(((4, 4, 4), create_order)).unwrap(); + a1.slice_collapse(s![.., ..1, ..]); // now shape (4, 1, 4) + let v1 = a1.broadcast((4, 4, 4)).unwrap(); // Now shape (4, 4, 4) + assert!(v1.as_slice_memory_order().is_none()); + + for &order in &[Order::C, Order::F] { + let v2 = v1.to_shape(((2, 2, 2, 2, 2, 2), order)).unwrap(); + assert_eq!(v2.strides(), match (create_order, order) { + (Order::C, Order::C) => { &[32, 16, 0, 0, 2, 1] } + (Order::C, Order::F) => { &[16, 32, 0, 0, 1, 2] } + (Order::F, Order::C) => { &[2, 1, 0, 0, 32, 16] } + (Order::F, Order::F) => { &[1, 2, 0, 0, 16, 32] } + _other => unreachable!() + }); + + let v2 = v1.to_shape(((4, 4, 4), order)).unwrap(); + assert!(v2.is_view()); + let v2 = v1.to_shape(((8, 8), order)).unwrap(); + assert!(v2.is_owned()); + } + } +} + +#[test] +fn into_shape_with_order() +{ + // 1D -> C -> C + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.into_shape_with_order(((3, 3), Order::RowMajor)); + assert!(u.is_err()); + + let u = v.into_shape_with_order(((2, 2, 2), Order::C)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + + let s = u.into_shape_with_order((4, 2)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); + + // 1D -> F -> F + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.into_shape_with_order(((3, 3), Order::ColumnMajor)); + assert!(u.is_err()); + + let u = v.into_shape_with_order(((2, 2, 2), Order::ColumnMajor)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]); + + let s = u + .into_shape_with_order(((4, 2), Order::ColumnMajor)) + .unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, array![[1, 5], [2, 6], [3, 7], [4, 8]]); +} + +#[test] +fn into_shape_clone() +{ + // 1D -> C -> C + { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = Array::from(data.to_vec()); + let u = v.clone().into_shape_clone(((3, 3), Order::RowMajor)); + assert!(u.is_err()); + + let u = v.clone().into_shape_clone(((2, 2, 2), Order::C)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + + let s = u.into_shape_clone((4, 2)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); + + let u = v.clone().into_shape_clone(((2, 2, 2), Order::F)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]); + } + + // 1D -> F -> F + { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = Array::from(data.to_vec()); + let u = v.clone().into_shape_clone(((3, 3), Order::ColumnMajor)); + assert!(u.is_err()); + + let u = v.into_shape_clone(((2, 2, 2), Order::ColumnMajor)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]); + + let s = u.into_shape_clone(((4, 2), Order::ColumnMajor)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, array![[1, 5], [2, 6], [3, 7], [4, 8]]); + } +} diff --git a/tests/s.rs b/tests/s.rs index dbc68184a..edb3f071a 100644 --- a/tests/s.rs +++ b/tests/s.rs @@ 
-1,14 +1,12 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] use ndarray::{s, Array}; #[test] -fn test_s() { +fn test_s() +{ let a = Array::::zeros((3, 4)); let vi = a.slice(s![1.., ..;2]); assert_eq!(vi.shape(), &[2, 2]); diff --git a/tests/stacking.rs b/tests/stacking.rs index a9a031711..bdfe478b4 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -1,26 +1,55 @@ -use ndarray::{arr2, aview1, stack, Array2, Axis, ErrorKind}; +use ndarray::{arr2, arr3, aview1, aview2, concatenate, stack, Array2, Axis, ErrorKind, Ix1}; #[test] -fn stacking() { +fn concatenating() +{ let a = arr2(&[[2., 2.], [3., 3.]]); - let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap(); + let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap(); assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]])); - let c = stack![Axis(0), a, b]; + let c = concatenate![Axis(0), a, b]; assert_eq!( c, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]]) ); - let d = stack![Axis(0), a.row(0), &[9., 9.]]; + let d = concatenate![Axis(0), a.row(0), &[9., 9.]]; assert_eq!(d, aview1(&[2., 2., 9., 9.])); + let d = concatenate![Axis(1), a.row(0).insert_axis(Axis(1)), aview1(&[9., 9.]).insert_axis(Axis(1))]; + assert_eq!(d, aview2(&[[2., 9.], + [2., 9.]])); + + let d = concatenate![Axis(0), a.row(0).insert_axis(Axis(1)), aview1(&[9., 9.]).insert_axis(Axis(1))]; + assert_eq!(d, aview2(&[[2.], [2.], [9.], [9.]])); + + let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); + + let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); + + let res: Result, _> = ndarray::concatenate(Axis(0), &[]); + assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); +} + +#[test] +fn stacking() +{ + let a = arr2(&[[2., 2.], [3., 3.]]); + let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap(); + assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]])); + + let c = stack![Axis(0), a, a]; + assert_eq!(c, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]])); + + let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]); let res = ndarray::stack(Axis(1), &[a.view(), c.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); - let res = ndarray::stack(Axis(2), &[a.view(), c.view()]); + let res = ndarray::stack(Axis(3), &[a.view(), a.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::stack(Axis(0), &[]); + let res: Result, _> = ndarray::stack::<_, Ix1>(Axis(0), &[]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } diff --git a/tests/views.rs b/tests/views.rs new file mode 100644 index 000000000..02970b1b7 --- /dev/null +++ b/tests/views.rs @@ -0,0 +1,17 @@ +use ndarray::prelude::*; +use ndarray::Zip; + +#[test] +fn cell_view() +{ + let mut a = Array::from_shape_fn((10, 5), |(i, j)| (i * j) as f32); + let answer = &a + 1.; + + { + let cv1 = a.cell_view(); + let cv2 = cv1; + + Zip::from(cv1).and(cv2).for_each(|a, b| a.set(b.get() + 1.)); + } + assert_eq!(a, answer); +} diff --git a/tests/windows.rs b/tests/windows.rs index eb1a33411..6506d8301 100644 --- a/tests/windows.rs +++ b/tests/windows.rs @@ -1,13 +1,9 @@ #![allow( - clippy::many_single_char_names, - clippy::deref_addrof, 
- clippy::unreadable_literal, - clippy::many_single_char_names + clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names )] use ndarray::prelude::*; -use ndarray::Zip; -use std::iter::FromIterator; +use ndarray::{arr3, Zip}; // Edge Cases for Windows iterator: // @@ -26,26 +22,31 @@ use std::iter::FromIterator; /// Test that verifies the `Windows` iterator panics on window sizes equal to zero. #[test] #[should_panic] -fn windows_iterator_zero_size() { - let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap(); +fn windows_iterator_zero_size() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); a.windows(Dim((0, 0, 0))); } -/// Test that verifites that no windows are yielded on oversized window sizes. +/// Test that verifies that no windows are yielded on oversized window sizes. #[test] -fn windows_iterator_oversized() { - let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap(); +fn windows_iterator_oversized() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); let mut iter = a.windows((4, 3, 2)).into_iter(); // (4,3,2) doesn't fit into (3,3,3) => oversized! assert_eq!(iter.next(), None); } /// Simple test for iterating 1d-arrays via `Windows`. #[test] -fn windows_iterator_1d() { - let a = Array::from_iter(10..20).into_shape(10).unwrap(); - itertools::assert_equal( - a.windows(Dim(4)), - vec![ +fn windows_iterator_1d() +{ + let a = Array::from_iter(10..20).into_shape_with_order(10).unwrap(); + itertools::assert_equal(a.windows(Dim(4)), vec![ arr1(&[10, 11, 12, 13]), arr1(&[11, 12, 13, 14]), arr1(&[12, 13, 14, 15]), @@ -53,17 +54,17 @@ fn windows_iterator_1d() { arr1(&[14, 15, 16, 17]), arr1(&[15, 16, 17, 18]), arr1(&[16, 17, 18, 19]), - ], - ); + ]); } /// Simple test for iterating 2d-arrays via `Windows`. #[test] -fn windows_iterator_2d() { - let a = Array::from_iter(10..30).into_shape((5, 4)).unwrap(); - itertools::assert_equal( - a.windows(Dim((3, 2))), - vec![ +fn windows_iterator_2d() +{ + let a = Array::from_iter(10..30) + .into_shape_with_order((5, 4)) + .unwrap(); + itertools::assert_equal(a.windows(Dim((3, 2))), vec![ arr2(&[[10, 11], [14, 15], [18, 19]]), arr2(&[[11, 12], [15, 16], [19, 20]]), arr2(&[[12, 13], [16, 17], [20, 21]]), @@ -73,18 +74,17 @@ fn windows_iterator_2d() { arr2(&[[18, 19], [22, 23], [26, 27]]), arr2(&[[19, 20], [23, 24], [27, 28]]), arr2(&[[20, 21], [24, 25], [28, 29]]), - ], - ); + ]); } /// Simple test for iterating 3d-arrays via `Windows`. #[test] -fn windows_iterator_3d() { - use ndarray::arr3; - let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap(); - itertools::assert_equal( - a.windows(Dim((2, 2, 2))), - vec![ +fn windows_iterator_3d() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + itertools::assert_equal(a.windows(Dim((2, 2, 2))), vec![ arr3(&[[[10, 11], [13, 14]], [[19, 20], [22, 23]]]), arr3(&[[[11, 12], [14, 15]], [[20, 21], [23, 24]]]), arr3(&[[[13, 14], [16, 17]], [[22, 23], [25, 26]]]), @@ -93,18 +93,91 @@ fn windows_iterator_3d() { arr3(&[[[20, 21], [23, 24]], [[29, 30], [32, 33]]]), arr3(&[[[22, 23], [25, 26]], [[31, 32], [34, 35]]]), arr3(&[[[23, 24], [26, 27]], [[32, 33], [35, 36]]]), - ], - ); + ]); } +/// Test that verifies the `Windows` iterator panics when stride has an axis equal to zero. 
#[test] -fn test_window_zip() { - let a = Array::from_iter(0..64).into_shape((4, 4, 4)).unwrap(); +#[should_panic] +fn windows_iterator_stride_axis_zero() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + a.windows_with_stride((2, 2, 2), (0, 2, 2)); +} + +/// Test that verifies that only first window is yielded when stride is oversized on every axis. +#[test] +fn windows_iterator_only_one_valid_window_for_oversized_stride() +{ + let a = Array::from_iter(10..135) + .into_shape_with_order((5, 5, 5)) + .unwrap(); + let mut iter = a.windows_with_stride((2, 2, 2), (8, 8, 8)).into_iter(); // (4,3,2) doesn't fit into (3,3,3) => oversized! + itertools::assert_equal(iter.next(), Some(arr3(&[[[10, 11], [15, 16]], [[35, 36], [40, 41]]]))); +} + +/// Simple test for iterating 1d-arrays via `Windows` with stride. +#[test] +fn windows_iterator_1d_with_stride() +{ + let a = Array::from_iter(10..20).into_shape_with_order(10).unwrap(); + itertools::assert_equal(a.windows_with_stride(4, 2), vec![ + arr1(&[10, 11, 12, 13]), + arr1(&[12, 13, 14, 15]), + arr1(&[14, 15, 16, 17]), + arr1(&[16, 17, 18, 19]), + ]); +} + +/// Simple test for iterating 2d-arrays via `Windows` with stride. +#[test] +fn windows_iterator_2d_with_stride() +{ + let a = Array::from_iter(10..30) + .into_shape_with_order((5, 4)) + .unwrap(); + itertools::assert_equal(a.windows_with_stride((3, 2), (2, 1)), vec![ + arr2(&[[10, 11], [14, 15], [18, 19]]), + arr2(&[[11, 12], [15, 16], [19, 20]]), + arr2(&[[12, 13], [16, 17], [20, 21]]), + arr2(&[[18, 19], [22, 23], [26, 27]]), + arr2(&[[19, 20], [23, 24], [27, 28]]), + arr2(&[[20, 21], [24, 25], [28, 29]]), + ]); +} + +/// Simple test for iterating 3d-arrays via `Windows` with stride. +#[test] +fn windows_iterator_3d_with_stride() +{ + let a = Array::from_iter(10..74) + .into_shape_with_order((4, 4, 4)) + .unwrap(); + itertools::assert_equal(a.windows_with_stride((2, 2, 2), (2, 2, 2)), vec![ + arr3(&[[[10, 11], [14, 15]], [[26, 27], [30, 31]]]), + arr3(&[[[12, 13], [16, 17]], [[28, 29], [32, 33]]]), + arr3(&[[[18, 19], [22, 23]], [[34, 35], [38, 39]]]), + arr3(&[[[20, 21], [24, 25]], [[36, 37], [40, 41]]]), + arr3(&[[[42, 43], [46, 47]], [[58, 59], [62, 63]]]), + arr3(&[[[44, 45], [48, 49]], [[60, 61], [64, 65]]]), + arr3(&[[[50, 51], [54, 55]], [[66, 67], [70, 71]]]), + arr3(&[[[52, 53], [56, 57]], [[68, 69], [72, 73]]]), + ]); +} + +#[test] +fn test_window_zip() +{ + let a = Array::from_iter(0..64) + .into_shape_with_order((4, 4, 4)) + .unwrap(); for x in 1..4 { for y in 1..4 { for z in 1..4 { - Zip::indexed(a.windows((x, y, z))).apply(|(i, j, k), window| { + Zip::indexed(a.windows((x, y, z))).for_each(|(i, j, k), window| { let x = x as isize; let y = y as isize; let z = z as isize; @@ -117,3 +190,162 @@ fn test_window_zip() { } } } + +/// Test verifies that non existent Axis results in panic +#[test] +#[should_panic] +fn axis_windows_outofbound() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + a.axis_windows(Axis(4), 2); +} + +/// Test verifies that zero sizes results in panic +#[test] +#[should_panic] +fn axis_windows_zero_size() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + a.axis_windows(Axis(0), 0); +} + +/// Test verifies that over sized windows yield nothing +#[test] +fn axis_windows_oversized() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + let mut iter = a.axis_windows(Axis(2), 4).into_iter(); + assert_eq!(iter.next(), 
None); +} + +/// Simple test for iterating 1d-arrays via `Axis Windows`. +#[test] +fn test_axis_windows_1d() +{ + let a = Array::from_iter(10..20).into_shape_with_order(10).unwrap(); + + itertools::assert_equal(a.axis_windows(Axis(0), 5), vec![ + arr1(&[10, 11, 12, 13, 14]), + arr1(&[11, 12, 13, 14, 15]), + arr1(&[12, 13, 14, 15, 16]), + arr1(&[13, 14, 15, 16, 17]), + arr1(&[14, 15, 16, 17, 18]), + arr1(&[15, 16, 17, 18, 19]), + ]); +} + +/// Simple test for iterating 2d-arrays via `Axis Windows`. +#[test] +fn test_axis_windows_2d() +{ + let a = Array::from_iter(10..30) + .into_shape_with_order((5, 4)) + .unwrap(); + + itertools::assert_equal(a.axis_windows(Axis(0), 2), vec![ + arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]), + arr2(&[[14, 15, 16, 17], [18, 19, 20, 21]]), + arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]), + arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]), + ]); +} + +/// Simple test for iterating 3d-arrays via `Axis Windows`. +#[test] +fn test_axis_windows_3d() +{ + let a = Array::from_iter(0..27) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + + itertools::assert_equal(a.axis_windows(Axis(1), 2), vec![ + arr3(&[ + [[0, 1, 2], [3, 4, 5]], + [[9, 10, 11], [12, 13, 14]], + [[18, 19, 20], [21, 22, 23]], + ]), + arr3(&[ + [[3, 4, 5], [6, 7, 8]], + [[12, 13, 14], [15, 16, 17]], + [[21, 22, 23], [24, 25, 26]], + ]), + ]); +} + +#[test] +fn tests_axis_windows_3d_zips_with_1d() +{ + let a = Array::from_iter(0..27) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + let mut b = Array::zeros(2); + + Zip::from(b.view_mut()) + .and(a.axis_windows(Axis(1), 2)) + .for_each(|b, a| { + *b = a.sum(); + }); + assert_eq!(b,arr1(&[207, 261])); +} + +#[test] +fn test_window_neg_stride() +{ + let array = Array::from_iter(1..10) + .into_shape_with_order((3, 3)) + .unwrap(); + + // window neg/pos stride combinations + + // Make a 2 x 2 array of the windows of the 3 x 3 array + // and compute test answers from here + let mut answer = Array::from_iter(array.windows((2, 2)).into_iter().map(|a| a.to_owned())) + .into_shape_with_order((2, 2)) + .unwrap(); + + answer.invert_axis(Axis(1)); + answer.map_inplace(|a| a.invert_axis(Axis(1))); + + itertools::assert_equal(array.slice(s![.., ..;-1]).windows((2, 2)), answer.iter()); + + answer.invert_axis(Axis(0)); + answer.map_inplace(|a| a.invert_axis(Axis(0))); + + itertools::assert_equal(array.slice(s![..;-1, ..;-1]).windows((2, 2)), answer.iter()); + + answer.invert_axis(Axis(1)); + answer.map_inplace(|a| a.invert_axis(Axis(1))); + + itertools::assert_equal(array.slice(s![..;-1, ..]).windows((2, 2)), answer.iter()); +} + +#[test] +fn test_windows_with_stride_on_inverted_axis() +{ + let mut array = Array::from_iter(1..17) + .into_shape_with_order((4, 4)) + .unwrap(); + + // inverting axis results in negative stride + array.invert_axis(Axis(0)); + itertools::assert_equal(array.windows_with_stride((2, 2), (2, 2)), vec![ + arr2(&[[13, 14], [9, 10]]), + arr2(&[[15, 16], [11, 12]]), + arr2(&[[5, 6], [1, 2]]), + arr2(&[[7, 8], [3, 4]]), + ]); + + array.invert_axis(Axis(1)); + itertools::assert_equal(array.windows_with_stride((2, 2), (2, 2)), vec![ + arr2(&[[16, 15], [12, 11]]), + arr2(&[[14, 13], [10, 9]]), + arr2(&[[8, 7], [4, 3]]), + arr2(&[[6, 5], [2, 1]]), + ]); +} diff --git a/tests/zst.rs b/tests/zst.rs index c3c779d2c..f5f2c8e32 100644 --- a/tests/zst.rs +++ b/tests/zst.rs @@ -2,7 +2,8 @@ use ndarray::arr2; use ndarray::ArcArray; #[test] -fn test_swap() { +fn test_swap() +{ let mut a = arr2(&[[(); 3]; 3]); let b = a.clone(); @@ -16,7 +17,8 @@ fn test_swap() { } 
#[test] -fn test() { +fn test() +{ let c = ArcArray::<(), _>::default((3, 4)); let mut d = c.clone(); for _ in d.iter_mut() {}

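The reshape tests above exercise the renamed, order-aware API: `into_shape_with_order` and `to_shape`, both taking an optional `ndarray::Order`. The following is a minimal, self-contained sketch of how those calls behave outside the test harness; it is illustrative only and not part of the changeset.

```rust
use ndarray::prelude::*;
use ndarray::Order;

fn main() {
    // Owned reshape: the element order must now be stated explicitly.
    let a = Array::from_iter(0..12);
    let c = a.clone().into_shape_with_order(((3, 4), Order::RowMajor)).unwrap();
    let f = a.into_shape_with_order(((3, 4), Order::ColumnMajor)).unwrap();
    assert_eq!(c[[0, 1]], 1); // row-major: consecutive elements fill a row
    assert_eq!(f[[0, 1]], 3); // column-major: consecutive elements fill a column

    // `to_shape` borrows and returns a CowArray: a view when the existing
    // layout allows it, an owned copy otherwise.
    let flat = c.to_shape((12, Order::RowMajor)).unwrap();
    assert!(flat.is_view());
}
```

The borrowing `to_shape` variant only copies when the requested order cannot be expressed as a view of the existing layout, which is what the `is_view()`/`is_owned()` assertions in `to_shape_discontig` and `to_shape_broadcast` check.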
IntoNdProducer for P -where - P: NdProducer, +trait ZippableTuple: Sized { - type Item = P::Item; - type Dim = P::Dim; - type Output = Self; - fn into_producer(self) -> Self::Output { - self - } -} - -/// A producer of an n-dimensional set of elements; -/// for example an array view, mutable array view or an iterator -/// that yields chunks. -/// -/// Producers are used as a arguments to [`Zip`](struct.Zip.html) and -/// [`azip!()`](macro.azip.html). -/// -/// # Comparison to `IntoIterator` -/// -/// Most `NdProducers` are *iterable* (implement `IntoIterator`) but not directly -/// iterators. This separation is needed because the producer represents -/// a multidimensional set of items, it can be split along a particular axis for -/// parallelization, and it has no fixed correspondance to a sequence. -/// -/// The natural exception is one dimensional producers, like `AxisIter`, which -/// implement `Iterator` directly -/// (`AxisIter` traverses a one dimensional sequence, along an axis, while -/// *producing* multidimensional items). -/// -/// See also [`IntoNdProducer`](trait.IntoNdProducer.html) -pub trait NdProducer { - /// The element produced per iteration. - type Item; - // Internal use / Pointee type - /// Dimension type - type Dim: Dimension; - - // The pointer Ptr is used by an array view to simply point to the - // current element. It doesn't have to be a pointer (see Indices). - // Its main function is that it can be incremented with a particular - // stride (= along a particular axis) - #[doc(hidden)] - /// Pointer or stand-in for pointer - type Ptr: Offset; - #[doc(hidden)] - /// Pointer stride - type Stride: Copy; - - #[doc(hidden)] - fn layout(&self) -> Layout; - #[doc(hidden)] - fn raw_dim(&self) -> Self::Dim; - #[doc(hidden)] - fn equal_dim(&self, dim: &Self::Dim) -> bool { - self.raw_dim() == *dim - } - #[doc(hidden)] - fn as_ptr(&self) -> Self::Ptr; - #[doc(hidden)] - unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item; - #[doc(hidden)] - unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr; - #[doc(hidden)] - fn stride_of(&self, axis: Axis) -> ::Stride; - #[doc(hidden)] - fn contiguous_stride(&self) -> Self::Stride; - #[doc(hidden)] - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) - where - Self: Sized; - private_decl! {} -} - -pub trait Offset: Copy { - type Stride: Copy; - unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self; - private_decl! {} -} - -impl Offset for *const T { - type Stride = isize; - unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self { - self.offset(s * (index as isize)) - } - private_impl! {} -} - -impl Offset for *mut T { - type Stride = isize; - unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self { - self.offset(s * (index as isize)) - } - private_impl! {} -} - -trait ZippableTuple: Sized { type Item; type Ptr: OffsetTuple + Copy; type Dim: Dimension; @@ -222,288 +116,6 @@ trait ZippableTuple: Sized { fn split_at(self, axis: Axis, index: usize) -> (Self, Self); } -/// An array reference is an n-dimensional producer of element references -/// (like ArrayView). -impl<'a, A: 'a, S, D> IntoNdProducer for &'a ArrayBase -where - D: Dimension, - S: Data, -{ - type Item = &'a A; - type Dim = D; - type Output = ArrayView<'a, A, D>; - fn into_producer(self) -> Self::Output { - self.view() - } -} - -/// A mutable array reference is an n-dimensional producer of mutable element -/// references (like ArrayViewMut). 
-impl<'a, A: 'a, S, D> IntoNdProducer for &'a mut ArrayBase -where - D: Dimension, - S: DataMut, -{ - type Item = &'a mut A; - type Dim = D; - type Output = ArrayViewMut<'a, A, D>; - fn into_producer(self) -> Self::Output { - self.view_mut() - } -} - -/// A slice is a one-dimensional producer -impl<'a, A: 'a> IntoNdProducer for &'a [A] { - type Item = ::Item; - type Dim = Ix1; - type Output = ArrayView1<'a, A>; - fn into_producer(self) -> Self::Output { - <_>::from(self) - } -} - -/// A mutable slice is a mutable one-dimensional producer -impl<'a, A: 'a> IntoNdProducer for &'a mut [A] { - type Item = ::Item; - type Dim = Ix1; - type Output = ArrayViewMut1<'a, A>; - fn into_producer(self) -> Self::Output { - <_>::from(self) - } -} - -/// A Vec is a one-dimensional producer -impl<'a, A: 'a> IntoNdProducer for &'a Vec { - type Item = ::Item; - type Dim = Ix1; - type Output = ArrayView1<'a, A>; - fn into_producer(self) -> Self::Output { - <_>::from(self) - } -} - -/// A mutable Vec is a mutable one-dimensional producer -impl<'a, A: 'a> IntoNdProducer for &'a mut Vec { - type Item = ::Item; - type Dim = Ix1; - type Output = ArrayViewMut1<'a, A>; - fn into_producer(self) -> Self::Output { - <_>::from(self) - } -} - -impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> { - type Item = &'a A; - type Dim = D; - type Ptr = *mut A; - type Stride = isize; - - private_impl! {} - #[doc(hidden)] - fn raw_dim(&self) -> Self::Dim { - self.raw_dim() - } - - #[doc(hidden)] - fn equal_dim(&self, dim: &Self::Dim) -> bool { - self.dim.equal(dim) - } - - #[doc(hidden)] - fn as_ptr(&self) -> *mut A { - self.as_ptr() as _ - } - - #[doc(hidden)] - fn layout(&self) -> Layout { - self.layout_impl() - } - - #[doc(hidden)] - unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item { - &*ptr - } - - #[doc(hidden)] - unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { - self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) - } - - #[doc(hidden)] - fn stride_of(&self, axis: Axis) -> isize { - self.stride_of(axis) - } - - #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride { - 1 - } - - #[doc(hidden)] - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { - self.split_at(axis, index) - } -} - -impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> { - type Item = &'a mut A; - type Dim = D; - type Ptr = *mut A; - type Stride = isize; - - private_impl! {} - #[doc(hidden)] - fn raw_dim(&self) -> Self::Dim { - self.raw_dim() - } - - #[doc(hidden)] - fn equal_dim(&self, dim: &Self::Dim) -> bool { - self.dim.equal(dim) - } - - #[doc(hidden)] - fn as_ptr(&self) -> *mut A { - self.as_ptr() as _ - } - - #[doc(hidden)] - fn layout(&self) -> Layout { - self.layout_impl() - } - - #[doc(hidden)] - unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item { - &mut *ptr - } - - #[doc(hidden)] - unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { - self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) - } - - #[doc(hidden)] - fn stride_of(&self, axis: Axis) -> isize { - self.stride_of(axis) - } - - #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride { - 1 - } - - #[doc(hidden)] - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { - self.split_at(axis, index) - } -} - -impl NdProducer for RawArrayView { - type Item = *const A; - type Dim = D; - type Ptr = *const A; - type Stride = isize; - - private_impl! 
{} - #[doc(hidden)] - fn raw_dim(&self) -> Self::Dim { - self.raw_dim() - } - - #[doc(hidden)] - fn equal_dim(&self, dim: &Self::Dim) -> bool { - self.dim.equal(dim) - } - - #[doc(hidden)] - fn as_ptr(&self) -> *const A { - self.as_ptr() - } - - #[doc(hidden)] - fn layout(&self) -> Layout { - self.layout_impl() - } - - #[doc(hidden)] - unsafe fn as_ref(&self, ptr: *const A) -> *const A { - ptr - } - - #[doc(hidden)] - unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A { - self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) - } - - #[doc(hidden)] - fn stride_of(&self, axis: Axis) -> isize { - self.stride_of(axis) - } - - #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride { - 1 - } - - #[doc(hidden)] - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { - self.split_at(axis, index) - } -} - -impl NdProducer for RawArrayViewMut { - type Item = *mut A; - type Dim = D; - type Ptr = *mut A; - type Stride = isize; - - private_impl! {} - #[doc(hidden)] - fn raw_dim(&self) -> Self::Dim { - self.raw_dim() - } - - #[doc(hidden)] - fn equal_dim(&self, dim: &Self::Dim) -> bool { - self.dim.equal(dim) - } - - #[doc(hidden)] - fn as_ptr(&self) -> *mut A { - self.as_ptr() as _ - } - - #[doc(hidden)] - fn layout(&self) -> Layout { - self.layout_impl() - } - - #[doc(hidden)] - unsafe fn as_ref(&self, ptr: *mut A) -> *mut A { - ptr - } - - #[doc(hidden)] - unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { - self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) - } - - #[doc(hidden)] - fn stride_of(&self, axis: Axis) -> isize { - self.stride_of(axis) - } - - #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride { - 1 - } - - #[doc(hidden)] - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { - self.split_at(axis, index) - } -} - /// Lock step function application across several arrays or other producers. /// /// Zip allows matching several producers to each other elementwise and applying @@ -511,7 +123,7 @@ impl NdProducer for RawArrayViewMut { /// a time). /// /// In general, the zip uses a tuple of producers -/// ([`NdProducer`](trait.NdProducer.html) trait) that all have to be of the +/// ([`NdProducer`] trait) that all have to be of the /// same shape. The NdProducer implementation defines what its item type is /// (for example if it's a shared reference, mutable reference or an array /// view etc). @@ -522,15 +134,13 @@ impl NdProducer for RawArrayViewMut { /// The order elements are visited is not specified. The producers don’t have to /// have the same item type. /// -/// The `Zip` has two methods for function application: `apply` and +/// The `Zip` has two methods for function application: `for_each` and /// `fold_while`. The zip object can be split, which allows parallelization. /// A read-only zip object (no mutable producers) can be cloned. /// -/// See also the [`azip!()` macro][az] which offers a convenient shorthand +/// See also the [`azip!()`] which offers a convenient shorthand /// to common ways to use `Zip`. 
/// -/// [az]: macro.azip.html -/// /// ``` /// use ndarray::Zip; /// use ndarray::Array2; @@ -550,7 +160,7 @@ impl NdProducer for RawArrayViewMut { /// .and(&b) /// .and(&c) /// .and(&d) -/// .apply(|w, &x, &y, &z| { +/// .for_each(|w, &x, &y, &z| { /// *w += x + y * z; /// }); /// @@ -566,17 +176,30 @@ impl NdProducer for RawArrayViewMut { /// let mut totals = Array1::zeros(a.nrows()); /// /// Zip::from(&mut totals) -/// .and(a.genrows()) -/// .apply(|totals, row| *totals = row.sum()); +/// .and(a.rows()) +/// .for_each(|totals, row| *totals = row.sum()); /// /// // Check the result against the built in `.sum_axis()` along axis 1. /// assert_eq!(totals, a.sum_axis(Axis(1))); +/// +/// +/// // Example 3: Recreate Example 2 using map_collect to make a new array +/// +/// let totals2 = Zip::from(a.rows()).map_collect(|row| row.sum()); +/// +/// // Check the result against the previous example. +/// assert_eq!(totals, totals2); /// ``` #[derive(Debug, Clone)] -pub struct Zip { +#[must_use = "zipping producers is lazy and does nothing unless consumed"] +pub struct Zip +{ parts: Parts, dimension: D, layout: Layout, + /// The sum of the layout tendencies of the parts; + /// positive for c- and negative for f-layout preference. + layout_tendency: i32, } impl Zip<(P,), D> @@ -589,15 +212,16 @@ where /// The Zip will take the exact dimension of `p` and all inputs /// must have the same dimensions (or be broadcast to them). pub fn from(p: IP) -> Self - where - IP: IntoNdProducer, + where IP: IntoNdProducer { let array = p.into_producer(); let dim = array.raw_dim(); + let layout = array.layout(); Zip { dimension: dim, - layout: array.layout(), + layout, parts: (array,), + layout_tendency: layout.tendency(), } } } @@ -613,8 +237,7 @@ where /// /// *Note:* Indexed zip has overhead. pub fn indexed(p: IP) -> Self - where - IP: IntoNdProducer, + where IP: IntoNdProducer { let array = p.into_producer(); let dim = array.raw_dim(); @@ -622,91 +245,120 @@ where } } -impl Zip +#[inline] +fn zip_dimension_check(dimension: &D, part: &P) where D: Dimension, + P: NdProducer, { - fn check

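The `Zip` documentation and the parallel tests earlier in this diff reflect the method renames from `apply`/`par_apply` to `for_each`/`par_for_each`, plus the `map_collect` family for building a new array from zipped inputs. Below is a minimal serial sketch of the renamed API (the `par_*` variants mirror it when the `rayon` feature is enabled); it is an illustration, not code from the changeset.

```rust
use ndarray::prelude::*;
use ndarray::Zip;

fn main() {
    let a = Array::from_shape_fn((4, 3), |(i, j)| (i + j) as f64);
    let b = Array::from_elem((4, 3), 1.0);
    let mut sum = Array::zeros((4, 3));

    // `apply` is now `for_each` (and `par_apply` is `par_for_each`).
    Zip::from(&mut sum).and(&a).and(&b).for_each(|s, &x, &y| {
        *s = x + y;
    });
    assert_eq!(sum, &a + &b);

    // `map_collect` (and `par_map_collect`) builds a new array directly.
    let prod = Zip::from(&a).and(&b).map_collect(|&x, &y| x * y);
    assert_eq!(prod, &a * &b);
}
```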
//! -//! [`a * b`, `a + b`, etc.](../../struct.ArrayBase.html#arithmetic-operations) +//! [`a * b`, `a + b`, etc.](ArrayBase#arithmetic-operations) //! //! //! @@ -532,26 +532,29 @@ //! ------|-----------|------ //! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value //! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a` -//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1 -//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1 +//! `a[:5, 2] = 3.` | [`a.slice_mut(s![..5, 2]).fill(3.)`][.fill()] | set a portion of the array to the same scalar value +//! `a[:5, 2] = b` | [`a.slice_mut(s![..5, 2]).assign(&b)`][.assign()] | copy the data from array `b` into part of array `a` +//! `np.concatenate((a,b), axis=1)` | [`concatenate![Axis(1), a, b]`][concatenate!] or [`concatenate(Axis(1), &[a.view(), b.view()])`][concatenate()] | concatenate arrays `a` and `b` along axis 1 +//! `np.stack((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), vec![a.view(), b.view()])`][stack()] | stack arrays `a` and `b` along axis 1 +//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.slice(s![.., NewAxis])`][.slice()] or [`a.insert_axis(Axis(1))`][.insert_axis()] | create an view of 1-D array `a`, inserting a new axis 1 //! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`) //! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a` -//! `a.flatten()` | [`Array::from_iter(a.iter())`][::from_iter()] | create a 1-D array by flattening `a` +//! `a.flatten()` | [`use std::iter::FromIterator; Array::from_iter(a.iter().cloned())`][::from_iter()] | create a 1-D array by flattening `a` //! //! ## Iteration //! //! `ndarray` has lots of interesting iterators/producers that implement the -//! [`NdProducer`][NdProducer] trait, which is a generalization of `Iterator` +//! [`NdProducer`](crate::NdProducer) trait, which is a generalization of `Iterator` //! to multiple dimensions. This makes it possible to correctly and efficiently //! zip together slices/subviews of arrays in multiple dimensions with -//! [`Zip`][Zip] or [`azip!()`][azip!]. The purpose of this is similar to +//! [`Zip`] or [`azip!()`]. The purpose of this is similar to //! [`np.nditer`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.nditer.html), -//! but [`Zip`][Zip] is implemented and used somewhat differently. +//! but [`Zip`] is implemented and used somewhat differently. //! //! This table lists some of the iterators/producers which have a direct //! equivalent in NumPy. For a more complete introduction to producers and //! iterators, see [*Loops, Producers, and -//! Iterators*](../../struct.ArrayBase.html#loops-producers-and-iterators). +//! Iterators*](ArrayBase#loops-producers-and-iterators). //! Note that there are also variants of these iterators (with a `_mut` suffix) //! that yield `ArrayViewMut` instead of `ArrayView`. //! @@ -561,6 +564,102 @@ //! `np.ndenumerate(a)` | [`a.indexed_iter()`][.indexed_iter()] | flat iterator yielding the index along with each element reference //! 
`iter(a)` | [`a.outer_iter()`][.outer_iter()] or [`a.axis_iter(Axis(0))`][.axis_iter()] | iterator over the first (outermost) axis, yielding each subview //! +//! ## Type conversions +//! +//! In `ndarray`, conversions between datatypes are done with `mapv()` by +//! passing a closure to convert every element independently. +//! For the conversion itself, we have several options: +//! - `std::convert::From` ensures lossless, safe conversions at compile-time +//! and is generally recommended. +//! - `std::convert::TryFrom` can be used for potentially unsafe conversions. It +//! will return a `Result` which can be handled or `unwrap()`ed to panic if +//! any value at runtime cannot be converted losslessly. +//! - The `as` keyword compiles to lossless/lossy conversions depending on the +//! source and target datatypes. It can be useful when `TryFrom` is a +//! performance issue or does not apply. A notable difference to NumPy is that +//! `as` performs a [*saturating* cast][sat_conv] when casting +//! from floats to integers. Further information can be found in the +//! [reference on type cast expressions][as_typecast]. +//! +//! For details, be sure to check out the type conversion examples. +//! + +//! +//! +//! +//! +//! +//! +//! +//! +//! +//! +//!
+//! <table>
+//! <tr>
+//! <th> NumPy </th>
+//! <th> `ndarray` </th>
+//! <th> Notes </th>
+//! </tr>
+//! <tr>
+//! <td> `a.astype(np.float32)` </td>
+//! <td> `a.mapv(f32::from)` </td>
+//! <td> convert `u8` array infallibly to `f32` array with `std::convert::From`, generally recommended </td>
+//! </tr>
+//! <tr>
+//! <td> `a.astype(np.int32)` </td>
+//! <td> `a.mapv(i32::from)` </td>
+//! <td> upcast `u8` array to `i32` array with `std::convert::From`, preferable over `as` because it ensures at compile-time that the conversion is lossless </td>
+//! </tr>
+//! <tr>
+//! <td> `a.astype(np.uint8)` </td>
+//! <td> `a.mapv(|x| u8::try_from(x).unwrap())` </td>
+//! <td> try to convert `i8` array to `u8` array, panicking at runtime if any value cannot be converted losslessly (e.g. a negative value) </td>
+//! </tr>
+//! <tr>
+//! <td> `a.astype(np.int32)` </td>
+//! <td> `a.mapv(|x| x as i32)` </td>
+//! <td> convert `f32` array to `i32` array with a ["saturating" conversion][sat_conv]; care needed because it can be a lossy conversion or result in non-finite values! See [the reference for information][as_typecast]. </td>
+//! </tr>
+//! </table>
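The conversion table above boils down to choosing a conversion function to pass to `mapv`. Here is a small runnable sketch of the three approaches (`From`, `TryFrom`, and a saturating `as` cast), added for illustration only.

```rust
use ndarray::prelude::*;
use std::convert::TryFrom;

fn main() {
    // Lossless, compile-time checked: u8 -> f32 via `From`.
    let a: Array1<u8> = array![0, 100, 255];
    let b = a.mapv(f32::from);
    assert_eq!(b, array![0.0f32, 100.0, 255.0]);

    // Fallible: i8 -> u8 via `TryFrom`, panicking on values that don't fit.
    let c: Array1<i8> = array![0, 1, 2];
    let d = c.mapv(|x| u8::try_from(x).unwrap());
    assert_eq!(d, array![0u8, 1, 2]);

    // Saturating `as` cast: f32 -> u8 clamps out-of-range values, NaN -> 0.
    let e: Array1<f32> = array![-1.5, 300.0, f32::NAN];
    let f = e.mapv(|x| x as u8);
    assert_eq!(f, array![0u8, 255, 0]);
}
```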
+//! +//! [as_conv]: https://doc.rust-lang.org/rust-by-example/types/cast.html +//! [sat_conv]: https://blog.rust-lang.org/2020/07/16/Rust-1.45.0.html#fixing-unsoundness-in-casts +//! [as_typecast]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#type-cast-expressions +//! //! ## Convenience methods for 2-D arrays //! //! NumPy | `ndarray` | Notes @@ -571,87 +670,79 @@ //! `a[:,4]` | [`a.column(4)`][.column()] or [`a.column_mut(4)`][.column_mut()] | view (or mutable view) of column 4 in a 2-D array //! `a.shape[0] == a.shape[1]` | [`a.is_square()`][.is_square()] | check if the array is square //! -//! [.abs_diff_eq()]: ../../struct.ArrayBase.html#impl-AbsDiffEq> -//! [ArcArray]: ../../type.ArcArray.html -//! [arr2()]: ../../fn.arr2.html -//! [array!]: ../../macro.array.html -//! [Array]: ../../type.Array.html -//! [Array2]: ../../type.Array2.html -//! [ArrayBase]: ../../struct.ArrayBase.html -//! [ArrayView]: ../../type.ArrayView.html -//! [ArrayViewMut]: ../../type.ArrayViewMut.html -//! [.assign()]: ../../struct.ArrayBase.html#method.assign -//! [.axis_iter()]: ../../struct.ArrayBase.html#method.axis_iter -//! [azip!]: ../../macro.azip.html -//! [.ncols()]: ../../struct.ArrayBase.html#method.ncols -//! [.column()]: ../../struct.ArrayBase.html#method.column -//! [.column_mut()]: ../../struct.ArrayBase.html#method.column_mut -//! [CowArray]: ../../type.CowArray.html -//! [::default()]: ../../struct.ArrayBase.html#method.default -//! [.diag()]: ../../struct.ArrayBase.html#method.diag -//! [.dim()]: ../../struct.ArrayBase.html#method.dim -//! [::eye()]: ../../struct.ArrayBase.html#method.eye -//! [.fill()]: ../../struct.ArrayBase.html#method.fill -//! [.fold()]: ../../struct.ArrayBase.html#method.fold -//! [.fold_axis()]: ../../struct.ArrayBase.html#method.fold_axis -//! [::from_elem()]: ../../struct.ArrayBase.html#method.from_elem -//! [::from_iter()]: ../../struct.ArrayBase.html#method.from_iter -//! [::from_diag()]: ../../struct.ArrayBase.html#method.from_diag -//! [::from_shape_fn()]: ../../struct.ArrayBase.html#method.from_shape_fn -//! [::from_shape_vec()]: ../../struct.ArrayBase.html#method.from_shape_vec -//! [::from_shape_vec_unchecked()]: ../../struct.ArrayBase.html#method.from_shape_vec_unchecked -//! [::from_vec()]: ../../struct.ArrayBase.html#method.from_vec -//! [.index()]: ../../struct.ArrayBase.html#impl-Index -//! [.indexed_iter()]: ../../struct.ArrayBase.html#method.indexed_iter -//! [.insert_axis()]: ../../struct.ArrayBase.html#method.insert_axis -//! [.is_empty()]: ../../struct.ArrayBase.html#method.is_empty -//! [.is_square()]: ../../struct.ArrayBase.html#method.is_square -//! [.iter()]: ../../struct.ArrayBase.html#method.iter -//! [Ix]: ../../type.Ix.html -//! [.len()]: ../../struct.ArrayBase.html#method.len -//! [.len_of()]: ../../struct.ArrayBase.html#method.len_of -//! [::linspace()]: ../../struct.ArrayBase.html#method.linspace -//! [::logspace()]: ../../struct.ArrayBase.html#method.logspace -//! [::geomspace()]: ../../struct.ArrayBase.html#method.geomspace -//! [.map()]: ../../struct.ArrayBase.html#method.map -//! [.map_axis()]: ../../struct.ArrayBase.html#method.map_axis -//! [.map_inplace()]: ../../struct.ArrayBase.html#method.map_inplace -//! [.mapv()]: ../../struct.ArrayBase.html#method.mapv -//! [.mapv_inplace()]: ../../struct.ArrayBase.html#method.mapv_inplace -//! [.mapv_into()]: ../../struct.ArrayBase.html#method.mapv_into -//! [matrix-* dot]: ../../struct.ArrayBase.html#method.dot-1 -//! 
[.mean()]: ../../struct.ArrayBase.html#method.mean -//! [.mean_axis()]: ../../struct.ArrayBase.html#method.mean_axis -//! [.ndim()]: ../../struct.ArrayBase.html#method.ndim -//! [NdProducer]: ../../trait.NdProducer.html -//! [::ones()]: ../../struct.ArrayBase.html#method.ones -//! [.outer_iter()]: ../../struct.ArrayBase.html#method.outer_iter -//! [::range()]: ../../struct.ArrayBase.html#method.range -//! [.raw_dim()]: ../../struct.ArrayBase.html#method.raw_dim -//! [.reversed_axes()]: ../../struct.ArrayBase.html#method.reversed_axes -//! [.row()]: ../../struct.ArrayBase.html#method.row -//! [.row_mut()]: ../../struct.ArrayBase.html#method.row_mut -//! [.nrows()]: ../../struct.ArrayBase.html#method.nrows -//! [s!]: ../../macro.s.html -//! [.sum()]: ../../struct.ArrayBase.html#method.sum -//! [.slice()]: ../../struct.ArrayBase.html#method.slice -//! [.slice_axis()]: ../../struct.ArrayBase.html#method.slice_axis -//! [.slice_collapse()]: ../../struct.ArrayBase.html#method.slice_collapse -//! [.slice_move()]: ../../struct.ArrayBase.html#method.slice_move -//! [.slice_mut()]: ../../struct.ArrayBase.html#method.slice_mut -//! [.shape()]: ../../struct.ArrayBase.html#method.shape -//! [stack!]: ../../macro.stack.html -//! [stack()]: ../../fn.stack.html -//! [.strides()]: ../../struct.ArrayBase.html#method.strides -//! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis -//! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis -//! [.t()]: ../../struct.ArrayBase.html#method.t -//! [::uninitialized()]: ../../struct.ArrayBase.html#method.uninitialized -//! [vec-* dot]: ../../struct.ArrayBase.html#method.dot -//! [.visit()]: ../../struct.ArrayBase.html#method.visit -//! [::zeros()]: ../../struct.ArrayBase.html#method.zeros -//! [Zip]: ../../struct.Zip.html +//! [.abs_diff_eq()]: ArrayBase#impl-AbsDiffEq> +//! [.assign()]: ArrayBase::assign +//! [.axis_iter()]: ArrayBase::axis_iter +//! [.ncols()]: ArrayBase::ncols +//! [.column()]: ArrayBase::column +//! [.column_mut()]: ArrayBase::column_mut +//! [concatenate()]: crate::concatenate() +//! [concatenate!]: crate::concatenate! +//! [stack!]: crate::stack! +//! [::default()]: ArrayBase::default +//! [.diag()]: ArrayBase::diag +//! [.dim()]: ArrayBase::dim +//! [::eye()]: ArrayBase::eye +//! [.fill()]: ArrayBase::fill +//! [.fold()]: ArrayBase::fold +//! [.fold_axis()]: ArrayBase::fold_axis +//! [::from_elem()]: ArrayBase::from_elem +//! [::from_iter()]: ArrayBase::from_iter +//! [::from_diag()]: ArrayBase::from_diag +//! [::from_shape_fn()]: ArrayBase::from_shape_fn +//! [::from_shape_vec()]: ArrayBase::from_shape_vec +//! [::from_shape_vec_unchecked()]: ArrayBase::from_shape_vec_unchecked +//! [::from_vec()]: ArrayBase::from_vec +//! [.index()]: ArrayBase#impl-Index +//! [.indexed_iter()]: ArrayBase::indexed_iter +//! [.insert_axis()]: ArrayBase::insert_axis +//! [.is_empty()]: ArrayBase::is_empty +//! [.is_square()]: ArrayBase::is_square +//! [.iter()]: ArrayBase::iter +//! [.len()]: ArrayBase::len +//! [.len_of()]: ArrayBase::len_of +//! [::linspace()]: ArrayBase::linspace +//! [::logspace()]: ArrayBase::logspace +//! [::geomspace()]: ArrayBase::geomspace +//! [.map()]: ArrayBase::map +//! [.map_axis()]: ArrayBase::map_axis +//! [.map_inplace()]: ArrayBase::map_inplace +//! [.mapv()]: ArrayBase::mapv +//! [.mapv_inplace()]: ArrayBase::mapv_inplace +//! [.mapv_into()]: ArrayBase::mapv_into +//! [matrix-* dot]: ArrayBase::dot-1 +//! [.mean()]: ArrayBase::mean +//! [.mean_axis()]: ArrayBase::mean_axis +//! 
[.ndim()]: ArrayBase::ndim +//! [::ones()]: ArrayBase::ones +//! [.outer_iter()]: ArrayBase::outer_iter +//! [::range()]: ArrayBase::range +//! [.raw_dim()]: ArrayBase::raw_dim +//! [.reversed_axes()]: ArrayBase::reversed_axes +//! [.row()]: ArrayBase::row +//! [.row_mut()]: ArrayBase::row_mut +//! [.nrows()]: ArrayBase::nrows +//! [.sum()]: ArrayBase::sum +//! [.slice()]: ArrayBase::slice +//! [.slice_axis()]: ArrayBase::slice_axis +//! [.slice_collapse()]: ArrayBase::slice_collapse +//! [.slice_move()]: ArrayBase::slice_move +//! [.slice_mut()]: ArrayBase::slice_mut +//! [.shape()]: ArrayBase::shape +//! [stack()]: crate::stack() +//! [.strides()]: ArrayBase::strides +//! [.index_axis()]: ArrayBase::index_axis +//! [.sum_axis()]: ArrayBase::sum_axis +//! [.t()]: ArrayBase::t +//! [vec-* dot]: ArrayBase::dot +//! [.for_each()]: ArrayBase::for_each +//! [::zeros()]: ArrayBase::zeros +//! [`Zip`]: crate::Zip pub mod coord_transform; pub mod rk_step; pub mod simple_math; + +// This is to avoid putting `crate::` everywhere +#[allow(unused_imports)] +use crate::imp_prelude::*; diff --git a/src/doc/ndarray_for_numpy_users/rk_step.rs b/src/doc/ndarray_for_numpy_users/rk_step.rs index a68f6b171..c882a3d00 100644 --- a/src/doc/ndarray_for_numpy_users/rk_step.rs +++ b/src/doc/ndarray_for_numpy_users/rk_step.rs @@ -71,8 +71,6 @@ //! A direct translation to `ndarray` looks like this: //! //! ``` -//! extern crate ndarray; -//! //! use ndarray::prelude::*; //! //! fn rk_step( @@ -106,7 +104,7 @@ //! (y_new, f_new, error) //! } //! # -//! # fn main() {} +//! # fn main() { let _ = rk_step::) -> _>; } //! ``` //! //! It's possible to improve the efficiency by doing the following: @@ -122,15 +120,13 @@ //! * Don't return a newly allocated `f_new` array. If the caller wants this //! information, they can get it from the last row of `k`. //! -//! * Use [`c.mul_add(h, t)`][f64.mul_add()] instead of `t + c * h`. This is +//! * Use [`c.mul_add(h, t)`](f64::mul_add) instead of `t + c * h`. This is //! faster and reduces the floating-point error. It might also be beneficial -//! to use [`.scaled_add()`][.scaled_add()] or a combination of -//! [`azip!()`][azip!] and [`.mul_add()`][f64.mul_add()] on the arrays in +//! to use [`.scaled_add()`] or a combination of +//! [`azip!()`] and [`.mul_add()`](f64::mul_add) on the arrays in //! some places, but that's not demonstrated in the example below. //! //! ``` -//! extern crate ndarray; -//! //! use ndarray::prelude::*; //! //! fn rk_step( @@ -169,12 +165,11 @@ //! (y_new, error) //! } //! # -//! # fn main() {} +//! # fn main() { let _ = rk_step::, ArrayViewMut1<'_, f64>)>; } //! ``` //! -//! [f64.mul_add()]: https://doc.rust-lang.org/std/primitive.f64.html#method.mul_add -//! [.scaled_add()]: ../../../struct.ArrayBase.html#method.scaled_add -//! [azip!]: ../../../macro.azip.html +//! [`.scaled_add()`]: crate::ArrayBase::scaled_add +//! [`azip!()`]: crate::azip! //! //! ### SciPy license //! diff --git a/src/doc/ndarray_for_numpy_users/simple_math.rs b/src/doc/ndarray_for_numpy_users/simple_math.rs index b6549d3c2..1bba0c286 100644 --- a/src/doc/ndarray_for_numpy_users/simple_math.rs +++ b/src/doc/ndarray_for_numpy_users/simple_math.rs @@ -51,8 +51,6 @@ //!
//! //! ``` -//! extern crate ndarray; -//! //! use ndarray::prelude::*; //! //! # fn main() { diff --git a/src/error.rs b/src/error.rs index cb896223f..eb7395ad8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,25 +6,30 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. use super::Dimension; +#[cfg(feature = "std")] use std::error::Error; use std::fmt; /// An error related to array shape or layout. #[derive(Clone)] -pub struct ShapeError { +pub struct ShapeError +{ // we want to be able to change this representation later repr: ErrorKind, } -impl ShapeError { +impl ShapeError +{ /// Return the `ErrorKind` of this error. #[inline] - pub fn kind(&self) -> ErrorKind { + pub fn kind(&self) -> ErrorKind + { self.repr } /// Create a new `ShapeError` - pub fn from_kind(error: ErrorKind) -> Self { + pub fn from_kind(error: ErrorKind) -> Self + { from_kind(error) } } @@ -33,8 +38,10 @@ impl ShapeError { /// /// This enumeration is not exhaustive. The representation of the enum /// is not guaranteed. +#[non_exhaustive] #[derive(Copy, Clone, Debug)] -pub enum ErrorKind { +pub enum ErrorKind +{ /// incompatible shape IncompatibleShape = 1, /// incompatible memory layout @@ -47,52 +54,56 @@ pub enum ErrorKind { Unsupported, /// overflow when computing offset, length, etc. Overflow, - #[doc(hidden)] - __Incomplete, } #[inline(always)] -pub fn from_kind(k: ErrorKind) -> ShapeError { +pub fn from_kind(k: ErrorKind) -> ShapeError +{ ShapeError { repr: k } } -impl PartialEq for ErrorKind { +impl PartialEq for ErrorKind +{ #[inline(always)] - fn eq(&self, rhs: &Self) -> bool { + fn eq(&self, rhs: &Self) -> bool + { *self as u8 == *rhs as u8 } } -impl PartialEq for ShapeError { +impl PartialEq for ShapeError +{ #[inline(always)] - fn eq(&self, rhs: &Self) -> bool { + fn eq(&self, rhs: &Self) -> bool + { self.repr == rhs.repr } } -impl Error for ShapeError { - fn description(&self) -> &str { - match self.kind() { +#[cfg(feature = "std")] +impl Error for ShapeError {} + +impl fmt::Display for ShapeError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + let description = match self.kind() { ErrorKind::IncompatibleShape => "incompatible shapes", ErrorKind::IncompatibleLayout => "incompatible memory layout", ErrorKind::RangeLimited => "the shape does not fit in type limits", ErrorKind::OutOfBounds => "out of bounds indexing", ErrorKind::Unsupported => "unsupported operation", ErrorKind::Overflow => "arithmetic overflow", - ErrorKind::__Incomplete => "this error variant is not in use", - } - } -} - -impl fmt::Display for ShapeError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "ShapeError/{:?}: {}", self.kind(), self.description()) + }; + write!(f, "ShapeError/{:?}: {}", self.kind(), description) } } -impl fmt::Debug for ShapeError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "ShapeError/{:?}: {}", self.kind(), self.description()) +impl fmt::Debug for ShapeError +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "{}", self) } } diff --git a/src/extension/nonnull.rs b/src/extension/nonnull.rs index 32fbb07c4..08f80927e 100644 --- a/src/extension/nonnull.rs +++ b/src/extension/nonnull.rs @@ -1,7 +1,10 @@ +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; use std::ptr::NonNull; /// Return a NonNull pointer to the vector's data -pub(crate) fn nonnull_from_vec_data(v: &mut Vec) -> NonNull { +pub(crate) fn nonnull_from_vec_data(v: &mut Vec) -> NonNull +{ // this 
pointer is guaranteed to be non-null unsafe { NonNull::new_unchecked(v.as_mut_ptr()) } } @@ -12,7 +15,8 @@ pub(crate) fn nonnull_from_vec_data(v: &mut Vec) -> NonNull { /// This is checked with a debug assertion, and will panic if this is not true, /// but treat this as an unconditional conversion. #[inline] -pub(crate) unsafe fn nonnull_debug_checked_from_ptr(ptr: *mut T) -> NonNull { +pub(crate) unsafe fn nonnull_debug_checked_from_ptr(ptr: *mut T) -> NonNull +{ debug_assert!(!ptr.is_null()); NonNull::new_unchecked(ptr) } diff --git a/src/free_functions.rs b/src/free_functions.rs index ff7984ee6..5659d7024 100644 --- a/src/free_functions.rs +++ b/src/free_functions.rs @@ -6,14 +6,18 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +use alloc::vec; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; +#[allow(unused_imports)] +use std::compile_error; use std::mem::{forget, size_of}; -use std::slice; +use std::ptr::NonNull; use crate::imp_prelude::*; use crate::{dimension, ArcArray1, ArcArray2}; -/// Create an [**`Array`**](type.Array.html) with one, two or -/// three dimensions. +/// Create an **[`Array`]** with one, two, three, four, five, or six dimensions. /// /// ``` /// use ndarray::array; @@ -25,17 +29,49 @@ use crate::{dimension, ArcArray1, ArcArray2}; /// let a3 = array![[[1, 2], [3, 4]], /// [[5, 6], [7, 8]]]; /// +/// let a4 = array![[[[1, 2, 3, 4]]]]; +/// +/// let a5 = array![[[[[1, 2, 3, 4, 5]]]]]; +/// +/// let a6 = array![[[[[[1, 2, 3, 4, 5, 6]]]]]]; +/// /// assert_eq!(a1.shape(), &[4]); /// assert_eq!(a2.shape(), &[2, 2]); /// assert_eq!(a3.shape(), &[2, 2, 2]); +/// assert_eq!(a4.shape(), &[1, 1, 1, 4]); +/// assert_eq!(a5.shape(), &[1, 1, 1, 1, 5]); +/// assert_eq!(a6.shape(), &[1, 1, 1, 1, 1, 6]); /// ``` /// /// This macro uses `vec![]`, and has the same ownership semantics; /// elements are moved into the resulting `Array`. /// /// Use `array![...].into_shared()` to create an `ArcArray`. +/// +/// Attempts to crate 7D+ arrays with this macro will lead to +/// a compiler error, since the difference between a 7D array +/// of i32 and a 6D array of `[i32; 3]` is ambiguous. Higher-dim +/// arrays can be created with [`ArrayD`]. +/// +/// ```compile_fail +/// use ndarray::array; +/// let a7 = array![[[[[[[1, 2, 3]]]]]]]; +/// // error: Arrays of 7 dimensions or more (or ndarrays of Rust arrays) cannot be constructed with the array! macro. +/// ``` #[macro_export] macro_rules! array { + ($([$([$([$([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{ + compile_error!("Arrays of 7 dimensions or more (or ndarrays of Rust arrays) cannot be constructed with the array! macro."); + }}; + ($([$([$([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{ + $crate::Array6::from(vec![$([$([$([$([$([$($x,)*],)*],)*],)*],)*],)*]) + }}; + ($([$([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{ + $crate::Array5::from(vec![$([$([$([$([$($x,)*],)*],)*],)*],)*]) + }}; + ($([$([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*]),+ $(,)*) => {{ + $crate::Array4::from(vec![$([$([$([$($x,)*],)*],)*],)*]) + }}; ($([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*) => {{ $crate::Array3::from(vec![$([$([$($x,)*],)*],)*]) }}; @@ -48,62 +84,125 @@ macro_rules! array { } /// Create a zero-dimensional array with the element `x`. 
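The macro arms above stop at six dimensions, and the new doc text points 7-D-and-up cases at `ArrayD`. A minimal sketch of that fallback (not part of this patch; it only assumes the crate's existing `ArrayD`/`IxDyn` API):

```rust
use ndarray::{array, ArrayD, IxDyn};

fn main() {
    // Six dimensions still work directly with the macro.
    let a6 = array![[[[[[1, 2, 3]]]]]];
    assert_eq!(a6.shape(), &[1, 1, 1, 1, 1, 3]);

    // Seven or more dimensions: build a dynamic-dimensional array from a flat Vec.
    let a7: ArrayD<i32> =
        ArrayD::from_shape_vec(IxDyn(&[1, 1, 1, 1, 1, 1, 3]), vec![1, 2, 3]).unwrap();
    assert_eq!(a7.ndim(), 7);
}
```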
-pub fn arr0(x: A) -> Array0 { +pub fn arr0(x: A) -> Array0 +{ unsafe { ArrayBase::from_shape_vec_unchecked((), vec![x]) } } /// Create a one-dimensional array with elements from `xs`. -pub fn arr1(xs: &[A]) -> Array1 { +pub fn arr1(xs: &[A]) -> Array1 +{ ArrayBase::from(xs.to_vec()) } /// Create a one-dimensional array with elements from `xs`. -pub fn rcarr1(xs: &[A]) -> ArcArray1 { +pub fn rcarr1(xs: &[A]) -> ArcArray1 +{ arr1(xs).into_shared() } /// Create a zero-dimensional array view borrowing `x`. -pub fn aview0(x: &A) -> ArrayView0<'_, A> { - unsafe { ArrayView::from_shape_ptr(Ix0(), x) } +pub const fn aview0(x: &A) -> ArrayView0<'_, A> +{ + ArrayBase { + data: ViewRepr::new(), + // Safe because references are always non-null. + ptr: unsafe { NonNull::new_unchecked(x as *const A as *mut A) }, + dim: Ix0(), + strides: Ix0(), + } } /// Create a one-dimensional array view with elements borrowing `xs`. /// +/// **Panics** if the length of the slice overflows `isize`. (This can only +/// occur if `A` is zero-sized, because slices cannot contain more than +/// `isize::MAX` number of bytes.) +/// /// ``` -/// use ndarray::aview1; +/// use ndarray::{aview1, ArrayView1}; /// /// let data = [1.0; 1024]; /// /// // Create a 2D array view from borrowed data -/// let a2d = aview1(&data).into_shape((32, 32)).unwrap(); +/// let a2d = aview1(&data).into_shape_with_order((32, 32)).unwrap(); /// /// assert_eq!(a2d.sum(), 1024.0); +/// +/// // Create a const 1D array view +/// const C: ArrayView1<'static, f64> = aview1(&[1., 2., 3.]); +/// +/// assert_eq!(C.sum(), 6.); /// ``` -pub fn aview1(xs: &[A]) -> ArrayView1<'_, A> { - ArrayView::from(xs) +pub const fn aview1(xs: &[A]) -> ArrayView1<'_, A> +{ + if size_of::() == 0 { + assert!( + xs.len() <= isize::MAX as usize, + "Slice length must fit in `isize`.", + ); + } + ArrayBase { + data: ViewRepr::new(), + // Safe because references are always non-null. + ptr: unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) }, + dim: Ix1(xs.len()), + strides: Ix1(1), + } } /// Create a two-dimensional array view with elements borrowing `xs`. /// -/// **Panics** if the product of non-zero axis lengths overflows `isize`. (This -/// can only occur when `V` is zero-sized.) -pub fn aview2>(xs: &[V]) -> ArrayView2<'_, A> { - let cols = V::len(); +/// **Panics** if the product of non-zero axis lengths overflows `isize` (This +/// can only occur if A is zero-sized or if `N` is zero, because slices cannot +/// contain more than `isize::MAX` number of bytes). 
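Since `aview0`/`aview1` (and `aview2` below) become `const fn` in this change, read-only views over static data can now live in `const` items. A small illustration along the lines of the doc examples added here (the constant names are purely illustrative):

```rust
use ndarray::{aview1, aview2, ArrayView1, ArrayView2};

// Views over constant data, built entirely at compile time.
const WEIGHTS: ArrayView1<'static, f64> = aview1(&[0.25, 0.5, 0.25]);
const LAPLACIAN: ArrayView2<'static, f64> =
    aview2(&[[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]]);

fn main() {
    assert_eq!(WEIGHTS.sum(), 1.0);
    assert_eq!(LAPLACIAN.sum(), 0.0);
}
```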
+/// +/// ``` +/// use ndarray::{aview2, ArrayView2}; +/// +/// let data = vec![[1., 2., 3.], [4., 5., 6.]]; +/// +/// let view = aview2(&data); +/// assert_eq!(view.sum(), 21.); +/// +/// // Create a const 2D array view +/// const C: ArrayView2<'static, f64> = aview2(&[[1., 2., 3.], [4., 5., 6.]]); +/// assert_eq!(C.sum(), 21.); +/// ``` +pub const fn aview2(xs: &[[A; N]]) -> ArrayView2<'_, A> +{ + let cols = N; let rows = xs.len(); - let dim = Ix2(rows, cols); - if size_of::() == 0 { - dimension::size_of_shape_checked(&dim) - .expect("Product of non-zero axis lengths must not overflow isize."); + if size_of::() == 0 { + if let Some(n_elems) = rows.checked_mul(cols) { + assert!( + rows <= isize::MAX as usize + && cols <= isize::MAX as usize + && n_elems <= isize::MAX as usize, + "Product of non-zero axis lengths must not overflow isize.", + ); + } else { + panic!("Overflow in number of elements."); + } + } else if N == 0 { + assert!( + rows <= isize::MAX as usize, + "Product of non-zero axis lengths must not overflow isize.", + ); } - // `rows` is guaranteed to fit in `isize` because we've checked the ZST - // case and slices never contain > `isize::MAX` bytes. `cols` is guaranteed - // to fit in `isize` because `FixedInitializer` is not implemented for any - // array lengths > `isize::MAX`. `cols * rows` is guaranteed to fit in - // `isize` because we've checked the ZST case and slices never contain > - // `isize::MAX` bytes. - unsafe { - let data = slice::from_raw_parts(xs.as_ptr() as *const A, cols * rows); - ArrayView::from_shape_ptr(dim, data.as_ptr()) + // Safe because references are always non-null. + let ptr = unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) }; + let dim = Ix2(rows, cols); + let strides = if rows == 0 || cols == 0 { + Ix2(0, 0) + } else { + Ix2(cols, 1) + }; + ArrayBase { + data: ViewRepr::new(), + ptr, + dim, + strides, } } @@ -114,27 +213,27 @@ pub fn aview2>(xs: &[V]) -> ArrayView2<'_, A> { /// // Create an array view over some data, then slice it and modify it. /// let mut data = [0; 1024]; /// { -/// let mut a = aview_mut1(&mut data).into_shape((32, 32)).unwrap(); +/// let mut a = aview_mut1(&mut data).into_shape_with_order((32, 32)).unwrap(); /// a.slice_mut(s![.., ..;3]).fill(5); /// } /// assert_eq!(&data[..10], [5, 0, 0, 5, 0, 0, 5, 0, 0, 5]); /// ``` -pub fn aview_mut1(xs: &mut [A]) -> ArrayViewMut1<'_, A> { +pub fn aview_mut1(xs: &mut [A]) -> ArrayViewMut1<'_, A> +{ ArrayViewMut::from(xs) } /// Create a two-dimensional read-write array view with elements borrowing `xs`. /// -/// **Panics** if the product of non-zero axis lengths overflows `isize`. (This -/// can only occur when `V` is zero-sized.) +/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A +/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes). /// /// # Example /// /// ``` /// use ndarray::aview_mut2; /// -/// // The inner (nested) array must be of length 1 to 16, but the outer -/// // can be of any length. +/// // The inner (nested) and outer arrays can be of any length. 
/// let mut data = [[0.; 2]; 128]; /// { /// // Make a 128 x 2 mut array view then turn it into 2 x 128 @@ -146,56 +245,11 @@ pub fn aview_mut1(xs: &mut [A]) -> ArrayViewMut1<'_, A> { /// // look at the start of the result /// assert_eq!(&data[..3], [[1., -1.], [1., -1.], [1., -1.]]); /// ``` -pub fn aview_mut2>(xs: &mut [V]) -> ArrayViewMut2<'_, A> { - let cols = V::len(); - let rows = xs.len(); - let dim = Ix2(rows, cols); - if size_of::() == 0 { - dimension::size_of_shape_checked(&dim) - .expect("Product of non-zero axis lengths must not overflow isize."); - } - // `rows` is guaranteed to fit in `isize` because we've checked the ZST - // case and slices never contain > `isize::MAX` bytes. `cols` is guaranteed - // to fit in `isize` because `FixedInitializer` is not implemented for any - // array lengths > `isize::MAX`. `cols * rows` is guaranteed to fit in - // `isize` because we've checked the ZST case and slices never contain > - // `isize::MAX` bytes. - unsafe { - let data = slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut A, cols * rows); - ArrayViewMut::from_shape_ptr(dim, data.as_mut_ptr()) - } -} - -/// Fixed-size array used for array initialization -pub unsafe trait FixedInitializer { - type Elem; - fn as_init_slice(&self) -> &[Self::Elem]; - fn len() -> usize; -} - -macro_rules! impl_arr_init { - (__impl $n: expr) => ( - unsafe impl FixedInitializer for [T; $n] { - type Elem = T; - fn as_init_slice(&self) -> &[T] { self } - fn len() -> usize { $n } - } - ); - () => (); - ($n: expr, $($m:expr,)*) => ( - impl_arr_init!(__impl $n); - impl_arr_init!($($m,)*); - ) - +pub fn aview_mut2(xs: &mut [[A; N]]) -> ArrayViewMut2<'_, A> +{ + ArrayViewMut2::from(xs) } -// For implementors: If you ever implement `FixedInitializer` for array lengths -// > `isize::MAX` (e.g. once Rust adds const generics), you must update -// `aview2` and `aview_mut2` to perform the necessary checks. In particular, -// the assumption that `cols` can never exceed `isize::MAX` would be incorrect. -// (Consider e.g. `let xs: &[[i32; ::std::usize::MAX]] = &[]`.) -impl_arr_init!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,); - /// Create a two-dimensional array with elements from `xs`. /// /// ``` @@ -207,77 +261,49 @@ impl_arr_init!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,); /// a.shape() == [2, 3] /// ); /// ``` -pub fn arr2>(xs: &[V]) -> Array2 -where - V: Clone, +pub fn arr2(xs: &[[A; N]]) -> Array2 { Array2::from(xs.to_vec()) } -impl From> for Array2 -where - V: FixedInitializer, -{ - /// Converts the `Vec` of arrays to an owned 2-D array. - /// - /// **Panics** if the product of non-zero axis lengths overflows `isize`. - fn from(mut xs: Vec) -> Self { - let dim = Ix2(xs.len(), V::len()); - let ptr = xs.as_mut_ptr(); - let cap = xs.capacity(); - let expand_len = dimension::size_of_shape_checked(&dim) - .expect("Product of non-zero axis lengths must not overflow isize."); - forget(xs); - unsafe { - let v = if size_of::() == 0 { - Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len) - } else if V::len() == 0 { - Vec::new() - } else { - // Guaranteed not to overflow in this case since A is non-ZST - // and Vec never allocates more than isize bytes. - let expand_cap = cap * V::len(); - Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap) - }; - ArrayBase::from_shape_vec_unchecked(dim, v) +macro_rules! 
impl_from_nested_vec { + ($arr_type:ty, $ix_type:tt, $($n:ident),+) => { + impl From> for Array + { + fn from(mut xs: Vec<$arr_type>) -> Self + { + let dim = $ix_type(xs.len(), $($n),+); + let ptr = xs.as_mut_ptr(); + let cap = xs.capacity(); + let expand_len = dimension::size_of_shape_checked(&dim) + .expect("Product of non-zero axis lengths must not overflow isize."); + forget(xs); + unsafe { + let v = if size_of::() == 0 { + Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len) + } else if $($n == 0 ||)+ false { + Vec::new() + } else { + let expand_cap = cap $(* $n)+; + Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap) + }; + ArrayBase::from_shape_vec_unchecked(dim, v) + } + } } - } + }; } -impl From> for Array3 -where - V: FixedInitializer, - U: FixedInitializer, -{ - /// Converts the `Vec` of arrays to an owned 3-D array. - /// - /// **Panics** if the product of non-zero axis lengths overflows `isize`. - fn from(mut xs: Vec) -> Self { - let dim = Ix3(xs.len(), V::len(), U::len()); - let ptr = xs.as_mut_ptr(); - let cap = xs.capacity(); - let expand_len = dimension::size_of_shape_checked(&dim) - .expect("Product of non-zero axis lengths must not overflow isize."); - forget(xs); - unsafe { - let v = if size_of::() == 0 { - Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len) - } else if V::len() == 0 || U::len() == 0 { - Vec::new() - } else { - // Guaranteed not to overflow in this case since A is non-ZST - // and Vec never allocates more than isize bytes. - let expand_cap = cap * V::len() * U::len(); - Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap) - }; - ArrayBase::from_shape_vec_unchecked(dim, v) - } - } -} +impl_from_nested_vec!([A; N], Ix2, N); +impl_from_nested_vec!([[A; M]; N], Ix3, N, M); +impl_from_nested_vec!([[[A; L]; M]; N], Ix4, N, M, L); +impl_from_nested_vec!([[[[A; K]; L]; M]; N], Ix5, N, M, L, K); +impl_from_nested_vec!([[[[[A; J]; K]; L]; M]; N], Ix6, N, M, L, K, J); /// Create a two-dimensional array with elements from `xs`. /// -pub fn rcarr2>(xs: &[V]) -> ArcArray2 { +pub fn rcarr2(xs: &[[A; N]]) -> ArcArray2 +{ arr2(xs).into_shared() } @@ -298,23 +324,13 @@ pub fn rcarr2>(xs: &[V]) -> ArcA /// a.shape() == [3, 2, 2] /// ); /// ``` -pub fn arr3, U: FixedInitializer>( - xs: &[V], -) -> Array3 -where - V: Clone, - U: Clone, +pub fn arr3(xs: &[[[A; M]; N]]) -> Array3 { Array3::from(xs.to_vec()) } /// Create a three-dimensional array with elements from `xs`. -pub fn rcarr3, U: FixedInitializer>( - xs: &[V], -) -> ArcArray -where - V: Clone, - U: Clone, +pub fn rcarr3(xs: &[[[A; M]; N]]) -> ArcArray { arr3(xs).into_shared() } diff --git a/src/geomspace.rs b/src/geomspace.rs index 06242f68e..0ac91f529 100644 --- a/src/geomspace.rs +++ b/src/geomspace.rs @@ -5,12 +5,14 @@ // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. +#![cfg(feature = "std")] use num_traits::Float; /// An iterator of a sequence of geometrically spaced floats. /// /// Iterator element type is `F`. 
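The `impl_from_nested_vec!` macro above replaces the old `FixedInitializer`-based conversions with const-generic `From` impls for 2-D through 6-D. A short sketch of what those impls allow (only the conversions generated just above are assumed):

```rust
use ndarray::{Array2, Array3};

fn main() {
    // Vec of fixed-size rows -> owned 2-D array, for any inner length N.
    let rows: Vec<[f64; 3]> = vec![[1., 2., 3.], [4., 5., 6.]];
    let a = Array2::from(rows);
    assert_eq!(a.shape(), &[2, 3]);

    // The same pattern nests further; a Vec of 2x2 blocks becomes a 3-D array.
    let blocks: Vec<[[i32; 2]; 2]> = vec![[[1, 2], [3, 4]], [[5, 6], [7, 8]]];
    let b = Array3::from(blocks);
    assert_eq!(b.shape(), &[2, 2, 2]);
}
```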
-pub struct Geomspace { +pub struct Geomspace +{ sign: F, start: F, step: F, @@ -19,13 +21,13 @@ pub struct Geomspace { } impl Iterator for Geomspace -where - F: Float, +where F: Float { type Item = F; #[inline] - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { if self.index >= self.len { None } else { @@ -38,18 +40,19 @@ where } #[inline] - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { let n = self.len - self.index; (n, Some(n)) } } impl DoubleEndedIterator for Geomspace -where - F: Float, +where F: Float { #[inline] - fn next_back(&mut self) -> Option { + fn next_back(&mut self) -> Option + { if self.index >= self.len { None } else { @@ -79,8 +82,7 @@ impl ExactSizeIterator for Geomspace where Geomspace: Iterator {} /// **Panics** if converting `n - 1` to type `F` fails. #[inline] pub fn geomspace(a: F, b: F, n: usize) -> Option> -where - F: Float, +where F: Float { if a == F::zero() || b == F::zero() || a.is_sign_negative() != b.is_sign_negative() { return None; @@ -103,12 +105,14 @@ where } #[cfg(test)] -mod tests { +mod tests +{ use super::geomspace; #[test] #[cfg(feature = "approx")] - fn valid() { + fn valid() + { use crate::{arr1, Array1}; use approx::assert_abs_diff_eq; @@ -126,7 +130,8 @@ mod tests { } #[test] - fn iter_forward() { + fn iter_forward() + { let mut iter = geomspace(1.0f64, 1e3, 4).unwrap(); assert!(iter.size_hint() == (4, Some(4))); @@ -141,7 +146,8 @@ mod tests { } #[test] - fn iter_backward() { + fn iter_backward() + { let mut iter = geomspace(1.0f64, 1e3, 4).unwrap(); assert!(iter.size_hint() == (4, Some(4))); @@ -156,17 +162,20 @@ mod tests { } #[test] - fn zero_lower() { + fn zero_lower() + { assert!(geomspace(0.0, 1.0, 4).is_none()); } #[test] - fn zero_upper() { + fn zero_upper() + { assert!(geomspace(1.0, 0.0, 4).is_none()); } #[test] - fn zero_included() { + fn zero_included() + { assert!(geomspace(-1.0, 1.0, 4).is_none()); } } diff --git a/src/impl_1d.rs b/src/impl_1d.rs index fa877eff0..e49fdd731 100644 --- a/src/impl_1d.rs +++ b/src/impl_1d.rs @@ -7,12 +7,16 @@ // except according to those terms. //! Methods for one-dimensional arrays. +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; +use std::mem::MaybeUninit; + use crate::imp_prelude::*; +use crate::low_level_util::AbortIfPanic; /// # Methods For 1-D Arrays impl ArrayBase -where - S: RawData, +where S: RawData { /// Return an vector with the elements of the one-dimensional array. pub fn to_vec(&self) -> Vec @@ -26,4 +30,34 @@ where crate::iterators::to_vec(self.iter().cloned()) } } + + /// Rotate the elements of the array by 1 element towards the front; + /// the former first element becomes the last. 
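The new `rotate1_front` below performs this rotation with raw copies through a `MaybeUninit` slot (guarded by `AbortIfPanic`) to avoid the extra writes of a swap loop. For reference, the same observable behaviour can be sketched safely on a plain slice using only `std`; this is an illustration, not the crate's implementation:

```rust
// Rotate one step towards the front: the former first element becomes the last.
fn rotate1_front<T>(xs: &mut [T]) {
    if xs.len() > 1 {
        xs.rotate_left(1);
    }
}

fn main() {
    let mut v = [1, 2, 3, 4];
    rotate1_front(&mut v);
    assert_eq!(v, [2, 3, 4, 1]);
}
```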
+ pub(crate) fn rotate1_front(&mut self) + where S: DataMut + { + // use swapping to keep all elements initialized (as required by owned storage) + let mut lane_iter = self.iter_mut(); + let mut dst = if let Some(dst) = lane_iter.next() { dst } else { return }; + + // Logically we do a circular swap here, all elements in a chain + // Using MaybeUninit to avoid unnecessary writes in the safe swap solution + // + // for elt in lane_iter { + // std::mem::swap(dst, elt); + // dst = elt; + // } + // + let guard = AbortIfPanic(&"rotate1_front: temporarily moving out of owned value"); + let mut slot = MaybeUninit::::uninit(); + unsafe { + slot.as_mut_ptr().copy_from_nonoverlapping(dst, 1); + for elt in lane_iter { + (dst as *mut A).copy_from_nonoverlapping(elt, 1); + dst = elt; + } + (dst as *mut A).copy_from_nonoverlapping(slot.as_ptr(), 1); + } + guard.defuse(); + } } diff --git a/src/impl_2d.rs b/src/impl_2d.rs index f44990ed7..c2e9725ac 100644 --- a/src/impl_2d.rs +++ b/src/impl_2d.rs @@ -11,15 +11,20 @@ use crate::imp_prelude::*; /// # Methods For 2-D Arrays impl ArrayBase -where - S: RawData, +where S: RawData { /// Return an array view of row `index`. /// /// **Panics** if `index` is out of bounds. + /// + /// ``` + /// use ndarray::array; + /// let array = array![[1., 2.], [3., 4.]]; + /// assert_eq!(array.row(0), array![1., 2.]); + /// ``` + #[track_caller] pub fn row(&self, index: Ix) -> ArrayView1<'_, A> - where - S: Data, + where S: Data { self.index_axis(Axis(0), index) } @@ -27,30 +32,54 @@ where /// Return a mutable array view of row `index`. /// /// **Panics** if `index` is out of bounds. + /// + /// ``` + /// use ndarray::array; + /// let mut array = array![[1., 2.], [3., 4.]]; + /// array.row_mut(0)[1] = 5.; + /// assert_eq!(array, array![[1., 5.], [3., 4.]]); + /// ``` + #[track_caller] pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut1<'_, A> - where - S: DataMut, + where S: DataMut { self.index_axis_mut(Axis(0), index) } /// Return the number of rows (length of `Axis(0)`) in the two-dimensional array. - pub fn nrows(&self) -> usize { + /// + /// ``` + /// use ndarray::{array, Axis}; + /// + /// let array = array![[1., 2.], + /// [3., 4.], + /// [5., 6.]]; + /// assert_eq!(array.nrows(), 3); + /// + /// // equivalent ways of getting the dimensions + /// // get nrows, ncols by using dim: + /// let (m, n) = array.dim(); + /// assert_eq!(m, array.nrows()); + /// // get length of any particular axis with .len_of() + /// assert_eq!(m, array.len_of(Axis(0))); + /// ``` + pub fn nrows(&self) -> usize + { self.len_of(Axis(0)) } - /// Return the number of rows (length of `Axis(0)`) in the two-dimensional array. - #[deprecated(note = "Renamed to .nrows(), please use the new name")] - pub fn rows(&self) -> usize { - self.nrows() - } - /// Return an array view of column `index`. /// /// **Panics** if `index` is out of bounds. + /// + /// ``` + /// use ndarray::array; + /// let array = array![[1., 2.], [3., 4.]]; + /// assert_eq!(array.column(0), array![1., 3.]); + /// ``` + #[track_caller] pub fn column(&self, index: Ix) -> ArrayView1<'_, A> - where - S: Data, + where S: Data { self.index_axis(Axis(1), index) } @@ -58,26 +87,60 @@ where /// Return a mutable array view of column `index`. /// /// **Panics** if `index` is out of bounds. 
+ /// + /// ``` + /// use ndarray::array; + /// let mut array = array![[1., 2.], [3., 4.]]; + /// array.column_mut(0)[1] = 5.; + /// assert_eq!(array, array![[1., 2.], [5., 4.]]); + /// ``` + #[track_caller] pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut1<'_, A> - where - S: DataMut, + where S: DataMut { self.index_axis_mut(Axis(1), index) } /// Return the number of columns (length of `Axis(1)`) in the two-dimensional array. - pub fn ncols(&self) -> usize { + /// + /// ``` + /// use ndarray::{array, Axis}; + /// + /// let array = array![[1., 2.], + /// [3., 4.], + /// [5., 6.]]; + /// assert_eq!(array.ncols(), 2); + /// + /// // equivalent ways of getting the dimensions + /// // get nrows, ncols by using dim: + /// let (m, n) = array.dim(); + /// assert_eq!(n, array.ncols()); + /// // get length of any particular axis with .len_of() + /// assert_eq!(n, array.len_of(Axis(1))); + /// ``` + pub fn ncols(&self) -> usize + { self.len_of(Axis(1)) } - /// Return the number of columns (length of `Axis(1)`) in the two-dimensional array. - #[deprecated(note = "Renamed to .ncols(), please use the new name")] - pub fn cols(&self) -> usize { - self.ncols() - } - /// Return true if the array is square, false otherwise. - pub fn is_square(&self) -> bool { - self.nrows() == self.ncols() + /// + /// # Examples + /// Square: + /// ``` + /// use ndarray::array; + /// let array = array![[1., 2.], [3., 4.]]; + /// assert!(array.is_square()); + /// ``` + /// Not square: + /// ``` + /// use ndarray::array; + /// let array = array![[1., 2., 5.], [3., 4., 6.]]; + /// assert!(!array.is_square()); + /// ``` + pub fn is_square(&self) -> bool + { + let (m, n) = self.dim(); + m == n } } diff --git a/src/impl_arc_array.rs b/src/impl_arc_array.rs new file mode 100644 index 000000000..619ae2506 --- /dev/null +++ b/src/impl_arc_array.rs @@ -0,0 +1,30 @@ +// Copyright 2019 ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::imp_prelude::*; + +#[cfg(target_has_atomic = "ptr")] +use alloc::sync::Arc; + +#[cfg(not(target_has_atomic = "ptr"))] +use portable_atomic_util::Arc; + +/// Methods specific to `ArcArray`. +/// +/// ***See also all methods for [`ArrayBase`]*** +impl ArcArray +where D: Dimension +{ + /// Returns `true` iff the inner `Arc` is not shared. + /// If you want to ensure the `Arc` is not concurrently cloned, you need to provide a `&mut self` to this function. + pub fn is_unique(&self) -> bool + { + // Only strong pointers are used in this crate. + Arc::strong_count(&self.data.0) == 1 + } +} diff --git a/src/impl_clone.rs b/src/impl_clone.rs index 603708849..d65f6c338 100644 --- a/src/impl_clone.rs +++ b/src/impl_clone.rs @@ -9,8 +9,11 @@ use crate::imp_prelude::*; use crate::RawDataClone; -impl Clone for ArrayBase { - fn clone(&self) -> ArrayBase { +impl Clone for ArrayBase +{ + fn clone(&self) -> ArrayBase + { + // safe because `clone_with_ptr` promises to provide equivalent data and ptr unsafe { let (data, ptr) = self.data.clone_with_ptr(self.ptr); ArrayBase { @@ -25,7 +28,8 @@ impl Clone for ArrayBase { /// `Array` implements `.clone_from()` to reuse an array's existing /// allocation. Semantically equivalent to `*self = other.clone()`, but /// potentially more efficient. 
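The `clone_from` documentation above promises allocation reuse; a brief usage sketch (this is just the standard `Clone` API, nothing added by this patch):

```rust
use ndarray::Array2;

fn main() {
    let src = Array2::<f64>::ones((64, 64));
    let mut dst = Array2::<f64>::zeros((64, 64));

    // Semantically `dst = src.clone()`, but reuses dst's existing allocation.
    dst.clone_from(&src);
    assert_eq!(dst, src);
}
```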
- fn clone_from(&mut self, other: &Self) { + fn clone_from(&mut self, other: &Self) + { unsafe { self.ptr = self.data.clone_from_with_ptr(&other.data, other.ptr); self.dim.clone_from(&other.dim); diff --git a/src/impl_constructors.rs b/src/impl_constructors.rs index 1f8b9708e..260937a90 100644 --- a/src/impl_constructors.rs +++ b/src/impl_constructors.rs @@ -11,18 +11,31 @@ //! #![allow(clippy::match_wild_err_arm)] +use alloc::vec; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; +#[cfg(feature = "std")] +use num_traits::Float; +use num_traits::{One, Zero}; +use std::mem; +use std::mem::MaybeUninit; -use num_traits::{Float, One, Zero}; - -use crate::dimension; +use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; +use crate::dimension::{self, CanIndexCheckMode}; use crate::error::{self, ShapeError}; use crate::extension::nonnull::nonnull_from_vec_data; use crate::imp_prelude::*; use crate::indexes; use crate::indices; -use crate::iterators::{to_vec, to_vec_mapped}; +#[cfg(feature = "std")] +use crate::iterators::to_vec; +use crate::iterators::to_vec_mapped; +use crate::iterators::TrustedIterator; use crate::StrideShape; +#[cfg(feature = "std")] use crate::{geomspace, linspace, logspace}; +#[allow(unused_imports)] +use rawpointer::PointerExt; /// # Constructor Methods for Owned Arrays /// @@ -31,8 +44,7 @@ use crate::{geomspace, linspace, logspace}; /// /// ## Constructor methods for one-dimensional arrays. impl ArrayBase -where - S: DataOwned, +where S: DataOwned { /// Create a one-dimensional array from a vector (no copying needed). /// @@ -41,11 +53,32 @@ where /// ```rust /// use ndarray::Array; /// - /// let array = Array::from(vec![1., 2., 3., 4.]); + /// let array = Array::from_vec(vec![1., 2., 3., 4.]); /// ``` - #[deprecated(note = "use standard `from`", since = "0.13.0")] - pub fn from_vec(v: Vec) -> Self { - Self::from(v) + pub fn from_vec(v: Vec) -> Self + { + if mem::size_of::() == 0 { + assert!( + v.len() <= isize::MAX as usize, + "Length must fit in `isize`.", + ); + } + unsafe { Self::from_shape_vec_unchecked(v.len() as Ix, v) } + } + + /// Create a one-dimensional array from an iterator or iterable. + /// + /// **Panics** if the length is greater than `isize::MAX`. + /// + /// ```rust + /// use ndarray::Array; + /// + /// let array = Array::from_iter(0..10); + /// ``` + #[allow(clippy::should_implement_trait)] + pub fn from_iter>(iterable: I) -> Self + { + Self::from_vec(iterable.into_iter().collect()) } /// Create a one-dimensional array with `n` evenly spaced elements from @@ -65,9 +98,9 @@ where /// let array = Array::linspace(0., 1., 5); /// assert!(array == arr1(&[0.0, 0.25, 0.5, 0.75, 1.0])) /// ``` + #[cfg(feature = "std")] pub fn linspace(start: A, end: A, n: usize) -> Self - where - A: Float, + where A: Float { Self::from(to_vec(linspace::linspace(start, end, n))) } @@ -83,9 +116,9 @@ where /// let array = Array::range(0., 5., 1.); /// assert!(array == arr1(&[0., 1., 2., 3., 4.])) /// ``` + #[cfg(feature = "std")] pub fn range(start: A, end: A, step: A) -> Self - where - A: Float, + where A: Float { Self::from(to_vec(linspace::range(start, end, step))) } @@ -100,10 +133,10 @@ where /// to type `A` fails. 
/// /// ```rust + /// # #[cfg(feature = "approx")] { /// use approx::assert_abs_diff_eq; /// use ndarray::{Array, arr1}; /// - /// # #[cfg(feature = "approx")] { /// let array = Array::logspace(10.0, 0.0, 3.0, 4); /// assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3])); /// @@ -111,9 +144,9 @@ where /// assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0])); /// # } /// ``` + #[cfg(feature = "std")] pub fn logspace(base: A, start: A, end: A, n: usize) -> Self - where - A: Float, + where A: Float { Self::from(to_vec(logspace::logspace(base, start, end, n))) } @@ -129,11 +162,11 @@ where /// to type `A` fails. /// /// ```rust + /// # fn example() -> Option<()> { + /// # #[cfg(feature = "approx")] { /// use approx::assert_abs_diff_eq; /// use ndarray::{Array, arr1}; /// - /// # fn example() -> Option<()> { - /// # #[cfg(feature = "approx")] { /// let array = Array::geomspace(1e0, 1e3, 4)?; /// assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12); /// @@ -145,9 +178,9 @@ where /// # /// # example().unwrap(); /// ``` + #[cfg(feature = "std")] pub fn geomspace(start: A, end: A, n: usize) -> Option - where - A: Float, + where A: Float { Some(Self::from(to_vec(geomspace::geomspace(start, end, n)?))) } @@ -155,8 +188,7 @@ where /// ## Constructor methods for two-dimensional arrays. impl ArrayBase -where - S: DataOwned, +where S: DataOwned { /// Create an identity matrix of size `n` (square 2D array). /// @@ -192,9 +224,32 @@ where { let n = diag.len(); let mut arr = Self::zeros((n, n)); - arr.diag_mut().assign(&diag); + arr.diag_mut().assign(diag); arr } + + /// Create a square 2D matrix of the specified size, with the specified + /// element along the diagonal and zeros elsewhere. + /// + /// **Panics** if `n * n` would overflow `isize`. + /// + /// ```rust + /// use ndarray::{array, Array2}; + /// + /// let array = Array2::from_diag_elem(2, 5.); + /// assert_eq!(array, array![[5., 0.], [0., 5.]]); + /// ``` + pub fn from_diag_elem(n: usize, elem: A) -> Self + where + S: DataMut, + A: Clone + Zero, + { + let mut eye = Self::zeros((n, n)); + for a_ii in eye.diag_mut() { + *a_ii = elem.clone(); + } + eye + } } #[cfg(not(debug_assertions))] @@ -234,7 +289,7 @@ macro_rules! size_of_shape_checked_unwrap { /// column major (“f” order) memory layout instead of the default row major. /// For example `Array::zeros((5, 6).f())` makes a column major 5 × 6 array. /// -/// Use [`IxDyn`](type.IxDyn.html) for the shape to create an array with dynamic +/// Use [`type@IxDyn`] for the shape to create an array with dynamic /// number of axes. 
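The overview above mentions `.f()` for column-major (“f” order) layout and `IxDyn` for a dynamic number of axes. A small illustration of both, assuming the crate's `ShapeBuilder` trait (which provides `.f()`) and the `IxDyn` constructor:

```rust
use ndarray::{Array, Array2, ArrayD, IxDyn, ShapeBuilder};

fn main() {
    // Row major ("c" order) is the default; `.f()` requests column major ("f" order).
    let c = Array2::<f64>::zeros((5, 6));
    let f = Array2::<f64>::zeros((5, 6).f());
    assert_eq!(c.strides(), &[6, 1]);
    assert_eq!(f.strides(), &[1, 5]);

    // IxDyn builds an array whose number of axes is chosen at runtime.
    let d: ArrayD<f64> = Array::zeros(IxDyn(&[2, 3, 4]));
    assert_eq!(d.ndim(), 3);
}
```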
/// /// Finally, the few constructors that take a completely general @@ -270,7 +325,7 @@ where A: Clone, Sh: ShapeBuilder, { - let shape = shape.into_shape(); + let shape = shape.into_shape_with_order(); let size = size_of_shape_checked_unwrap!(&shape.dim); let v = vec![elem; size]; unsafe { Self::from_shape_vec_unchecked(shape, v) } @@ -324,7 +379,7 @@ where Sh: ShapeBuilder, F: FnMut() -> A, { - let shape = shape.into_shape(); + let shape = shape.into_shape_with_order(); let len = size_of_shape_checked_unwrap!(&shape.dim); let v = to_vec_mapped(0..len, move |_| f()); unsafe { Self::from_shape_vec_unchecked(shape, v) } @@ -355,9 +410,9 @@ where Sh: ShapeBuilder, F: FnMut(D::Pattern) -> A, { - let shape = shape.into_shape(); + let shape = shape.into_shape_with_order(); let _ = size_of_shape_checked_unwrap!(&shape.dim); - if shape.is_c { + if shape.is_c() { let v = to_vec_mapped(indices(shape.dim.clone()).into_iter(), f); unsafe { Self::from_shape_vec_unchecked(shape, v) } } else { @@ -401,37 +456,37 @@ where /// ); /// ``` pub fn from_shape_vec(shape: Sh, v: Vec) -> Result - where - Sh: Into>, + where Sh: Into> { // eliminate the type parameter Sh as soon as possible Self::from_shape_vec_impl(shape.into(), v) } - fn from_shape_vec_impl(shape: StrideShape, v: Vec) -> Result { + fn from_shape_vec_impl(shape: StrideShape, v: Vec) -> Result + { let dim = shape.dim; - let strides = shape.strides; - if shape.custom { - dimension::can_index_slice(&v, &dim, &strides)?; - } else { - dimension::can_index_slice_not_custom::(&v, &dim)?; - if dim.size() != v.len() { - return Err(error::incompatible_shapes(&Ix1(v.len()), &dim)); - } + let is_custom = shape.strides.is_custom(); + dimension::can_index_slice_with_strides(&v, &dim, &shape.strides, dimension::CanIndexCheckMode::OwnedMutable)?; + if !is_custom && dim.size() != v.len() { + return Err(error::incompatible_shapes(&Ix1(v.len()), &dim)); } + let strides = shape.strides.strides_for_dim(&dim); unsafe { Ok(Self::from_vec_dim_stride_unchecked(dim, strides, v)) } } /// Creates an array from a vector and interpret it according to the /// provided shape and strides. (No cloning of elements needed.) /// + /// # Safety + /// /// The caller must ensure that the following conditions are met: /// /// 1. The ndim of `dim` and `strides` must be the same. /// /// 2. The product of non-zero axis lengths must not exceed `isize::MAX`. /// - /// 3. For axes with length > 1, the stride must be nonnegative. + /// 3. For axes with length > 1, the pointer cannot move outside the + /// slice. /// /// 4. If the array will be empty (any axes are zero-length), the /// difference between the least address and greatest address accessible @@ -444,74 +499,130 @@ where /// 5. The strides must not allow any element to be referenced by two different /// indices. 
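Only the `_unchecked` variant below shifts those five conditions onto the caller; the checked `from_shape_vec` path above reports a mismatch between shape and vector length as a `ShapeError` (the `incompatible_shapes` branch). A quick sketch of the checked behaviour:

```rust
use ndarray::{Array, Array2};

fn main() {
    // Length matches the shape: Ok.
    let a: Array2<i32> = Array::from_shape_vec((2, 3), vec![1, 2, 3, 4, 5, 6]).unwrap();
    assert_eq!(a.shape(), &[2, 3]);

    // Length does not match the shape: a ShapeError instead of unsoundness.
    let bad: Result<Array2<i32>, _> = Array::from_shape_vec((2, 4), vec![1, 2, 3]);
    assert!(bad.is_err());
}
```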
pub unsafe fn from_shape_vec_unchecked(shape: Sh, v: Vec) -> Self - where - Sh: Into>, + where Sh: Into> { let shape = shape.into(); - Self::from_vec_dim_stride_unchecked(shape.dim, shape.strides, v) + let dim = shape.dim; + let strides = shape.strides.strides_for_dim(&dim); + Self::from_vec_dim_stride_unchecked(dim, strides, v) } - unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec) -> Self { + unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec) -> Self + { // debug check for issues that indicates wrong use of this constructor - debug_assert!(dimension::can_index_slice(&v, &dim, &strides).is_ok()); - ArrayBase { - ptr: nonnull_from_vec_data(&mut v), - data: DataOwned::new(v), - strides, - dim, - } + debug_assert!(dimension::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok()); + + let ptr = nonnull_from_vec_data(&mut v).add(offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides)); + ArrayBase::from_data_ptr(DataOwned::new(v), ptr).with_strides_dim(strides, dim) + } + + /// Creates an array from an iterator, mapped by `map` and interpret it according to the + /// provided shape and strides. + /// + /// # Safety + /// + /// See from_shape_vec_unchecked + pub(crate) unsafe fn from_shape_trusted_iter_unchecked(shape: Sh, iter: I, map: F) -> Self + where + Sh: Into>, + I: TrustedIterator + ExactSizeIterator, + F: FnMut(I::Item) -> A, + { + let shape = shape.into(); + let dim = shape.dim; + let strides = shape.strides.strides_for_dim(&dim); + let v = to_vec_mapped(iter, map); + Self::from_vec_dim_stride_unchecked(dim, strides, v) } - /// Create an array with uninitalized elements, shape `shape`. + /// Create an array with uninitialized elements, shape `shape`. + /// + /// The uninitialized elements of type `A` are represented by the type `MaybeUninit`, + /// an easier way to handle uninit values correctly. + /// + /// Only *when* the array is completely initialized with valid elements, can it be + /// converted to an array of `A` elements using [`.assume_init()`]. /// /// **Panics** if the number of elements in `shape` would overflow isize. /// /// ### Safety /// - /// Accessing uninitalized values is undefined behaviour. You must - /// overwrite *all* the elements in the array after it is created; for - /// example using the methods `.fill()` or `.assign()`. + /// The whole of the array must be initialized before it is converted + /// using [`.assume_init()`] or otherwise traversed/read with the element type `A`. /// - /// The contents of the array is indeterminate before initialization and it - /// is an error to perform operations that use the previous values. For - /// example it would not be legal to use `a += 1.;` on such an array. + /// ### Examples /// - /// This constructor is limited to elements where `A: Copy` (no destructors) - /// to avoid users shooting themselves too hard in the foot; it is not - /// a problem to drop an array created with this method even before elements - /// are initialized. (Note that constructors `from_shape_vec` and - /// `from_shape_vec_unchecked` allow the user yet more control). + /// It is possible to assign individual values through `*elt = MaybeUninit::new(value)` + /// and so on. 
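The last sentence above notes that elements can also be written one at a time with `*elt = MaybeUninit::new(value)`. A minimal sketch of that flow (`uninit`, element-wise writes, then `assume_init`), complementing the `assign_to` example that follows:

```rust
use ndarray::Array2;
use std::mem::MaybeUninit;

fn main() {
    let mut b = Array2::<f64>::uninit((2, 3));
    for elt in b.iter_mut() {
        *elt = MaybeUninit::new(1.0);
    }
    // Sound only because every element above has been initialized.
    let b = unsafe { b.assume_init() };
    assert_eq!(b.sum(), 6.0);
}
```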
/// - /// ### Examples + /// [`.assume_init()`]: ArrayBase::assume_init /// /// ``` /// use ndarray::{s, Array2}; /// - /// // Example Task: Let's create a column shifted copy of a in b + /// // Example Task: Let's create a column shifted copy of the input /// /// fn shift_by_two(a: &Array2) -> Array2 { - /// let mut b = unsafe { Array2::uninitialized(a.dim()) }; + /// // create an uninitialized array + /// let mut b = Array2::uninit(a.dim()); /// /// // two first columns in b are two last in a /// // rest of columns in b are the initial columns in a - /// b.slice_mut(s![.., ..2]).assign(&a.slice(s![.., -2..])); - /// b.slice_mut(s![.., 2..]).assign(&a.slice(s![.., ..-2])); /// - /// // `b` is safe to use with all operations at this point - /// b + /// a.slice(s![.., -2..]).assign_to(b.slice_mut(s![.., ..2])); + /// a.slice(s![.., 2..]).assign_to(b.slice_mut(s![.., ..-2])); + /// + /// // Now we can promise that `b` is safe to use with all operations + /// unsafe { + /// b.assume_init() + /// } /// } /// - /// # shift_by_two(&Array2::zeros((8, 8))); + /// # let _ = shift_by_two; /// ``` - pub unsafe fn uninitialized(shape: Sh) -> Self + pub fn uninit(shape: Sh) -> ArrayBase + where Sh: ShapeBuilder + { + unsafe { + let shape = shape.into_shape_with_order(); + let size = size_of_shape_checked_unwrap!(&shape.dim); + let mut v = Vec::with_capacity(size); + v.set_len(size); + ArrayBase::from_shape_vec_unchecked(shape, v) + } + } + + /// Create an array with uninitialized elements, shape `shape`. + /// + /// The uninitialized elements of type `A` are represented by the type `MaybeUninit`, + /// an easier way to handle uninit values correctly. + /// + /// The `builder` closure gets unshared access to the array through a view and can use it to + /// modify the array before it is returned. This allows initializing the array for any owned + /// array type (avoiding clone requirements for copy-on-write, because the array is unshared + /// when initially created). + /// + /// Only *when* the array is completely initialized with valid elements, can it be + /// converted to an array of `A` elements using [`.assume_init()`]. + /// + /// **Panics** if the number of elements in `shape` would overflow isize. + /// + /// ### Safety + /// + /// The whole of the array must be initialized before it is converted + /// using [`.assume_init()`] or otherwise traversed/read with the element type `A`. + /// + /// [`.assume_init()`]: ArrayBase::assume_init + pub fn build_uninit(shape: Sh, builder: F) -> ArrayBase where - A: Copy, Sh: ShapeBuilder, + F: FnOnce(ArrayViewMut, D>), { - let shape = shape.into_shape(); - let size = size_of_shape_checked_unwrap!(&shape.dim); - let mut v = Vec::with_capacity(size); - v.set_len(size); - Self::from_shape_vec_unchecked(shape, v) + let mut array = Self::uninit(shape); + // Safe because: the array is unshared here + unsafe { + builder(array.raw_view_mut_unchecked().deref_into_view_mut()); + } + array } } diff --git a/src/impl_cow.rs b/src/impl_cow.rs index 52a06bfa2..4843e305b 100644 --- a/src/impl_cow.rs +++ b/src/impl_cow.rs @@ -11,47 +11,72 @@ use crate::imp_prelude::*; /// Methods specific to `CowArray`. /// /// ***See also all methods for [`ArrayBase`]*** -/// -/// [`ArrayBase`]: struct.ArrayBase.html -impl<'a, A, D> CowArray<'a, A, D> -where - D: Dimension, +impl CowArray<'_, A, D> +where D: Dimension { /// Returns `true` iff the array is the view (borrowed) variant. 
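A `CowArray` starts out as the borrowed variant and clones its data on the first write, which is what `is_view`/`is_owned` below report. An illustrative sketch, assuming only the `From<ArrayView>` conversion defined in this file:

```rust
use ndarray::{array, CowArray};

fn main() {
    let owned = array![1., 2., 3.];
    let mut cow = CowArray::from(owned.view());
    assert!(cow.is_view());

    // The first mutation clones the data into the owned variant;
    // the array it was borrowed from is left untouched.
    cow[0] = 10.;
    assert!(cow.is_owned());
    assert_eq!(owned, array![1., 2., 3.]);
}
```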
- pub fn is_view(&self) -> bool { + pub fn is_view(&self) -> bool + { self.data.is_view() } /// Returns `true` iff the array is the owned variant. - pub fn is_owned(&self) -> bool { + pub fn is_owned(&self) -> bool + { self.data.is_owned() } } impl<'a, A, D> From> for CowArray<'a, A, D> -where - D: Dimension, +where D: Dimension { - fn from(view: ArrayView<'a, A, D>) -> CowArray<'a, A, D> { - ArrayBase { - data: CowRepr::View(view.data), - ptr: view.ptr, - dim: view.dim, - strides: view.strides, - } + fn from(view: ArrayView<'a, A, D>) -> CowArray<'a, A, D> + { + // safe because equivalent data + unsafe { ArrayBase::from_data_ptr(CowRepr::View(view.data), view.ptr).with_strides_dim(view.strides, view.dim) } } } impl<'a, A, D> From> for CowArray<'a, A, D> +where D: Dimension +{ + fn from(array: Array) -> CowArray<'a, A, D> + { + // safe because equivalent data + unsafe { + ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.ptr).with_strides_dim(array.strides, array.dim) + } + } +} + +impl<'a, A, Slice: ?Sized> From<&'a Slice> for CowArray<'a, A, Ix1> +where Slice: AsRef<[A]> +{ + /// Create a one-dimensional clone-on-write view of the data in `slice`. + /// + /// **Panics** if the slice length is greater than [`isize::MAX`]. + /// + /// ``` + /// use ndarray::{array, CowArray}; + /// + /// let array = CowArray::from(&[1., 2., 3., 4.]); + /// assert!(array.is_view()); + /// assert_eq!(array, array![1., 2., 3., 4.]); + /// ``` + fn from(slice: &'a Slice) -> Self + { + Self::from(ArrayView1::from(slice)) + } +} + +impl<'a, A, S, D> From<&'a ArrayBase> for CowArray<'a, A, D> where + S: Data, D: Dimension, { - fn from(array: Array) -> CowArray<'a, A, D> { - ArrayBase { - data: CowRepr::Owned(array.data), - ptr: array.ptr, - dim: array.dim, - strides: array.strides, - } + /// Create a read-only clone-on-write view of the array. + fn from(array: &'a ArrayBase) -> Self + { + Self::from(array.view()) } } diff --git a/src/impl_dyn.rs b/src/impl_dyn.rs index 72851cc81..b86c5dd69 100644 --- a/src/impl_dyn.rs +++ b/src/impl_dyn.rs @@ -11,8 +11,7 @@ use crate::imp_prelude::*; /// # Methods for Dynamic-Dimensional Arrays impl ArrayBase -where - S: Data, +where S: Data { /// Insert new array axis of length 1 at `axis`, modifying the shape and /// strides in-place. @@ -29,7 +28,9 @@ where /// assert_eq!(a, arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn()); /// assert_eq!(a.shape(), &[2, 1, 3]); /// ``` - pub fn insert_axis_inplace(&mut self, axis: Axis) { + #[track_caller] + pub fn insert_axis_inplace(&mut self, axis: Axis) + { assert!(axis.index() <= self.ndim()); self.dim = self.dim.insert_axis(axis); self.strides = self.strides.insert_axis(axis); @@ -50,9 +51,67 @@ where /// assert_eq!(a, arr1(&[2, 5]).into_dyn()); /// assert_eq!(a.shape(), &[2]); /// ``` - pub fn index_axis_inplace(&mut self, axis: Axis, index: usize) { + #[track_caller] + pub fn index_axis_inplace(&mut self, axis: Axis, index: usize) + { self.collapse_axis(axis, index); self.dim = self.dim.remove_axis(axis); self.strides = self.strides.remove_axis(axis); } + + /// Remove axes of length 1 and return the modified array. + /// + /// If the array has more the one dimension, the result array will always + /// have at least one dimension, even if it has a length of 1. 
+ /// + /// ``` + /// use ndarray::{arr1, arr2, arr3}; + /// + /// let a = arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn(); + /// assert_eq!(a.shape(), &[2, 1, 3]); + /// let b = a.squeeze(); + /// assert_eq!(b, arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn()); + /// assert_eq!(b.shape(), &[2, 3]); + /// + /// let c = arr2(&[[1]]).into_dyn(); + /// assert_eq!(c.shape(), &[1, 1]); + /// let d = c.squeeze(); + /// assert_eq!(d, arr1(&[1]).into_dyn()); + /// assert_eq!(d.shape(), &[1]); + /// ``` + #[track_caller] + pub fn squeeze(self) -> Self + { + let mut out = self; + for axis in (0..out.shape().len()).rev() { + if out.shape()[axis] == 1 && out.shape().len() > 1 { + out = out.remove_axis(Axis(axis)); + } + } + out + } +} + +#[cfg(test)] +mod tests +{ + use crate::{arr1, arr2, arr3}; + + #[test] + fn test_squeeze() + { + let a = arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn(); + assert_eq!(a.shape(), &[2, 1, 3]); + + let b = a.squeeze(); + assert_eq!(b, arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn()); + assert_eq!(b.shape(), &[2, 3]); + + let c = arr2(&[[1]]).into_dyn(); + assert_eq!(c.shape(), &[1, 1]); + + let d = c.squeeze(); + assert_eq!(d, arr1(&[1]).into_dyn()); + assert_eq!(d.shape(), &[1]); + } } diff --git a/src/impl_internal_constructors.rs b/src/impl_internal_constructors.rs new file mode 100644 index 000000000..adb4cbd35 --- /dev/null +++ b/src/impl_internal_constructors.rs @@ -0,0 +1,66 @@ +// Copyright 2021 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::ptr::NonNull; + +use crate::imp_prelude::*; + +// internal "builder-like" methods +impl ArrayBase +where S: RawData +{ + /// Create an (initially) empty one-dimensional array from the given data and array head + /// pointer + /// + /// ## Safety + /// + /// The caller must ensure that the data storage and pointer is valid. + /// + /// See ArrayView::from_shape_ptr for general pointer validity documentation. + #[inline] + pub(crate) unsafe fn from_data_ptr(data: S, ptr: NonNull) -> Self + { + let array = ArrayBase { + data, + ptr, + dim: Ix1(0), + strides: Ix1(1), + }; + debug_assert!(array.pointer_is_inbounds()); + array + } +} + +// internal "builder-like" methods +impl ArrayBase +where + S: RawData, + D: Dimension, +{ + /// Set strides and dimension of the array to the new values + /// + /// The argument order with strides before dimensions is used because strides are often + /// computed as derived from the dimension. + /// + /// ## Safety + /// + /// The caller needs to ensure that the new strides and dimensions are correct + /// for the array data. + #[inline] + pub(crate) unsafe fn with_strides_dim(self, strides: E, dim: E) -> ArrayBase + where E: Dimension + { + debug_assert_eq!(strides.ndim(), dim.ndim()); + ArrayBase { + data: self.data, + ptr: self.ptr, + dim, + strides, + } + } +} diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 027f5a8af..4a00ea000 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -6,31 +6,60 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. 
-use std::cmp; -use std::ptr as std_ptr; -use std::slice; - +use alloc::slice; +use alloc::vec; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; +#[allow(unused_imports)] use rawpointer::PointerExt; +use std::mem::{size_of, ManuallyDrop}; use crate::imp_prelude::*; -use crate::arraytraits; +use crate::argument_traits::AssignElem; use crate::dimension; +use crate::dimension::broadcast::co_broadcast; +use crate::dimension::reshape_dim; use crate::dimension::IntoDimension; use crate::dimension::{ - abs_index, axes_of, do_slice, merge_axes, size_of_shape_checked, stride_offset, Axes, + abs_index, + axes_of, + do_slice, + merge_axes, + move_min_stride_axis_to_last, + offset_from_low_addr_ptr_to_logical_ptr, + size_of_shape_checked, + stride_offset, + Axes, }; -use crate::error::{self, ErrorKind, ShapeError}; +use crate::error::{self, from_kind, ErrorKind, ShapeError}; use crate::itertools::zip; -use crate::zip::Zip; +use crate::math_cell::MathCell; +use crate::order::Order; +use crate::shape_builder::ShapeArg; +use crate::zip::{IntoNdProducer, Zip}; +use crate::AxisDescription; +use crate::{arraytraits, DimMax}; use crate::iter::{ - AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut, - IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, + AxisChunksIter, + AxisChunksIterMut, + AxisIter, + AxisIterMut, + AxisWindows, + ExactChunks, + ExactChunksMut, + IndexedIter, + IndexedIterMut, + Iter, + IterMut, + Lanes, + LanesMut, + Windows, }; -use crate::slice::MultiSlice; -use crate::stacking::stack; -use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex}; +use crate::slice::{MultiSliceArg, SliceArg}; +use crate::stacking::concatenate; +use crate::{NdIndex, Slice, SliceInfoElem}; /// # Methods For All Array Types impl ArrayBase @@ -39,7 +68,8 @@ where D: Dimension, { /// Return the total number of elements in the array. - pub fn len(&self) -> usize { + pub fn len(&self) -> usize + { self.dim.size() } @@ -49,28 +79,33 @@ where /// number of dimensions (axes) of the array. /// /// ***Panics*** if the axis is out of bounds. - pub fn len_of(&self, axis: Axis) -> usize { + #[track_caller] + pub fn len_of(&self, axis: Axis) -> usize + { self.dim[axis.index()] } /// Return whether the array has any elements - pub fn is_empty(&self) -> bool { + pub fn is_empty(&self) -> bool + { self.len() == 0 } /// Return the number of dimensions (axes) in the array - pub fn ndim(&self) -> usize { + pub fn ndim(&self) -> usize + { self.dim.ndim() } /// Return the shape of the array in its “pattern” form, /// an integer in the one-dimensional case, tuple in the n-dimensional cases /// and so on. - pub fn dim(&self) -> D::Pattern { + pub fn dim(&self) -> D::Pattern + { self.dim.clone().into_pattern() } - /// Return the shape of the array as it stored in the array. + /// Return the shape of the array as it's stored in the array. /// /// This is primarily useful for passing to other `ArrayBase` /// functions, such as when creating another array of the same @@ -84,7 +119,8 @@ where /// // Create an array of zeros that's the same shape and dimensionality as `a`. /// let b = Array::::zeros(a.raw_dim()); /// ``` - pub fn raw_dim(&self) -> D { + pub fn raw_dim(&self) -> D + { self.dim.clone() } @@ -112,12 +148,14 @@ where /// let c = Array::zeros(a.raw_dim()); /// assert_eq!(a, c); /// ``` - pub fn shape(&self) -> &[usize] { + pub fn shape(&self) -> &[usize] + { self.dim.slice() } /// Return the strides of the array as a slice. 
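`strides()` reports per-axis offsets in units of elements (not bytes), in the same order as `shape()`; a small illustration for the default row-major layout:

```rust
use ndarray::Array3;

fn main() {
    let a = Array3::<f32>::zeros((2, 3, 4));
    assert_eq!(a.shape(), &[2, 3, 4]);
    assert_eq!(a.strides(), &[12, 4, 1]); // row-major: the last axis is contiguous
}
```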
- pub fn strides(&self) -> &[isize] { + pub fn strides(&self) -> &[isize] + { let s = self.strides.slice(); // reinterpret unsigned integer as signed unsafe { slice::from_raw_parts(s.as_ptr() as *const _, s.len()) } @@ -129,15 +167,16 @@ where /// number of dimensions (axes) of the array. /// /// ***Panics*** if the axis is out of bounds. - pub fn stride_of(&self, axis: Axis) -> isize { + #[track_caller] + pub fn stride_of(&self, axis: Axis) -> isize + { // strides are reinterpreted as isize self.strides[axis.index()] as isize } /// Return a read-only view of the array pub fn view(&self) -> ArrayView<'_, A, D> - where - S: Data, + where S: Data { debug_assert!(self.pointer_is_inbounds()); unsafe { ArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) } @@ -145,21 +184,33 @@ where /// Return a read-write view of the array pub fn view_mut(&mut self) -> ArrayViewMut<'_, A, D> - where - S: DataMut, + where S: DataMut { self.ensure_unique(); unsafe { ArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } } + /// Return a shared view of the array with elements as if they were embedded in cells. + /// + /// The cell view requires a mutable borrow of the array. Once borrowed the + /// cell view itself can be copied and accessed without exclusivity. + /// + /// The view acts "as if" the elements are temporarily in cells, and elements + /// can be changed through shared references using the regular cell methods. + pub fn cell_view(&mut self) -> ArrayView<'_, MathCell, D> + where S: DataMut + { + self.view_mut().into_cell_view() + } + /// Return an uniquely owned copy of the array. /// - /// If the input array is contiguous and its strides are positive, then the - /// output array will have the same memory layout. Otherwise, the layout of - /// the output array is unspecified. If you need a particular layout, you - /// can allocate a new array with the desired memory layout and - /// [`.assign()`](#method.assign) the data. Alternatively, you can collect - /// an iterator, like this for a result in standard layout: + /// If the input array is contiguous, then the output array will have the same + /// memory layout. Otherwise, the layout of the output array is unspecified. + /// If you need a particular layout, you can allocate a new array with the + /// desired memory layout and [`.assign()`](Self::assign) the data. + /// Alternatively, you can collectan iterator, like this for a result in + /// standard layout: /// /// ``` /// # use ndarray::prelude::*; @@ -188,25 +239,20 @@ where S: Data, { if let Some(slc) = self.as_slice_memory_order() { - unsafe { - Array::from_shape_vec_unchecked( - self.dim.clone().strides(self.strides.clone()), - slc.to_vec(), - ) - } + unsafe { Array::from_shape_vec_unchecked(self.dim.clone().strides(self.strides.clone()), slc.to_vec()) } } else { - self.map(|x| x.clone()) + self.map(A::clone) } } - /// Return a shared ownership (copy on write) array. + /// Return a shared ownership (copy on write) array, cloning the array + /// elements if necessary. pub fn to_shared(&self) -> ArcArray where A: Clone, S: Data, { - // FIXME: Avoid copying if it’s already an ArcArray. - self.to_owned().into_shared() + S::to_shared(self) } /// Turn the array into a uniquely owned array, cloning the array elements @@ -219,26 +265,64 @@ where S::into_owned(self) } + /// Converts the array into `Array` if this is possible without + /// cloning the array elements. Otherwise, returns `self` unchanged. 
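A minimal usage sketch of the new `.cell_view()` described above; it assumes `MathCell` exposes the usual `Cell` accessors (`get`/`set`) via deref:

```rust
use ndarray::Array2;

let mut a = Array2::<i32>::zeros((2, 2));
{
    // The cell view borrows `a` mutably, but the view itself is a shared view
    // whose elements can be mutated through the cell methods.
    let cells = a.cell_view();
    cells[[0, 1]].set(7);
    assert_eq!(cells[[0, 1]].get(), 7);
}
// The write went through to the underlying array.
assert_eq!(a[[0, 1]], 7);
```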
+ /// + /// ``` + /// use ndarray::{array, rcarr2, ArcArray2, Array2}; + /// + /// // Reference-counted, clone-on-write `ArcArray`. + /// let a: ArcArray2<_> = rcarr2(&[[1., 2.], [3., 4.]]); + /// { + /// // Another reference to the same data. + /// let b: ArcArray2<_> = a.clone(); + /// // Since there are two references to the same data, `.into_owned()` + /// // would require cloning the data, so `.try_into_owned_nocopy()` + /// // returns `Err`. + /// assert!(b.try_into_owned_nocopy().is_err()); + /// } + /// // Here, since the second reference has been dropped, the `ArcArray` + /// // can be converted into an `Array` without cloning the data. + /// let unique: Array2<_> = a.try_into_owned_nocopy().unwrap(); + /// assert_eq!(unique, array![[1., 2.], [3., 4.]]); + /// ``` + pub fn try_into_owned_nocopy(self) -> Result, Self> + where S: Data + { + S::try_into_owned_nocopy(self) + } + /// Turn the array into a shared ownership (copy on write) array, - /// without any copying. + /// cloning the array elements if necessary. + /// + /// If you want to generalize over `Array` and `ArcArray` inputs but avoid + /// an `A: Clone` bound, use `Into::>::into` instead of this + /// method. pub fn into_shared(self) -> ArcArray where + A: Clone, S: DataOwned, { - let data = self.data.into_shared(); - ArrayBase { - data, - ptr: self.ptr, - dim: self.dim, - strides: self.strides, - } + S::into_shared(self) } /// Returns a reference to the first element of the array, or `None` if it /// is empty. + /// + /// # Example + /// + /// ```rust + /// use ndarray::Array3; + /// + /// let mut a = Array3::::zeros([3, 4, 2]); + /// a[[0, 0, 0]] = 42.; + /// assert_eq!(a.first(), Some(&42.)); + /// + /// let b = Array3::::zeros([3, 0, 5]); + /// assert_eq!(b.first(), None); + /// ``` pub fn first(&self) -> Option<&A> - where - S: Data, + where S: Data { if self.is_empty() { None @@ -249,9 +333,21 @@ where /// Returns a mutable reference to the first element of the array, or /// `None` if it is empty. + /// + /// # Example + /// + /// ```rust + /// use ndarray::Array3; + /// + /// let mut a = Array3::::zeros([3, 4, 2]); + /// *a.first_mut().unwrap() = 42.; + /// assert_eq!(a[[0, 0, 0]], 42.); + /// + /// let mut b = Array3::::zeros([3, 0, 5]); + /// assert_eq!(b.first_mut(), None); + /// ``` pub fn first_mut(&mut self) -> Option<&mut A> - where - S: DataMut, + where S: DataMut { if self.is_empty() { None @@ -260,6 +356,65 @@ where } } + /// Returns a reference to the last element of the array, or `None` if it + /// is empty. + /// + /// # Example + /// + /// ```rust + /// use ndarray::Array3; + /// + /// let mut a = Array3::::zeros([3, 4, 2]); + /// a[[2, 3, 1]] = 42.; + /// assert_eq!(a.last(), Some(&42.)); + /// + /// let b = Array3::::zeros([3, 0, 5]); + /// assert_eq!(b.last(), None); + /// ``` + pub fn last(&self) -> Option<&A> + where S: Data + { + if self.is_empty() { + None + } else { + let mut index = self.raw_dim(); + for ax in 0..index.ndim() { + index[ax] -= 1; + } + Some(unsafe { self.uget(index) }) + } + } + + /// Returns a mutable reference to the last element of the array, or `None` + /// if it is empty. 
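A short sketch combining `.into_shared()` and `.try_into_owned_nocopy()` from above: when the shared data is uniquely held, the round trip back to `Array` needs no element cloning.

```rust
use ndarray::{array, ArcArray2, Array2};

// Move an owned array into a reference-counted `ArcArray`.
let shared: ArcArray2<f64> = array![[1., 2.], [3., 4.]].into_shared();
// Only one reference is alive, so no copy is required here.
let owned: Array2<f64> = shared.try_into_owned_nocopy().unwrap();
assert_eq!(owned, array![[1., 2.], [3., 4.]]);
```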
+ /// + /// # Example + /// + /// ```rust + /// use ndarray::Array3; + /// + /// let mut a = Array3::::zeros([3, 4, 2]); + /// *a.last_mut().unwrap() = 42.; + /// assert_eq!(a[[2, 3, 1]], 42.); + /// + /// let mut b = Array3::::zeros([3, 0, 5]); + /// assert_eq!(b.last_mut(), None); + /// ``` + pub fn last_mut(&mut self) -> Option<&mut A> + where S: DataMut + { + if self.is_empty() { + None + } else { + self.ensure_unique(); + let mut index = self.raw_dim(); + for ax in 0..index.ndim() { + index[ax] -= 1; + } + Some(unsafe { self.uget_mut(index) }) + } + } + /// Return an iterator of references to the elements of the array. /// /// Elements are visited in the *logical order* of the array, which @@ -267,8 +422,7 @@ where /// /// Iterator element type is `&A`. pub fn iter(&self) -> Iter<'_, A, D> - where - S: Data, + where S: Data { debug_assert!(self.pointer_is_inbounds()); self.view().into_iter_() @@ -281,8 +435,7 @@ where /// /// Iterator element type is `&mut A`. pub fn iter_mut(&mut self) -> IterMut<'_, A, D> - where - S: DataMut, + where S: DataMut { self.view_mut().into_iter_() } @@ -294,10 +447,9 @@ where /// /// Iterator element type is `(D::Pattern, &A)`. /// - /// See also [`Zip::indexed`](struct.Zip.html) + /// See also [`Zip::indexed`] pub fn indexed_iter(&self) -> IndexedIter<'_, A, D> - where - S: Data, + where S: Data { IndexedIter::new(self.view().into_elements_base()) } @@ -309,8 +461,7 @@ where /// /// Iterator element type is `(D::Pattern, &mut A)`. pub fn indexed_iter_mut(&mut self) -> IndexedIterMut<'_, A, D> - where - S: DataMut, + where S: DataMut { IndexedIterMut::new(self.view_mut().into_elements_base()) } @@ -318,16 +469,14 @@ where /// Return a sliced view of the array. /// /// See [*Slicing*](#slicing) for full documentation. - /// See also [`SliceInfo`] and [`D::SliceArg`]. - /// - /// [`SliceInfo`]: struct.SliceInfo.html - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). /// /// **Panics** if an index is out of bounds or step size is zero.
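A small illustration of `.indexed_iter()` above, which pairs each element with its index pattern:

```rust
use ndarray::array;

let a = array![[1, 2, 3], [4, 5, 6]];
// Elements are visited in logical order together with their (row, column) index.
for ((row, col), &value) in a.indexed_iter() {
    assert_eq!(value, a[[row, col]]);
}
```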
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice(&self, info: &SliceInfo) -> ArrayView<'_, A, Do> + #[track_caller] + pub fn slice(&self, info: I) -> ArrayView<'_, A, I::OutDim> where - Do: Dimension, + I: SliceArg, S: Data, { self.view().slice_move(info) @@ -336,16 +485,14 @@ where /// Return a sliced read-write view of the array. /// /// See [*Slicing*](#slicing) for full documentation. - /// See also [`SliceInfo`] and [`D::SliceArg`]. - /// - /// [`SliceInfo`]: struct.SliceInfo.html - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). /// /// **Panics** if an index is out of bounds or step size is zero.
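A brief example of `.slice()` with the `s![]` macro, as documented above:

```rust
use ndarray::{array, s};

let a = array![[1, 2, 3], [4, 5, 6]];
// Keep every row, but only the columns from index 1 onwards.
let v = a.slice(s![.., 1..]);
assert_eq!(v, array![[2, 3], [5, 6]]);
```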
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice_mut(&mut self, info: &SliceInfo) -> ArrayViewMut<'_, A, Do> + #[track_caller] + pub fn slice_mut(&mut self, info: I) -> ArrayViewMut<'_, A, I::OutDim> where - Do: Dimension, + I: SliceArg, S: DataMut, { self.view_mut().slice_move(info) @@ -353,11 +500,9 @@ where /// Return multiple disjoint, sliced, mutable views of the array. /// - /// See [*Slicing*](#slicing) for full documentation. - /// See also [`SliceInfo`] and [`D::SliceArg`]. - /// - /// [`SliceInfo`]: struct.SliceInfo.html - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// See [*Slicing*](#slicing) for full documentation. See also + /// [`MultiSliceArg`], [`s!`], [`SliceArg`], and + /// [`SliceInfo`](crate::SliceInfo). /// /// **Panics** if any of the following occur: /// @@ -376,9 +521,10 @@ where /// middle.fill(0); /// assert_eq!(a, arr2(&[[1, 0, 1], [1, 0, 1]])); /// ``` + #[track_caller] pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output where - M: MultiSlice<'a, A, D>, + M: MultiSliceArg<'a, A, D>, S: DataMut, { info.multi_slice_move(self.view_mut()) @@ -387,92 +533,112 @@ where /// Slice the array, possibly changing the number of dimensions. /// /// See [*Slicing*](#slicing) for full documentation. - /// See also [`SliceInfo`] and [`D::SliceArg`]. - /// - /// [`SliceInfo`]: struct.SliceInfo.html - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). /// /// **Panics** if an index is out of bounds or step size is zero.
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) - pub fn slice_move(mut self, info: &SliceInfo) -> ArrayBase - where - Do: Dimension, - { - // Slice and collapse in-place without changing the number of dimensions. - self.slice_collapse(&*info); - - let indices: &[SliceOrIndex] = (**info).as_ref(); - - // Copy the dim and strides that remain after removing the subview axes. + #[track_caller] + pub fn slice_move(mut self, info: I) -> ArrayBase + where I: SliceArg + { + assert_eq!( + info.in_ndim(), + self.ndim(), + "The input dimension of `info` must match the array to be sliced.", + ); let out_ndim = info.out_ndim(); - let mut new_dim = Do::zeros(out_ndim); - let mut new_strides = Do::zeros(out_ndim); - izip!(self.dim.slice(), self.strides.slice(), indices) - .filter_map(|(d, s, slice_or_index)| match slice_or_index { - SliceOrIndex::Slice { .. } => Some((d, s)), - SliceOrIndex::Index(_) => None, - }) - .zip(izip!(new_dim.slice_mut(), new_strides.slice_mut())) - .for_each(|((d, s), (new_d, new_s))| { - *new_d = *d; - *new_s = *s; - }); - - ArrayBase { - ptr: self.ptr, - data: self.data, - dim: new_dim, - strides: new_strides, - } + let mut new_dim = I::OutDim::zeros(out_ndim); + let mut new_strides = I::OutDim::zeros(out_ndim); + + let mut old_axis = 0; + let mut new_axis = 0; + info.as_ref().iter().for_each(|&ax_info| match ax_info { + SliceInfoElem::Slice { start, end, step } => { + // Slice the axis in-place to update the `dim`, `strides`, and `ptr`. + self.slice_axis_inplace(Axis(old_axis), Slice { start, end, step }); + // Copy the sliced dim and stride to corresponding axis. + new_dim[new_axis] = self.dim[old_axis]; + new_strides[new_axis] = self.strides[old_axis]; + old_axis += 1; + new_axis += 1; + } + SliceInfoElem::Index(index) => { + // Collapse the axis in-place to update the `ptr`. + let i_usize = abs_index(self.len_of(Axis(old_axis)), index); + self.collapse_axis(Axis(old_axis), i_usize); + // Skip copying the axis since it should be removed. Note that + // removing this axis is safe because `.collapse_axis()` panics + // if the index is out-of-bounds, so it will panic if the axis + // is zero length. + old_axis += 1; + } + SliceInfoElem::NewAxis => { + // Set the dim and stride of the new axis. + new_dim[new_axis] = 1; + new_strides[new_axis] = 0; + new_axis += 1; + } + }); + debug_assert_eq!(old_axis, self.ndim()); + debug_assert_eq!(new_axis, out_ndim); + + // safe because new dimension, strides allow access to a subset of old data + unsafe { self.with_strides_dim(new_strides, new_dim) } } /// Slice the array in place without changing the number of dimensions. /// - /// Note that [`&SliceInfo`](struct.SliceInfo.html) (produced by the - /// [`s![]`](macro.s!.html) macro) will usually coerce into `&D::SliceArg` - /// automatically, but in some cases (e.g. if `D` is `IxDyn`), you may need - /// to call `.as_ref()`. - /// - /// See [*Slicing*](#slicing) for full documentation. - /// See also [`D::SliceArg`]. - /// - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// In particular, if an axis is sliced with an index, the axis is + /// collapsed, as in [`.collapse_axis()`], rather than removed, as in + /// [`.slice_move()`] or [`.index_axis_move()`]. /// - /// **Panics** if an index is out of bounds or step size is zero.
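A sketch of `.slice_move()` above when the slice specification mixes an index with `NewAxis`; indexing removes an axis while `NewAxis` inserts one, so the output dimensionality can differ from the input:

```rust
use ndarray::{array, s, NewAxis};

let a = array![[1, 2, 3], [4, 5, 6]];
// Select column 1 (removing that axis) and append a new length-one axis.
let b = a.slice_move(s![.., 1, NewAxis]);
assert_eq!(b, array![[2], [5]]);
assert_eq!(b.shape(), &[2, 1]);
```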
- /// (**Panics** if `D` is `IxDyn` and `indices` does not match the number of array axes.) - pub fn slice_collapse(&mut self, indices: &D::SliceArg) { - let indices: &[SliceOrIndex] = indices.as_ref(); - assert_eq!(indices.len(), self.ndim()); - indices - .iter() - .enumerate() - .for_each(|(axis, &slice_or_index)| match slice_or_index { - SliceOrIndex::Slice { start, end, step } => { - self.slice_axis_inplace(Axis(axis), Slice { start, end, step }) - } - SliceOrIndex::Index(index) => { - let i_usize = abs_index(self.len_of(Axis(axis)), index); - self.collapse_axis(Axis(axis), i_usize) - } - }); - } - - /// Slice the array in place without changing the number of dimensions. + /// [`.collapse_axis()`]: Self::collapse_axis + /// [`.slice_move()`]: Self::slice_move + /// [`.index_axis_move()`]: Self::index_axis_move /// - /// **Panics** if an index is out of bounds or step size is zero.
- /// (**Panics** if `D` is `IxDyn` and `indices` does not match the number of array axes.) - #[deprecated(note = "renamed to `slice_collapse`", since = "0.12.1")] - pub fn slice_inplace(&mut self, indices: &D::SliceArg) { - self.slice_collapse(indices) + /// See [*Slicing*](#slicing) for full documentation. + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). + /// + /// **Panics** in the following cases: + /// + /// - if an index is out of bounds + /// - if a step size is zero + /// - if [`SliceInfoElem::NewAxis`] is in `info`, e.g. if [`NewAxis`] was + /// used in the [`s!`] macro + /// - if `D` is `IxDyn` and `info` does not match the number of array axes + #[track_caller] + pub fn slice_collapse(&mut self, info: I) + where I: SliceArg + { + assert_eq!( + info.in_ndim(), + self.ndim(), + "The input dimension of `info` must match the array to be sliced.", + ); + let mut axis = 0; + info.as_ref().iter().for_each(|&ax_info| match ax_info { + SliceInfoElem::Slice { start, end, step } => { + self.slice_axis_inplace(Axis(axis), Slice { start, end, step }); + axis += 1; + } + SliceInfoElem::Index(index) => { + let i_usize = abs_index(self.len_of(Axis(axis)), index); + self.collapse_axis(Axis(axis), i_usize); + axis += 1; + } + SliceInfoElem::NewAxis => panic!("`slice_collapse` does not support `NewAxis`."), + }); + debug_assert_eq!(axis, self.ndim()); } /// Return a view of the array, sliced along the specified axis. /// /// **Panics** if an index is out of bounds or step size is zero.
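A sketch of `.slice_collapse()` above, showing that an indexed axis is collapsed to length one rather than removed:

```rust
use ndarray::{array, s};

let mut a = array![[1, 2, 3], [4, 5, 6]];
// Index row 1 (collapsed, not removed) and take every second column.
a.slice_collapse(s![1, ..;2]);
assert_eq!(a, array![[4, 6]]);
// The array is still two-dimensional.
assert_eq!(a.ndim(), 2);
```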
/// **Panics** if `axis` is out of bounds. + #[track_caller] + #[must_use = "slice_axis returns an array view with the sliced result"] pub fn slice_axis(&self, axis: Axis, indices: Slice) -> ArrayView<'_, A, D> - where - S: Data, + where S: Data { let mut view = self.view(); view.slice_axis_inplace(axis, indices); @@ -483,9 +649,10 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
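A short example of `.slice_axis()` above, using `Slice::from` on a range:

```rust
use ndarray::{array, Axis, Slice};

let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
// Keep only the rows from index 1 onwards; the other axes are untouched.
let v = a.slice_axis(Axis(0), Slice::from(1..));
assert_eq!(v, array![[4, 5, 6], [7, 8, 9]]);
```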
/// **Panics** if `axis` is out of bounds. + #[track_caller] + #[must_use = "slice_axis_mut returns an array view with the sliced result"] pub fn slice_axis_mut(&mut self, axis: Axis, indices: Slice) -> ArrayViewMut<'_, A, D> - where - S: DataMut, + where S: DataMut { let mut view_mut = self.view_mut(); view_mut.slice_axis_inplace(axis, indices); @@ -496,18 +663,87 @@ where /// /// **Panics** if an index is out of bounds or step size is zero.
/// **Panics** if `axis` is out of bounds. - pub fn slice_axis_inplace(&mut self, axis: Axis, indices: Slice) { - let offset = do_slice( - &mut self.dim.slice_mut()[axis.index()], - &mut self.strides.slice_mut()[axis.index()], - indices, - ); + #[track_caller] + pub fn slice_axis_inplace(&mut self, axis: Axis, indices: Slice) + { + let offset = + do_slice(&mut self.dim.slice_mut()[axis.index()], &mut self.strides.slice_mut()[axis.index()], indices); unsafe { self.ptr = self.ptr.offset(offset); } debug_assert!(self.pointer_is_inbounds()); } + /// Slice the array in place along the specified axis, then return the sliced array. + /// + /// **Panics** if an index is out of bounds or step size is zero.
+ /// **Panics** if `axis` is out of bounds. + #[must_use = "slice_axis_move returns an array with the sliced result"] + pub fn slice_axis_move(mut self, axis: Axis, indices: Slice) -> Self + { + self.slice_axis_inplace(axis, indices); + self + } + + /// Return a view of a slice of the array, with a closure specifying the + /// slice for each axis. + /// + /// This is especially useful for code which is generic over the + /// dimensionality of the array. + /// + /// **Panics** if an index is out of bounds or step size is zero. + #[track_caller] + pub fn slice_each_axis(&self, f: F) -> ArrayView<'_, A, D> + where + F: FnMut(AxisDescription) -> Slice, + S: Data, + { + let mut view = self.view(); + view.slice_each_axis_inplace(f); + view + } + + /// Return a mutable view of a slice of the array, with a closure + /// specifying the slice for each axis. + /// + /// This is especially useful for code which is generic over the + /// dimensionality of the array. + /// + /// **Panics** if an index is out of bounds or step size is zero. + #[track_caller] + pub fn slice_each_axis_mut(&mut self, f: F) -> ArrayViewMut<'_, A, D> + where + F: FnMut(AxisDescription) -> Slice, + S: DataMut, + { + let mut view = self.view_mut(); + view.slice_each_axis_inplace(f); + view + } + + /// Slice the array in place, with a closure specifying the slice for each + /// axis. + /// + /// This is especially useful for code which is generic over the + /// dimensionality of the array. + /// + /// **Panics** if an index is out of bounds or step size is zero. + #[track_caller] + pub fn slice_each_axis_inplace(&mut self, mut f: F) + where F: FnMut(AxisDescription) -> Slice + { + for ax in 0..self.ndim() { + self.slice_axis_inplace( + Axis(ax), + f(AxisDescription { + axis: Axis(ax), + len: self.dim[ax], + stride: self.strides[ax] as isize, + }), + ) + } + } + /// Return a reference to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -528,15 +764,27 @@ where /// ``` pub fn get(&self, index: I) -> Option<&A> where - I: NdIndex, S: Data, + I: NdIndex, { unsafe { self.get_ptr(index).map(|ptr| &*ptr) } } - pub(crate) fn get_ptr(&self, index: I) -> Option<*const A> - where - I: NdIndex, + /// Return a raw pointer to the element at `index`, or return `None` + /// if the index is out of bounds. + /// + /// ``` + /// use ndarray::arr2; + /// + /// let a = arr2(&[[1., 2.], [3., 4.]]); + /// + /// let v = a.raw_view(); + /// let p = a.get_ptr((0, 1)).unwrap(); + /// + /// assert_eq!(unsafe { *p }, 2.); + /// ``` + pub fn get_ptr(&self, index: I) -> Option<*const A> + where I: NdIndex { let ptr = self.ptr; index @@ -551,10 +799,27 @@ where S: DataMut, I: NdIndex, { - unsafe { self.get_ptr_mut(index).map(|ptr| &mut *ptr) } + unsafe { self.get_mut_ptr(index).map(|ptr| &mut *ptr) } } - pub(crate) fn get_ptr_mut(&mut self, index: I) -> Option<*mut A> + /// Return a raw pointer to the element at `index`, or return `None` + /// if the index is out of bounds. + /// + /// ``` + /// use ndarray::arr2; + /// + /// let mut a = arr2(&[[1., 2.], [3., 4.]]); + /// + /// let v = a.raw_view_mut(); + /// let p = a.get_mut_ptr((0, 1)).unwrap(); + /// + /// unsafe { + /// *p = 5.; + /// } + /// + /// assert_eq!(a.get((0, 1)), Some(&5.)); + /// ``` + pub fn get_mut_ptr(&mut self, index: I) -> Option<*mut A> where S: RawDataMut, I: NdIndex, @@ -572,6 +837,10 @@ where /// Return a reference to the element at `index`. /// /// **Note:** only unchecked for non-debug builds of ndarray. 
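A sketch of `.slice_each_axis()` from above; it assumes the public `len` field on `AxisDescription`, as constructed in `slice_each_axis_inplace`:

```rust
use ndarray::{Array2, Slice};

let a = Array2::from_shape_fn((3, 4), |(i, j)| i * 4 + j);
// Generic over dimensionality: drop the last element along every axis.
let v = a.slice_each_axis(|ax| Slice::from(..ax.len - 1));
assert_eq!(v.shape(), &[2, 3]);
assert_eq!(v[[1, 2]], 6);
```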
+ /// + /// # Safety + /// + /// The caller must ensure that the index is in-bounds. #[inline] pub unsafe fn uget(&self, index: I) -> &A where @@ -587,8 +856,16 @@ where /// /// Return a mutable reference to the element at `index`. /// - /// **Note:** Only unchecked for non-debug builds of ndarray.
- /// **Note:** (For `ArcArray`) The array must be uniquely held when mutating it. + /// **Note:** Only unchecked for non-debug builds of ndarray. + /// + /// # Safety + /// + /// The caller must ensure that: + /// + /// 1. the index is in-bounds and + /// + /// 2. the data is uniquely held by the array. (This property is guaranteed + /// for `Array` and `ArrayViewMut`, but not for `ArcArray` or `CowArray`.) #[inline] pub unsafe fn uget_mut(&mut self, index: I) -> &mut A where @@ -606,24 +883,40 @@ where /// Indices may be equal. /// /// ***Panics*** if an index is out of bounds. + #[track_caller] pub fn swap(&mut self, index1: I, index2: I) where S: DataMut, I: NdIndex, { - let ptr1: *mut _ = &mut self[index1]; - let ptr2: *mut _ = &mut self[index2]; - unsafe { - std_ptr::swap(ptr1, ptr2); + let ptr = self.as_mut_ptr(); + let offset1 = index1.index_checked(&self.dim, &self.strides); + let offset2 = index2.index_checked(&self.dim, &self.strides); + if let Some(offset1) = offset1 { + if let Some(offset2) = offset2 { + unsafe { + std::ptr::swap(ptr.offset(offset1), ptr.offset(offset2)); + } + return; + } } + panic!("swap: index out of bounds for indices {:?} {:?}", index1, index2); } /// Swap elements *unchecked* at indices `index1` and `index2`. /// /// Indices may be equal. /// - /// **Note:** only unchecked for non-debug builds of ndarray.
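A small example of the bounds-checked `.swap()` above:

```rust
use ndarray::array;

let mut a = array![[1, 2], [3, 4]];
// Both indices are checked; out-of-bounds indices panic.
a.swap([0, 0], [1, 1]);
assert_eq!(a, array![[4, 2], [3, 1]]);
```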
- /// **Note:** (For `ArcArray`) The array must be uniquely held. + /// **Note:** only unchecked for non-debug builds of ndarray. + /// + /// # Safety + /// + /// The caller must ensure that: + /// + /// 1. both `index1` and `index2` are in-bounds and + /// + /// 2. the data is uniquely held by the array. (This property is guaranteed + /// for `Array` and `ArrayViewMut`, but not for `ArcArray` or `CowArray`.) pub unsafe fn uswap(&mut self, index1: I, index2: I) where S: DataMut, @@ -634,17 +927,13 @@ where arraytraits::debug_bounds_check(self, &index2); let off1 = index1.index_unchecked(&self.strides); let off2 = index2.index_unchecked(&self.strides); - std_ptr::swap( - self.ptr.as_ptr().offset(off1), - self.ptr.as_ptr().offset(off2), - ); + std::ptr::swap(self.ptr.as_ptr().offset(off1), self.ptr.as_ptr().offset(off2)); } // `get` for zero-dimensional arrays // panics if dimension is not zero. otherwise an element is always present. fn get_0d(&self) -> &A - where - S: Data, + where S: Data { assert!(self.ndim() == 0); unsafe { &*self.as_ptr() } @@ -671,6 +960,7 @@ where /// a.index_axis(Axis(1), 1) == ArrayView::from(&[2., 4., 6.]) /// ); /// ``` + #[track_caller] pub fn index_axis(&self, axis: Axis, index: usize) -> ArrayView<'_, A, D::Smaller> where S: Data, @@ -703,6 +993,7 @@ where /// [3., 14.]]) /// ); /// ``` + #[track_caller] pub fn index_axis_mut(&mut self, axis: Axis, index: usize) -> ArrayViewMut<'_, A, D::Smaller> where S: DataMut, @@ -713,80 +1004,33 @@ where /// Collapses the array to `index` along the axis and removes the axis. /// - /// See [`.index_axis()`](#method.index_axis) and [*Subviews*](#subviews) for full documentation. + /// See [`.index_axis()`](Self::index_axis) and [*Subviews*](#subviews) for full documentation. /// /// **Panics** if `axis` or `index` is out of bounds. + #[track_caller] pub fn index_axis_move(mut self, axis: Axis, index: usize) -> ArrayBase - where - D: RemoveAxis, + where D: RemoveAxis { self.collapse_axis(axis, index); let dim = self.dim.remove_axis(axis); let strides = self.strides.remove_axis(axis); - ArrayBase { - ptr: self.ptr, - data: self.data, - dim, - strides, - } + // safe because new dimension, strides allow access to a subset of old data + unsafe { self.with_strides_dim(strides, dim) } } /// Selects `index` along the axis, collapsing the axis into length one. /// /// **Panics** if `axis` or `index` is out of bounds. - pub fn collapse_axis(&mut self, axis: Axis, index: usize) { + #[track_caller] + pub fn collapse_axis(&mut self, axis: Axis, index: usize) + { let offset = dimension::do_collapse_axis(&mut self.dim, &self.strides, axis.index(), index); self.ptr = unsafe { self.ptr.offset(offset) }; debug_assert!(self.pointer_is_inbounds()); } - /// Along `axis`, select the subview `index` and return a - /// view with that axis removed. - /// - /// **Panics** if `axis` or `index` is out of bounds. - #[deprecated(note = "renamed to `index_axis`", since = "0.12.1")] - pub fn subview(&self, axis: Axis, index: Ix) -> ArrayView<'_, A, D::Smaller> - where - S: Data, - D: RemoveAxis, - { - self.index_axis(axis, index) - } - - /// Along `axis`, select the subview `index` and return a read-write view - /// with the axis removed. - /// - /// **Panics** if `axis` or `index` is out of bounds. 
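A short example of `.collapse_axis()` above, which keeps the selected axis with length one (in contrast to `.index_axis_move()`):

```rust
use ndarray::{array, Axis};

let mut a = array![[1, 2, 3], [4, 5, 6]];
// Select column index 2; the axis stays, collapsed to length 1.
a.collapse_axis(Axis(1), 2);
assert_eq!(a, array![[3], [6]]);
assert_eq!(a.ndim(), 2);
```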
- #[deprecated(note = "renamed to `index_axis_mut`", since = "0.12.1")] - pub fn subview_mut(&mut self, axis: Axis, index: Ix) -> ArrayViewMut<'_, A, D::Smaller> - where - S: DataMut, - D: RemoveAxis, - { - self.index_axis_mut(axis, index) - } - - /// Collapse dimension `axis` into length one, - /// and select the subview of `index` along that axis. - /// - /// **Panics** if `index` is past the length of the axis. - #[deprecated(note = "renamed to `collapse_axis`", since = "0.12.1")] - pub fn subview_inplace(&mut self, axis: Axis, index: Ix) { - self.collapse_axis(axis, index) - } - - /// Along `axis`, select the subview `index` and return `self` - /// with that axis removed. - #[deprecated(note = "renamed to `index_axis_move`", since = "0.12.1")] - pub fn into_subview(self, axis: Axis, index: Ix) -> ArrayBase - where - D: RemoveAxis, - { - self.index_axis_move(axis, index) - } - /// Along `axis`, select arbitrary subviews corresponding to `indices` - /// and and copy them into a new array. + /// and copy them into a new array. /// /// **Panics** if `axis` or an element of `indices` is out of bounds. /// @@ -806,22 +1050,42 @@ where /// [6., 7.]]) ///); /// ``` + #[track_caller] pub fn select(&self, axis: Axis, indices: &[Ix]) -> Array where - A: Copy, + A: Clone, S: Data, D: RemoveAxis, { - let mut subs = vec![self.view(); indices.len()]; - for (&i, sub) in zip(indices, &mut subs[..]) { - sub.collapse_axis(axis, i); - } - if subs.is_empty() { - let mut dim = self.raw_dim(); - dim.set_axis(axis, 0); - unsafe { Array::from_shape_vec_unchecked(dim, vec![]) } + if self.ndim() == 1 { + // using .len_of(axis) means that we check if `axis` is in bounds too. + let axis_len = self.len_of(axis); + // bounds check the indices first + if let Some(max_index) = indices.iter().cloned().max() { + if max_index >= axis_len { + panic!("ndarray: index {} is out of bounds in array of len {}", + max_index, self.len_of(axis)); + } + } // else: indices empty is ok + let view = self.view().into_dimensionality::().unwrap(); + Array::from_iter(indices.iter().map(move |&index| { + // Safety: bounds checked indexes + unsafe { view.uget(index).clone() } + })) + .into_dimensionality::() + .unwrap() } else { - stack(axis, &subs).unwrap() + let mut subs = vec![self.view(); indices.len()]; + for (&i, sub) in zip(indices, &mut subs[..]) { + sub.collapse_axis(axis, i); + } + if subs.is_empty() { + let mut dim = self.raw_dim(); + dim.set_axis(axis, 0); + unsafe { Array::from_shape_vec_unchecked(dim, vec![]) } + } else { + concatenate(axis, &subs).unwrap() + } } } @@ -839,21 +1103,20 @@ where /// Iterator element is `ArrayView1
` (1D array view). /// /// ``` - /// use ndarray::{arr3, Axis, arr1}; + /// use ndarray::arr3; /// /// let a = arr3(&[[[ 0, 1, 2], // -- row 0, 0 /// [ 3, 4, 5]], // -- row 0, 1 /// [[ 6, 7, 8], // -- row 1, 0 /// [ 9, 10, 11]]]); // -- row 1, 1 /// - /// // `genrows` will yield the four generalized rows of the array. - /// for row in a.genrows() { + /// // `rows` will yield the four generalized rows of the array. + /// for row in a.rows() { /// /* loop body */ /// } /// ``` - pub fn genrows(&self) -> Lanes<'_, A, D::Smaller> - where - S: Data, + pub fn rows(&self) -> Lanes<'_, A, D::Smaller> + where S: Data { let mut n = self.ndim(); if n == 0 { @@ -866,9 +1129,8 @@ where /// rows of the array and yields mutable array views. /// /// Iterator element is `ArrayView1` (1D read-write array view). - pub fn genrows_mut(&mut self) -> LanesMut<'_, A, D::Smaller> - where - S: DataMut, + pub fn rows_mut(&mut self) -> LanesMut<'_, A, D::Smaller> + where S: DataMut { let mut n = self.ndim(); if n == 0 { @@ -891,21 +1153,20 @@ where /// Iterator element is `ArrayView1` (1D array view). /// /// ``` - /// use ndarray::{arr3, Axis, arr1}; + /// use ndarray::arr3; /// /// // The generalized columns of a 3D array: /// // are directed along the 0th axis: 0 and 6, 1 and 7 and so on... /// let a = arr3(&[[[ 0, 1, 2], [ 3, 4, 5]], /// [[ 6, 7, 8], [ 9, 10, 11]]]); /// - /// // Here `gencolumns` will yield the six generalized columns of the array. - /// for row in a.gencolumns() { + /// // Here `columns` will yield the six generalized columns of the array. + /// for column in a.columns() { /// /* loop body */ /// } /// ``` - pub fn gencolumns(&self) -> Lanes<'_, A, D::Smaller> - where - S: Data, + pub fn columns(&self) -> Lanes<'_, A, D::Smaller> + where S: Data { Lanes::new(self.view(), Axis(0)) } @@ -914,9 +1175,8 @@ where /// columns of the array and yields mutable array views. /// /// Iterator element is `ArrayView1` (1D read-write array view). - pub fn gencolumns_mut(&mut self) -> LanesMut<'_, A, D::Smaller> - where - S: DataMut, + pub fn columns_mut(&mut self) -> LanesMut<'_, A, D::Smaller> + where S: DataMut { LanesMut::new(self.view_mut(), Axis(0)) } @@ -924,7 +1184,7 @@ where /// Return a producer and iterable that traverses over all 1D lanes /// pointing in the direction of `axis`. /// - /// When the pointing in the direction of the first axis, they are *columns*, + /// When pointing in the direction of the first axis, they are *columns*, /// in the direction of the last axis *rows*; in general they are all /// *lanes* and are one dimensional. /// @@ -950,8 +1210,7 @@ where /// assert_eq!(inner2.into_iter().next().unwrap(), aview1(&[0, 1, 2])); /// ``` pub fn lanes(&self, axis: Axis) -> Lanes<'_, A, D::Smaller> - where - S: Data, + where S: Data { Lanes::new(self.view(), axis) } @@ -961,8 +1220,7 @@ where /// /// Iterator element is `ArrayViewMut1` (1D read-write array view). pub fn lanes_mut(&mut self, axis: Axis) -> LanesMut<'_, A, D::Smaller> - where - S: DataMut, + where S: DataMut { LanesMut::new(self.view_mut(), axis) } @@ -1012,6 +1270,7 @@ where /// **Panics** if `axis` is out of bounds. /// /// + #[track_caller] pub fn axis_iter(&self, axis: Axis) -> AxisIter<'_, A, D::Smaller> where S: Data, @@ -1027,6 +1286,7 @@ where /// (read-write array view). /// /// **Panics** if `axis` is out of bounds. 
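A brief example of `.axis_iter()` above, iterating over the subviews along an axis:

```rust
use ndarray::{array, Axis};

let a = array![[1, 2, 3], [4, 5, 6]];
// Iterate over the subviews along axis 0 (the rows of a 2-D array).
let mut sums = Vec::new();
for row in a.axis_iter(Axis(0)) {
    sums.push(row.sum());
}
assert_eq!(sums, vec![6, 15]);
```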
+ #[track_caller] pub fn axis_iter_mut(&mut self, axis: Axis) -> AxisIterMut<'_, A, D::Smaller> where S: DataMut, @@ -1048,9 +1308,8 @@ where /// ``` /// use ndarray::Array; /// use ndarray::{arr3, Axis}; - /// use std::iter::FromIterator; /// - /// let a = Array::from_iter(0..28).into_shape((2, 7, 2)).unwrap(); + /// let a = Array::from_iter(0..28).into_shape_with_order((2, 7, 2)).unwrap(); /// let mut iter = a.axis_chunks_iter(Axis(1), 2); /// /// // first iteration yields a 2 × 2 × 2 view @@ -1062,9 +1321,9 @@ where /// assert_eq!(iter.next_back().unwrap(), arr3(&[[[12, 13]], /// [[26, 27]]])); /// ``` + #[track_caller] pub fn axis_chunks_iter(&self, axis: Axis, size: usize) -> AxisChunksIter<'_, A, D> - where - S: Data, + where S: Data { AxisChunksIter::new(self.view(), axis, size) } @@ -1075,9 +1334,9 @@ where /// Iterator element is `ArrayViewMut` /// /// **Panics** if `axis` is out of bounds or if `size` is zero. + #[track_caller] pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<'_, A, D> - where - S: DataMut, + where S: DataMut { AxisChunksIterMut::new(self.view_mut(), axis, size) } @@ -1093,6 +1352,7 @@ where /// **Panics** if any dimension of `chunk_size` is zero
/// (**Panics** if `D` is `IxDyn` and `chunk_size` does not match the /// number of array axes.) + #[track_caller] pub fn exact_chunks(&self, chunk_size: E) -> ExactChunks<'_, A, D> where E: IntoDimension, @@ -1133,6 +1393,7 @@ where /// [6, 6, 7, 7, 8, 8, 0], /// [6, 6, 7, 7, 8, 8, 0]])); /// ``` + #[track_caller] pub fn exact_chunks_mut(&mut self, chunk_size: E) -> ExactChunksMut<'_, A, D> where E: IntoDimension, @@ -1146,81 +1407,141 @@ where /// The windows are all distinct overlapping views of size `window_size` /// that fit into the array's shape. /// - /// Will yield over no elements if window size is larger - /// than the actual array size of any dimension. - /// + /// This is essentially equivalent to [`.windows_with_stride()`] with unit stride. + #[track_caller] + pub fn windows(&self, window_size: E) -> Windows<'_, A, D> + where + E: IntoDimension, + S: Data, + { + Windows::new(self.view(), window_size) + } + + /// Return a window producer and iterable. + /// + /// The windows are all distinct views of size `window_size` + /// that fit into the array's shape. + /// + /// The stride is ordered by the outermost axis.
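A small example of `.windows()` above with a 2 × 2 window on a 2 × 3 array:

```rust
use ndarray::{arr2, array};

let a = array![[1, 2, 3], [4, 5, 6]];
// All distinct 2 × 2 windows, moving one step at a time along each axis.
let windows: Vec<_> = a.windows((2, 2)).into_iter().collect();
assert_eq!(windows.len(), 2);
assert_eq!(windows[0], arr2(&[[1, 2], [4, 5]]));
assert_eq!(windows[1], arr2(&[[2, 3], [5, 6]]));
```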
+ /// Hence, a (x₀, x₁, ..., xₙ) stride will be applied to + /// (A₀, A₁, ..., Aₙ) where Aₓ stands for `Axis(x)`. + /// + /// This produces all windows that fit within the array for the given stride, + /// assuming the window size is not larger than the array size. + /// /// The produced element is an `ArrayView` with exactly the dimension /// `window_size`. /// - /// **Panics** if any dimension of `window_size` is zero.
- /// (**Panics** if `D` is `IxDyn` and `window_size` does not match the + /// Note that passing a stride of only ones is similar to + /// calling [`ArrayBase::windows()`]. + /// + /// **Panics** if any dimension of `window_size` or `stride` is zero.
+ /// (**Panics** if `D` is `IxDyn` and `window_size` or `stride` does not match the /// number of array axes.) /// - /// This is an illustration of the 2×2 windows in a 3×4 array: + /// This is the same illustration found in [`ArrayBase::windows()`], + /// 2×2 windows in a 3×4 array, but now with a (1, 2) stride: /// /// ```text /// ──▶ Axis(1) /// - /// │ ┏━━━━━┳━━━━━┱─────┬─────┐ ┌─────┲━━━━━┳━━━━━┱─────┐ ┌─────┬─────┲━━━━━┳━━━━━┓ - /// ▼ ┃ a₀₀ ┃ a₀₁ ┃ │ │ │ ┃ a₀₁ ┃ a₀₂ ┃ │ │ │ ┃ a₀₂ ┃ a₀₃ ┃ - /// Axis(0) ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────╊━━━━━╋━━━━━╉─────┤ ├─────┼─────╊━━━━━╋━━━━━┫ - /// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ ┃ a₁₁ ┃ a₁₂ ┃ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃ - /// ┡━━━━━╇━━━━━╃─────┼─────┤ ├─────╄━━━━━╇━━━━━╃─────┤ ├─────┼─────╄━━━━━╇━━━━━┩ - /// │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ - /// └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘ - /// - /// ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐ - /// │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ - /// ┢━━━━━╈━━━━━╅─────┼─────┤ ├─────╆━━━━━╈━━━━━╅─────┤ ├─────┼─────╆━━━━━╈━━━━━┪ - /// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ ┃ a₁₁ ┃ a₁₂ ┃ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃ - /// ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────╊━━━━━╋━━━━━╉─────┤ ├─────┼─────╊━━━━━╋━━━━━┫ - /// ┃ a₂₀ ┃ a₂₁ ┃ │ │ │ ┃ a₂₁ ┃ a₂₂ ┃ │ │ │ ┃ a₂₂ ┃ a₂₃ ┃ - /// ┗━━━━━┻━━━━━┹─────┴─────┘ └─────┺━━━━━┻━━━━━┹─────┘ └─────┴─────┺━━━━━┻━━━━━┛ + /// │ ┏━━━━━┳━━━━━┱─────┬─────┐ ┌─────┬─────┲━━━━━┳━━━━━┓ + /// ▼ ┃ a₀₀ ┃ a₀₁ ┃ │ │ │ │ ┃ a₀₂ ┃ a₀₃ ┃ + /// Axis(0) ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────┼─────╊━━━━━╋━━━━━┫ + /// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃ + /// ┡━━━━━╇━━━━━╃─────┼─────┤ ├─────┼─────╄━━━━━╇━━━━━┩ + /// │ │ │ │ │ │ │ │ │ │ + /// └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘ + /// + /// ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐ + /// │ │ │ │ │ │ │ │ │ │ + /// ┢━━━━━╈━━━━━╅─────┼─────┤ ├─────┼─────╆━━━━━╈━━━━━┪ + /// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃ + /// ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────┼─────╊━━━━━╋━━━━━┫ + /// ┃ a₂₀ ┃ a₂₁ ┃ │ │ │ │ ┃ a₂₂ ┃ a₂₃ ┃ + /// ┗━━━━━┻━━━━━┹─────┴─────┘ └─────┴─────┺━━━━━┻━━━━━┛ /// ``` - pub fn windows(&self, window_size: E) -> Windows<'_, A, D> + #[track_caller] + pub fn windows_with_stride(&self, window_size: E, stride: E) -> Windows<'_, A, D> where E: IntoDimension, S: Data, { - Windows::new(self.view(), window_size) + Windows::new_with_stride(self.view(), window_size, stride) + } + + /// Returns a producer which traverses over all windows of a given length along an axis. + /// + /// The windows are all distinct, possibly-overlapping views. The shape of each window + /// is the shape of `self`, with the length of `axis` replaced with `window_size`. + /// + /// **Panics** if `axis` is out-of-bounds or if `window_size` is zero. 
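A sketch of `.windows_with_stride()` above, using a (1, 2) stride so that the windows advance two columns at a time:

```rust
use ndarray::{arr2, array};

let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
// 2 × 2 windows with stride (1, 2): one window per pair of columns.
let windows: Vec<_> = a.windows_with_stride((2, 2), (1, 2)).into_iter().collect();
assert_eq!(windows.len(), 2);
assert_eq!(windows[0], arr2(&[[1, 2], [5, 6]]));
assert_eq!(windows[1], arr2(&[[3, 4], [7, 8]]));
```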
+ /// + /// ``` + /// use ndarray::{Array3, Axis, s}; + /// + /// let arr = Array3::from_shape_fn([4, 5, 2], |(i, j, k)| i * 100 + j * 10 + k); + /// let correct = vec![ + /// arr.slice(s![.., 0..3, ..]), + /// arr.slice(s![.., 1..4, ..]), + /// arr.slice(s![.., 2..5, ..]), + /// ]; + /// for (window, correct) in arr.axis_windows(Axis(1), 3).into_iter().zip(&correct) { + /// assert_eq!(window, correct); + /// assert_eq!(window.shape(), &[4, 3, 2]); + /// } + /// ``` + pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D> + where S: Data + { + let axis_index = axis.index(); + + ndassert!( + axis_index < self.ndim(), + concat!( + "Window axis {} does not match array dimension {} ", + "(with array of shape {:?})" + ), + axis_index, + self.ndim(), + self.shape() + ); + + AxisWindows::new(self.view(), axis, window_size) } // Return (length, stride) for diagonal - fn diag_params(&self) -> (Ix, Ixs) { + fn diag_params(&self) -> (Ix, Ixs) + { /* empty shape has len 1 */ let len = self.dim.slice().iter().cloned().min().unwrap_or(1); let stride = self.strides().iter().sum(); (len, stride) } - /// Return an view of the diagonal elements of the array. + /// Return a view of the diagonal elements of the array. /// /// The diagonal is simply the sequence indexed by *(0, 0, .., 0)*, /// *(1, 1, ..., 1)* etc as long as all axes have elements. pub fn diag(&self) -> ArrayView1<'_, A> - where - S: Data, + where S: Data { self.view().into_diag() } /// Return a read-write view over the diagonal elements of the array. pub fn diag_mut(&mut self) -> ArrayViewMut1<'_, A> - where - S: DataMut, + where S: DataMut { self.view_mut().into_diag() } /// Return the diagonal as a one-dimensional array. - pub fn into_diag(self) -> ArrayBase { + pub fn into_diag(self) -> ArrayBase + { let (len, stride) = self.diag_params(); - ArrayBase { - data: self.data, - ptr: self.ptr, - dim: Ix1(len), - strides: Ix1(stride as Ix), - } + // safe because new len stride allows access to a subset of the current elements + unsafe { self.with_strides_dim(Ix1(stride as Ix), Ix1(len)) } } /// Try to make the array unshared. @@ -1229,8 +1550,7 @@ where /// /// This method is mostly only useful with unsafe code. fn try_ensure_unique(&mut self) - where - S: RawDataMut, + where S: RawDataMut { debug_assert!(self.pointer_is_inbounds()); S::try_ensure_unique(self); @@ -1241,8 +1561,7 @@ where /// /// This method is mostly only useful with unsafe code. fn ensure_unique(&mut self) - where - S: DataMut, + where S: DataMut { debug_assert!(self.pointer_is_inbounds()); S::ensure_unique(self); @@ -1252,29 +1571,16 @@ where /// Return `true` if the array data is laid out in contiguous “C order” in /// memory (where the last index is the most rapidly varying). /// - /// Return `false` otherwise, i.e the array is possibly not + /// Return `false` otherwise, i.e. the array is possibly not /// contiguous in memory, it has custom strides, etc. 
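A short example for `.is_standard_layout()` as described above:

```rust
use ndarray::array;

let a = array![[1., 2.], [3., 4.]];
// A freshly created array is in contiguous row-major ("C") order.
assert!(a.is_standard_layout());
// A transposed view shares the data but has column-major strides.
assert!(!a.t().is_standard_layout());
```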
- pub fn is_standard_layout(&self) -> bool { - fn is_standard_layout(dim: &D, strides: &D) -> bool { - if let Some(1) = D::NDIM { - return strides[0] == 1 || dim[0] <= 1; - } - if dim.slice().iter().any(|&d| d == 0) { - return true; - } - let defaults = dim.default_strides(); - // check all dimensions -- a dimension of length 1 can have unequal strides - for (&dim, &s, &ds) in izip!(dim.slice(), strides.slice(), defaults.slice()) { - if dim != 1 && s != ds { - return false; - } - } - true - } - is_standard_layout(&self.dim, &self.strides) + pub fn is_standard_layout(&self) -> bool + { + dimension::is_layout_c(&self.dim, &self.strides) } - fn is_contiguous(&self) -> bool { + /// Return true if the array is known to be contiguous. + pub(crate) fn is_contiguous(&self) -> bool + { D::is_contiguous(&self.dim, &self.strides) } @@ -1308,15 +1614,15 @@ where if self.is_standard_layout() { CowArray::from(self.view()) } else { - let v: Vec
= self.iter().cloned().collect(); + let v = crate::iterators::to_vec_mapped(self.iter(), A::clone); let dim = self.dim.clone(); - assert_eq!(v.len(), dim.size()); - let owned_array: Array = unsafe { + debug_assert_eq!(v.len(), dim.size()); + + unsafe { // Safe because the shape and element type are from the existing array // and the strides are the default strides. - Array::from_shape_vec_unchecked(dim, v) - }; - CowArray::from(owned_array) + CowArray::from(Array::from_shape_vec_unchecked(dim, v)) + } } } @@ -1330,44 +1636,65 @@ where /// /// where *d* is `self.ndim()`. #[inline(always)] - pub fn as_ptr(&self) -> *const A { + pub fn as_ptr(&self) -> *const A + { self.ptr.as_ptr() as *const A } /// Return a mutable pointer to the first element in the array. + /// + /// This method attempts to unshare the data. If `S: DataMut`, then the + /// data is guaranteed to be uniquely held on return. + /// + /// # Warning + /// + /// When accessing elements through this pointer, make sure to use strides + /// obtained *after* calling this method, since the process of unsharing + /// the data may change the strides. #[inline(always)] pub fn as_mut_ptr(&mut self) -> *mut A - where - S: RawDataMut, + where S: RawDataMut { - self.try_ensure_unique(); // for RcArray + self.try_ensure_unique(); // for ArcArray self.ptr.as_ptr() } /// Return a raw view of the array. #[inline] - pub fn raw_view(&self) -> RawArrayView { + pub fn raw_view(&self) -> RawArrayView + { unsafe { RawArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) } } /// Return a raw mutable view of the array. + /// + /// This method attempts to unshare the data. If `S: DataMut`, then the + /// data is guaranteed to be uniquely held on return. #[inline] pub fn raw_view_mut(&mut self) -> RawArrayViewMut - where - S: RawDataMut, + where S: RawDataMut { - self.try_ensure_unique(); // for RcArray + self.try_ensure_unique(); // for ArcArray unsafe { RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } } + /// Return a raw mutable view of the array. + /// + /// Safety: The caller must ensure that the owned array is unshared when this is called + #[inline] + pub(crate) unsafe fn raw_view_mut_unchecked(&mut self) -> RawArrayViewMut + where S: DataOwned + { + RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) + } + /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. /// /// If this function returns `Some(_)`, then the element order in the slice /// corresponds to the logical order of the array’s elements. pub fn as_slice(&self) -> Option<&[A]> - where - S: Data, + where S: Data { if self.is_standard_layout() { unsafe { Some(slice::from_raw_parts(self.ptr.as_ptr(), self.len())) } @@ -1379,8 +1706,7 @@ where /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. pub fn as_slice_mut(&mut self) -> Option<&mut [A]> - where - S: DataMut, + where S: DataMut { if self.is_standard_layout() { self.ensure_unique(); @@ -1395,14 +1721,12 @@ where /// /// If this function returns `Some(_)`, then the elements in the slice /// have whatever order the elements have in memory. - /// - /// Implementation notes: Does not yet support negatively strided arrays. 
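A sketch contrasting `.as_slice()` and `.as_slice_memory_order()` from above on a transposed view:

```rust
use ndarray::array;

let a = array![[1., 2.], [3., 4.]];
let t = a.t();
// The transposed view is not in standard (row-major) order...
assert!(t.as_slice().is_none());
// ...but its elements are still contiguous in memory.
assert!(t.as_slice_memory_order().is_some());
```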
pub fn as_slice_memory_order(&self) -> Option<&[A]> - where - S: Data, + where S: Data { if self.is_contiguous() { - unsafe { Some(slice::from_raw_parts(self.ptr.as_ptr(), self.len())) } + let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides); + unsafe { Some(slice::from_raw_parts(self.ptr.sub(offset).as_ptr(), self.len())) } } else { None } @@ -1410,21 +1734,217 @@ where /// Return the array’s data as a slice if it is contiguous, /// return `None` otherwise. + /// + /// In the contiguous case, in order to return a unique reference, this + /// method unshares the data if necessary, but it preserves the existing + /// strides. pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]> - where - S: DataMut, + where S: DataMut + { + self.try_as_slice_memory_order_mut().ok() + } + + /// Return the array’s data as a slice if it is contiguous, otherwise + /// return `self` in the `Err` variant. + pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self> + where S: DataMut { if self.is_contiguous() { self.ensure_unique(); - unsafe { Some(slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len())) } + let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides); + unsafe { Ok(slice::from_raw_parts_mut(self.ptr.sub(offset).as_ptr(), self.len())) } } else { - None + Err(self) + } + } + + /// Transform the array into `new_shape`; any shape with the same number of elements is + /// accepted. + /// + /// `order` specifies the *logical* order in which the array is to be read and reshaped. + /// The array is returned as a `CowArray`; a view if possible, otherwise an owned array. + /// + /// For example, when starting from the one-dimensional sequence 1 2 3 4 5 6, it would be + /// understood as a 2 x 3 array in row major ("C") order this way: + /// + /// ```text + /// 1 2 3 + /// 4 5 6 + /// ``` + /// + /// and as 2 x 3 in column major ("F") order this way: + /// + /// ```text + /// 1 3 5 + /// 2 4 6 + /// ``` + /// + /// This example should show that any time we "reflow" the elements in the array to a different + /// number of rows and columns (or more axes if applicable), it is important to pick an index + /// ordering, and that's the reason for the function parameter for `order`. + /// + /// The `new_shape` parameter should be a dimension and an optional order like these examples: + /// + /// ```text + /// (3, 4) // Shape 3 x 4 with default order (RowMajor) + /// ((3, 4), Order::RowMajor)) // use specific order + /// ((3, 4), Order::ColumnMajor)) // use specific order + /// ((3, 4), Order::C)) // use shorthand for order - shorthands C and F + /// ``` + /// + /// **Errors** if the new shape doesn't have the same number of elements as the array's current + /// shape. 
+ /// + /// # Example + /// + /// ``` + /// use ndarray::array; + /// use ndarray::Order; + /// + /// assert!( + /// array![1., 2., 3., 4., 5., 6.].to_shape(((2, 3), Order::RowMajor)).unwrap() + /// == array![[1., 2., 3.], + /// [4., 5., 6.]] + /// ); + /// + /// assert!( + /// array![1., 2., 3., 4., 5., 6.].to_shape(((2, 3), Order::ColumnMajor)).unwrap() + /// == array![[1., 3., 5.], + /// [2., 4., 6.]] + /// ); + /// ``` + pub fn to_shape(&self, new_shape: E) -> Result, ShapeError> + where + E: ShapeArg, + A: Clone, + S: Data, + { + let (shape, order) = new_shape.into_shape_and_order(); + self.to_shape_order(shape, order.unwrap_or(Order::RowMajor)) + } + + fn to_shape_order(&self, shape: E, order: Order) -> Result, ShapeError> + where + E: Dimension, + A: Clone, + S: Data, + { + let len = self.dim.size(); + if size_of_shape_checked(&shape) != Ok(len) { + return Err(error::incompatible_shapes(&self.dim, &shape)); + } + + // Create a view if the length is 0, safe because the array and new shape is empty. + if len == 0 { + unsafe { + return Ok(CowArray::from(ArrayView::from_shape_ptr(shape, self.as_ptr()))); + } + } + + // Try to reshape the array as a view into the existing data + match reshape_dim(&self.dim, &self.strides, &shape, order) { + Ok(to_strides) => unsafe { + return Ok(CowArray::from(ArrayView::new(self.ptr, shape, to_strides))); + }, + Err(err) if err.kind() == ErrorKind::IncompatibleShape => { + return Err(error::incompatible_shapes(&self.dim, &shape)); + } + _otherwise => {} + } + + // otherwise create a new array and copy the elements + unsafe { + let (shape, view) = match order { + Order::RowMajor => (shape.set_f(false), self.view()), + Order::ColumnMajor => (shape.set_f(true), self.t()), + }; + Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(shape, view.into_iter(), A::clone))) + } + } + + /// Transform the array into `shape`; any shape with the same number of + /// elements is accepted, but the source array must be contiguous. + /// + /// If an index ordering is not specified, the default is `RowMajor`. + /// The operation will only succeed if the array's memory layout is compatible with + /// the index ordering, so that the array elements can be rearranged in place. + /// + /// If required use `.to_shape()` or `.into_shape_clone` instead for more flexible reshaping of + /// arrays, which allows copying elements if required. + /// + /// **Errors** if the shapes don't have the same number of elements.
+ /// **Errors** if order RowMajor is given but input is not c-contiguous. + /// **Errors** if order ColumnMajor is given but input is not f-contiguous. + /// + /// If shape is not given: use memory layout of incoming array. Row major arrays are + /// reshaped using row major index ordering, column major arrays with column major index + /// ordering. + /// + /// The `new_shape` parameter should be a dimension and an optional order like these examples: + /// + /// ```text + /// (3, 4) // Shape 3 x 4 with default order (RowMajor) + /// ((3, 4), Order::RowMajor)) // use specific order + /// ((3, 4), Order::ColumnMajor)) // use specific order + /// ((3, 4), Order::C)) // use shorthand for order - shorthands C and F + /// ``` + /// + /// # Example + /// + /// ``` + /// use ndarray::{aview1, aview2}; + /// use ndarray::Order; + /// + /// assert!( + /// aview1(&[1., 2., 3., 4.]).into_shape_with_order((2, 2)).unwrap() + /// == aview2(&[[1., 2.], + /// [3., 4.]]) + /// ); + /// + /// assert!( + /// aview1(&[1., 2., 3., 4.]).into_shape_with_order(((2, 2), Order::ColumnMajor)).unwrap() + /// == aview2(&[[1., 3.], + /// [2., 4.]]) + /// ); + /// ``` + pub fn into_shape_with_order(self, shape: E) -> Result, ShapeError> + where E: ShapeArg + { + let (shape, order) = shape.into_shape_and_order(); + self.into_shape_with_order_impl(shape, order.unwrap_or(Order::RowMajor)) + } + + fn into_shape_with_order_impl(self, shape: E, order: Order) -> Result, ShapeError> + where E: Dimension + { + let shape = shape.into_dimension(); + if size_of_shape_checked(&shape) != Ok(self.dim.size()) { + return Err(error::incompatible_shapes(&self.dim, &shape)); + } + + // Check if contiguous, then we can change shape + unsafe { + // safe because arrays are contiguous and len is unchanged + match order { + Order::RowMajor if self.is_standard_layout() => + Ok(self.with_strides_dim(shape.default_strides(), shape)), + Order::ColumnMajor if self.raw_view().reversed_axes().is_standard_layout() => + Ok(self.with_strides_dim(shape.fortran_strides(), shape)), + _otherwise => Err(error::from_kind(error::ErrorKind::IncompatibleLayout)), + } } } /// Transform the array into `shape`; any shape with the same number of - /// elements is accepted, but the source array or view must be - /// contiguous, otherwise we cannot rearrange the dimension. + /// elements is accepted, but the source array or view must be in standard + /// or column-major (Fortran) layout. + /// + /// **Note** that `.into_shape()` "moves" elements differently depending on if the input array + /// is C-contig or F-contig, it follows the index order that corresponds to the memory order. + /// Prefer to use `.to_shape()` or `.into_shape_with_order()`. + /// + /// Because of this, the method **is deprecated**. That reshapes depend on memory order is not + /// intuitive. /// /// **Errors** if the shapes don't have the same number of elements.
/// **Errors** if the input array is not c- or f-contiguous. @@ -1438,35 +1958,92 @@ where /// [3., 4.]]) /// ); /// ``` + #[deprecated(note = "Use `.into_shape_with_order()` or `.to_shape()`", since = "0.16.0")] pub fn into_shape(self, shape: E) -> Result, ShapeError> - where - E: IntoDimension, + where E: IntoDimension { let shape = shape.into_dimension(); if size_of_shape_checked(&shape) != Ok(self.dim.size()) { return Err(error::incompatible_shapes(&self.dim, &shape)); } // Check if contiguous, if not => copy all, else just adapt strides - if self.is_standard_layout() { - Ok(ArrayBase { - data: self.data, - ptr: self.ptr, - strides: shape.default_strides(), - dim: shape, - }) - } else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() { - Ok(ArrayBase { - data: self.data, - ptr: self.ptr, - strides: shape.fortran_strides(), - dim: shape, - }) - } else { - Err(error::from_kind(error::ErrorKind::IncompatibleLayout)) + unsafe { + // safe because arrays are contiguous and len is unchanged + if self.is_standard_layout() { + Ok(self.with_strides_dim(shape.default_strides(), shape)) + } else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() { + Ok(self.with_strides_dim(shape.fortran_strides(), shape)) + } else { + Err(error::from_kind(error::ErrorKind::IncompatibleLayout)) + } } } - /// *Note: Reshape is for `ArcArray` only. Use `.into_shape()` for + /// Transform the array into `shape`; any shape with the same number of + /// elements is accepted. Array elements are reordered in place if + /// possible, otherwise they are copied to create a new array. + /// + /// If an index ordering is not specified, the default is `RowMajor`. + /// + /// # `.to_shape` vs `.into_shape_clone` + /// + /// - `to_shape` supports views and outputting views + /// - `to_shape` borrows the original array, `into_shape_clone` consumes the original + /// - `into_shape_clone` preserves array type (Array vs ArcArray), but does not support views. + /// + /// **Errors** if the shapes don't have the same number of elements.
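A sketch of `.into_shape_clone()` documented here, with and without an explicit `Order`:

```rust
use ndarray::{array, Order};

// A contiguous 1-D array reshaped row-major reuses the existing allocation.
let b = array![1, 2, 3, 4, 5, 6].into_shape_clone((2, 3)).unwrap();
assert_eq!(b, array![[1, 2, 3], [4, 5, 6]]);

// With a column-major order the elements are read column by column.
let c = array![1, 2, 3, 4, 5, 6]
    .into_shape_clone(((3, 2), Order::ColumnMajor))
    .unwrap();
assert_eq!(c, array![[1, 4], [2, 5], [3, 6]]);
```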
+ pub fn into_shape_clone(self, shape: E) -> Result, ShapeError> + where + S: DataOwned, + A: Clone, + E: ShapeArg, + { + let (shape, order) = shape.into_shape_and_order(); + let order = order.unwrap_or(Order::RowMajor); + self.into_shape_clone_order(shape, order) + } + + fn into_shape_clone_order(self, shape: E, order: Order) -> Result, ShapeError> + where + S: DataOwned, + A: Clone, + E: Dimension, + { + let len = self.dim.size(); + if size_of_shape_checked(&shape) != Ok(len) { + return Err(error::incompatible_shapes(&self.dim, &shape)); + } + + // Safe because the array and new shape is empty. + if len == 0 { + unsafe { + return Ok(self.with_strides_dim(shape.default_strides(), shape)); + } + } + + // Try to reshape the array's current data + match reshape_dim(&self.dim, &self.strides, &shape, order) { + Ok(to_strides) => unsafe { + return Ok(self.with_strides_dim(to_strides, shape)); + }, + Err(err) if err.kind() == ErrorKind::IncompatibleShape => { + return Err(error::incompatible_shapes(&self.dim, &shape)); + } + _otherwise => {} + } + + // otherwise, clone and allocate a new array + unsafe { + let (shape, view) = match order { + Order::RowMajor => (shape.set_f(false), self.view()), + Order::ColumnMajor => (shape.set_f(true), self.t()), + }; + + Ok(ArrayBase::from_shape_trusted_iter_unchecked(shape, view.into_iter(), A::clone)) + } + } + + /// *Note: Reshape is for `ArcArray` only. Use `.into_shape_with_order()` for /// other arrays and array views.* /// /// Transform the array into `shape`; any shape with the same number of @@ -1477,6 +2054,9 @@ where /// /// **Panics** if shapes are incompatible. /// + /// *This method is obsolete, because it is inflexible in how logical order + /// of the array is handled. See [`.to_shape()`].* + /// /// ``` /// use ndarray::{rcarr1, rcarr2}; /// @@ -1486,6 +2066,8 @@ where /// [3., 4.]]) /// ); /// ``` + #[track_caller] + #[deprecated(note = "Use `.into_shape_with_order()` or `.to_shape()`", since = "0.16.0")] pub fn reshape(&self, shape: E) -> ArrayBase where S: DataShared + DataOwned, @@ -1503,18 +2085,77 @@ where // Check if contiguous, if not => copy all, else just adapt strides if self.is_standard_layout() { let cl = self.clone(); - ArrayBase { - data: cl.data, - ptr: cl.ptr, - strides: shape.default_strides(), - dim: shape, - } + // safe because array is contiguous and shape has equal number of elements + unsafe { cl.with_strides_dim(shape.default_strides(), shape) } } else { let v = self.iter().cloned().collect::>(); unsafe { ArrayBase::from_shape_vec_unchecked(shape, v) } } } + /// Flatten the array to a one-dimensional array. + /// + /// The array is returned as a `CowArray`; a view if possible, otherwise an owned array. + /// + /// ``` + /// use ndarray::{arr1, arr3}; + /// + /// let array = arr3(&[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + /// let flattened = array.flatten(); + /// assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); + /// ``` + pub fn flatten(&self) -> CowArray<'_, A, Ix1> + where + A: Clone, + S: Data, + { + self.flatten_with_order(Order::RowMajor) + } + + /// Flatten the array to a one-dimensional array. + /// + /// `order` specifies the *logical* order in which the array is to be read and reshaped. + /// The array is returned as a `CowArray`; a view if possible, otherwise an owned array. 
+ /// + /// ``` + /// use ndarray::{arr1, arr2}; + /// use ndarray::Order; + /// + /// let array = arr2(&[[1, 2], [3, 4], [5, 6], [7, 8]]); + /// let flattened = array.flatten_with_order(Order::RowMajor); + /// assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); + /// let flattened = array.flatten_with_order(Order::ColumnMajor); + /// assert_eq!(flattened, arr1(&[1, 3, 5, 7, 2, 4, 6, 8])); + /// ``` + pub fn flatten_with_order(&self, order: Order) -> CowArray<'_, A, Ix1> + where + A: Clone, + S: Data, + { + self.to_shape((self.len(), order)).unwrap() + } + + /// Flatten the array to a one-dimensional array, consuming the array. + /// + /// If possible, no copy is made, and the new array use the same memory as the original array. + /// Otherwise, a new array is allocated and the elements are copied. + /// + /// ``` + /// use ndarray::{arr1, arr3}; + /// + /// let array = arr3(&[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + /// let flattened = array.into_flat(); + /// assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); + /// ``` + pub fn into_flat(self) -> ArrayBase + where + A: Clone, + S: DataOwned, + { + let len = self.len(); + self.into_shape_clone(Ix1(len)).unwrap() + } + /// Convert any array or array view to a dynamic dimensional array or /// array view (respectively). /// @@ -1524,17 +2165,19 @@ where /// let array: ArrayD = arr2(&[[1, 2], /// [3, 4]]).into_dyn(); /// ``` - pub fn into_dyn(self) -> ArrayBase { - ArrayBase { - data: self.data, - ptr: self.ptr, - dim: self.dim.into_dyn(), - strides: self.strides.into_dyn(), + pub fn into_dyn(self) -> ArrayBase + { + // safe because new dims equivalent + unsafe { + ArrayBase::from_data_ptr(self.data, self.ptr).with_strides_dim(self.strides.into_dyn(), self.dim.into_dyn()) } } - /// Convert an array or array view to another with the same type, but - /// different dimensionality type. Errors if the dimensions don't agree. + /// Convert an array or array view to another with the same type, but different dimensionality + /// type. Errors if the dimensions don't agree (the number of axes must match). + /// + /// Note that conversion to a dynamic dimensional array will never fail (and is equivalent to + /// the `into_dyn` method). /// /// ``` /// use ndarray::{ArrayD, Ix2, IxDyn}; @@ -1547,17 +2190,22 @@ where /// assert!(array.into_dimensionality::().is_ok()); /// ``` pub fn into_dimensionality(self) -> Result, ShapeError> - where - D2: Dimension, - { - if let Some(dim) = D2::from_dimension(&self.dim) { - if let Some(strides) = D2::from_dimension(&self.strides) { - return Ok(ArrayBase { - data: self.data, - ptr: self.ptr, - dim, - strides, - }); + where D2: Dimension + { + unsafe { + if D::NDIM == D2::NDIM { + // safe because D == D2 + let dim = unlimited_transmute::(self.dim); + let strides = unlimited_transmute::(self.strides); + return Ok(ArrayBase::from_data_ptr(self.data, self.ptr).with_strides_dim(strides, dim)); + } else if D::NDIM.is_none() || D2::NDIM.is_none() { + // one is dynamic dim + // safe because dim, strides are equivalent under a different type + if let Some(dim) = D2::from_dimension(&self.dim) { + if let Some(strides) = D2::from_dimension(&self.strides) { + return Ok(self.with_strides_dim(strides, dim)); + } + } } } Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) @@ -1605,7 +2253,8 @@ where /// /// **Note:** Cannot be used for mutable iterators, since repeating /// elements would create aliasing pointers. 
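To make the dimensionality conversions described above concrete, here is a short sketch (not part of the patch) of the `into_dyn` / `into_dimensionality` round trip:

```rust
use ndarray::{array, ArrayD, Ix2, IxDyn};

// Erase the static dimensionality, then recover it.
let a = array![[1, 2], [3, 4]];
let a_dyn: ArrayD<i32> = a.into_dyn(); // conversion to IxDyn never fails
assert_eq!(a_dyn.ndim(), 2);

// Converting back succeeds only if the number of axes matches.
let back = a_dyn.into_dimensionality::<Ix2>().unwrap();
assert_eq!(back, array![[1, 2], [3, 4]]);

// A mismatched axis count is reported as a ShapeError.
let one_d = ArrayD::<f64>::zeros(IxDyn(&[4]));
assert!(one_d.into_dimensionality::<Ix2>().is_err());
```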
- fn upcast(to: &D, from: &E, stride: &E) -> Option { + fn upcast(to: &D, from: &E, stride: &E) -> Option + { // Make sure the product of non-zero axis lengths does not exceed // `isize::MAX`. This is the only safety check we need to perform // because all the other constraints of `ArrayBase` are guaranteed @@ -1657,6 +2306,43 @@ where unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) } } + /// For two arrays or views, find their common shape if possible and + /// broadcast them as array views into that shape. + /// + /// Return `ShapeError` if their shapes can not be broadcast together. + #[allow(clippy::type_complexity)] + pub(crate) fn broadcast_with<'a, 'b, B, S2, E>( + &'a self, other: &'b ArrayBase, + ) -> Result<(ArrayView<'a, A, DimMaxOf>, ArrayView<'b, B, DimMaxOf>), ShapeError> + where + S: Data, + S2: Data, + D: Dimension + DimMax, + E: Dimension, + { + let shape = co_broadcast::>::Output>(&self.dim, &other.dim)?; + let view1 = if shape.slice() == self.dim.slice() { + self.view() + .into_dimensionality::<>::Output>() + .unwrap() + } else if let Some(view1) = self.broadcast(shape.clone()) { + view1 + } else { + return Err(from_kind(ErrorKind::IncompatibleShape)); + }; + let view2 = if shape.slice() == other.dim.slice() { + other + .view() + .into_dimensionality::<>::Output>() + .unwrap() + } else if let Some(view2) = other.broadcast(shape) { + view2 + } else { + return Err(from_kind(ErrorKind::IncompatibleShape)); + }; + Ok((view1, view2)) + } + /// Swap axes `ax` and `bx`. /// /// This does not move any data, it just adjusts the array’s dimensions @@ -1673,7 +2359,9 @@ where /// a == arr2(&[[1.], [2.], [3.]]) /// ); /// ``` - pub fn swap_axes(&mut self, ax: usize, bx: usize) { + #[track_caller] + pub fn swap_axes(&mut self, ax: usize, bx: usize) + { self.dim.slice_mut().swap(ax, bx); self.strides.slice_mut().swap(ax, bx); } @@ -1700,9 +2388,9 @@ where /// let b = Array3::::zeros((1, 2, 3)); /// assert_eq!(b.permuted_axes([1, 0, 2]).shape(), &[2, 1, 3]); /// ``` + #[track_caller] pub fn permuted_axes(self, axes: T) -> ArrayBase - where - T: IntoDimension, + where T: IntoDimension { let axes = axes.into_dimension(); // Ensure that each axis is used exactly once. @@ -1724,18 +2412,16 @@ where new_strides[new_axis] = strides[axis]; } } - ArrayBase { - dim: new_dim, - strides: new_strides, - ..self - } + // safe because axis invariants are checked above; they are a permutation of the old + unsafe { self.with_strides_dim(new_strides, new_dim) } } /// Transpose the array by reversing axes. /// /// Transposition reverses the order of the axes (dimensions and strides) /// while retaining the same data. - pub fn reversed_axes(mut self) -> ArrayBase { + pub fn reversed_axes(mut self) -> ArrayBase + { self.dim.slice_mut().reverse(); self.strides.slice_mut().reverse(); self @@ -1747,14 +2433,14 @@ where /// /// See also the more general methods `.reversed_axes()` and `.swap_axes()`. pub fn t(&self) -> ArrayView<'_, A, D> - where - S: Data, + where S: Data { self.view().reversed_axes() } /// Return an iterator over the length and stride of each axis. - pub fn axes(&self) -> Axes<'_, D> { + pub fn axes(&self) -> Axes<'_, D> + { axes_of(&self.dim, &self.strides) } @@ -1767,14 +2453,17 @@ where /// Return the axis with the greatest stride (by absolute value), /// preferring axes with len > 1. - pub fn max_stride_axis(&self) -> Axis { + pub fn max_stride_axis(&self) -> Axis + { self.dim.max_stride_axis(&self.strides) } /// Reverse the stride of `axis`. 
/// /// ***Panics*** if the axis is out of bounds. - pub fn invert_axis(&mut self, axis: Axis) { + #[track_caller] + pub fn invert_axis(&mut self, axis: Axis) + { unsafe { let s = self.strides.axis(axis) as Ixs; let m = self.dim.axis(axis); @@ -1820,7 +2509,9 @@ where /// ``` /// /// ***Panics*** if an axis is out of bounds. - pub fn merge_axes(&mut self, take: Axis, into: Axis) -> bool { + #[track_caller] + pub fn merge_axes(&mut self, take: Axis, into: Axis) -> bool + { merge_axes(&mut self.dim, &mut self.strides, take, into) } @@ -1845,45 +2536,34 @@ where /// ``` /// /// ***Panics*** if the axis is out of bounds. - pub fn insert_axis(self, axis: Axis) -> ArrayBase { + #[track_caller] + pub fn insert_axis(self, axis: Axis) -> ArrayBase + { assert!(axis.index() <= self.ndim()); - let ArrayBase { - ptr, - data, - dim, - strides, - } = self; - ArrayBase { - ptr, - data, - dim: dim.insert_axis(axis), - strides: strides.insert_axis(axis), + // safe because a new axis of length one does not affect memory layout + unsafe { + let strides = self.strides.insert_axis(axis); + let dim = self.dim.insert_axis(axis); + self.with_strides_dim(strides, dim) } } /// Remove array axis `axis` and return the result. /// + /// This is equivalent to `.index_axis_move(axis, 0)` and makes most sense to use if the + /// axis to remove is of length 1. + /// /// **Panics** if the axis is out of bounds or its length is zero. - #[deprecated(note = "use `.index_axis_move(Axis(_), 0)` instead", since = "0.12.1")] + #[track_caller] pub fn remove_axis(self, axis: Axis) -> ArrayBase - where - D: RemoveAxis, + where D: RemoveAxis { self.index_axis_move(axis, 0) } - fn pointer_is_inbounds(&self) -> bool { - match self.data._data_slice() { - None => { - // special case for non-owned views - true - } - Some(slc) => { - let ptr = slc.as_ptr() as *mut A; - let end = unsafe { ptr.add(slc.len()) }; - self.ptr.as_ptr() >= ptr && self.ptr.as_ptr() <= end - } - } + pub(crate) fn pointer_is_inbounds(&self) -> bool + { + self.data._is_pointer_inbounds(self.as_ptr()) } /// Perform an elementwise assigment to `self` from `rhs`. @@ -1891,13 +2571,31 @@ where /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. /// /// **Panics** if broadcasting isn’t possible. + #[track_caller] pub fn assign(&mut self, rhs: &ArrayBase) where S: DataMut, A: Clone, S2: Data, { - self.zip_mut_with(rhs, |x, y| *x = y.clone()); + self.zip_mut_with(rhs, |x, y| x.clone_from(y)); + } + + /// Perform an elementwise assigment of values cloned from `self` into array or producer `to`. + /// + /// The destination `to` can be another array or a producer of assignable elements. + /// [`AssignElem`] determines how elements are assigned. + /// + /// **Panics** if shapes disagree. + #[track_caller] + pub fn assign_to
<P>
(&self, to: P) + where + S: Data, + P: IntoNdProducer, + P::Item: AssignElem, + A: Clone, + { + Zip::from(self).map_assign_into(to, A::clone); } /// Perform an elementwise assigment to `self` from element `x`. @@ -1906,10 +2604,10 @@ where S: DataMut, A: Clone, { - self.unordered_foreach_mut(move |elt| *elt = x.clone()); + self.map_inplace(move |elt| elt.clone_from(&x)); } - fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) + pub(crate) fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) where S: DataMut, S2: Data, @@ -1917,18 +2615,19 @@ where F: FnMut(&mut A, &B), { debug_assert_eq!(self.shape(), rhs.shape()); - if let Some(self_s) = self.as_slice_mut() { - if let Some(rhs_s) = rhs.as_slice() { - let len = cmp::min(self_s.len(), rhs_s.len()); - let s = &mut self_s[..len]; - let r = &rhs_s[..len]; - for i in 0..len { - f(&mut s[i], &r[i]); + + if self.dim.strides_equivalent(&self.strides, &rhs.strides) { + if let Some(self_s) = self.as_slice_memory_order_mut() { + if let Some(rhs_s) = rhs.as_slice_memory_order() { + for (s, r) in self_s.iter_mut().zip(rhs_s) { + f(s, r); + } + return; } - return; } } - // otherwise, fall back to the outer iter + + // Otherwise, fall back to the outer iter self.zip_mut_with_by_rows(rhs, f); } @@ -1949,7 +2648,7 @@ where let dim = self.raw_dim(); Zip::from(LanesMut::new(self.view_mut(), Axis(n - 1))) .and(Lanes::new(rhs.broadcast_assume(dim), Axis(n - 1))) - .apply(move |s_row, r_row| Zip::from(s_row).and(r_row).apply(|a, b| f(a, b))); + .for_each(move |s_row, r_row| Zip::from(s_row).and(r_row).for_each(&mut f)); } fn zip_mut_with_elem(&mut self, rhs_elem: &B, mut f: F) @@ -1957,7 +2656,7 @@ where S: DataMut, F: FnMut(&mut A, &B), { - self.unordered_foreach_mut(move |elt| f(elt, rhs_elem)); + self.map_inplace(move |elt| f(elt, rhs_elem)); } /// Traverse two arrays in unspecified order, in lock step, @@ -1966,6 +2665,7 @@ where /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. /// /// **Panics** if broadcasting isn’t possible. 
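A short usage sketch (not part of the patch) tying together `fill`, `assign`, the new `assign_to`, and `zip_mut_with` described above; the shapes are chosen so broadcasting applies where the docs allow it:

```rust
use ndarray::{array, Array2};

let mut a = Array2::<f64>::zeros((2, 3));

// Fill every element with one value.
a.fill(1.);

// Elementwise assignment; the right-hand side is broadcast to `a`'s shape.
let row = array![10., 20., 30.];
a.assign(&row);
assert_eq!(a, array![[10., 20., 30.], [10., 20., 30.]]);

// assign_to goes the other way: clone `a`'s elements into another array
// (or any producer of assignable elements) of the same shape.
let mut b = Array2::<f64>::zeros((2, 3));
a.assign_to(&mut b);
assert_eq!(a, b);

// zip_mut_with traverses both arrays in lock step, broadcasting rhs as needed.
let mut c = b.clone();
c.zip_mut_with(&row, |x, y| *x += y);
assert_eq!(c, array![[20., 40., 60.], [20., 40., 60.]]);
```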
+ #[track_caller] #[inline] pub fn zip_mut_with(&mut self, rhs: &ArrayBase, f: F) where @@ -1999,27 +2699,7 @@ where slc.iter().fold(init, f) } else { let mut v = self.view(); - // put the narrowest axis at the last position - match v.ndim() { - 0 | 1 => {} - 2 => { - if self.len_of(Axis(1)) <= 1 - || self.len_of(Axis(0)) > 1 - && self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs() - { - v.swap_axes(0, 1); - } - } - n => { - let last = n - 1; - let narrow_axis = v - .axes() - .filter(|ax| ax.len() > 1) - .min_by_key(|ax| ax.stride().abs()) - .map_or(last, |ax| ax.axis().index()); - v.swap_axes(last, narrow_axis); - } - } + move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); v.into_elements_base().fold(init, f) } } @@ -2048,17 +2728,16 @@ where A: 'a, S: Data, { - if let Some(slc) = self.as_slice_memory_order() { - let v = crate::iterators::to_vec_mapped(slc.iter(), f); - unsafe { - ArrayBase::from_shape_vec_unchecked( + unsafe { + if let Some(slc) = self.as_slice_memory_order() { + ArrayBase::from_shape_trusted_iter_unchecked( self.dim.clone().strides(self.strides.clone()), - v, + slc.iter(), + f, ) + } else { + ArrayBase::from_shape_trusted_iter_unchecked(self.dim.clone(), self.iter(), f) } - } else { - let v = crate::iterators::to_vec_mapped(self.iter(), f); - unsafe { ArrayBase::from_shape_vec_unchecked(self.dim.clone(), v) } } } @@ -2078,11 +2757,9 @@ where if self.is_contiguous() { let strides = self.strides.clone(); let slc = self.as_slice_memory_order_mut().unwrap(); - let v = crate::iterators::to_vec_mapped(slc.iter_mut(), f); - unsafe { ArrayBase::from_shape_vec_unchecked(dim.strides(strides), v) } + unsafe { ArrayBase::from_shape_trusted_iter_unchecked(dim.strides(strides), slc.iter_mut(), f) } } else { - let v = crate::iterators::to_vec_mapped(self.iter_mut(), f); - unsafe { ArrayBase::from_shape_vec_unchecked(dim, v) } + unsafe { ArrayBase::from_shape_trusted_iter_unchecked(dim, self.iter_mut(), f) } } } @@ -2126,15 +2803,65 @@ where self } + /// Consume the array, call `f` by **v**alue on each element, and return an + /// owned array with the new values. Works for **any** `F: FnMut(A)->B`. + /// + /// If `A` and `B` are the same type then the map is performed by delegating + /// to [`mapv_into`] and then converting into an owned array. This avoids + /// unnecessary memory allocations in [`mapv`]. + /// + /// If `A` and `B` are different types then a new array is allocated and the + /// map is performed as in [`mapv`]. + /// + /// Elements are visited in arbitrary order. + /// + /// [`mapv_into`]: ArrayBase::mapv_into + /// [`mapv`]: ArrayBase::mapv + pub fn mapv_into_any(self, mut f: F) -> Array + where + S: DataMut, + F: FnMut(A) -> B, + A: Clone + 'static, + B: 'static, + { + if core::any::TypeId::of::() == core::any::TypeId::of::() { + // A and B are the same type. + // Wrap f in a closure of type FnMut(A) -> A . + let f = |a| { + let b = f(a); + // Safe because A and B are the same type. + unsafe { unlimited_transmute::(b) } + }; + // Delegate to mapv_into() using the wrapped closure. + // Convert output to a uniquely owned array of type Array. + let output = self.mapv_into(f).into_owned(); + // Change the return type from Array to Array. + // Again, safe because A and B are the same type. + unsafe { unlimited_transmute::, Array>(output) } + } else { + // A and B are not the same type. + // Fallback to mapv(). + self.mapv(f) + } + } + /// Modify the array in place by calling `f` by mutable reference on each element. 
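A small sketch (not part of the patch) of the `mapv_into_any` behaviour explained above: the same-type case can reuse the existing allocation, while the different-type case falls back to allocating a new array as `mapv` does:

```rust
use ndarray::{array, Array1};

// A == B: delegates to mapv_into and reuses the allocation.
let a = array![1.0_f64, 4.0, 9.0];
let roots = a.mapv_into_any(f64::sqrt);
assert_eq!(roots, array![1.0, 2.0, 3.0]);

// A != B: a new array is allocated, as with mapv.
let b = array![1.0_f64, 2.5, 3.9];
let rounded: Array1<i64> = b.mapv_into_any(|x| x.round() as i64);
assert_eq!(rounded, array![1, 3, 4]);
```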
/// /// Elements are visited in arbitrary order. - pub fn map_inplace(&mut self, f: F) + pub fn map_inplace<'a, F>(&'a mut self, f: F) where S: DataMut, - F: FnMut(&mut A), - { - self.unordered_foreach_mut(f); + A: 'a, + F: FnMut(&'a mut A), + { + match self.try_as_slice_memory_order_mut() { + Ok(slc) => slc.iter_mut().for_each(f), + Err(arr) => { + let mut v = arr.view_mut(); + move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); + v.into_elements_base().for_each(f); + } + } } /// Modify the array in place by calling `f` by **v**alue on each element. @@ -2143,10 +2870,10 @@ where /// Elements are visited in arbitrary order. /// /// ``` + /// # #[cfg(feature = "approx")] { /// use approx::assert_abs_diff_eq; /// use ndarray::arr2; /// - /// # #[cfg(feature = "approx")] { /// let mut a = arr2(&[[ 0., 1.], /// [-1., 2.]]); /// a.mapv_inplace(f32::exp); @@ -2164,14 +2891,13 @@ where F: FnMut(A) -> A, A: Clone, { - self.unordered_foreach_mut(move |x| *x = f(x.clone())); + self.map_inplace(move |x| *x = f(x.clone())); } - /// Visit each element in the array by calling `f` by reference - /// on each element. + /// Call `f` for each element in the array. /// /// Elements are visited in arbitrary order. - pub fn visit<'a, F>(&'a self, mut f: F) + pub fn for_each<'a, F>(&'a self, mut f: F) where F: FnMut(&'a A), A: 'a, @@ -2188,6 +2914,7 @@ where /// Return the result as an `Array`. /// /// **Panics** if `axis` is out of bounds. + #[track_caller] pub fn fold_axis(&self, axis: Axis, init: B, mut fold: F) -> Array where D: RemoveAxis, @@ -2210,6 +2937,7 @@ where /// Return the result as an `Array`. /// /// **Panics** if `axis` is out of bounds. + #[track_caller] pub fn map_axis<'a, B, F>(&'a self, axis: Axis, mut mapping: F) -> Array where D: RemoveAxis, @@ -2217,17 +2945,11 @@ where A: 'a, S: Data, { - let view_len = self.len_of(axis); - let view_stride = self.strides.axis(axis); - if view_len == 0 { + if self.len_of(axis) == 0 { let new_dim = self.dim.remove_axis(axis); Array::from_shape_simple_fn(new_dim, move || mapping(ArrayView::from(&[]))) } else { - // use the 0th subview as a map to each 1d array view extended from - // the 0th element. - self.index_axis(axis, 0).map(|first_elt| unsafe { - mapping(ArrayView::new_(first_elt, Ix1(view_len), Ix1(view_stride))) - }) + Zip::from(self.lanes(axis)).map_collect(mapping) } } @@ -2241,6 +2963,7 @@ where /// Return the result as an `Array`. /// /// **Panics** if `axis` is out of bounds. + #[track_caller] pub fn map_axis_mut<'a, B, F>(&'a mut self, axis: Axis, mut mapping: F) -> Array where D: RemoveAxis, @@ -2248,24 +2971,37 @@ where A: 'a, S: DataMut, { - let view_len = self.len_of(axis); - let view_stride = self.strides.axis(axis); - if view_len == 0 { + if self.len_of(axis) == 0 { let new_dim = self.dim.remove_axis(axis); Array::from_shape_simple_fn(new_dim, move || mapping(ArrayViewMut::from(&mut []))) } else { - // use the 0th subview as a map to each 1d array view extended from - // the 0th element. - self.index_axis_mut(axis, 0).map_mut(|first_elt| unsafe { - mapping(ArrayViewMut::new_( - first_elt, - Ix1(view_len), - Ix1(view_stride), - )) - }) + Zip::from(self.lanes_mut(axis)).map_collect(mapping) } } + /// Remove the `index`th elements along `axis` and shift down elements from higher indexes. 
+ /// + /// Note that this "removes" the elements by swapping them around to the end of the axis and + /// shortening the length of the axis; the elements are not deinitialized or dropped by this, + /// just moved out of view (this only matters for elements with ownership semantics). It's + /// similar to slicing an owned array in place. + /// + /// Decreases the length of `axis` by one. + /// + /// ***Panics*** if `axis` is out of bounds
+ /// ***Panics*** if not `index < self.len_of(axis)`. + pub fn remove_index(&mut self, axis: Axis, index: usize) + where S: DataOwned + DataMut + { + assert!(index < self.len_of(axis), "index {} must be less than length of Axis({})", + index, axis.index()); + let (_, mut tail) = self.view_mut().split_at(axis, index); + // shift elements to the front + Zip::from(tail.lanes_mut(axis)).for_each(|mut lane| lane.rotate1_front()); + // then slice the axis in place to cut out the removed final element + self.slice_axis_inplace(axis, Slice::new(0, Some(-1), 1)); + } + /// Iterates over pairs of consecutive elements along the axis. /// /// The first argument to the closure is an element, and the second @@ -2307,7 +3043,7 @@ where prev.slice_axis_inplace(axis, Slice::from(..-1)); curr.slice_axis_inplace(axis, Slice::from(1..)); // This implementation relies on `Zip` iterating along `axis` in order. - Zip::from(prev).and(curr).apply(|prev, curr| unsafe { + Zip::from(prev).and(curr).for_each(|prev, curr| unsafe { // These pointer dereferences and borrows are safe because: // // 1. They're pointers to elements in the array. @@ -2322,3 +3058,96 @@ where }); } } + +/// Transmute from A to B. +/// +/// Like transmute, but does not have the compile-time size check which blocks +/// using regular transmute in some cases. +/// +/// **Panics** if the size of A and B are different. +#[track_caller] +#[inline] +unsafe fn unlimited_transmute(data: A) -> B +{ + // safe when sizes are equal and caller guarantees that representations are equal + assert_eq!(size_of::
(), size_of::()); + let old_data = ManuallyDrop::new(data); + (&*old_data as *const A as *const B).read() +} + +type DimMaxOf = >::Output; + +#[cfg(test)] +mod tests +{ + use super::*; + use crate::arr3; + use defmac::defmac; + + #[test] + fn test_flatten() + { + let array = arr3(&[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + let flattened = array.flatten(); + assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); + } + + #[test] + fn test_flatten_with_order() + { + let array = arr2(&[[1, 2], [3, 4], [5, 6], [7, 8]]); + let flattened = array.flatten_with_order(Order::RowMajor); + assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); + let flattened = array.flatten_with_order(Order::ColumnMajor); + assert_eq!(flattened, arr1(&[1, 3, 5, 7, 2, 4, 6, 8])); + } + + #[test] + fn test_into_flat() + { + let array = arr3(&[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + let flattened = array.into_flat(); + assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); + } + + #[test] + fn test_first_last() + { + let first = 2; + let last = 3; + + defmac!(assert_first mut array => { + assert_eq!(array.first().copied(), Some(first)); + assert_eq!(array.first_mut().copied(), Some(first)); + }); + defmac!(assert_last mut array => { + assert_eq!(array.last().copied(), Some(last)); + assert_eq!(array.last_mut().copied(), Some(last)); + }); + + let base = Array::from_vec(vec![first, last]); + let a = base.clone(); + assert_first!(a); + + let a = base.clone(); + assert_last!(a); + + let a = CowArray::from(base.view()); + assert_first!(a); + let a = CowArray::from(base.view()); + assert_last!(a); + + let a = CowArray::from(base.clone()); + assert_first!(a); + let a = CowArray::from(base.clone()); + assert_last!(a); + + let a = ArcArray::from(base.clone()); + let _a2 = a.clone(); + assert_last!(a); + + let a = ArcArray::from(base.clone()); + let _a2 = a.clone(); + assert_first!(a); + } +} diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 4804356e8..46ea18a7c 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -6,6 +6,8 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +use crate::dimension::DimMax; +use crate::Zip; use num_complex::Complex; /// Elements that can be used as direct operands in arithmetic with arrays. @@ -53,11 +55,11 @@ macro_rules! impl_binary_op( /// Perform elementwise #[doc=$doc] /// between `self` and `rhs`, -/// and return the result (based on `self`). +/// and return the result. /// /// `self` must be an `Array` or `ArcArray`. /// -/// If their shapes disagree, `rhs` is broadcast to the shape of `self`. +/// If their shapes disagree, `self` is broadcast to their broadcast shape. /// /// **Panics** if broadcasting isn’t possible. impl $trt> for ArrayBase @@ -66,11 +68,12 @@ where B: Clone, S: DataOwned + DataMut, S2: Data, - D: Dimension, + D: Dimension + DimMax, E: Dimension, { - type Output = ArrayBase; - fn $mth(self, rhs: ArrayBase) -> ArrayBase + type Output = ArrayBase>::Output>; + #[track_caller] + fn $mth(self, rhs: ArrayBase) -> Self::Output { self.$mth(&rhs) } @@ -79,9 +82,12 @@ where /// Perform elementwise #[doc=$doc] /// between `self` and reference `rhs`, -/// and return the result (based on `self`). +/// and return the result. +/// +/// `rhs` must be an `Array` or `ArcArray`. /// -/// If their shapes disagree, `rhs` is broadcast to the shape of `self`. +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. 
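A brief sketch (not part of the patch) of the co-broadcasting behaviour these operator impls describe: operands of different shapes are broadcast to their common shape, and an owned operand that already has that shape is reused for the result:

```rust
use ndarray::array;

let a = array![[0., 0., 0.], [10., 10., 10.]]; // shape (2, 3)
let b = array![1., 2., 3.];                    // shape (3,)

// Both operands by reference: a new array with the broadcast shape (2, 3).
let sum = &a + &b;
assert_eq!(sum, array![[1., 2., 3.], [11., 12., 13.]]);

// Owned left operand: it already has the broadcast shape, so its storage
// is reused for the output.
let sum2 = a + &b;
assert_eq!(sum2, array![[1., 2., 3.], [11., 12., 13.]]);
```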
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for ArrayBase @@ -90,16 +96,69 @@ where B: Clone, S: DataOwned + DataMut, S2: Data, - D: Dimension, + D: Dimension + DimMax, E: Dimension, { - type Output = ArrayBase; - fn $mth(mut self, rhs: &ArrayBase) -> ArrayBase + type Output = ArrayBase>::Output>; + #[track_caller] + fn $mth(self, rhs: &ArrayBase) -> Self::Output { - self.zip_mut_with(rhs, |x, y| { - *x = x.clone() $operator y.clone(); - }); - self + if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { + let mut out = self.into_dimensionality::<>::Output>().unwrap(); + out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth)); + out + } else { + let (lhs_view, rhs_view) = self.broadcast_with(&rhs).unwrap(); + if lhs_view.shape() == self.shape() { + let mut out = self.into_dimensionality::<>::Output>().unwrap(); + out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$mth)); + out + } else { + Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth)) + } + } + } +} + +/// Perform elementwise +#[doc=$doc] +/// between reference `self` and `rhs`, +/// and return the result. +/// +/// `rhs` must be an `Array` or `ArcArray`. +/// +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. +/// +/// **Panics** if broadcasting isn’t possible. +impl<'a, A, B, S, S2, D, E> $trt> for &'a ArrayBase +where + A: Clone + $trt, + B: Clone, + S: Data, + S2: DataOwned + DataMut, + D: Dimension, + E: Dimension + DimMax, +{ + type Output = ArrayBase>::Output>; + #[track_caller] + fn $mth(self, rhs: ArrayBase) -> Self::Output + where + { + if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { + let mut out = rhs.into_dimensionality::<>::Output>().unwrap(); + out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth)); + out + } else { + let (rhs_view, lhs_view) = rhs.broadcast_with(self).unwrap(); + if rhs_view.shape() == rhs.shape() { + let mut out = rhs.into_dimensionality::<>::Output>().unwrap(); + out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$mth)); + out + } else { + Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth)) + } + } } } @@ -108,7 +167,8 @@ where /// between references `self` and `rhs`, /// and return the result as a new `Array`. /// -/// If their shapes disagree, `rhs` is broadcast to the shape of `self`. +/// If their shapes disagree, `self` and `rhs` is broadcast to their broadcast shape, +/// cloning the data if needed. /// /// **Panics** if broadcasting isn’t possible. impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for &'a ArrayBase @@ -117,13 +177,20 @@ where B: Clone, S: Data, S2: Data, - D: Dimension, + D: Dimension + DimMax, E: Dimension, { - type Output = Array; - fn $mth(self, rhs: &'a ArrayBase) -> Array { - // FIXME: Can we co-broadcast arrays here? And how? 
- self.to_owned().$mth(rhs) + type Output = Array>::Output>; + #[track_caller] + fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { + let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { + let lhs = self.view().into_dimensionality::<>::Output>().unwrap(); + let rhs = rhs.view().into_dimensionality::<>::Output>().unwrap(); + (lhs, rhs) + } else { + self.broadcast_with(rhs).unwrap() + }; + Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth)) } } @@ -141,7 +208,7 @@ impl $trt for ArrayBase { type Output = ArrayBase; fn $mth(mut self, x: B) -> ArrayBase { - self.unordered_foreach_mut(move |elt| { + self.map_inplace(move |elt| { *elt = elt.clone() $operator x.clone(); }); self @@ -159,8 +226,8 @@ impl<'a, A, S, D, B> $trt for &'a ArrayBase B: ScalarOperand, { type Output = Array; - fn $mth(self, x: B) -> Array { - self.to_owned().$mth(x) + fn $mth(self, x: B) -> Self::Output { + self.map(move |elt| elt.clone() $operator x.clone()) } } ); @@ -194,7 +261,7 @@ impl $trt> for $scalar rhs.$mth(self) } or {{ let mut rhs = rhs; - rhs.unordered_foreach_mut(move |elt| { + rhs.map_inplace(move |elt| { *elt = self $operator *elt; }); rhs @@ -210,24 +277,39 @@ impl<'a, S, D> $trt<&'a ArrayBase> for $scalar D: Dimension, { type Output = Array<$scalar, D>; - fn $mth(self, rhs: &ArrayBase) -> Array<$scalar, D> { + fn $mth(self, rhs: &ArrayBase) -> Self::Output { if_commutative!($commutative { rhs.$mth(self) } or { - self.$mth(rhs.to_owned()) + rhs.map(move |elt| self.clone() $operator elt.clone()) }) } } ); } -mod arithmetic_ops { +mod arithmetic_ops +{ use super::*; use crate::imp_prelude::*; - use num_complex::Complex; use std::ops::*; + fn clone_opf(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C + { + move |x, y| f(x.clone(), y.clone()) + } + + fn clone_iopf(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) + { + move |x, y| *x = f(x.clone(), y.clone()) + } + + fn clone_iopf_rev(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) + { + move |x, y| *x = f(y.clone(), x.clone()) + } + impl_binary_op!(Add, +, add, +=, "addition"); impl_binary_op!(Sub, -, sub, -=, "subtraction"); impl_binary_op!(Mul, *, mul, *=, "multiplication"); @@ -261,6 +343,8 @@ mod arithmetic_ops { all_scalar_ops!(u32); all_scalar_ops!(i64); all_scalar_ops!(u64); + all_scalar_ops!(isize); + all_scalar_ops!(usize); all_scalar_ops!(i128); all_scalar_ops!(u128); @@ -298,8 +382,9 @@ mod arithmetic_ops { { type Output = Self; /// Perform an elementwise negation of `self` and return the result. - fn neg(mut self) -> Self { - self.unordered_foreach_mut(|elt| { + fn neg(mut self) -> Self + { + self.map_inplace(|elt| { *elt = -elt.clone(); }); self @@ -315,7 +400,8 @@ mod arithmetic_ops { type Output = Array; /// Perform an elementwise negation of reference `self` and return the /// result as a new `Array`. - fn neg(self) -> Array { + fn neg(self) -> Array + { self.map(Neg::neg) } } @@ -328,8 +414,9 @@ mod arithmetic_ops { { type Output = Self; /// Perform an elementwise unary not of `self` and return the result. - fn not(mut self) -> Self { - self.unordered_foreach_mut(|elt| { + fn not(mut self) -> Self + { + self.map_inplace(|elt| { *elt = !elt.clone(); }); self @@ -345,13 +432,15 @@ mod arithmetic_ops { type Output = Array; /// Perform an elementwise unary not of reference `self` and return the /// result as a new `Array`. 
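For the scalar and unary operator impls above, a quick sketch (not part of the patch); `f64` implements `ScalarOperand`, so it can appear on either side of the operator:

```rust
use ndarray::array;

let a = array![[1.0_f64, -2.0], [3.0, -4.0]];

// Scalar operand on the right and on the left.
assert_eq!(&a * 2.0, array![[2.0, -4.0], [6.0, -8.0]]);
assert_eq!(2.0 * &a, array![[2.0, -4.0], [6.0, -8.0]]);

// Elementwise unary negation of a reference returns a new Array.
assert_eq!(-&a, array![[-1.0, 2.0], [-3.0, 4.0]]);
```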
- fn not(self) -> Array { + fn not(self) -> Array + { self.map(Not::not) } } } -mod assign_ops { +mod assign_ops +{ use super::*; use crate::imp_prelude::*; @@ -371,6 +460,7 @@ mod assign_ops { D: Dimension, E: Dimension, { + #[track_caller] fn $method(&mut self, rhs: &ArrayBase) { self.zip_mut_with(rhs, |x, y| { x.$method(y.clone()); @@ -386,7 +476,7 @@ mod assign_ops { D: Dimension, { fn $method(&mut self, rhs: A) { - self.unordered_foreach_mut(move |elt| { + self.map_inplace(move |elt| { elt.$method(rhs.clone()); }); } diff --git a/src/impl_owned_array.rs b/src/impl_owned_array.rs index 96075593f..bb970f876 100644 --- a/src/impl_owned_array.rs +++ b/src/impl_owned_array.rs @@ -1,11 +1,26 @@ +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; +use core::ptr::NonNull; +use std::mem; +use std::mem::MaybeUninit; + +#[allow(unused_imports)] // Needed for Rust 1.64 +use rawpointer::PointerExt; + use crate::imp_prelude::*; +use crate::dimension; +use crate::error::{ErrorKind, ShapeError}; +use crate::iterators::Baseiter; +use crate::low_level_util::AbortIfPanic; +use crate::OwnedRepr; +use crate::Zip; + /// Methods specific to `Array0`. /// /// ***See also all methods for [`ArrayBase`]*** -/// -/// [`ArrayBase`]: struct.ArrayBase.html -impl Array { +impl Array +{ /// Returns the single element in the array without cloning it. /// /// ``` @@ -19,22 +34,23 @@ impl Array { /// let scalar: Foo = array.into_scalar(); /// assert_eq!(scalar, Foo); /// ``` - pub fn into_scalar(mut self) -> A { - let size = ::std::mem::size_of::(); + pub fn into_scalar(self) -> A + { + let size = mem::size_of::(); if size == 0 { // Any index in the `Vec` is fine since all elements are identical. - self.data.0.remove(0) + self.data.into_vec().remove(0) } else { // Find the index in the `Vec` corresponding to `self.ptr`. // (This is necessary because the element in the array might not be // the first element in the `Vec`, such as if the array was created // by `array![1, 2, 3, 4].slice_move(s![2])`.) let first = self.ptr.as_ptr() as usize; - let base = self.data.0.as_ptr() as usize; + let base = self.data.as_ptr() as usize; let index = (first - base) / size; debug_assert_eq!((first - base) % size, 0); // Remove the element at the index and return it. - self.data.0.remove(index) + self.data.into_vec().remove(index) } } } @@ -42,18 +58,961 @@ impl Array { /// Methods specific to `Array`. /// /// ***See also all methods for [`ArrayBase`]*** -/// -/// [`ArrayBase`]: struct.ArrayBase.html impl Array -where - D: Dimension, +where D: Dimension { + /// Returns the offset (in units of `A`) from the start of the allocation + /// to the first element, or `None` if the array is empty. + fn offset_from_alloc_to_logical_ptr(&self) -> Option + { + if self.is_empty() { + return None; + } + if std::mem::size_of::() == 0 { + Some(dimension::offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides)) + } else { + let offset = unsafe { self.as_ptr().offset_from(self.data.as_ptr()) }; + debug_assert!(offset >= 0); + Some(offset as usize) + } + } + /// Return a vector of the elements in the array, in the way they are - /// stored internally. + /// stored internally, and the index in the vector corresponding to the + /// logically first element of the array (or None if the array is empty). /// /// If the array is in standard memory layout, the logical element order /// of the array (`.iter()` order) and of the returned vector will be the same. 
- pub fn into_raw_vec(self) -> Vec { - self.data.0 + /// + /// ``` + /// use ndarray::{array, Array2, Axis}; + /// + /// let mut arr: Array2 = array![[1., 2.], [3., 4.], [5., 6.]]; + /// arr.slice_axis_inplace(Axis(0), (1..).into()); + /// assert_eq!(arr[[0, 0]], 3.); + /// let copy = arr.clone(); + /// + /// let shape = arr.shape().to_owned(); + /// let strides = arr.strides().to_owned(); + /// let (v, offset) = arr.into_raw_vec_and_offset(); + /// + /// assert_eq!(v, &[1., 2., 3., 4., 5., 6.]); + /// assert_eq!(offset, Some(2)); + /// assert_eq!(v[offset.unwrap()], 3.); + /// for row in 0..shape[0] { + /// for col in 0..shape[1] { + /// let index = ( + /// offset.unwrap() as isize + /// + row as isize * strides[0] + /// + col as isize * strides[1] + /// ) as usize; + /// assert_eq!(v[index], copy[[row, col]]); + /// } + /// } + /// ``` + /// + /// In the case of zero-sized elements, the offset to the logically first + /// element is somewhat meaningless. For convenience, an offset will be + /// returned such that all indices computed using the offset, shape, and + /// strides will be in-bounds for the `Vec`. Note that this offset won't + /// necessarily be the same as the offset for an array of nonzero-sized + /// elements sliced in the same way. + /// + /// ``` + /// use ndarray::{array, Array2, Axis}; + /// + /// let mut arr: Array2<()> = array![[(), ()], [(), ()], [(), ()]]; + /// arr.slice_axis_inplace(Axis(0), (1..).into()); + /// + /// let shape = arr.shape().to_owned(); + /// let strides = arr.strides().to_owned(); + /// let (v, offset) = arr.into_raw_vec_and_offset(); + /// + /// assert_eq!(v, &[(), (), (), (), (), ()]); + /// for row in 0..shape[0] { + /// for col in 0..shape[1] { + /// let index = ( + /// offset.unwrap() as isize + /// + row as isize * strides[0] + /// + col as isize * strides[1] + /// ) as usize; + /// assert_eq!(v[index], ()); + /// } + /// } + /// ``` + pub fn into_raw_vec_and_offset(self) -> (Vec, Option) + { + let offset = self.offset_from_alloc_to_logical_ptr(); + (self.data.into_vec(), offset) + } + + /// Return a vector of the elements in the array, in the way they are + /// stored internally. + /// + /// Depending on slicing and strides, the logically first element of the + /// array can be located at an offset. Because of this, prefer to use + /// `.into_raw_vec_and_offset()` instead. + #[deprecated(note = "Use .into_raw_vec_and_offset() instead", since = "0.16.0")] + pub fn into_raw_vec(self) -> Vec + { + self.into_raw_vec_and_offset().0 + } +} + +/// Methods specific to `Array2`. +/// +/// ***See also all methods for [`ArrayBase`]*** +impl Array +{ + /// Append a row to an array + /// + /// The elements from `row` are cloned and added as a new row in the array. + /// + /// ***Errors*** with a shape error if the length of the row does not match the length of the + /// rows in the array. + /// + /// The memory layout of the `self` array matters for ensuring that the append is efficient. + /// Appending automatically changes memory layout of the array so that it is appended to + /// along the "growing axis". However, if the memory layout needs adjusting, the array must + /// reallocate and move memory. + /// + /// The operation leaves the existing data in place and is most efficent if one of these is + /// true: + /// + /// - The axis being appended to is the longest stride axis, i.e the array is in row major + /// ("C") layout. 
+ /// - The array has 0 or 1 rows (It is converted to row major) + /// + /// Ensure appending is efficient by, for example, appending to an empty array and then always + /// pushing/appending along the same axis. For pushing rows, ndarray's default layout (C order) + /// is efficient. + /// + /// When repeatedly appending to a single axis, the amortized average complexity of each + /// append is O(m), where *m* is the length of the row. + /// + /// ```rust + /// use ndarray::{Array, ArrayView, array}; + /// + /// // create an empty array and append + /// let mut a = Array::zeros((0, 4)); + /// a.push_row(ArrayView::from(&[ 1., 2., 3., 4.])).unwrap(); + /// a.push_row(ArrayView::from(&[-1., -2., -3., -4.])).unwrap(); + /// + /// assert_eq!( + /// a, + /// array![[ 1., 2., 3., 4.], + /// [-1., -2., -3., -4.]]); + /// ``` + pub fn push_row(&mut self, row: ArrayView) -> Result<(), ShapeError> + where A: Clone + { + self.append(Axis(0), row.insert_axis(Axis(0))) + } + + /// Append a column to an array + /// + /// The elements from `column` are cloned and added as a new column in the array. + /// + /// ***Errors*** with a shape error if the length of the column does not match the length of + /// the columns in the array. + /// + /// The memory layout of the `self` array matters for ensuring that the append is efficient. + /// Appending automatically changes memory layout of the array so that it is appended to + /// along the "growing axis". However, if the memory layout needs adjusting, the array must + /// reallocate and move memory. + /// + /// The operation leaves the existing data in place and is most efficent if one of these is + /// true: + /// + /// - The axis being appended to is the longest stride axis, i.e the array is in column major + /// ("F") layout. + /// - The array has 0 or 1 columns (It is converted to column major) + /// + /// Ensure appending is efficient by, for example, appending to an empty array and then always + /// pushing/appending along the same axis. For pushing columns, column major layout (F order) + /// is efficient. + /// + /// When repeatedly appending to a single axis, the amortized average complexity of each append + /// is O(m), where *m* is the length of the column. + /// + /// ```rust + /// use ndarray::{Array, ArrayView, array}; + /// + /// // create an empty array and append + /// let mut a = Array::zeros((2, 0)); + /// a.push_column(ArrayView::from(&[1., 2.])).unwrap(); + /// a.push_column(ArrayView::from(&[-1., -2.])).unwrap(); + /// + /// assert_eq!( + /// a, + /// array![[1., -1.], + /// [2., -2.]]); + /// ``` + pub fn push_column(&mut self, column: ArrayView) -> Result<(), ShapeError> + where A: Clone + { + self.append(Axis(1), column.insert_axis(Axis(1))) + } + + /// Reserve capacity to grow array by at least `additional` rows. + /// + /// Existing elements of `array` are untouched and the backing storage is grown by + /// calling the underlying `reserve` method of the `OwnedRepr`. + /// + /// This is useful when pushing or appending repeatedly to an array to avoid multiple + /// allocations. + /// + /// ***Errors*** with a shape error if the resultant capacity is larger than the addressable + /// bounds; that is, the product of non-zero axis lengths once `axis` has been extended by + /// `additional` exceeds `isize::MAX`. 
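A combined sketch (not part of the patch) of the row-append workflow described above: reserving capacity once and then pushing rows repeatedly, so the pushes stay amortized O(m) without extra reallocations:

```rust
use ndarray::{Array2, ArrayView};

// Start from an empty array in the default (row major) layout.
let mut a = Array2::<f64>::zeros((0, 3));
a.reserve_rows(100).unwrap();

for i in 0..100 {
    let i = i as f64;
    a.push_row(ArrayView::from(&[i, 2.0 * i, 3.0 * i])).unwrap();
}

assert_eq!(a.nrows(), 100);
assert_eq!(a.row(1), ArrayView::from(&[1.0, 2.0, 3.0]));
```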
+ /// + /// ```rust + /// use ndarray::Array2; + /// let mut a = Array2::::zeros((2,4)); + /// a.reserve_rows(1000).unwrap(); + /// assert!(a.into_raw_vec().capacity() >= 4*1002); + /// ``` + pub fn reserve_rows(&mut self, additional: usize) -> Result<(), ShapeError> + { + self.reserve(Axis(0), additional) + } + + /// Reserve capacity to grow array by at least `additional` columns. + /// + /// Existing elements of `array` are untouched and the backing storage is grown by + /// calling the underlying `reserve` method of the `OwnedRepr`. + /// + /// This is useful when pushing or appending repeatedly to an array to avoid multiple + /// allocations. + /// + /// ***Errors*** with a shape error if the resultant capacity is larger than the addressable + /// bounds; that is, the product of non-zero axis lengths once `axis` has been extended by + /// `additional` exceeds `isize::MAX`. + /// + /// ```rust + /// use ndarray::Array2; + /// let mut a = Array2::::zeros((2,4)); + /// a.reserve_columns(1000).unwrap(); + /// assert!(a.into_raw_vec().capacity() >= 2*1002); + /// ``` + pub fn reserve_columns(&mut self, additional: usize) -> Result<(), ShapeError> + { + self.reserve(Axis(1), additional) + } +} + +impl Array +where D: Dimension +{ + /// Move all elements from self into `new_array`, which must be of the same shape but + /// can have a different memory layout. The destination is overwritten completely. + /// + /// The destination should be a mut reference to an array or an `ArrayViewMut` with + /// `A` elements. + /// + /// ***Panics*** if the shapes don't agree. + /// + /// ## Example + /// + /// ``` + /// use ndarray::Array; + /// + /// // Usage example of move_into in safe code + /// let mut a = Array::default((10, 10)); + /// let b = Array::from_shape_fn((10, 10), |(i, j)| (i + j).to_string()); + /// b.move_into(&mut a); + /// ``` + pub fn move_into<'a, AM>(self, new_array: AM) + where + AM: Into>, + A: 'a, + { + // Remove generic parameter P and call the implementation + let new_array = new_array.into(); + if mem::needs_drop::() { + self.move_into_needs_drop(new_array); + } else { + // If `A` doesn't need drop, we can overwrite the destination. + // Safe because: move_into_uninit only writes initialized values + unsafe { self.move_into_uninit(new_array.into_maybe_uninit()) } + } + } + + fn move_into_needs_drop(mut self, new_array: ArrayViewMut) + { + // Simple case where `A` has a destructor: just swap values between self and new_array. + // Afterwards, `self` drops full of initialized values and dropping works as usual. + // This avoids moving out of owned values in `self` while at the same time managing + // the dropping if the values being overwritten in `new_array`. + Zip::from(&mut self) + .and(new_array) + .for_each(|src, dst| mem::swap(src, dst)); + } + + /// Move all elements from self into `new_array`, which must be of the same shape but + /// can have a different memory layout. The destination is overwritten completely. + /// + /// The destination should be a mut reference to an array or an `ArrayViewMut` with + /// `MaybeUninit` elements (which are overwritten without dropping any existing value). + /// + /// Minor implementation note: Owned arrays like `self` may be sliced in place and own elements + /// that are not part of their active view; these are dropped at the end of this function, + /// after all elements in the "active view" are moved into `new_array`. If there is a panic in + /// drop of any such element, other elements may be leaked. 
+ /// + /// ***Panics*** if the shapes don't agree. + /// + /// ## Example + /// + /// ``` + /// use ndarray::Array; + /// + /// let a = Array::from_iter(0..100).into_shape_with_order((10, 10)).unwrap(); + /// let mut b = Array::uninit((10, 10)); + /// a.move_into_uninit(&mut b); + /// unsafe { + /// // we can now promise we have fully initialized `b`. + /// let b = b.assume_init(); + /// } + /// ``` + pub fn move_into_uninit<'a, AM>(self, new_array: AM) + where + AM: Into, D>>, + A: 'a, + { + // Remove generic parameter AM and call the implementation + self.move_into_impl(new_array.into()) + } + + fn move_into_impl(mut self, new_array: ArrayViewMut, D>) + { + unsafe { + // Safety: copy_to_nonoverlapping cannot panic + let guard = AbortIfPanic(&"move_into: moving out of owned value"); + // Move all reachable elements; we move elements out of `self`. + // and thus must not panic for the whole section until we call `self.data.set_len(0)`. + Zip::from(self.raw_view_mut()) + .and(new_array) + .for_each(|src, dst| { + src.copy_to_nonoverlapping(dst.as_mut_ptr(), 1); + }); + guard.defuse(); + // Drop all unreachable elements + self.drop_unreachable_elements(); + } + } + + /// This drops all "unreachable" elements in the data storage of self. + /// + /// That means those elements that are not visible in the slicing of the array. + /// *Reachable elements are assumed to already have been moved from.* + /// + /// # Safety + /// + /// This is a panic critical section since `self` is already moved-from. + fn drop_unreachable_elements(mut self) -> OwnedRepr + { + let self_len = self.len(); + + // "deconstruct" self; the owned repr releases ownership of all elements and we + // and carry on with raw view methods + let data_len = self.data.len(); + + let has_unreachable_elements = self_len != data_len; + if !has_unreachable_elements || mem::size_of::() == 0 || !mem::needs_drop::() { + unsafe { + self.data.set_len(0); + } + self.data + } else { + self.drop_unreachable_elements_slow() + } + } + + #[inline(never)] + #[cold] + fn drop_unreachable_elements_slow(mut self) -> OwnedRepr + { + // "deconstruct" self; the owned repr releases ownership of all elements and we + // carry on with raw view methods + let data_len = self.data.len(); + let data_ptr = self.data.as_nonnull_mut(); + + unsafe { + // Safety: self.data releases ownership of the elements. Any panics below this point + // will result in leaking elements instead of double drops. + let self_ = self.raw_view_mut(); + self.data.set_len(0); + + drop_unreachable_raw(self_, data_ptr, data_len); + } + + self.data + } + + /// Create an empty array with an all-zeros shape + /// + /// ***Panics*** if D is zero-dimensional, because it can't be empty + pub(crate) fn empty() -> Array + { + assert_ne!(D::NDIM, Some(0)); + let ndim = D::NDIM.unwrap_or(1); + Array::from_shape_simple_fn(D::zeros(ndim), || unreachable!()) + } + + /// Create new_array with the right layout for appending to `growing_axis` + #[cold] + fn change_to_contig_append_layout(&mut self, growing_axis: Axis) + { + let ndim = self.ndim(); + let mut dim = self.raw_dim(); + + // The array will be created with 0 (C) or ndim-1 (F) as the biggest stride + // axis. Rearrange the shape so that `growing_axis` is the biggest stride axis + // afterwards. 
+ let mut new_array; + if growing_axis == Axis(ndim - 1) { + new_array = Self::uninit(dim.f()); + } else { + dim.slice_mut()[..=growing_axis.index()].rotate_right(1); + new_array = Self::uninit(dim); + new_array.dim.slice_mut()[..=growing_axis.index()].rotate_left(1); + new_array.strides.slice_mut()[..=growing_axis.index()].rotate_left(1); + } + + // self -> old_self. + // dummy array -> self. + // old_self elements are moved -> new_array. + let old_self = std::mem::replace(self, Self::empty()); + old_self.move_into_uninit(new_array.view_mut()); + + // new_array -> self. + unsafe { + *self = new_array.assume_init(); + } + } + + /// Append an array to the array along an axis. + /// + /// The elements of `array` are cloned and extend the axis `axis` in the present array; + /// `self` will grow in size by 1 along `axis`. + /// + /// Append to the array, where the array being pushed to the array has one dimension less than + /// the `self` array. This method is equivalent to [append](ArrayBase::append) in this way: + /// `self.append(axis, array.insert_axis(axis))`. + /// + /// ***Errors*** with a shape error if the shape of self does not match the array-to-append; + /// all axes *except* the axis along which it being appended matter for this check: + /// the shape of `self` with `axis` removed must be the same as the shape of `array`. + /// + /// The memory layout of the `self` array matters for ensuring that the append is efficient. + /// Appending automatically changes memory layout of the array so that it is appended to + /// along the "growing axis". However, if the memory layout needs adjusting, the array must + /// reallocate and move memory. + /// + /// The operation leaves the existing data in place and is most efficent if `axis` is a + /// "growing axis" for the array, i.e. one of these is true: + /// + /// - The axis is the longest stride axis, for example the 0th axis in a C-layout or the + /// *n-1*th axis in an F-layout array. + /// - The axis has length 0 or 1 (It is converted to the new growing axis) + /// + /// Ensure appending is efficient by for example starting from an empty array and/or always + /// appending to an array along the same axis. + /// + /// The amortized average complexity of the append, when appending along its growing axis, is + /// O(*m*) where *m* is the number of individual elements to append. + /// + /// The memory layout of the argument `array` does not matter to the same extent. + /// + /// ```rust + /// use ndarray::{Array, ArrayView, array, Axis}; + /// + /// // create an empty array and push rows to it + /// let mut a = Array::zeros((0, 4)); + /// let ones = ArrayView::from(&[1.; 4]); + /// let zeros = ArrayView::from(&[0.; 4]); + /// a.push(Axis(0), ones).unwrap(); + /// a.push(Axis(0), zeros).unwrap(); + /// a.push(Axis(0), ones).unwrap(); + /// + /// assert_eq!( + /// a, + /// array![[1., 1., 1., 1.], + /// [0., 0., 0., 0.], + /// [1., 1., 1., 1.]]); + /// ``` + pub fn push(&mut self, axis: Axis, array: ArrayView) -> Result<(), ShapeError> + where + A: Clone, + D: RemoveAxis, + { + // same-dimensionality conversion + self.append(axis, array.insert_axis(axis).into_dimensionality::().unwrap()) + } + + /// Append an array to the array along an axis. + /// + /// The elements of `array` are cloned and extend the axis `axis` in the present array; + /// `self` will grow in size by `array.len_of(axis)` along `axis`. 
+ /// + /// ***Errors*** with a shape error if the shape of self does not match the array-to-append; + /// all axes *except* the axis along which it being appended matter for this check: + /// the shape of `self` with `axis` removed must be the same as the shape of `array` with + /// `axis` removed. + /// + /// The memory layout of the `self` array matters for ensuring that the append is efficient. + /// Appending automatically changes memory layout of the array so that it is appended to + /// along the "growing axis". However, if the memory layout needs adjusting, the array must + /// reallocate and move memory. + /// + /// The operation leaves the existing data in place and is most efficent if `axis` is a + /// "growing axis" for the array, i.e. one of these is true: + /// + /// - The axis is the longest stride axis, for example the 0th axis in a C-layout or the + /// *n-1*th axis in an F-layout array. + /// - The axis has length 0 or 1 (It is converted to the new growing axis) + /// + /// Ensure appending is efficient by for example starting from an empty array and/or always + /// appending to an array along the same axis. + /// + /// The amortized average complexity of the append, when appending along its growing axis, is + /// O(*m*) where *m* is the number of individual elements to append. + /// + /// The memory layout of the argument `array` does not matter to the same extent. + /// + /// ```rust + /// use ndarray::{Array, ArrayView, array, Axis}; + /// + /// // create an empty array and append two rows at a time + /// let mut a = Array::zeros((0, 4)); + /// let ones = ArrayView::from(&[1.; 8]).into_shape_with_order((2, 4)).unwrap(); + /// let zeros = ArrayView::from(&[0.; 8]).into_shape_with_order((2, 4)).unwrap(); + /// a.append(Axis(0), ones).unwrap(); + /// a.append(Axis(0), zeros).unwrap(); + /// a.append(Axis(0), ones).unwrap(); + /// + /// assert_eq!( + /// a, + /// array![[1., 1., 1., 1.], + /// [1., 1., 1., 1.], + /// [0., 0., 0., 0.], + /// [0., 0., 0., 0.], + /// [1., 1., 1., 1.], + /// [1., 1., 1., 1.]]); + /// ``` + pub fn append(&mut self, axis: Axis, mut array: ArrayView) -> Result<(), ShapeError> + where + A: Clone, + D: RemoveAxis, + { + if self.ndim() == 0 { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + let current_axis_len = self.len_of(axis); + let self_dim = self.raw_dim(); + let array_dim = array.raw_dim(); + let remaining_shape = self_dim.remove_axis(axis); + let array_rem_shape = array_dim.remove_axis(axis); + + if remaining_shape != array_rem_shape { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + let len_to_append = array.len(); + + let mut res_dim = self_dim; + res_dim[axis.index()] += array_dim[axis.index()]; + let new_len = dimension::size_of_shape_checked(&res_dim)?; + + if len_to_append == 0 { + // There are no elements to append and shapes are compatible: + // either the dimension increment is zero, or there is an existing + // zero in another axis in self. 
+ debug_assert_eq!(self.len(), new_len); + self.dim = res_dim; + return Ok(()); + } + + let self_is_empty = self.is_empty(); + let mut incompatible_layout = false; + + // array must be empty or have `axis` as the outermost (longest stride) axis + if !self_is_empty && current_axis_len > 1 { + // `axis` must be max stride axis or equal to its stride + let axis_stride = self.stride_of(axis); + if axis_stride < 0 { + incompatible_layout = true; + } else { + for ax in self.axes() { + if ax.axis == axis { + continue; + } + if ax.len > 1 && ax.stride.abs() > axis_stride { + incompatible_layout = true; + break; + } + } + } + } + + // array must be be "full" (contiguous and have no exterior holes) + if self.len() != self.data.len() { + incompatible_layout = true; + } + + if incompatible_layout { + self.change_to_contig_append_layout(axis); + // safety-check parameters after remodeling + debug_assert_eq!(self_is_empty, self.is_empty()); + debug_assert_eq!(current_axis_len, self.len_of(axis)); + } + + let strides = if self_is_empty { + // recompute strides - if the array was previously empty, it could have zeros in + // strides. + // The new order is based on c/f-contig but must have `axis` as outermost axis. + if axis == Axis(self.ndim() - 1) { + // prefer f-contig when appending to the last axis + // Axis n - 1 is outermost axis + res_dim.fortran_strides() + } else { + // standard axis order except for the growing axis; + // anticipates that it's likely that `array` has standard order apart from the + // growing axis. + res_dim.slice_mut()[..=axis.index()].rotate_right(1); + let mut strides = res_dim.default_strides(); + res_dim.slice_mut()[..=axis.index()].rotate_left(1); + strides.slice_mut()[..=axis.index()].rotate_left(1); + strides + } + } else if current_axis_len == 1 { + // This is the outermost/longest stride axis; so we find the max across the other axes + let new_stride = self.axes().fold(1, |acc, ax| { + if ax.axis == axis || ax.len <= 1 { + acc + } else { + let this_ax = ax.len as isize * ax.stride.abs(); + if this_ax > acc { + this_ax + } else { + acc + } + } + }); + let mut strides = self.strides.clone(); + strides[axis.index()] = new_stride as usize; + strides + } else { + self.strides.clone() + }; + + // grow backing storage and update head ptr + self.reserve(axis, array_dim[axis.index()])?; + + unsafe { + // clone elements from view to the array now + // + // To be robust for panics and drop the right elements, we want + // to fill the tail in memory order, so that we can drop the right elements on panic. + // + // We have: Zip::from(tail_view).and(array) + // Transform tail_view into standard order by inverting and moving its axes. + // Keep the Zip traversal unchanged by applying the same axis transformations to + // `array`. This ensures the Zip traverses the underlying memory in order. + // + // XXX It would be possible to skip this transformation if the element + // doesn't have drop. However, in the interest of code coverage, all elements + // use this code initially. 
+ + // Invert axes in tail_view by inverting strides + let mut tail_strides = strides.clone(); + if tail_strides.ndim() > 1 { + for i in 0..tail_strides.ndim() { + let s = tail_strides[i] as isize; + if s < 0 { + tail_strides.set_axis(Axis(i), -s as usize); + array.invert_axis(Axis(i)); + } + } + } + + // With > 0 strides, the current end of data is the correct base pointer for tail_view + let tail_ptr = self.data.as_end_nonnull(); + let mut tail_view = RawArrayViewMut::new(tail_ptr, array_dim, tail_strides); + + if tail_view.ndim() > 1 { + sort_axes_in_default_order_tandem(&mut tail_view, &mut array); + debug_assert!(tail_view.is_standard_layout(), + "not std layout dim: {:?}, strides: {:?}", + tail_view.shape(), tail_view.strides()); + } + + // Keep track of currently filled length of `self.data` and update it + // on scope exit (panic or loop finish). This "indirect" way to + // write the length is used to help the compiler, the len store to self.data may + // otherwise be mistaken to alias with other stores in the loop. + struct SetLenOnDrop<'a, A: 'a> + { + len: usize, + data: &'a mut OwnedRepr, + } + + impl Drop for SetLenOnDrop<'_, A> + { + fn drop(&mut self) + { + unsafe { + self.data.set_len(self.len); + } + } + } + + let mut data_length_guard = SetLenOnDrop { + len: self.data.len(), + data: &mut self.data, + }; + + // Safety: tail_view is constructed to have the same shape as array + Zip::from(tail_view) + .and_unchecked(array) + .debug_assert_c_order() + .for_each(|to, from| { + to.write(from.clone()); + data_length_guard.len += 1; + }); + drop(data_length_guard); + + // update array dimension + self.strides = strides; + self.dim = res_dim; + } + // multiple assertions after pointer & dimension update + debug_assert_eq!(self.data.len(), self.len()); + debug_assert_eq!(self.len(), new_len); + debug_assert!(self.pointer_is_inbounds()); + + Ok(()) + } + + /// Reserve capacity to grow array along `axis` by at least `additional` elements. + /// + /// The axis should be in the range `Axis(` 0 .. *n* `)` where *n* is the + /// number of dimensions (axes) of the array. + /// + /// Existing elements of `array` are untouched and the backing storage is grown by + /// calling the underlying `reserve` method of the `OwnedRepr`. + /// + /// This is useful when pushing or appending repeatedly to an array to avoid multiple + /// allocations. + /// + /// ***Panics*** if the axis is out of bounds. + /// + /// ***Errors*** with a shape error if the resultant capacity is larger than the addressable + /// bounds; that is, the product of non-zero axis lengths once `axis` has been extended by + /// `additional` exceeds `isize::MAX`. 
+ /// + /// ```rust + /// use ndarray::{Array3, Axis}; + /// let mut a = Array3::::zeros((0,2,4)); + /// a.reserve(Axis(0), 1000).unwrap(); + /// assert!(a.into_raw_vec().capacity() >= 2*4*1000); + /// ``` + /// + pub fn reserve(&mut self, axis: Axis, additional: usize) -> Result<(), ShapeError> + where D: RemoveAxis + { + debug_assert!(axis.index() < self.ndim()); + let self_dim = self.raw_dim(); + let remaining_shape = self_dim.remove_axis(axis); + + // Make sure added capacity doesn't overflow usize::MAX + let len_to_append = remaining_shape + .size() + .checked_mul(additional) + .ok_or(ShapeError::from_kind(ErrorKind::Overflow))?; + + // Make sure new capacity is still in bounds + let mut res_dim = self_dim; + res_dim[axis.index()] += additional; + let new_len = dimension::size_of_shape_checked(&res_dim)?; + + // Check whether len_to_append would cause an overflow + debug_assert_eq!(self.len().checked_add(len_to_append).unwrap(), new_len); + + unsafe { + // grow backing storage and update head ptr + let data_to_array_offset = if std::mem::size_of::() != 0 { + self.as_ptr().offset_from(self.data.as_ptr()) + } else { + 0 + }; + debug_assert!(data_to_array_offset >= 0); + self.ptr = self + .data + .reserve(len_to_append) + .offset(data_to_array_offset); + } + + debug_assert!(self.pointer_is_inbounds()); + + Ok(()) + } +} + +/// This drops all "unreachable" elements in `self_` given the data pointer and data length. +/// +/// # Safety +/// +/// This is an internal function for use by move_into and IntoIter only, safety invariants may need +/// to be upheld across the calls from those implementations. +pub(crate) unsafe fn drop_unreachable_raw( + mut self_: RawArrayViewMut, data_ptr: NonNull, data_len: usize, +) where D: Dimension +{ + let self_len = self_.len(); + + for i in 0..self_.ndim() { + if self_.stride_of(Axis(i)) < 0 { + self_.invert_axis(Axis(i)); + } + } + sort_axes_in_default_order(&mut self_); + // with uninverted axes this is now the element with lowest address + let array_memory_head_ptr = self_.ptr; + let data_end_ptr = data_ptr.add(data_len); + debug_assert!(data_ptr <= array_memory_head_ptr); + debug_assert!(array_memory_head_ptr <= data_end_ptr); + + // The idea is simply this: the iterator will yield the elements of self_ in + // increasing address order. + // + // The pointers produced by the iterator are those that we *do not* touch. + // The pointers *not mentioned* by the iterator are those we have to drop. + // + // We have to drop elements in the range from `data_ptr` until (not including) + // `data_end_ptr`, except those that are produced by `iter`. + + // As an optimization, the innermost axis is removed if it has stride 1, because + // we then have a long stretch of contiguous elements we can skip as one. + let inner_lane_len; + if self_.ndim() > 1 && self_.strides.last_elem() == 1 { + self_.dim.slice_mut().rotate_right(1); + self_.strides.slice_mut().rotate_right(1); + inner_lane_len = self_.dim[0]; + self_.dim[0] = 1; + self_.strides[0] = 1; + } else { + inner_lane_len = 1; + } + + // iter is a raw pointer iterator traversing the array in memory order now with the + // sorted axes. + let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides); + let mut dropped_elements = 0; + + let mut last_ptr = data_ptr; + + while let Some(elem_ptr) = iter.next() { + // The interval from last_ptr up until (not including) elem_ptr + // should now be dropped. This interval may be empty, then we just skip this loop. 
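// Minimal sketch of the intended `reserve` workflow: pre-allocate once, then
// push repeatedly without reallocating. Assumes the companion `push` method,
// which appends one lower-dimensional view along `axis`.
use ndarray::{Array2, ArrayView1, Axis};

fn collect_rows(rows: &[[f64; 3]]) -> Array2<f64>
{
    let mut a = Array2::<f64>::zeros((0, 3));
    a.reserve(Axis(0), rows.len()).unwrap();
    for row in rows {
        a.push(Axis(0), ArrayView1::from(&row[..])).unwrap();
    }
    a
}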
+ while last_ptr != elem_ptr { + debug_assert!(last_ptr < data_end_ptr); + std::ptr::drop_in_place(last_ptr.as_mut()); + last_ptr = last_ptr.add(1); + dropped_elements += 1; + } + // Next interval will continue one past the current lane + last_ptr = elem_ptr.add(inner_lane_len); + } + + while last_ptr < data_end_ptr { + std::ptr::drop_in_place(last_ptr.as_mut()); + last_ptr = last_ptr.add(1); + dropped_elements += 1; + } + + assert_eq!(data_len, dropped_elements + self_len, + "Internal error: inconsistency in move_into"); +} + +/// Sort axes to standard order, i.e Axis(0) has biggest stride and Axis(n - 1) least stride +/// +/// The axes should have stride >= 0 before calling this method. +fn sort_axes_in_default_order(a: &mut ArrayBase) +where + S: RawData, + D: Dimension, +{ + if a.ndim() <= 1 { + return; + } + sort_axes1_impl(&mut a.dim, &mut a.strides); +} + +fn sort_axes1_impl(adim: &mut D, astrides: &mut D) +where D: Dimension +{ + debug_assert!(adim.ndim() > 1); + debug_assert_eq!(adim.ndim(), astrides.ndim()); + // bubble sort axes + let mut changed = true; + while changed { + changed = false; + for i in 0..adim.ndim() - 1 { + let axis_i = i; + let next_axis = i + 1; + + // make sure higher stride axes sort before. + debug_assert!(astrides.slice()[axis_i] as isize >= 0); + if (astrides.slice()[axis_i] as isize) < astrides.slice()[next_axis] as isize { + changed = true; + adim.slice_mut().swap(axis_i, next_axis); + astrides.slice_mut().swap(axis_i, next_axis); + } + } + } +} + +/// Sort axes to standard order, i.e Axis(0) has biggest stride and Axis(n - 1) least stride +/// +/// Axes in a and b are sorted by the strides of `a`, and `a`'s axes should have stride >= 0 before +/// calling this method. +fn sort_axes_in_default_order_tandem(a: &mut ArrayBase, b: &mut ArrayBase) +where + S: RawData, + S2: RawData, + D: Dimension, +{ + if a.ndim() <= 1 { + return; + } + sort_axes2_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides); +} + +fn sort_axes2_impl(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D) +where D: Dimension +{ + debug_assert!(adim.ndim() > 1); + debug_assert_eq!(adim.ndim(), bdim.ndim()); + // bubble sort axes + let mut changed = true; + while changed { + changed = false; + for i in 0..adim.ndim() - 1 { + let axis_i = i; + let next_axis = i + 1; + + // make sure higher stride axes sort before. + debug_assert!(astrides.slice()[axis_i] as isize >= 0); + if (astrides.slice()[axis_i] as isize) < astrides.slice()[next_axis] as isize { + changed = true; + adim.slice_mut().swap(axis_i, next_axis); + astrides.slice_mut().swap(axis_i, next_axis); + bdim.slice_mut().swap(axis_i, next_axis); + bstrides.slice_mut().swap(axis_i, next_axis); + } + } } } diff --git a/src/impl_raw_views.rs b/src/impl_raw_views.rs index d10c7909e..5132b1158 100644 --- a/src/impl_raw_views.rs +++ b/src/impl_raw_views.rs @@ -1,37 +1,38 @@ +use num_complex::Complex; use std::mem; use std::ptr::NonNull; use crate::dimension::{self, stride_offset}; use crate::extension::nonnull::nonnull_debug_checked_from_ptr; use crate::imp_prelude::*; -use crate::{is_aligned, StrideShape}; +use crate::is_aligned; +use crate::shape_builder::{StrideShape, Strides}; impl RawArrayView -where - D: Dimension, +where D: Dimension { /// Create a new `RawArrayView`. /// /// Unsafe because caller is responsible for ensuring that the array will /// meet all of the invariants of the `ArrayBase` type. 
#[inline] - pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { - RawArrayView { - data: RawViewRepr::new(), - ptr, - dim, - strides, - } + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self + { + RawArrayView::from_data_ptr(RawViewRepr::new(), ptr).with_strides_dim(strides, dim) } - unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self { + #[inline] + unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self + { Self::new(nonnull_debug_checked_from_ptr(ptr as *mut A), dim, strides) } /// Create an `RawArrayView` from shape information and a raw pointer /// to the elements. /// - /// Unsafe because caller is responsible for ensuring all of the following: + /// # Safety + /// + /// The caller is responsible for ensuring all of the following: /// /// * `ptr` must be non-null, and it must be safe to [`.offset()`] `ptr` by /// zero. @@ -60,29 +61,42 @@ where /// /// * The product of non-zero axis lengths must not exceed `isize::MAX`. /// + /// * Strides must be non-negative. + /// + /// This function can use debug assertions to check some of these requirements, + /// but it's not a complete check. + /// /// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset + #[inline] pub unsafe fn from_shape_ptr(shape: Sh, ptr: *const A) -> Self - where - Sh: Into>, + where Sh: Into> { let shape = shape.into(); let dim = shape.dim; - let strides = shape.strides; if cfg!(debug_assertions) { assert!(!ptr.is_null(), "The pointer must be non-null."); - dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); + if let Strides::Custom(strides) = &shape.strides { + dimension::strides_non_negative(strides).unwrap(); + dimension::max_abs_offset_check_overflow::(&dim, strides).unwrap(); + } else { + dimension::size_of_shape_checked(&dim).unwrap(); + } } + let strides = shape.strides.strides_for_dim(&dim); RawArrayView::new_(ptr, dim, strides) } /// Converts to a read-only view of the array. /// - /// **Warning** from a safety standpoint, this is equivalent to - /// dereferencing a raw pointer for every element in the array. You must - /// ensure that all of the data is valid, ensure that the pointer is - /// aligned, and choose the correct lifetime. + /// # Safety + /// + /// From a safety standpoint, this is equivalent to dereferencing a raw + /// pointer for every element in the array. You must ensure that all of the + /// data is valid, ensure that the pointer is aligned, and choose the + /// correct lifetime. #[inline] - pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> { + pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> + { debug_assert!( is_aligned(self.ptr.as_ptr()), "The pointer must be aligned." @@ -94,7 +108,10 @@ where /// before the split and one array pointer after the split. /// /// **Panics** if `axis` or `index` is out of bounds. - pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { + #[track_caller] + #[inline] + pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) + { assert!(index <= self.len_of(axis)); let left_ptr = self.ptr.as_ptr(); let right_ptr = if index == self.len_of(axis) { @@ -127,7 +144,9 @@ where /// While this method is safe, for the same reason as regular raw pointer /// casts are safe, access through the produced raw view is only possible /// in an unsafe block or function. 
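// Small sketch of the raw-view workflow: build a `RawArrayView` from a pointer
// and shape, then (unsafely) promise validity to obtain an ordinary view.
use ndarray::RawArrayView;

fn raw_view_roundtrip()
{
    let data = [1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Safety: the pointer is valid for a 2 × 3 standard-layout view of `data`.
    let raw = unsafe { RawArrayView::from_shape_ptr((2, 3), data.as_ptr()) };
    // Safety: `data` outlives the produced view and is properly aligned.
    let view = unsafe { raw.deref_into_view() };
    assert_eq!(view[[1, 2]], 6.0);
}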
- pub fn cast(self) -> RawArrayView { + #[track_caller] + pub fn cast(self) -> RawArrayView + { assert_eq!( mem::size_of::(), mem::size_of::(), @@ -138,32 +157,98 @@ where } } +impl RawArrayView, D> +where D: Dimension +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + pub fn split_complex(self) -> Complex> + { + // Check that the size and alignment of `Complex` are as expected. + // These assertions should always pass, for arbitrary `T`. + assert_eq!( + mem::size_of::>(), + mem::size_of::().checked_mul(2).unwrap() + ); + assert_eq!(mem::align_of::>(), mem::align_of::()); + + let dim = self.dim.clone(); + + // Double the strides. In the zero-sized element case and for axes of + // length <= 1, we leave the strides as-is to avoid possible overflow. + let mut strides = self.strides.clone(); + if mem::size_of::() != 0 { + for ax in 0..strides.ndim() { + if dim[ax] > 1 { + strides[ax] = (strides[ax] as isize * 2) as usize; + } + } + } + + let ptr_re: *mut T = self.ptr.as_ptr().cast(); + let ptr_im: *mut T = if self.is_empty() { + // In the empty case, we can just reuse the existing pointer since + // it won't be dereferenced anyway. It is not safe to offset by + // one, since the allocation may be empty. + ptr_re + } else { + // In the nonempty case, we can safely offset into the first + // (complex) element. + unsafe { ptr_re.add(1) } + }; + + // `Complex` is `repr(C)` with only fields `re: T` and `im: T`. So, the + // real components of the elements start at the same pointer, and the + // imaginary components start at the pointer offset by one, with + // exactly double the strides. The new, doubled strides still meet the + // overflow constraints: + // + // - For the zero-sized element case, the strides are unchanged in + // units of bytes and in units of the element type. + // + // - For the nonzero-sized element case: + // + // - In units of bytes, the strides are unchanged. The only exception + // is axes of length <= 1, but those strides are irrelevant anyway. + // + // - Since `Complex` for nonzero `T` is always at least 2 bytes, + // and the original strides did not overflow in units of bytes, we + // know that the new, doubled strides will not overflow in units of + // `T`. + unsafe { + Complex { + re: RawArrayView::new_(ptr_re, dim.clone(), strides.clone()), + im: RawArrayView::new_(ptr_im, dim, strides), + } + } + } +} + impl RawArrayViewMut -where - D: Dimension, +where D: Dimension { /// Create a new `RawArrayViewMut`. /// /// Unsafe because caller is responsible for ensuring that the array will /// meet all of the invariants of the `ArrayBase` type. #[inline] - pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { - RawArrayViewMut { - data: RawViewRepr::new(), - ptr, - dim, - strides, - } + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self + { + RawArrayViewMut::from_data_ptr(RawViewRepr::new(), ptr).with_strides_dim(strides, dim) } - unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self { + #[inline] + unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self + { Self::new(nonnull_debug_checked_from_ptr(ptr), dim, strides) } /// Create an `RawArrayViewMut` from shape information and a raw /// pointer to the elements. /// - /// Unsafe because caller is responsible for ensuring all of the following: + /// # Safety + /// + /// The caller is responsible for ensuring all of the following: /// /// * `ptr` must be non-null, and it must be safe to [`.offset()`] `ptr` by /// zero. 
@@ -192,35 +277,51 @@ where /// /// * The product of non-zero axis lengths must not exceed `isize::MAX`. /// + /// * Strides must be non-negative. + /// + /// This function can use debug assertions to check some of these requirements, + /// but it's not a complete check. + /// /// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset + #[inline] pub unsafe fn from_shape_ptr(shape: Sh, ptr: *mut A) -> Self - where - Sh: Into>, + where Sh: Into> { let shape = shape.into(); let dim = shape.dim; - let strides = shape.strides; if cfg!(debug_assertions) { assert!(!ptr.is_null(), "The pointer must be non-null."); - dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); + if let Strides::Custom(strides) = &shape.strides { + dimension::strides_non_negative(strides).unwrap(); + dimension::max_abs_offset_check_overflow::(&dim, strides).unwrap(); + assert!(!dimension::dim_stride_overlap(&dim, strides), + "The strides must not allow any element to be referenced by two different indices"); + } else { + dimension::size_of_shape_checked(&dim).unwrap(); + } } + let strides = shape.strides.strides_for_dim(&dim); RawArrayViewMut::new_(ptr, dim, strides) } /// Converts to a non-mutable `RawArrayView`. #[inline] - pub(crate) fn into_raw_view(self) -> RawArrayView { + pub(crate) fn into_raw_view(self) -> RawArrayView + { unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) } } /// Converts to a read-only view of the array. /// - /// **Warning** from a safety standpoint, this is equivalent to - /// dereferencing a raw pointer for every element in the array. You must - /// ensure that all of the data is valid, ensure that the pointer is - /// aligned, and choose the correct lifetime. + /// # Safety + /// + /// From a safety standpoint, this is equivalent to dereferencing a raw + /// pointer for every element in the array. You must ensure that all of the + /// data is valid, ensure that the pointer is aligned, and choose the + /// correct lifetime. #[inline] - pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> { + pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> + { debug_assert!( is_aligned(self.ptr.as_ptr()), "The pointer must be aligned." @@ -230,12 +331,15 @@ where /// Converts to a mutable view of the array. /// - /// **Warning** from a safety standpoint, this is equivalent to - /// dereferencing a raw pointer for every element in the array. You must - /// ensure that all of the data is valid, ensure that the pointer is - /// aligned, and choose the correct lifetime. + /// # Safety + /// + /// From a safety standpoint, this is equivalent to dereferencing a raw + /// pointer for every element in the array. You must ensure that all of the + /// data is valid, ensure that the pointer is aligned, and choose the + /// correct lifetime. #[inline] - pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> { + pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> + { debug_assert!( is_aligned(self.ptr.as_ptr()), "The pointer must be aligned." @@ -247,14 +351,12 @@ where /// before the split and one array pointer after the split. /// /// **Panics** if `axis` or `index` is out of bounds. 
- pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { + #[track_caller] + #[inline] + pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) + { let (left, right) = self.into_raw_view().split_at(axis, index); - unsafe { - ( - Self::new(left.ptr, left.dim, left.strides), - Self::new(right.ptr, right.dim, right.strides), - ) - } + unsafe { (Self::new(left.ptr, left.dim, left.strides), Self::new(right.ptr, right.dim, right.strides)) } } /// Cast the raw pointer of the raw array view to a different type @@ -267,7 +369,9 @@ where /// While this method is safe, for the same reason as regular raw pointer /// casts are safe, access through the produced raw view is only possible /// in an unsafe block or function. - pub fn cast(self) -> RawArrayViewMut { + #[track_caller] + pub fn cast(self) -> RawArrayViewMut + { assert_eq!( mem::size_of::(), mem::size_of::(), @@ -277,3 +381,20 @@ where unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) } } } + +impl RawArrayViewMut, D> +where D: Dimension +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + pub fn split_complex(self) -> Complex> + { + let Complex { re, im } = self.into_raw_view().split_complex(); + unsafe { + Complex { + re: RawArrayViewMut::new(re.ptr, re.dim, re.strides), + im: RawArrayViewMut::new(im.ptr, im.dim, im.strides), + } + } + } +} diff --git a/src/impl_special_element_types.rs b/src/impl_special_element_types.rs new file mode 100644 index 000000000..e430b20bc --- /dev/null +++ b/src/impl_special_element_types.rs @@ -0,0 +1,48 @@ +// Copyright 2020 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::mem::MaybeUninit; + +use crate::imp_prelude::*; +use crate::RawDataSubst; + +/// Methods specific to arrays with `MaybeUninit` elements. +/// +/// ***See also all methods for [`ArrayBase`]*** +impl ArrayBase +where + S: RawDataSubst>, + D: Dimension, +{ + /// **Promise** that the array's elements are all fully initialized, and convert + /// the array from element type `MaybeUninit` to `A`. + /// + /// For example, it can convert an `Array, D>` to `Array`. + /// + /// ## Safety + /// + /// Safe to use if all the array's elements have been initialized. + /// + /// Note that for owned and shared ownership arrays, the promise must include all of the + /// array's storage; it is for example possible to slice these in place, but that must + /// only be done after all elements have been initialized. + pub unsafe fn assume_init(self) -> ArrayBase<>::Output, D> + { + let ArrayBase { + data, + ptr, + dim, + strides, + } = self; + + // "transmute" from storage of MaybeUninit to storage of A + let data = S::data_subst(data); + let ptr = ptr.cast::(); + ArrayBase::from_data_ptr(data, ptr).with_strides_dim(strides, dim) + } +} diff --git a/src/impl_views/constructors.rs b/src/impl_views/constructors.rs index efa854e51..d0089057d 100644 --- a/src/impl_views/constructors.rs +++ b/src/impl_views/constructors.rs @@ -8,7 +8,8 @@ use std::ptr::NonNull; -use crate::dimension; +use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; +use crate::dimension::{self, CanIndexCheckMode}; use crate::error::ShapeError; use crate::extension::nonnull::nonnull_debug_checked_from_ptr; use crate::imp_prelude::*; @@ -16,8 +17,7 @@ use crate::{is_aligned, StrideShape}; /// Methods for read-only array views. 
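// Sketch of the workflow that `assume_init` completes, assuming the companion
// `uninit` constructor from the same release series: allocate MaybeUninit
// storage, write every element, then promise initialization.
use ndarray::Array2;

fn index_sums() -> Array2<f64>
{
    let mut a = Array2::uninit((3, 4));
    for ((i, j), elt) in a.indexed_iter_mut() {
        elt.write((i + j) as f64);
    }
    // Safety: every element was written in the loop above.
    unsafe { a.assume_init() }
}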
impl<'a, A, D> ArrayView<'a, A, D> -where - D: Dimension, +where D: Dimension { /// Create a read-only array view borrowing its data from a slice. /// @@ -29,6 +29,7 @@ where /// use ndarray::arr3; /// use ndarray::ShapeBuilder; /// + /// // advanced example where we are even specifying exact strides to use (which is optional). /// let s = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; /// let a = ArrayView::from_shape((2, 3, 2).strides((1, 4, 2)), /// &s).unwrap(); @@ -44,28 +45,33 @@ where /// assert!(a.strides() == &[1, 4, 2]); /// ``` pub fn from_shape(shape: Sh, xs: &'a [A]) -> Result - where - Sh: Into>, + where Sh: Into> { // eliminate the type parameter Sh as soon as possible Self::from_shape_impl(shape.into(), xs) } - fn from_shape_impl(shape: StrideShape, xs: &'a [A]) -> Result { + fn from_shape_impl(shape: StrideShape, xs: &'a [A]) -> Result + { let dim = shape.dim; - let strides = shape.strides; - if shape.custom { - dimension::can_index_slice(xs, &dim, &strides)?; - } else { - dimension::can_index_slice_not_custom::(xs, &dim)?; + dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::ReadOnly)?; + let strides = shape.strides.strides_for_dim(&dim); + unsafe { + Ok(Self::new_( + xs.as_ptr() + .add(offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides)), + dim, + strides, + )) } - unsafe { Ok(Self::new_(xs.as_ptr(), dim, strides)) } } /// Create an `ArrayView` from shape information and a raw pointer to /// the elements. /// - /// Unsafe because caller is responsible for ensuring all of the following: + /// # Safety + /// + /// The caller is responsible for ensuring all of the following: /// /// * The elements seen by moving `ptr` according to the shape and strides /// must live at least as long as `'a` and must not be not mutably @@ -98,10 +104,15 @@ where /// /// * The product of non-zero axis lengths must not exceed `isize::MAX`. /// + /// * Strides must be non-negative. + /// + /// This function can use debug assertions to check some of these requirements, + /// but it's not a complete check. + /// /// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset + #[inline] pub unsafe fn from_shape_ptr(shape: Sh, ptr: *const A) -> Self - where - Sh: Into>, + where Sh: Into> { RawArrayView::from_shape_ptr(shape, ptr).deref_into_view() } @@ -109,8 +120,7 @@ where /// Methods for read-write array views. impl<'a, A, D> ArrayViewMut<'a, A, D> -where - D: Dimension, +where D: Dimension { /// Create a read-write array view borrowing its data from a slice. 
/// @@ -138,28 +148,33 @@ where /// assert!(a.strides() == &[1, 4, 2]); /// ``` pub fn from_shape(shape: Sh, xs: &'a mut [A]) -> Result - where - Sh: Into>, + where Sh: Into> { // eliminate the type parameter Sh as soon as possible Self::from_shape_impl(shape.into(), xs) } - fn from_shape_impl(shape: StrideShape, xs: &'a mut [A]) -> Result { + fn from_shape_impl(shape: StrideShape, xs: &'a mut [A]) -> Result + { let dim = shape.dim; - let strides = shape.strides; - if shape.custom { - dimension::can_index_slice(xs, &dim, &strides)?; - } else { - dimension::can_index_slice_not_custom::(xs, &dim)?; + dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::OwnedMutable)?; + let strides = shape.strides.strides_for_dim(&dim); + unsafe { + Ok(Self::new_( + xs.as_mut_ptr() + .add(offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides)), + dim, + strides, + )) } - unsafe { Ok(Self::new_(xs.as_mut_ptr(), dim, strides)) } } /// Create an `ArrayViewMut` from shape information and a /// raw pointer to the elements. /// - /// Unsafe because caller is responsible for ensuring all of the following: + /// # Safety + /// + /// The caller is responsible for ensuring all of the following: /// /// * The elements seen by moving `ptr` according to the shape and strides /// must live at least as long as `'a` and must not be aliased for the @@ -192,10 +207,15 @@ where /// /// * The product of non-zero axis lengths must not exceed `isize::MAX`. /// + /// * Strides must be non-negative. + /// + /// This function can use debug assertions to check some of these requirements, + /// but it's not a complete check. + /// /// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset + #[inline] pub unsafe fn from_shape_ptr(shape: Sh, ptr: *mut A) -> Self - where - Sh: Into>, + where Sh: Into> { RawArrayViewMut::from_shape_ptr(shape, ptr).deref_into_view_mut() } @@ -203,68 +223,59 @@ where /// Convert the view into an `ArrayViewMut<'b, A, D>` where `'b` is a lifetime /// outlived by `'a'`. pub fn reborrow<'b>(self) -> ArrayViewMut<'b, A, D> - where - 'a: 'b, + where 'a: 'b { unsafe { ArrayViewMut::new(self.ptr, self.dim, self.strides) } } } /// Private array view methods -impl<'a, A, D> ArrayView<'a, A, D> -where - D: Dimension, +impl ArrayView<'_, A, D> +where D: Dimension { /// Create a new `ArrayView` /// /// Unsafe because: `ptr` must be valid for the given dimension and strides. #[inline(always)] - pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self + { if cfg!(debug_assertions) { assert!(is_aligned(ptr.as_ptr()), "The pointer must be aligned."); dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); } - ArrayView { - data: ViewRepr::new(), - ptr, - dim, - strides, - } + ArrayView::from_data_ptr(ViewRepr::new(), ptr).with_strides_dim(strides, dim) } /// Unsafe because: `ptr` must be valid for the given dimension and strides. #[inline] - pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self { + pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self + { Self::new(nonnull_debug_checked_from_ptr(ptr as *mut A), dim, strides) } } -impl<'a, A, D> ArrayViewMut<'a, A, D> -where - D: Dimension, +impl ArrayViewMut<'_, A, D> +where D: Dimension { /// Create a new `ArrayView` /// /// Unsafe because: `ptr` must be valid for the given dimension and strides. 
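// Minimal sketch: a mutable view over a plain slice, using the default
// (standard-order) strides for the requested shape.
use ndarray::ArrayViewMut;

fn zero_diagonal(xs: &mut [f64])
{
    let mut v = ArrayViewMut::from_shape((3, 3), xs).expect("slice must fit a 3 × 3 view");
    for i in 0..3 {
        v[[i, i]] = 0.0;
    }
}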
#[inline(always)] - pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self { + pub(crate) unsafe fn new(ptr: NonNull, dim: D, strides: D) -> Self + { if cfg!(debug_assertions) { assert!(is_aligned(ptr.as_ptr()), "The pointer must be aligned."); dimension::max_abs_offset_check_overflow::(&dim, &strides).unwrap(); } - ArrayViewMut { - data: ViewRepr::new(), - ptr, - dim, - strides, - } + ArrayViewMut::from_data_ptr(ViewRepr::new(), ptr).with_strides_dim(strides, dim) } /// Create a new `ArrayView` /// /// Unsafe because: `ptr` must be valid for the given dimension and strides. #[inline(always)] - pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self { + pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self + { Self::new(nonnull_debug_checked_from_ptr(ptr), dim, strides) } } diff --git a/src/impl_views/conversions.rs b/src/impl_views/conversions.rs index 303541b8b..1dd7d97f2 100644 --- a/src/impl_views/conversions.rs +++ b/src/impl_views/conversions.rs @@ -6,34 +6,39 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::slice; +use alloc::slice; +#[allow(unused_imports)] +use rawpointer::PointerExt; +use std::mem::MaybeUninit; use crate::imp_prelude::*; use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut}; +use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; use crate::iter::{self, AxisIter, AxisIterMut}; +use crate::math_cell::MathCell; use crate::IndexLonger; /// Methods for read-only array views. impl<'a, A, D> ArrayView<'a, A, D> -where - D: Dimension, +where D: Dimension { /// Convert the view into an `ArrayView<'b, A, D>` where `'b` is a lifetime /// outlived by `'a'`. pub fn reborrow<'b>(self) -> ArrayView<'b, A, D> - where - 'a: 'b, + where 'a: 'b { unsafe { ArrayView::new(self.ptr, self.dim, self.strides) } } /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. - #[deprecated(note = "`into_slice` has been renamed to `to_slice`", since = "0.13.0")] - #[allow(clippy::wrong_self_convention)] - pub fn into_slice(&self) -> Option<&'a [A]> { + /// + /// Note that while the method is similar to [`ArrayBase::as_slice()`], this method transfers + /// the view's lifetime to the slice, so it is a bit more powerful. + pub fn to_slice(&self) -> Option<&'a [A]> + { if self.is_standard_layout() { unsafe { Some(slice::from_raw_parts(self.ptr.as_ptr(), self.len())) } } else { @@ -41,18 +46,26 @@ where } } - /// Return the array’s data as a slice, if it is contiguous and in standard order. + /// Return the array’s data as a slice, if it is contiguous. /// Return `None` otherwise. - pub fn to_slice(&self) -> Option<&'a [A]> { - if self.is_standard_layout() { - unsafe { Some(slice::from_raw_parts(self.ptr.as_ptr(), self.len())) } + /// + /// Note that while the method is similar to + /// [`ArrayBase::as_slice_memory_order()`], this method transfers the view's + /// lifetime to the slice, so it is a bit more powerful. + pub fn to_slice_memory_order(&self) -> Option<&'a [A]> + { + if self.is_contiguous() { + let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides); + unsafe { Some(slice::from_raw_parts(self.ptr.sub(offset).as_ptr(), self.len())) } } else { None } } /// Converts to a raw array view. 
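// Sketch of why `to_slice` is more powerful than `as_slice`: the returned slice
// borrows the data for the view's lifetime `'a`, so it can outlive the view value.
use ndarray::Array2;

fn first_row_as_slice(a: &Array2<i32>) -> Option<&[i32]>
{
    let row = a.row(0); // ArrayView1 borrowing from `a`
    row.to_slice() // the slice lives as long as the borrow of `a`, not of `row`
}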
- pub(crate) fn into_raw_view(self) -> RawArrayView { + #[inline] + pub(crate) fn into_raw_view(self) -> RawArrayView + { unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) } } } @@ -60,10 +73,8 @@ where /// Methods specific to `ArrayView0`. /// /// ***See also all methods for [`ArrayView`] and [`ArrayBase`]*** -/// -/// [`ArrayBase`]: struct.ArrayBase.html -/// [`ArrayView`]: struct.ArrayView.html -impl<'a, A> ArrayView<'a, A, Ix0> { +impl<'a, A> ArrayView<'a, A, Ix0> +{ /// Consume the view and return a reference to the single element in the array. /// /// The lifetime of the returned reference matches the lifetime of the data @@ -81,7 +92,8 @@ impl<'a, A> ArrayView<'a, A, Ix0> { /// let scalar: &Foo = view.into_scalar(); /// assert_eq!(scalar, &Foo); /// ``` - pub fn into_scalar(self) -> &'a A { + pub fn into_scalar(self) -> &'a A + { self.index(Ix0()) } } @@ -89,10 +101,8 @@ impl<'a, A> ArrayView<'a, A, Ix0> { /// Methods specific to `ArrayViewMut0`. /// /// ***See also all methods for [`ArrayViewMut`] and [`ArrayBase`]*** -/// -/// [`ArrayBase`]: struct.ArrayBase.html -/// [`ArrayViewMut`]: struct.ArrayViewMut.html -impl<'a, A> ArrayViewMut<'a, A, Ix0> { +impl<'a, A> ArrayViewMut<'a, A, Ix0> +{ /// Consume the mutable view and return a mutable reference to the single element in the array. /// /// The lifetime of the returned reference matches the lifetime of the data @@ -103,44 +113,124 @@ impl<'a, A> ArrayViewMut<'a, A, Ix0> { /// /// let mut array: Array0 = arr0(5.); /// let view = array.view_mut(); - /// let mut scalar = view.into_scalar(); + /// let scalar = view.into_scalar(); /// *scalar = 7.; /// assert_eq!(scalar, &7.); /// assert_eq!(array[()], 7.); /// ``` - pub fn into_scalar(self) -> &'a mut A { + pub fn into_scalar(self) -> &'a mut A + { self.index(Ix0()) } } /// Methods for read-write array views. impl<'a, A, D> ArrayViewMut<'a, A, D> -where - D: Dimension, +where D: Dimension { /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. - pub fn into_slice(self) -> Option<&'a mut [A]> { - self.into_slice_().ok() + /// + /// Note that while this is similar to [`ArrayBase::as_slice_mut()`], this method transfers the + /// view's lifetime to the slice. + pub fn into_slice(self) -> Option<&'a mut [A]> + { + self.try_into_slice().ok() + } + + /// Return the array’s data as a slice, if it is contiguous. + /// Return `None` otherwise. + /// + /// Note that while this is similar to + /// [`ArrayBase::as_slice_memory_order_mut()`], this method transfers the + /// view's lifetime to the slice. + pub fn into_slice_memory_order(self) -> Option<&'a mut [A]> + { + self.try_into_slice_memory_order().ok() + } + + /// Return a shared view of the array with elements as if they were embedded in cells. + /// + /// The cell view itself can be copied and accessed without exclusivity. + /// + /// The view acts "as if" the elements are temporarily in cells, and elements + /// can be changed through shared references using the regular cell methods. 
+ pub fn into_cell_view(self) -> ArrayView<'a, MathCell, D> + { + // safety: valid because + // A and MathCell have the same representation + // &'a mut T is interchangeable with &'a Cell -- see method Cell::from_mut in std + unsafe { + self.into_raw_view_mut() + .cast::>() + .deref_into_view() + } + } + + /// Return the array view as a view of `MaybeUninit` elements + /// + /// This conversion leaves the elements as they were (presumably initialized), but + /// they are represented with the `MaybeUninit` type. Effectively this means that + /// the elements can be overwritten without dropping the old element in its place. + /// (In some situations this is not what you want, while for `Copy` elements it makes + /// no difference at all.) + /// + /// # Safety + /// + /// This method allows writing uninitialized data into the view, which could leave any + /// original array that we borrow from in an inconsistent state. This is not allowed + /// when using the resulting array view. + pub(crate) unsafe fn into_maybe_uninit(self) -> ArrayViewMut<'a, MaybeUninit, D> + { + // Safe because: A and MaybeUninit have the same representation; + // and we can go from initialized to (maybe) not unconditionally in terms of + // representation. However, the user must be careful to not write uninit elements + // through the view. + self.into_raw_view_mut() + .cast::>() + .deref_into_view_mut() + } +} + +/// Private raw array view methods +impl RawArrayView +where D: Dimension +{ + #[inline] + pub(crate) fn into_base_iter(self) -> Baseiter + { + unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } + } +} + +impl RawArrayViewMut +where D: Dimension +{ + #[inline] + pub(crate) fn into_base_iter(self) -> Baseiter + { + unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } } } /// Private array view methods impl<'a, A, D> ArrayView<'a, A, D> -where - D: Dimension, +where D: Dimension { #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) } + pub(crate) fn into_base_iter(self) -> Baseiter + { + unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } } #[inline] - pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> { + pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> + { ElementsBase::new(self) } - pub(crate) fn into_iter_(self) -> Iter<'a, A, D> { + pub(crate) fn into_iter_(self) -> Iter<'a, A, D> + { Iter::new(self) } @@ -148,38 +238,43 @@ where #[doc(hidden)] // not official #[deprecated(note = "This method will be replaced.")] pub fn into_outer_iter(self) -> iter::AxisIter<'a, A, D::Smaller> - where - D: RemoveAxis, + where D: RemoveAxis { AxisIter::new(self, Axis(0)) } } impl<'a, A, D> ArrayViewMut<'a, A, D> -where - D: Dimension, +where D: Dimension { // Convert into a read-only view - pub(crate) fn into_view(self) -> ArrayView<'a, A, D> { + pub(crate) fn into_view(self) -> ArrayView<'a, A, D> + { unsafe { ArrayView::new(self.ptr, self.dim, self.strides) } } /// Converts to a mutable raw array view. 
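// Sketch of the cell view in use: once converted, elements can be written
// through a shared view with the usual `Cell` methods (via `MathCell`).
use ndarray::Array1;

fn cell_view_demo()
{
    let mut a = Array1::from(vec![1, 2, 3, 4]);
    {
        let cells = a.view_mut().into_cell_view();
        cells[0].set(10);
        cells[3].set(40);
    }
    assert_eq!(a, Array1::from(vec![10, 2, 3, 40]));
}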
- pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut { + pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut + { unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) } } #[inline] - pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) } + pub(crate) fn into_base_iter(self) -> Baseiter + { + unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } } #[inline] - pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> { + pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> + { ElementsBaseMut::new(self) } - pub(crate) fn into_slice_(self) -> Result<&'a mut [A], Self> { + /// Return the array’s data as a slice, if it is contiguous and in standard order. + /// Otherwise return self in the Err branch of the result. + pub(crate) fn try_into_slice(self) -> Result<&'a mut [A], Self> + { if self.is_standard_layout() { unsafe { Ok(slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len())) } } else { @@ -187,7 +282,20 @@ where } } - pub(crate) fn into_iter_(self) -> IterMut<'a, A, D> { + /// Return the array’s data as a slice, if it is contiguous. + /// Otherwise return self in the Err branch of the result. + fn try_into_slice_memory_order(self) -> Result<&'a mut [A], Self> + { + if self.is_contiguous() { + let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides); + unsafe { Ok(slice::from_raw_parts_mut(self.ptr.sub(offset).as_ptr(), self.len())) } + } else { + Err(self) + } + } + + pub(crate) fn into_iter_(self) -> IterMut<'a, A, D> + { IterMut::new(self) } @@ -195,8 +303,7 @@ where #[doc(hidden)] // not official #[deprecated(note = "This method will be replaced.")] pub fn into_outer_iter(self) -> iter::AxisIterMut<'a, A, D::Smaller> - where - D: RemoveAxis, + where D: RemoveAxis { AxisIterMut::new(self, Axis(0)) } diff --git a/src/impl_views/indexing.rs b/src/impl_views/indexing.rs index b03c9a9c5..827313478 100644 --- a/src/impl_views/indexing.rs +++ b/src/impl_views/indexing.rs @@ -33,7 +33,7 @@ use crate::NdIndex; /// let data = [0.; 256]; /// let long_life_ref = { /// // make a 16 × 16 array view -/// let view = ArrayView::from(&data[..]).into_shape((16, 16)).unwrap(); +/// let view = ArrayView::from(&data[..]).into_shape_with_order((16, 16)).unwrap(); /// /// // index the view and with `IndexLonger`. /// // Note here that we get a reference with a life that is derived from @@ -46,7 +46,8 @@ use crate::NdIndex; /// assert_eq!(long_life_ref, &0.); /// /// ``` -pub trait IndexLonger { +pub trait IndexLonger +{ /// The type of the reference to the element that is produced, including /// its lifetime. type Output; @@ -59,9 +60,10 @@ pub trait IndexLonger { /// See also [the `get` method][1] which works for all arrays and array /// views. /// - /// [1]: struct.ArrayBase.html#method.get + /// [1]: ArrayBase::get /// /// **Panics** if index is out of bounds. + #[track_caller] fn index(self, index: I) -> Self::Output; /// Get a reference of a element through the view. @@ -73,10 +75,11 @@ pub trait IndexLonger { /// See also [the `get` method][1] (and [`get_mut`][2]) which works for all arrays and array /// views. /// - /// [1]: struct.ArrayBase.html#method.get - /// [2]: struct.ArrayBase.html#method.get_mut + /// [1]: ArrayBase::get + /// [2]: ArrayBase::get_mut /// /// **Panics** if index is out of bounds. 
+ #[track_caller] fn get(self, index: I) -> Option; /// Get a reference of a element through the view without boundary check @@ -87,13 +90,17 @@ pub trait IndexLonger { /// See also [the `uget` method][1] which works for all arrays and array /// views. /// - /// [1]: struct.ArrayBase.html#method.uget + /// [1]: ArrayBase::uget /// /// **Note:** only unchecked for non-debug builds of ndarray. + /// + /// # Safety + /// + /// The caller must ensure that the index is in-bounds. unsafe fn uget(self, index: I) -> Self::Output; } -impl<'a, 'b, I, A, D> IndexLonger for &'b ArrayView<'a, A, D> +impl<'a, I, A, D> IndexLonger for &ArrayView<'a, A, D> where I: NdIndex, D: Dimension, @@ -109,15 +116,18 @@ where /// See also [the `get` method][1] which works for all arrays and array /// views. /// - /// [1]: struct.ArrayBase.html#method.get + /// [1]: ArrayBase::get /// /// **Panics** if index is out of bounds. - fn index(self, index: I) -> &'a A { + #[track_caller] + fn index(self, index: I) -> &'a A + { debug_bounds_check!(self, index); unsafe { &*self.get_ptr(index).unwrap_or_else(|| array_out_of_bounds()) } } - fn get(self, index: I) -> Option<&'a A> { + fn get(self, index: I) -> Option<&'a A> + { unsafe { self.get_ptr(index).map(|ptr| &*ptr) } } @@ -129,10 +139,11 @@ where /// See also [the `uget` method][1] which works for all arrays and array /// views. /// - /// [1]: struct.ArrayBase.html#method.uget + /// [1]: ArrayBase::uget /// /// **Note:** only unchecked for non-debug builds of ndarray. - unsafe fn uget(self, index: I) -> &'a A { + unsafe fn uget(self, index: I) -> &'a A + { debug_bounds_check!(self, index); &*self.as_ptr().offset(index.index_unchecked(&self.strides)) } @@ -154,13 +165,15 @@ where /// See also [the `get_mut` method][1] which works for all arrays and array /// views. /// - /// [1]: struct.ArrayBase.html#method.get_mut + /// [1]: ArrayBase::get_mut /// /// **Panics** if index is out of bounds. - fn index(mut self, index: I) -> &'a mut A { + #[track_caller] + fn index(mut self, index: I) -> &'a mut A + { debug_bounds_check!(self, index); unsafe { - match self.get_ptr_mut(index) { + match self.get_mut_ptr(index) { Some(ptr) => &mut *ptr, None => array_out_of_bounds(), } @@ -173,12 +186,13 @@ where /// See also [the `get_mut` method][1] which works for all arrays and array /// views. /// - /// [1]: struct.ArrayBase.html#method.get_mut + /// [1]: ArrayBase::get_mut /// - fn get(mut self, index: I) -> Option<&'a mut A> { + fn get(mut self, index: I) -> Option<&'a mut A> + { debug_bounds_check!(self, index); unsafe { - match self.get_ptr_mut(index) { + match self.get_mut_ptr(index) { Some(ptr) => Some(&mut *ptr), None => None, } @@ -191,10 +205,11 @@ where /// See also [the `uget_mut` method][1] which works for all arrays and array /// views. /// - /// [1]: struct.ArrayBase.html#method.uget_mut + /// [1]: ArrayBase::uget_mut /// /// **Note:** only unchecked for non-debug builds of ndarray. 
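// Sketch: `IndexLonger::index` ties the returned reference to the data's
// lifetime rather than to the temporary view, so the reference can be returned.
use ndarray::{ArrayView1, IndexLonger};

fn third(data: &[f64; 4]) -> &f64
{
    let view = ArrayView1::from(&data[..]);
    IndexLonger::index(&view, 2)
}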
- unsafe fn uget(mut self, index: I) -> &'a mut A { + unsafe fn uget(mut self, index: I) -> &'a mut A + { debug_bounds_check!(self, index); &mut *self .as_mut_ptr() diff --git a/src/impl_views/mod.rs b/src/impl_views/mod.rs index 487cc3cb2..cabea5b37 100644 --- a/src/impl_views/mod.rs +++ b/src/impl_views/mod.rs @@ -3,7 +3,4 @@ mod conversions; mod indexing; mod splitting; -pub use constructors::*; -pub use conversions::*; pub use indexing::*; -pub use splitting::*; diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index dcfb04b86..58d0a7556 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -7,12 +7,12 @@ // except according to those terms. use crate::imp_prelude::*; -use crate::slice::MultiSlice; +use crate::slice::MultiSliceArg; +use num_complex::Complex; /// Methods for read-only array views. -impl<'a, A, D> ArrayView<'a, A, D> -where - D: Dimension, +impl ArrayView<'_, A, D> +where D: Dimension { /// Split the array view along `axis` and return one view strictly before the /// split and one view after the split. @@ -51,7 +51,7 @@ where /// ```rust /// # use ndarray::prelude::*; /// # let a = aview2(&[[0; 4]; 3]); - /// let (v1, v2) = a.split_at(Axis(0), 1); + /// let (v1, v2) = a.split_at(Axis(0), 2); /// ``` /// ```text /// ┌─────┬─────┬─────┬─────┐ 0 ↓ indices @@ -87,7 +87,10 @@ where /// 0 1 2 3 4 indices → /// along Axis(1) /// ``` - pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { + #[track_caller] + #[inline] + pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) + { unsafe { let (left, right) = self.into_raw_view().split_at(axis, index); (left.deref_into_view(), right.deref_into_view()) @@ -95,16 +98,49 @@ where } } +impl<'a, T, D> ArrayView<'a, Complex, D> +where D: Dimension +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + /// + /// ``` + /// use ndarray::prelude::*; + /// use num_complex::{Complex, Complex64}; + /// + /// let arr = array![ + /// [Complex64::new(1., 2.), Complex64::new(3., 4.)], + /// [Complex64::new(5., 6.), Complex64::new(7., 8.)], + /// [Complex64::new(9., 10.), Complex64::new(11., 12.)], + /// ]; + /// let Complex { re, im } = arr.view().split_complex(); + /// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]); + /// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]); + /// ``` + pub fn split_complex(self) -> Complex> + { + unsafe { + let Complex { re, im } = self.into_raw_view().split_complex(); + Complex { + re: re.deref_into_view(), + im: im.deref_into_view(), + } + } + } +} + /// Methods for read-write array views. impl<'a, A, D> ArrayViewMut<'a, A, D> -where - D: Dimension, +where D: Dimension { /// Split the array view along `axis` and return one mutable view strictly /// before the split and one mutable view after the split. /// /// **Panics** if `axis` or `index` is out of bounds. - pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { + #[track_caller] + #[inline] + pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) + { unsafe { let (left, right) = self.into_raw_view_mut().split_at(axis, index); (left.deref_into_view_mut(), right.deref_into_view_mut()) @@ -117,22 +153,59 @@ where /// consumes `self` and produces views with lifetimes matching that of /// `self`. /// - /// See [*Slicing*](#slicing) for full documentation. - /// See also [`SliceInfo`] and [`D::SliceArg`]. + /// See [*Slicing*](#slicing) for full documentation. 
See also + /// [`MultiSliceArg`], [`s!`], [`SliceArg`](crate::SliceArg), and + /// [`SliceInfo`](crate::SliceInfo). /// - /// [`.multi_slice_mut()`]: struct.ArrayBase.html#method.multi_slice_mut - /// [`SliceInfo`]: struct.SliceInfo.html - /// [`D::SliceArg`]: trait.Dimension.html#associatedtype.SliceArg + /// [`.multi_slice_mut()`]: ArrayBase::multi_slice_mut /// /// **Panics** if any of the following occur: /// /// * if any of the views would intersect (i.e. if any element would appear in multiple slices) /// * if an index is out of bounds or step size is zero /// * if `D` is `IxDyn` and `info` does not match the number of array axes + #[track_caller] pub fn multi_slice_move(self, info: M) -> M::Output - where - M: MultiSlice<'a, A, D>, + where M: MultiSliceArg<'a, A, D> { info.multi_slice_move(self) } } + +impl<'a, T, D> ArrayViewMut<'a, Complex, D> +where D: Dimension +{ + /// Splits the view into views of the real and imaginary components of the + /// elements. + /// + /// ``` + /// use ndarray::prelude::*; + /// use num_complex::{Complex, Complex64}; + /// + /// let mut arr = array![ + /// [Complex64::new(1., 2.), Complex64::new(3., 4.)], + /// [Complex64::new(5., 6.), Complex64::new(7., 8.)], + /// [Complex64::new(9., 10.), Complex64::new(11., 12.)], + /// ]; + /// + /// let Complex { mut re, mut im } = arr.view_mut().split_complex(); + /// assert_eq!(re, array![[1., 3.], [5., 7.], [9., 11.]]); + /// assert_eq!(im, array![[2., 4.], [6., 8.], [10., 12.]]); + /// + /// re[[0, 1]] = 13.; + /// im[[2, 0]] = 14.; + /// + /// assert_eq!(arr[[0, 1]], Complex64::new(13., 4.)); + /// assert_eq!(arr[[2, 0]], Complex64::new(9., 14.)); + /// ``` + pub fn split_complex(self) -> Complex> + { + unsafe { + let Complex { re, im } = self.into_raw_view_mut().split_complex(); + Complex { + re: re.deref_into_view_mut(), + im: im.deref_into_view_mut(), + } + } + } +} diff --git a/src/indexes.rs b/src/indexes.rs index 0218da51e..0fa2b50fb 100644 --- a/src/indexes.rs +++ b/src/indexes.rs @@ -7,7 +7,8 @@ // except according to those terms. use super::Dimension; use crate::dimension::IntoDimension; -use crate::zip::{Offset, Splittable}; +use crate::split_at::SplitAt; +use crate::zip::Offset; use crate::Axis; use crate::Layout; use crate::NdProducer; @@ -17,7 +18,8 @@ use crate::{ArrayBase, Data}; /// /// Iterator element type is `D`. #[derive(Clone)] -pub struct IndicesIter { +pub struct IndicesIter +{ dim: D, index: Option, } @@ -27,8 +29,7 @@ pub struct IndicesIter { /// *Note:* prefer higher order methods, arithmetic operations and /// non-indexed iteration before using indices. 
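// Sketch of the index producer on its own: it yields index patterns without
// referencing any array data.
use ndarray::indices;

fn checkerboard_count(rows: usize, cols: usize) -> usize
{
    indices((rows, cols))
        .into_iter()
        .filter(|&(i, j)| (i + j) % 2 == 0)
        .count()
}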
pub fn indices(shape: E) -> Indices -where - E: IntoDimension, +where E: IntoDimension { let dim = shape.into_dimension(); Indices { @@ -50,12 +51,12 @@ where } impl Iterator for IndicesIter -where - D: Dimension, +where D: Dimension { type Item = D::Pattern; #[inline] - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { let index = match self.index { None => return None, Some(ref ix) => ix.clone(), @@ -64,7 +65,8 @@ where Some(index.into_pattern()) } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { let l = match self.index { None => 0, Some(ref ix) => { @@ -74,7 +76,7 @@ where .slice() .iter() .zip(ix.slice().iter()) - .fold(0, |s, (&a, &b)| s + a as usize * b as usize); + .fold(0, |s, (&a, &b)| s + a * b); self.dim.size() - gone } }; @@ -82,8 +84,7 @@ where } fn fold(self, init: B, mut f: F) -> B - where - F: FnMut(B, D::Pattern) -> B, + where F: FnMut(B, D::Pattern) -> B { let IndicesIter { mut index, dim } = self; let ndim = dim.ndim(); @@ -111,18 +112,15 @@ where impl ExactSizeIterator for IndicesIter where D: Dimension {} impl IntoIterator for Indices -where - D: Dimension, +where D: Dimension { type Item = D::Pattern; type IntoIter = IndicesIter; - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { let sz = self.dim.size(); let index = if sz != 0 { Some(self.start) } else { None }; - IndicesIter { - index, - dim: self.dim, - } + IndicesIter { index, dim: self.dim } } } @@ -131,33 +129,48 @@ where /// `Indices` is an `NdProducer` that produces the indices of an array shape. #[derive(Copy, Clone, Debug)] pub struct Indices -where - D: Dimension, +where D: Dimension { start: D, dim: D, } #[derive(Copy, Clone, Debug)] -pub struct IndexPtr { +pub struct IndexPtr +{ index: D, } impl Offset for IndexPtr -where - D: Dimension + Copy, +where D: Dimension + Copy { // stride: The axis to increment type Stride = usize; - unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self { + unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self + { self.index[stride] += index; self } private_impl! {} } -impl NdProducer for Indices { +// How the NdProducer for Indices works. +// +// NdProducer allows for raw pointers (Ptr), strides (Stride) and the produced +// item (Item). +// +// Instead of Ptr, there is `IndexPtr` which is an index value, like [0, 0, 0] +// for the three dimensional case. +// +// The stride is simply which axis is currently being incremented. The stride for axis 1, is 1. +// +// .stride_offset(stride, index) simply computes the new index along that axis, for example: +// [0, 0, 0].stride_offset(1, 10) => [0, 10, 0] axis 1 is incremented by 10. +// +// .as_ref() converts the Ptr value to an Item. For example [0, 10, 0] => (0, 10, 0) +impl NdProducer for Indices +{ type Item = D::Pattern; type Dim = D; type Ptr = IndexPtr; @@ -165,23 +178,23 @@ impl NdProducer for Indices { private_impl! 
{} - #[doc(hidden)] - fn raw_dim(&self) -> Self::Dim { + fn raw_dim(&self) -> Self::Dim + { self.dim } - #[doc(hidden)] - fn equal_dim(&self, dim: &Self::Dim) -> bool { + fn equal_dim(&self, dim: &Self::Dim) -> bool + { self.dim.equal(dim) } - #[doc(hidden)] - fn as_ptr(&self) -> Self::Ptr { + fn as_ptr(&self) -> Self::Ptr + { IndexPtr { index: self.start } } - #[doc(hidden)] - fn layout(&self) -> Layout { + fn layout(&self) -> Layout + { if self.dim.ndim() <= 1 { Layout::one_dimensional() } else { @@ -189,44 +202,36 @@ impl NdProducer for Indices { } } - #[doc(hidden)] - unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item { + unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item + { ptr.index.into_pattern() } - #[doc(hidden)] - unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { + unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr + { let mut index = *i; index += &self.start; IndexPtr { index } } - #[doc(hidden)] - fn stride_of(&self, axis: Axis) -> Self::Stride { + fn stride_of(&self, axis: Axis) -> Self::Stride + { axis.index() } #[inline(always)] - fn contiguous_stride(&self) -> Self::Stride { + fn contiguous_stride(&self) -> Self::Stride + { 0 } - #[doc(hidden)] - fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + { let start_a = self.start; let mut start_b = start_a; let (a, b) = self.dim.split_at(axis, index); start_b[axis.index()] += index; - ( - Indices { - start: start_a, - dim: a, - }, - Indices { - start: start_b, - dim: b, - }, - ) + (Indices { start: start_a, dim: a }, Indices { start: start_b, dim: b }) } } @@ -234,15 +239,15 @@ impl NdProducer for Indices { /// /// Iterator element type is `D`. #[derive(Clone)] -pub struct IndicesIterF { +pub struct IndicesIterF +{ dim: D, index: D, has_remaining: bool, } pub fn indices_iter_f(shape: E) -> IndicesIterF -where - E: IntoDimension, +where E: IntoDimension { let dim = shape.into_dimension(); let zero = E::Dim::zeros(dim.ndim()); @@ -254,12 +259,12 @@ where } impl Iterator for IndicesIterF -where - D: Dimension, +where D: Dimension { type Item = D::Pattern; #[inline] - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { if !self.has_remaining { None } else { @@ -269,22 +274,19 @@ where } } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { if !self.has_remaining { return (0, Some(0)); } - let l = match self.index { - ref ix => { - let gone = self - .dim - .fortran_strides() - .slice() - .iter() - .zip(ix.slice().iter()) - .fold(0, |s, (&a, &b)| s + a as usize * b as usize); - self.dim.size() - gone - } - }; + let gone = self + .dim + .fortran_strides() + .slice() + .iter() + .zip(self.index.slice().iter()) + .fold(0, |s, (&a, &b)| s + a * b); + let l = self.dim.size() - gone; (l, Some(l)) } } @@ -292,12 +294,14 @@ where impl ExactSizeIterator for IndicesIterF where D: Dimension {} #[cfg(test)] -mod tests { +mod tests +{ use super::indices; use super::indices_iter_f; #[test] - fn test_indices_iter_c_size_hint() { + fn test_indices_iter_c_size_hint() + { let dim = (3, 4); let mut it = indices(dim).into_iter(); let mut len = dim.0 * dim.1; @@ -310,7 +314,8 @@ mod tests { } #[test] - fn test_indices_iter_c_fold() { + fn test_indices_iter_c_fold() + { macro_rules! 
run_test { ($dim:expr) => { for num_consume in 0..3 { @@ -338,7 +343,8 @@ mod tests { } #[test] - fn test_indices_iter_f_size_hint() { + fn test_indices_iter_f_size_hint() + { let dim = (3, 4); let mut it = indices_iter_f(dim); let mut len = dim.0 * dim.1; diff --git a/src/iterators/chunks.rs b/src/iterators/chunks.rs index e41c1bf25..9e2f08e1e 100644 --- a/src/iterators/chunks.rs +++ b/src/iterators/chunks.rs @@ -1,6 +1,7 @@ +use std::marker::PhantomData; + use crate::imp_prelude::*; -use crate::ElementsBase; -use crate::ElementsBaseMut; +use crate::Baseiter; use crate::IntoDimension; use crate::{Layout, NdProducer}; @@ -9,6 +10,7 @@ impl_ndproducer! { [Clone => 'a, A, D: Clone ] ExactChunks { base, + life, chunk, inner_strides, } @@ -23,28 +25,28 @@ impl_ndproducer! { } } -type BaseProducerRef<'a, A, D> = ArrayView<'a, A, D>; -type BaseProducerMut<'a, A, D> = ArrayViewMut<'a, A, D>; - /// Exact chunks producer and iterable. /// -/// See [`.exact_chunks()`](../struct.ArrayBase.html#method.exact_chunks) for more +/// See [`.exact_chunks()`](ArrayBase::exact_chunks) for more /// information. //#[derive(Debug)] -pub struct ExactChunks<'a, A, D> { - base: BaseProducerRef<'a, A, D>, +pub struct ExactChunks<'a, A, D> +{ + base: RawArrayView, + life: PhantomData<&'a A>, chunk: D, inner_strides: D, } -impl<'a, A, D: Dimension> ExactChunks<'a, A, D> { +impl<'a, A, D: Dimension> ExactChunks<'a, A, D> +{ /// Creates a new exact chunks producer. /// /// **Panics** if any chunk dimension is zero - pub(crate) fn new(mut a: ArrayView<'a, A, D>, chunk: E) -> Self - where - E: IntoDimension, + pub(crate) fn new(a: ArrayView<'a, A, D>, chunk: E) -> Self + where E: IntoDimension { + let mut a = a.into_raw_view(); let chunk = chunk.into_dimension(); ndassert!( a.ndim() == chunk.ndim(), @@ -59,11 +61,12 @@ impl<'a, A, D: Dimension> ExactChunks<'a, A, D> { for i in 0..a.ndim() { a.dim[i] /= chunk[i]; } - let inner_strides = a.raw_strides(); + let inner_strides = a.strides.clone(); a.strides *= &chunk; ExactChunks { base: a, + life: PhantomData, chunk, inner_strides, } @@ -77,9 +80,11 @@ where { type Item = ::Item; type IntoIter = ExactChunksIter<'a, A, D>; - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { ExactChunksIter { - iter: self.base.into_elements_base(), + iter: self.base.into_base_iter(), + life: self.life, chunk: self.chunk, inner_strides: self.inner_strides, } @@ -88,10 +93,12 @@ where /// Exact chunks iterator. /// -/// See [`.exact_chunks()`](../struct.ArrayBase.html#method.exact_chunks) for more +/// See [`.exact_chunks()`](ArrayBase::exact_chunks) for more /// information. -pub struct ExactChunksIter<'a, A, D> { - iter: ElementsBase<'a, A, D>, +pub struct ExactChunksIter<'a, A, D> +{ + iter: Baseiter, + life: PhantomData<&'a A>, chunk: D, inner_strides: D, } @@ -101,6 +108,7 @@ impl_ndproducer! { [Clone => ] ExactChunksMut { base, + life, chunk, inner_strides, } @@ -118,23 +126,26 @@ impl_ndproducer! { /// Exact chunks producer and iterable. /// -/// See [`.exact_chunks_mut()`](../struct.ArrayBase.html#method.exact_chunks_mut) +/// See [`.exact_chunks_mut()`](ArrayBase::exact_chunks_mut) /// for more information. 
//#[derive(Debug)] -pub struct ExactChunksMut<'a, A, D> { - base: BaseProducerMut<'a, A, D>, +pub struct ExactChunksMut<'a, A, D> +{ + base: RawArrayViewMut, + life: PhantomData<&'a mut A>, chunk: D, inner_strides: D, } -impl<'a, A, D: Dimension> ExactChunksMut<'a, A, D> { +impl<'a, A, D: Dimension> ExactChunksMut<'a, A, D> +{ /// Creates a new exact chunks producer. /// /// **Panics** if any chunk dimension is zero - pub(crate) fn new(mut a: ArrayViewMut<'a, A, D>, chunk: E) -> Self - where - E: IntoDimension, + pub(crate) fn new(a: ArrayViewMut<'a, A, D>, chunk: E) -> Self + where E: IntoDimension { + let mut a = a.into_raw_view_mut(); let chunk = chunk.into_dimension(); ndassert!( a.ndim() == chunk.ndim(), @@ -149,11 +160,12 @@ impl<'a, A, D: Dimension> ExactChunksMut<'a, A, D> { for i in 0..a.ndim() { a.dim[i] /= chunk[i]; } - let inner_strides = a.raw_strides(); + let inner_strides = a.strides.clone(); a.strides *= &chunk; ExactChunksMut { base: a, + life: PhantomData, chunk, inner_strides, } @@ -167,9 +179,11 @@ where { type Item = ::Item; type IntoIter = ExactChunksIterMut<'a, A, D>; - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { ExactChunksIterMut { - iter: self.base.into_elements_base(), + iter: self.base.into_base_iter(), + life: self.life, chunk: self.chunk, inner_strides: self.inner_strides, } @@ -181,16 +195,17 @@ impl_iterator! { [Clone => 'a, A, D: Clone] ExactChunksIter { iter, + life, chunk, inner_strides, } ExactChunksIter<'a, A, D> { type Item = ArrayView<'a, A, D>; - fn item(&mut self, elt) { + fn item(&mut self, ptr) { unsafe { - ArrayView::new_( - elt, + ArrayView::new( + ptr, self.chunk.clone(), self.inner_strides.clone()) } @@ -209,10 +224,10 @@ impl_iterator! { ExactChunksIterMut<'a, A, D> { type Item = ArrayViewMut<'a, A, D>; - fn item(&mut self, elt) { + fn item(&mut self, ptr) { unsafe { - ArrayViewMut::new_( - elt, + ArrayViewMut::new( + ptr, self.chunk.clone(), self.inner_strides.clone()) } @@ -222,10 +237,18 @@ impl_iterator! { /// Exact chunks iterator. /// -/// See [`.exact_chunks_mut()`](../struct.ArrayBase.html#method.exact_chunks_mut) +/// See [`.exact_chunks_mut()`](ArrayBase::exact_chunks_mut) /// for more information. -pub struct ExactChunksIterMut<'a, A, D> { - iter: ElementsBaseMut<'a, A, D>, +pub struct ExactChunksIterMut<'a, A, D> +{ + iter: Baseiter, + life: PhantomData<&'a mut A>, chunk: D, inner_strides: D, } + +send_sync_read_only!(ExactChunks); +send_sync_read_only!(ExactChunksIter); + +send_sync_read_write!(ExactChunksMut); +send_sync_read_write!(ExactChunksIterMut); diff --git a/src/iterators/into_iter.rs b/src/iterators/into_iter.rs new file mode 100644 index 000000000..9374608cb --- /dev/null +++ b/src/iterators/into_iter.rs @@ -0,0 +1,141 @@ +// Copyright 2020-2021 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::mem; +use std::ptr::NonNull; + +use crate::imp_prelude::*; +use crate::OwnedRepr; + +use super::Baseiter; +use crate::impl_owned_array::drop_unreachable_raw; + +/// By-value iterator for an array +pub struct IntoIter +where D: Dimension +{ + array_data: OwnedRepr, + inner: Baseiter, + data_len: usize, + /// first memory address of an array element + array_head_ptr: NonNull, + // if true, the array owns elements that are not reachable by indexing + // through all the indices of the dimension. 
+ has_unreachable_elements: bool, +} + +impl IntoIter +where D: Dimension +{ + /// Create a new by-value iterator that consumes `array` + pub(crate) fn new(array: Array) -> Self + { + unsafe { + let array_head_ptr = array.ptr; + let mut array_data = array.data; + let data_len = array_data.release_all_elements(); + debug_assert!(data_len >= array.dim.size()); + let has_unreachable_elements = array.dim.size() != data_len; + let inner = Baseiter::new(array_head_ptr, array.dim, array.strides); + + IntoIter { + array_data, + inner, + data_len, + array_head_ptr, + has_unreachable_elements, + } + } + } +} + +impl Iterator for IntoIter +{ + type Item = A; + + #[inline] + fn next(&mut self) -> Option + { + self.inner.next().map(|p| unsafe { p.as_ptr().read() }) + } + + fn size_hint(&self) -> (usize, Option) + { + self.inner.size_hint() + } +} + +impl ExactSizeIterator for IntoIter +{ + fn len(&self) -> usize + { + self.inner.len() + } +} + +impl Drop for IntoIter +where D: Dimension +{ + fn drop(&mut self) + { + if !self.has_unreachable_elements || mem::size_of::() == 0 || !mem::needs_drop::() { + return; + } + + // iterate til the end + while let Some(_) = self.next() {} + + unsafe { + let data_ptr = self.array_data.as_nonnull_mut(); + let view = RawArrayViewMut::new(self.array_head_ptr, self.inner.dim.clone(), self.inner.strides.clone()); + debug_assert!(self.inner.dim.size() < self.data_len, "data_len {} and dim size {}", + self.data_len, self.inner.dim.size()); + drop_unreachable_raw(view, data_ptr, self.data_len); + } + } +} + +impl IntoIterator for Array +where D: Dimension +{ + type Item = A; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter + { + IntoIter::new(self) + } +} + +impl IntoIterator for ArcArray +where + D: Dimension, + A: Clone, +{ + type Item = A; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter + { + IntoIter::new(self.into_owned()) + } +} + +impl IntoIterator for CowArray<'_, A, D> +where + D: Dimension, + A: Clone, +{ + type Item = A; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter + { + IntoIter::new(self.into_owned()) + } +} diff --git a/src/iterators/iter.rs b/src/iterators/iter.rs index 7352e5e18..478987ee0 100644 --- a/src/iterators/iter.rs +++ b/src/iterators/iter.rs @@ -4,12 +4,28 @@ //! implementation structs. //! //! -//! See also [`NdProducer`](../trait.NdProducer.html). +//! See also [`NdProducer`](crate::NdProducer). pub use crate::dimension::Axes; pub use crate::indexes::{Indices, IndicesIter}; pub use crate::iterators::{ - AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksIter, - ExactChunksIterMut, ExactChunksMut, IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, - LanesIter, LanesIterMut, LanesMut, Windows, + AxisChunksIter, + AxisChunksIterMut, + AxisIter, + AxisIterMut, + AxisWindows, + ExactChunks, + ExactChunksIter, + ExactChunksIterMut, + ExactChunksMut, + IndexedIter, + IndexedIterMut, + IntoIter, + Iter, + IterMut, + Lanes, + LanesIter, + LanesIterMut, + LanesMut, + Windows, }; diff --git a/src/iterators/lanes.rs b/src/iterators/lanes.rs index 2163c58a6..11c83d002 100644 --- a/src/iterators/lanes.rs +++ b/src/iterators/lanes.rs @@ -23,33 +23,33 @@ impl_ndproducer! { } } -/// See [`.lanes()`](../struct.ArrayBase.html#method.lanes) +/// See [`.lanes()`](ArrayBase::lanes) /// for more information. 
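Stepping back to the new `into_iter.rs` module above, a small sketch of the by-value iteration it enables for owned arrays; elements move out in the array's logical order:

use ndarray::array;

fn main()
{
    let a = array![[1, 2], [3, 4]];
    // The array is consumed; elements are read out in logical (row-major) order.
    let v: Vec<i32> = a.into_iter().collect();
    assert_eq!(v, vec![1, 2, 3, 4]);
}

And a short sketch of the lanes producer documented just above, assuming only the public `lanes` method; each lane is a one-dimensional view along the chosen axis:

use ndarray::{array, Axis};

fn main()
{
    let a = array![[1, 2, 3], [4, 5, 6]];
    // Lanes along axis 1 are the rows, so summing each lane yields the row sums.
    let sums: Vec<i32> = a.lanes(Axis(1)).into_iter().map(|lane| lane.sum()).collect();
    assert_eq!(sums, vec![6, 15]);
}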
-pub struct Lanes<'a, A, D> { +pub struct Lanes<'a, A, D> +{ base: ArrayView<'a, A, D>, inner_len: Ix, inner_stride: Ixs, } -impl<'a, A, D: Dimension> Lanes<'a, A, D> { +impl<'a, A, D: Dimension> Lanes<'a, A, D> +{ pub(crate) fn new(v: ArrayView<'a, A, Di>, axis: Axis) -> Self - where - Di: Dimension, + where Di: Dimension { let ndim = v.ndim(); let len; let stride; - let iter_v; - if ndim == 0 { + let iter_v = if ndim == 0 { len = 1; stride = 1; - iter_v = v.try_remove_axis(Axis(0)) + v.try_remove_axis(Axis(0)) } else { let i = axis.index(); len = v.dim[i]; stride = v.strides[i] as isize; - iter_v = v.try_remove_axis(axis) - } + v.try_remove_axis(axis) + }; Lanes { inner_len: len, inner_stride: stride, @@ -77,12 +77,12 @@ impl_ndproducer! { } impl<'a, A, D> IntoIterator for Lanes<'a, A, D> -where - D: Dimension, +where D: Dimension { type Item = ::Item; type IntoIter = LanesIter<'a, A, D>; - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { LanesIter { iter: self.base.into_base_iter(), inner_len: self.inner_len, @@ -92,33 +92,33 @@ where } } -/// See [`.lanes_mut()`](../struct.ArrayBase.html#method.lanes_mut) +/// See [`.lanes_mut()`](ArrayBase::lanes_mut) /// for more information. -pub struct LanesMut<'a, A, D> { +pub struct LanesMut<'a, A, D> +{ base: ArrayViewMut<'a, A, D>, inner_len: Ix, inner_stride: Ixs, } -impl<'a, A, D: Dimension> LanesMut<'a, A, D> { +impl<'a, A, D: Dimension> LanesMut<'a, A, D> +{ pub(crate) fn new(v: ArrayViewMut<'a, A, Di>, axis: Axis) -> Self - where - Di: Dimension, + where Di: Dimension { let ndim = v.ndim(); let len; let stride; - let iter_v; - if ndim == 0 { + let iter_v = if ndim == 0 { len = 1; stride = 1; - iter_v = v.try_remove_axis(Axis(0)) + v.try_remove_axis(Axis(0)) } else { let i = axis.index(); len = v.dim[i]; stride = v.strides[i] as isize; - iter_v = v.try_remove_axis(axis) - } + v.try_remove_axis(axis) + }; LanesMut { inner_len: len, inner_stride: stride, @@ -128,12 +128,12 @@ impl<'a, A, D: Dimension> LanesMut<'a, A, D> { } impl<'a, A, D> IntoIterator for LanesMut<'a, A, D> -where - D: Dimension, +where D: Dimension { type Item = ::Item; type IntoIter = LanesIterMut<'a, A, D>; - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { LanesIterMut { iter: self.base.into_base_iter(), inner_len: self.inner_len, diff --git a/src/iterators/macros.rs b/src/iterators/macros.rs index d3a54453e..78697ec25 100644 --- a/src/iterators/macros.rs +++ b/src/iterators/macros.rs @@ -63,42 +63,34 @@ impl<$($typarm)*> NdProducer for $fulltype { type Ptr = *mut A; type Stride = isize; - #[doc(hidden)] fn raw_dim(&self) -> D { self.$base.raw_dim() } - #[doc(hidden)] fn layout(&self) -> Layout { self.$base.layout() } - #[doc(hidden)] fn as_ptr(&self) -> *mut A { self.$base.as_ptr() as *mut _ } - #[doc(hidden)] fn contiguous_stride(&self) -> isize { self.$base.contiguous_stride() } - #[doc(hidden)] unsafe fn as_ref(&$self_, $ptr: *mut A) -> Self::Item { $refexpr } - #[doc(hidden)] unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { - self.$base.uget_ptr(i) + self.$base.uget_ptr(i) as *mut _ } - #[doc(hidden)] fn stride_of(&self, axis: Axis) -> isize { self.$base.stride_of(axis) } - #[doc(hidden)] fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { let (a, b) = self.$base.split_at(axis, index); ($typename { @@ -114,6 +106,7 @@ impl<$($typarm)*> NdProducer for $fulltype { )* }) } + private_impl!{} } @@ -130,7 +123,7 @@ expand_if!(@nonempty [$($cloneparm)*] } ); - } + }; } macro_rules! 
impl_iterator { @@ -177,5 +170,5 @@ macro_rules! impl_iterator { self.$base.size_hint() } } - } + }; } diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 091119361..01fff14f5 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -9,13 +9,20 @@ #[macro_use] mod macros; mod chunks; +mod into_iter; pub mod iter; mod lanes; mod windows; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; use std::iter::FromIterator; use std::marker::PhantomData; use std::ptr; +use std::ptr::NonNull; + +#[allow(unused_imports)] // Needed for Rust 1.64 +use rawpointer::PointerExt; use crate::Ix1; @@ -23,27 +30,32 @@ use super::{ArrayBase, ArrayView, ArrayViewMut, Axis, Data, NdProducer, RemoveAx use super::{Dimension, Ix, Ixs}; pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut}; +pub use self::into_iter::IntoIter; pub use self::lanes::{Lanes, LanesMut}; -pub use self::windows::Windows; +pub use self::windows::{AxisWindows, Windows}; use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; /// Base for iterators over all axes. /// -/// Iterator element type is `*mut A`. -pub struct Baseiter { - ptr: *mut A, +/// Iterator element type is `NonNull`. +#[derive(Debug)] +pub struct Baseiter +{ + ptr: NonNull, dim: D, strides: D, index: Option, } -impl Baseiter { +impl Baseiter +{ /// Creating a Baseiter is unsafe because shape and stride parameters need /// to be correct to avoid performing an unsafe pointer offset while /// iterating. #[inline] - pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter { + pub unsafe fn new(ptr: NonNull, len: D, stride: D) -> Baseiter + { Baseiter { ptr, index: len.first_index(), @@ -53,11 +65,13 @@ impl Baseiter { } } -impl Iterator for Baseiter { - type Item = *mut A; +impl Iterator for Baseiter +{ + type Item = NonNull; #[inline] - fn next(&mut self) -> Option<*mut A> { + fn next(&mut self) -> Option + { let index = match self.index { None => return None, Some(ref ix) => ix.clone(), @@ -67,27 +81,30 @@ impl Iterator for Baseiter { unsafe { Some(self.ptr.offset(offset)) } } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { let len = self.len(); (len, Some(len)) } fn fold(mut self, init: Acc, mut g: G) -> Acc - where - G: FnMut(Acc, *mut A) -> Acc, + where G: FnMut(Acc, Self::Item) -> Acc { let ndim = self.dim.ndim(); debug_assert_ne!(ndim, 0); let mut accum = init; - while let Some(mut index) = self.index.clone() { + while let Some(mut index) = self.index { let stride = self.strides.last_elem() as isize; let elem_index = index.last_elem(); let len = self.dim.last_elem(); let offset = D::stride_offset(&index, &self.strides); unsafe { let row_ptr = self.ptr.offset(offset); - for i in 0..(len - elem_index) { + let mut i = 0; + let i_end = len - elem_index; + while i < i_end { accum = g(accum, row_ptr.offset(i as isize * stride)); + i += 1; } } index.set_last_elem(len - 1); @@ -97,8 +114,10 @@ impl Iterator for Baseiter { } } -impl<'a, A, D: Dimension> ExactSizeIterator for Baseiter { - fn len(&self) -> usize { +impl ExactSizeIterator for Baseiter +{ + fn len(&self) -> usize + { match self.index { None => 0, Some(ref ix) => { @@ -108,22 +127,24 @@ impl<'a, A, D: Dimension> ExactSizeIterator for Baseiter { .slice() .iter() .zip(ix.slice().iter()) - .fold(0, |s, (&a, &b)| s + a as usize * b as usize); + .fold(0, |s, (&a, &b)| s + a * b); self.dim.size() - gone } } } } -impl DoubleEndedIterator for Baseiter { +impl DoubleEndedIterator for Baseiter +{ #[inline] - fn 
next_back(&mut self) -> Option<*mut A> { + fn next_back(&mut self) -> Option + { let index = match self.index { None => return None, Some(ix) => ix, }; self.dim[0] -= 1; - let offset = <_>::stride_offset(&self.dim, &self.strides); + let offset = Ix1::stride_offset(&self.dim, &self.strides); if index == self.dim { self.index = None; } @@ -131,12 +152,13 @@ impl DoubleEndedIterator for Baseiter { unsafe { Some(self.ptr.offset(offset)) } } - fn nth_back(&mut self, n: usize) -> Option<*mut A> { + fn nth_back(&mut self, n: usize) -> Option + { let index = self.index?; let len = self.dim[0] - index[0]; if n < len { self.dim[0] -= n + 1; - let offset = <_>::stride_offset(&self.dim, &self.strides); + let offset = Ix1::stride_offset(&self.dim, &self.strides); if index == self.dim { self.index = None; } @@ -148,8 +170,7 @@ impl DoubleEndedIterator for Baseiter { } fn rfold(mut self, init: Acc, mut g: G) -> Acc - where - G: FnMut(Acc, *mut A) -> Acc, + where G: FnMut(Acc, Self::Item) -> Acc { let mut accum = init; if let Some(index) = self.index { @@ -192,8 +213,10 @@ clone_bounds!( } ); -impl<'a, A, D: Dimension> ElementsBase<'a, A, D> { - pub fn new(v: ArrayView<'a, A, D>) -> Self { +impl<'a, A, D: Dimension> ElementsBase<'a, A, D> +{ + pub fn new(v: ArrayView<'a, A, D>) -> Self + { ElementsBase { inner: v.into_base_iter(), life: PhantomData, @@ -201,44 +224,47 @@ impl<'a, A, D: Dimension> ElementsBase<'a, A, D> { } } -impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D> { +impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D> +{ type Item = &'a A; #[inline] - fn next(&mut self) -> Option<&'a A> { - self.inner.next().map(|p| unsafe { &*p }) + fn next(&mut self) -> Option<&'a A> + { + self.inner.next().map(|p| unsafe { p.as_ref() }) } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { self.inner.size_hint() } fn fold(self, init: Acc, mut g: G) -> Acc - where - G: FnMut(Acc, Self::Item) -> Acc, + where G: FnMut(Acc, Self::Item) -> Acc { - unsafe { self.inner.fold(init, move |acc, ptr| g(acc, &*ptr)) } + unsafe { self.inner.fold(init, move |acc, ptr| g(acc, ptr.as_ref())) } } } -impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1> { +impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1> +{ #[inline] - fn next_back(&mut self) -> Option<&'a A> { - self.inner.next_back().map(|p| unsafe { &*p }) + fn next_back(&mut self) -> Option<&'a A> + { + self.inner.next_back().map(|p| unsafe { p.as_ref() }) } fn rfold(self, init: Acc, mut g: G) -> Acc - where - G: FnMut(Acc, Self::Item) -> Acc, + where G: FnMut(Acc, Self::Item) -> Acc { - unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, &*ptr)) } + unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, ptr.as_ref())) } } } -impl<'a, A, D> ExactSizeIterator for ElementsBase<'a, A, D> -where - D: Dimension, +impl ExactSizeIterator for ElementsBase<'_, A, D> +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { self.inner.len() } } @@ -271,10 +297,10 @@ clone_bounds!( ); impl<'a, A, D> Iter<'a, A, D> -where - D: Dimension, +where D: Dimension { - pub(crate) fn new(self_: ArrayView<'a, A, D>) -> Self { + pub(crate) fn new(self_: ArrayView<'a, A, D>) -> Self + { Iter { inner: if let Some(slc) = self_.to_slice() { ElementsRepr::Slice(slc.iter()) @@ -286,12 +312,12 @@ where } impl<'a, A, D> IterMut<'a, A, D> -where - D: Dimension, +where D: Dimension { - pub(crate) fn new(self_: ArrayViewMut<'a, A, D>) -> Self { + pub(crate) fn new(self_: ArrayViewMut<'a, A, D>) -> Self + { 
IterMut { - inner: match self_.into_slice_() { + inner: match self_.try_into_slice() { Ok(x) => ElementsRepr::Slice(x.iter_mut()), Err(self_) => ElementsRepr::Counted(self_.into_elements_base()), }, @@ -299,8 +325,9 @@ where } } -#[derive(Clone)] -pub enum ElementsRepr { +#[derive(Clone, Debug)] +pub enum ElementsRepr +{ Slice(S), Counted(C), } @@ -309,13 +336,17 @@ pub enum ElementsRepr { /// /// Iterator element type is `&'a A`. /// -/// See [`.iter()`](../struct.ArrayBase.html#method.iter) for more information. -pub struct Iter<'a, A, D> { +/// See [`.iter()`](ArrayBase::iter) for more information. +#[derive(Debug)] +pub struct Iter<'a, A, D> +{ inner: ElementsRepr, ElementsBase<'a, A, D>>, } /// Counted read only iterator -pub struct ElementsBase<'a, A, D> { +#[derive(Debug)] +pub struct ElementsBase<'a, A, D> +{ inner: Baseiter, life: PhantomData<&'a A>, } @@ -324,21 +355,27 @@ pub struct ElementsBase<'a, A, D> { /// /// Iterator element type is `&'a mut A`. /// -/// See [`.iter_mut()`](../struct.ArrayBase.html#method.iter_mut) for more information. -pub struct IterMut<'a, A, D> { +/// See [`.iter_mut()`](ArrayBase::iter_mut) for more information. +#[derive(Debug)] +pub struct IterMut<'a, A, D> +{ inner: ElementsRepr, ElementsBaseMut<'a, A, D>>, } /// An iterator over the elements of an array. /// /// Iterator element type is `&'a mut A`. -pub struct ElementsBaseMut<'a, A, D> { +#[derive(Debug)] +pub struct ElementsBaseMut<'a, A, D> +{ inner: Baseiter, life: PhantomData<&'a mut A>, } -impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> { - pub fn new(v: ArrayViewMut<'a, A, D>) -> Self { +impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> +{ + pub fn new(v: ArrayViewMut<'a, A, D>) -> Self + { ElementsBaseMut { inner: v.into_base_iter(), life: PhantomData, @@ -348,136 +385,139 @@ impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> { /// An iterator over the indexes and elements of an array. /// -/// See [`.indexed_iter()`](../struct.ArrayBase.html#method.indexed_iter) for more information. +/// See [`.indexed_iter()`](ArrayBase::indexed_iter) for more information. #[derive(Clone)] pub struct IndexedIter<'a, A, D>(ElementsBase<'a, A, D>); /// An iterator over the indexes and elements of an array (mutable). /// -/// See [`.indexed_iter_mut()`](../struct.ArrayBase.html#method.indexed_iter_mut) for more information. +/// See [`.indexed_iter_mut()`](ArrayBase::indexed_iter_mut) for more information. 
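A brief sketch of the indexed iteration these wrappers provide, using only the public `indexed_iter` method; each item pairs the index pattern with a reference to the element:

use ndarray::array;

fn main()
{
    let a = array![[10, 20], [30, 40]];
    // For a 2-D array the index pattern is (row, column).
    let mut it = a.indexed_iter();
    assert_eq!(it.next(), Some(((0, 0), &10)));
    assert_eq!(it.next(), Some(((0, 1), &20)));
}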
pub struct IndexedIterMut<'a, A, D>(ElementsBaseMut<'a, A, D>); impl<'a, A, D> IndexedIter<'a, A, D> -where - D: Dimension, +where D: Dimension { - pub(crate) fn new(x: ElementsBase<'a, A, D>) -> Self { + pub(crate) fn new(x: ElementsBase<'a, A, D>) -> Self + { IndexedIter(x) } } impl<'a, A, D> IndexedIterMut<'a, A, D> -where - D: Dimension, +where D: Dimension { - pub(crate) fn new(x: ElementsBaseMut<'a, A, D>) -> Self { + pub(crate) fn new(x: ElementsBaseMut<'a, A, D>) -> Self + { IndexedIterMut(x) } } -impl<'a, A, D: Dimension> Iterator for Iter<'a, A, D> { +impl<'a, A, D: Dimension> Iterator for Iter<'a, A, D> +{ type Item = &'a A; #[inline] - fn next(&mut self) -> Option<&'a A> { + fn next(&mut self) -> Option<&'a A> + { either_mut!(self.inner, iter => iter.next()) } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { either!(self.inner, ref iter => iter.size_hint()) } fn fold(self, init: Acc, g: G) -> Acc - where - G: FnMut(Acc, Self::Item) -> Acc, + where G: FnMut(Acc, Self::Item) -> Acc { either!(self.inner, iter => iter.fold(init, g)) } - fn nth(&mut self, n: usize) -> Option { + fn nth(&mut self, n: usize) -> Option + { either_mut!(self.inner, iter => iter.nth(n)) } fn collect(self) -> B - where - B: FromIterator, + where B: FromIterator { either!(self.inner, iter => iter.collect()) } fn all(&mut self, f: F) -> bool - where - F: FnMut(Self::Item) -> bool, + where F: FnMut(Self::Item) -> bool { either_mut!(self.inner, iter => iter.all(f)) } fn any(&mut self, f: F) -> bool - where - F: FnMut(Self::Item) -> bool, + where F: FnMut(Self::Item) -> bool { either_mut!(self.inner, iter => iter.any(f)) } fn find
<P>
(&mut self, predicate: P) -> Option - where - P: FnMut(&Self::Item) -> bool, + where P: FnMut(&Self::Item) -> bool { either_mut!(self.inner, iter => iter.find(predicate)) } fn find_map(&mut self, f: F) -> Option - where - F: FnMut(Self::Item) -> Option, + where F: FnMut(Self::Item) -> Option { either_mut!(self.inner, iter => iter.find_map(f)) } - fn count(self) -> usize { + fn count(self) -> usize + { either!(self.inner, iter => iter.count()) } - fn last(self) -> Option { + fn last(self) -> Option + { either!(self.inner, iter => iter.last()) } fn position
<P>
(&mut self, predicate: P) -> Option - where - P: FnMut(Self::Item) -> bool, + where P: FnMut(Self::Item) -> bool { either_mut!(self.inner, iter => iter.position(predicate)) } } -impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> { +impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> +{ #[inline] - fn next_back(&mut self) -> Option<&'a A> { + fn next_back(&mut self) -> Option<&'a A> + { either_mut!(self.inner, iter => iter.next_back()) } - fn nth_back(&mut self, n: usize) -> Option<&'a A> { + fn nth_back(&mut self, n: usize) -> Option<&'a A> + { either_mut!(self.inner, iter => iter.nth_back(n)) } fn rfold(self, init: Acc, g: G) -> Acc - where - G: FnMut(Acc, Self::Item) -> Acc, + where G: FnMut(Acc, Self::Item) -> Acc { either!(self.inner, iter => iter.rfold(init, g)) } } -impl<'a, A, D> ExactSizeIterator for Iter<'a, A, D> -where - D: Dimension, +impl ExactSizeIterator for Iter<'_, A, D> +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { either!(self.inner, ref iter => iter.len()) } } -impl<'a, A, D: Dimension> Iterator for IndexedIter<'a, A, D> { +impl<'a, A, D: Dimension> Iterator for IndexedIter<'a, A, D> +{ type Item = (D::Pattern, &'a A); #[inline] - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { let index = match self.0.inner.index { None => return None, Some(ref ix) => ix.clone(), @@ -488,166 +528,179 @@ impl<'a, A, D: Dimension> Iterator for IndexedIter<'a, A, D> { } } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { self.0.size_hint() } } -impl<'a, A, D> ExactSizeIterator for IndexedIter<'a, A, D> -where - D: Dimension, +impl ExactSizeIterator for IndexedIter<'_, A, D> +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { self.0.inner.len() } } -impl<'a, A, D: Dimension> Iterator for IterMut<'a, A, D> { +impl<'a, A, D: Dimension> Iterator for IterMut<'a, A, D> +{ type Item = &'a mut A; #[inline] - fn next(&mut self) -> Option<&'a mut A> { + fn next(&mut self) -> Option<&'a mut A> + { either_mut!(self.inner, iter => iter.next()) } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { either!(self.inner, ref iter => iter.size_hint()) } fn fold(self, init: Acc, g: G) -> Acc - where - G: FnMut(Acc, Self::Item) -> Acc, + where G: FnMut(Acc, Self::Item) -> Acc { either!(self.inner, iter => iter.fold(init, g)) } - fn nth(&mut self, n: usize) -> Option { + fn nth(&mut self, n: usize) -> Option + { either_mut!(self.inner, iter => iter.nth(n)) } fn collect(self) -> B - where - B: FromIterator, + where B: FromIterator { either!(self.inner, iter => iter.collect()) } fn all(&mut self, f: F) -> bool - where - F: FnMut(Self::Item) -> bool, + where F: FnMut(Self::Item) -> bool { either_mut!(self.inner, iter => iter.all(f)) } fn any(&mut self, f: F) -> bool - where - F: FnMut(Self::Item) -> bool, + where F: FnMut(Self::Item) -> bool { either_mut!(self.inner, iter => iter.any(f)) } fn find
<P>
(&mut self, predicate: P) -> Option - where - P: FnMut(&Self::Item) -> bool, + where P: FnMut(&Self::Item) -> bool { either_mut!(self.inner, iter => iter.find(predicate)) } fn find_map(&mut self, f: F) -> Option - where - F: FnMut(Self::Item) -> Option, + where F: FnMut(Self::Item) -> Option { either_mut!(self.inner, iter => iter.find_map(f)) } - fn count(self) -> usize { + fn count(self) -> usize + { either!(self.inner, iter => iter.count()) } - fn last(self) -> Option { + fn last(self) -> Option + { either!(self.inner, iter => iter.last()) } fn position
<P>
(&mut self, predicate: P) -> Option - where - P: FnMut(Self::Item) -> bool, + where P: FnMut(Self::Item) -> bool { either_mut!(self.inner, iter => iter.position(predicate)) } } -impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> { +impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> +{ #[inline] - fn next_back(&mut self) -> Option<&'a mut A> { + fn next_back(&mut self) -> Option<&'a mut A> + { either_mut!(self.inner, iter => iter.next_back()) } - fn nth_back(&mut self, n: usize) -> Option<&'a mut A> { + fn nth_back(&mut self, n: usize) -> Option<&'a mut A> + { either_mut!(self.inner, iter => iter.nth_back(n)) } fn rfold(self, init: Acc, g: G) -> Acc - where - G: FnMut(Acc, Self::Item) -> Acc, + where G: FnMut(Acc, Self::Item) -> Acc { either!(self.inner, iter => iter.rfold(init, g)) } } -impl<'a, A, D> ExactSizeIterator for IterMut<'a, A, D> -where - D: Dimension, +impl ExactSizeIterator for IterMut<'_, A, D> +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { either!(self.inner, ref iter => iter.len()) } } -impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D> { +impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D> +{ type Item = &'a mut A; #[inline] - fn next(&mut self) -> Option<&'a mut A> { - self.inner.next().map(|p| unsafe { &mut *p }) + fn next(&mut self) -> Option<&'a mut A> + { + self.inner.next().map(|mut p| unsafe { p.as_mut() }) } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { self.inner.size_hint() } fn fold(self, init: Acc, mut g: G) -> Acc - where - G: FnMut(Acc, Self::Item) -> Acc, + where G: FnMut(Acc, Self::Item) -> Acc { - unsafe { self.inner.fold(init, move |acc, ptr| g(acc, &mut *ptr)) } + unsafe { + self.inner + .fold(init, move |acc, mut ptr| g(acc, ptr.as_mut())) + } } } -impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1> { +impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1> +{ #[inline] - fn next_back(&mut self) -> Option<&'a mut A> { - self.inner.next_back().map(|p| unsafe { &mut *p }) + fn next_back(&mut self) -> Option<&'a mut A> + { + self.inner.next_back().map(|mut p| unsafe { p.as_mut() }) } fn rfold(self, init: Acc, mut g: G) -> Acc - where - G: FnMut(Acc, Self::Item) -> Acc, + where G: FnMut(Acc, Self::Item) -> Acc { - unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, &mut *ptr)) } + unsafe { + self.inner + .rfold(init, move |acc, mut ptr| g(acc, ptr.as_mut())) + } } } -impl<'a, A, D> ExactSizeIterator for ElementsBaseMut<'a, A, D> -where - D: Dimension, +impl ExactSizeIterator for ElementsBaseMut<'_, A, D> +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { self.inner.len() } } -impl<'a, A, D: Dimension> Iterator for IndexedIterMut<'a, A, D> { +impl<'a, A, D: Dimension> Iterator for IndexedIterMut<'a, A, D> +{ type Item = (D::Pattern, &'a mut A); #[inline] - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { let index = match self.0.inner.index { None => return None, Some(ref ix) => ix.clone(), @@ -658,16 +711,17 @@ impl<'a, A, D: Dimension> Iterator for IndexedIterMut<'a, A, D> { } } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { self.0.size_hint() } } -impl<'a, A, D> ExactSizeIterator for IndexedIterMut<'a, A, D> -where - D: Dimension, +impl ExactSizeIterator for IndexedIterMut<'_, A, D> +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { self.0.inner.len() } } @@ -675,8 +729,9 @@ where /// An iterator that traverses 
over all axes but one, and yields a view for /// each lane along that axis. /// -/// See [`.lanes()`](../struct.ArrayBase.html#method.lanes) for more information. -pub struct LanesIter<'a, A, D> { +/// See [`.lanes()`](ArrayBase::lanes) for more information. +pub struct LanesIter<'a, A, D> +{ inner_len: Ix, inner_stride: Ixs, iter: Baseiter, @@ -696,39 +751,51 @@ clone_bounds!( ); impl<'a, A, D> Iterator for LanesIter<'a, A, D> -where - D: Dimension, +where D: Dimension { type Item = ArrayView<'a, A, Ix1>; - fn next(&mut self) -> Option { - self.iter.next().map(|ptr| unsafe { - ArrayView::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) - }) + fn next(&mut self) -> Option + { + self.iter + .next() + .map(|ptr| unsafe { ArrayView::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { self.iter.size_hint() } } -impl<'a, A, D> ExactSizeIterator for LanesIter<'a, A, D> -where - D: Dimension, +impl ExactSizeIterator for LanesIter<'_, A, D> +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { self.iter.len() } } +impl DoubleEndedIterator for LanesIter<'_, A, Ix1> +{ + fn next_back(&mut self) -> Option + { + self.iter + .next_back() + .map(|ptr| unsafe { ArrayView::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) + } +} + // NOTE: LanesIterMut is a mutable iterator and must not expose aliasing // pointers. Due to this we use an empty slice for the raw data (it's unused // anyway). /// An iterator that traverses over all dimensions but the innermost, /// and yields each inner row (mutable). /// -/// See [`.lanes_mut()`](../struct.ArrayBase.html#method.lanes_mut) +/// See [`.lanes_mut()`](ArrayBase::lanes_mut) /// for more information. -pub struct LanesIterMut<'a, A, D> { +pub struct LanesIterMut<'a, A, D> +{ inner_len: Ix, inner_stride: Ixs, iter: Baseiter, @@ -736,32 +803,44 @@ pub struct LanesIterMut<'a, A, D> { } impl<'a, A, D> Iterator for LanesIterMut<'a, A, D> -where - D: Dimension, +where D: Dimension { type Item = ArrayViewMut<'a, A, Ix1>; - fn next(&mut self) -> Option { - self.iter.next().map(|ptr| unsafe { - ArrayViewMut::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) - }) + fn next(&mut self) -> Option + { + self.iter + .next() + .map(|ptr| unsafe { ArrayViewMut::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { self.iter.size_hint() } } -impl<'a, A, D> ExactSizeIterator for LanesIterMut<'a, A, D> -where - D: Dimension, +impl ExactSizeIterator for LanesIterMut<'_, A, D> +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { self.iter.len() } } +impl DoubleEndedIterator for LanesIterMut<'_, A, Ix1> +{ + fn next_back(&mut self) -> Option + { + self.iter + .next_back() + .map(|ptr| unsafe { ArrayViewMut::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) + } +} + #[derive(Debug)] -pub struct AxisIterCore { +pub struct AxisIterCore +{ /// Index along the axis of the value of `.next()`, relative to the start /// of the axis. index: Ix, @@ -792,7 +871,8 @@ clone_bounds!( } ); -impl AxisIterCore { +impl AxisIterCore +{ /// Constructs a new iterator over the specified axis. 
fn new(v: ArrayBase, axis: Axis) -> Self where @@ -810,7 +890,8 @@ impl AxisIterCore { } #[inline] - unsafe fn offset(&self, index: usize) -> *mut A { + unsafe fn offset(&self, index: usize) -> *mut A + { debug_assert!( index < self.end, "index={}, end={}, stride={}", @@ -828,7 +909,9 @@ impl AxisIterCore { /// /// **Panics** if `index` is strictly greater than the iterator's remaining /// length. - fn split_at(self, index: usize) -> (Self, Self) { + #[track_caller] + fn split_at(self, index: usize) -> (Self, Self) + { assert!(index <= self.len()); let mid = self.index + index; let left = AxisIterCore { @@ -852,25 +935,27 @@ impl AxisIterCore { /// Does the same thing as `.next()` but also returns the index of the item /// relative to the start of the axis. - fn next_with_index(&mut self) -> Option<(usize, *mut A)> { + fn next_with_index(&mut self) -> Option<(usize, *mut A)> + { let index = self.index; self.next().map(|ptr| (index, ptr)) } /// Does the same thing as `.next_back()` but also returns the index of the /// item relative to the start of the axis. - fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> { + fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> + { self.next_back().map(|ptr| (self.end, ptr)) } } impl Iterator for AxisIterCore -where - D: Dimension, +where D: Dimension { type Item = *mut A; - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { if self.index >= self.end { None } else { @@ -880,17 +965,18 @@ where } } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { let len = self.len(); (len, Some(len)) } } impl DoubleEndedIterator for AxisIterCore -where - D: Dimension, +where D: Dimension { - fn next_back(&mut self) -> Option { + fn next_back(&mut self) -> Option + { if self.index >= self.end { None } else { @@ -902,10 +988,10 @@ where } impl ExactSizeIterator for AxisIterCore -where - D: Dimension, +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { self.end - self.index } } @@ -921,11 +1007,12 @@ where /// /// Iterator element type is `ArrayView<'a, A, D>`. /// -/// See [`.outer_iter()`](../struct.ArrayBase.html#method.outer_iter) -/// or [`.axis_iter()`](../struct.ArrayBase.html#method.axis_iter) +/// See [`.outer_iter()`](ArrayBase::outer_iter) +/// or [`.axis_iter()`](ArrayBase::axis_iter) /// for more information. #[derive(Debug)] -pub struct AxisIter<'a, A, D> { +pub struct AxisIter<'a, A, D> +{ iter: AxisIterCore, life: PhantomData<&'a A>, } @@ -940,11 +1027,11 @@ clone_bounds!( } ); -impl<'a, A, D: Dimension> AxisIter<'a, A, D> { +impl<'a, A, D: Dimension> AxisIter<'a, A, D> +{ /// Creates a new iterator over the specified axis. pub(crate) fn new(v: ArrayView<'a, A, Di>, axis: Axis) -> Self - where - Di: RemoveAxis, + where Di: RemoveAxis { AxisIter { iter: AxisIterCore::new(v, axis), @@ -959,7 +1046,9 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> { /// /// **Panics** if `index` is strictly greater than the iterator's remaining /// length. 
- pub fn split_at(self, index: usize) -> (Self, Self) { + #[track_caller] + pub fn split_at(self, index: usize) -> (Self, Self) + { let (left, right) = self.iter.split_at(index); ( AxisIter { @@ -975,34 +1064,35 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> { } impl<'a, A, D> Iterator for AxisIter<'a, A, D> -where - D: Dimension, +where D: Dimension { type Item = ArrayView<'a, A, D>; - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { self.iter.next().map(|ptr| unsafe { self.as_ref(ptr) }) } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { self.iter.size_hint() } } -impl<'a, A, D> DoubleEndedIterator for AxisIter<'a, A, D> -where - D: Dimension, +impl DoubleEndedIterator for AxisIter<'_, A, D> +where D: Dimension { - fn next_back(&mut self) -> Option { + fn next_back(&mut self) -> Option + { self.iter.next_back().map(|ptr| unsafe { self.as_ref(ptr) }) } } -impl<'a, A, D> ExactSizeIterator for AxisIter<'a, A, D> -where - D: Dimension, +impl ExactSizeIterator for AxisIter<'_, A, D> +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { self.iter.len() } } @@ -1018,19 +1108,20 @@ where /// /// Iterator element type is `ArrayViewMut<'a, A, D>`. /// -/// See [`.outer_iter_mut()`](../struct.ArrayBase.html#method.outer_iter_mut) -/// or [`.axis_iter_mut()`](../struct.ArrayBase.html#method.axis_iter_mut) +/// See [`.outer_iter_mut()`](ArrayBase::outer_iter_mut) +/// or [`.axis_iter_mut()`](ArrayBase::axis_iter_mut) /// for more information. -pub struct AxisIterMut<'a, A, D> { +pub struct AxisIterMut<'a, A, D> +{ iter: AxisIterCore, life: PhantomData<&'a mut A>, } -impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> { +impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> +{ /// Creates a new iterator over the specified axis. pub(crate) fn new(v: ArrayViewMut<'a, A, Di>, axis: Axis) -> Self - where - Di: RemoveAxis, + where Di: RemoveAxis { AxisIterMut { iter: AxisIterCore::new(v, axis), @@ -1045,7 +1136,9 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> { /// /// **Panics** if `index` is strictly greater than the iterator's remaining /// length. 
- pub fn split_at(self, index: usize) -> (Self, Self) { + #[track_caller] + pub fn split_at(self, index: usize) -> (Self, Self) + { let (left, right) = self.iter.split_at(index); ( AxisIterMut { @@ -1061,54 +1154,58 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> { } impl<'a, A, D> Iterator for AxisIterMut<'a, A, D> -where - D: Dimension, +where D: Dimension { type Item = ArrayViewMut<'a, A, D>; - fn next(&mut self) -> Option { + fn next(&mut self) -> Option + { self.iter.next().map(|ptr| unsafe { self.as_ref(ptr) }) } - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + { self.iter.size_hint() } } -impl<'a, A, D> DoubleEndedIterator for AxisIterMut<'a, A, D> -where - D: Dimension, +impl DoubleEndedIterator for AxisIterMut<'_, A, D> +where D: Dimension { - fn next_back(&mut self) -> Option { + fn next_back(&mut self) -> Option + { self.iter.next_back().map(|ptr| unsafe { self.as_ref(ptr) }) } } -impl<'a, A, D> ExactSizeIterator for AxisIterMut<'a, A, D> -where - D: Dimension, +impl ExactSizeIterator for AxisIterMut<'_, A, D> +where D: Dimension { - fn len(&self) -> usize { + fn len(&self) -> usize + { self.iter.len() } } -impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> { +impl NdProducer for AxisIter<'_, A, D> +{ type Item = ::Item; type Dim = Ix1; type Ptr = *mut A; type Stride = isize; - #[doc(hidden)] - fn layout(&self) -> crate::Layout { + fn layout(&self) -> crate::Layout + { crate::Layout::one_dimensional() } - #[doc(hidden)] - fn raw_dim(&self) -> Self::Dim { + + fn raw_dim(&self) -> Self::Dim + { Ix1(self.len()) } - #[doc(hidden)] - fn as_ptr(&self) -> Self::Ptr { + + fn as_ptr(&self) -> Self::Ptr + { if self.len() > 0 { // `self.iter.index` is guaranteed to be in-bounds if any of the // iterator remains (i.e. if `self.len() > 0`). @@ -1121,51 +1218,53 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> { } } - fn contiguous_stride(&self) -> isize { + fn contiguous_stride(&self) -> isize + { self.iter.stride } - #[doc(hidden)] - unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item { - ArrayView::new_( - ptr, - self.iter.inner_dim.clone(), - self.iter.inner_strides.clone(), - ) + unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item + { + ArrayView::new_(ptr, self.iter.inner_dim.clone(), self.iter.inner_strides.clone()) } - #[doc(hidden)] - unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { + + unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr + { self.iter.offset(self.iter.index + i[0]) } - #[doc(hidden)] - fn stride_of(&self, _axis: Axis) -> isize { + fn stride_of(&self, _axis: Axis) -> isize + { self.contiguous_stride() } - #[doc(hidden)] - fn split_at(self, _axis: Axis, index: usize) -> (Self, Self) { + fn split_at(self, _axis: Axis, index: usize) -> (Self, Self) + { self.split_at(index) } + private_impl! {} } -impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> { +impl NdProducer for AxisIterMut<'_, A, D> +{ type Item = ::Item; type Dim = Ix1; type Ptr = *mut A; type Stride = isize; - #[doc(hidden)] - fn layout(&self) -> crate::Layout { + fn layout(&self) -> crate::Layout + { crate::Layout::one_dimensional() } - #[doc(hidden)] - fn raw_dim(&self) -> Self::Dim { + + fn raw_dim(&self) -> Self::Dim + { Ix1(self.len()) } - #[doc(hidden)] - fn as_ptr(&self) -> Self::Ptr { + + fn as_ptr(&self) -> Self::Ptr + { if self.len() > 0 { // `self.iter.index` is guaranteed to be in-bounds if any of the // iterator remains (i.e. if `self.len() > 0`). 
@@ -1178,32 +1277,31 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> { } } - fn contiguous_stride(&self) -> isize { + fn contiguous_stride(&self) -> isize + { self.iter.stride } - #[doc(hidden)] - unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item { - ArrayViewMut::new_( - ptr, - self.iter.inner_dim.clone(), - self.iter.inner_strides.clone(), - ) + unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item + { + ArrayViewMut::new_(ptr, self.iter.inner_dim.clone(), self.iter.inner_strides.clone()) } - #[doc(hidden)] - unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { + + unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr + { self.iter.offset(self.iter.index + i[0]) } - #[doc(hidden)] - fn stride_of(&self, _axis: Axis) -> isize { + fn stride_of(&self, _axis: Axis) -> isize + { self.contiguous_stride() } - #[doc(hidden)] - fn split_at(self, _axis: Axis, index: usize) -> (Self, Self) { + fn split_at(self, _axis: Axis, index: usize) -> (Self, Self) + { self.split_at(index) } + private_impl! {} } @@ -1216,8 +1314,9 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> { /// /// Iterator element type is `ArrayView<'a, A, D>`. /// -/// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information. -pub struct AxisChunksIter<'a, A, D> { +/// See [`.axis_chunks_iter()`](ArrayBase::axis_chunks_iter) for more information. +pub struct AxisChunksIter<'a, A, D> +{ iter: AxisIterCore, /// Index of the partial chunk (the chunk smaller than the specified chunk /// size due to the axis length not being evenly divisible). If the axis @@ -1249,11 +1348,10 @@ clone_bounds!( /// the number of chunks, and the shape of the last chunk. /// /// **Panics** if `size == 0`. -fn chunk_iter_parts( - v: ArrayView<'_, A, D>, - axis: Axis, - size: usize, -) -> (AxisIterCore, usize, D) { +#[track_caller] +fn chunk_iter_parts(v: ArrayView<'_, A, D>, axis: Axis, size: usize) + -> (AxisIterCore, usize, D) +{ assert_ne!(size, 0, "Chunk size must be nonzero."); let axis_len = v.len_of(axis); let n_whole_chunks = axis_len / size; @@ -1290,8 +1388,10 @@ fn chunk_iter_parts( (iter, partial_chunk_index, partial_chunk_dim) } -impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> { - pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self { +impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> +{ + pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self + { let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size); AxisChunksIter { iter, @@ -1335,6 +1435,7 @@ macro_rules! chunk_iter_impl { /// /// **Panics** if `index` is strictly greater than the iterator's remaining /// length. + #[track_caller] pub fn split_at(self, index: usize) -> (Self, Self) { let (left, right) = self.iter.split_at(index); ( @@ -1395,19 +1496,21 @@ macro_rules! chunk_iter_impl { /// /// Iterator element type is `ArrayViewMut<'a, A, D>`. /// -/// See [`.axis_chunks_iter_mut()`](../struct.ArrayBase.html#method.axis_chunks_iter_mut) +/// See [`.axis_chunks_iter_mut()`](ArrayBase::axis_chunks_iter_mut) /// for more information. 
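A short sketch of how the axis-chunks iterators above handle a remainder, assuming only the public `axis_chunks_iter` method they back:

use ndarray::{array, Axis};

fn main()
{
    let a = array![[1, 2], [3, 4], [5, 6]];
    // Chunks of 2 rows along axis 0; the final chunk is the 1-row remainder.
    let lens: Vec<usize> = a.axis_chunks_iter(Axis(0), 2).map(|c| c.len_of(Axis(0))).collect();
    assert_eq!(lens, vec![2, 1]);
}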
-pub struct AxisChunksIterMut<'a, A, D> { +pub struct AxisChunksIterMut<'a, A, D> +{ iter: AxisIterCore, partial_chunk_index: usize, partial_chunk_dim: D, life: PhantomData<&'a mut A>, } -impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> { - pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self { - let (iter, partial_chunk_index, partial_chunk_dim) = - chunk_iter_parts(v.into_view(), axis, size); +impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> +{ + pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self + { + let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v.into_view(), axis, size); AxisChunksIterMut { iter, partial_chunk_index, @@ -1439,30 +1542,34 @@ send_sync_read_write!(ElementsBaseMut); /// /// The iterator must produce exactly the number of elements it reported or /// diverge before reaching the end. +#[allow(clippy::missing_safety_doc)] // not nameable downstream pub unsafe trait TrustedIterator {} use crate::indexes::IndicesIterF; use crate::iter::IndicesIter; +#[cfg(feature = "std")] use crate::{geomspace::Geomspace, linspace::Linspace, logspace::Logspace}; - -unsafe impl TrustedIterator for Geomspace {} +#[cfg(feature = "std")] unsafe impl TrustedIterator for Linspace {} +#[cfg(feature = "std")] +unsafe impl TrustedIterator for Geomspace {} +#[cfg(feature = "std")] unsafe impl TrustedIterator for Logspace {} -unsafe impl<'a, A, D> TrustedIterator for Iter<'a, A, D> {} -unsafe impl<'a, A, D> TrustedIterator for IterMut<'a, A, D> {} +unsafe impl TrustedIterator for Iter<'_, A, D> {} +unsafe impl TrustedIterator for IterMut<'_, A, D> {} unsafe impl TrustedIterator for std::iter::Cloned where I: TrustedIterator {} unsafe impl TrustedIterator for std::iter::Map where I: TrustedIterator {} -unsafe impl<'a, A> TrustedIterator for slice::Iter<'a, A> {} -unsafe impl<'a, A> TrustedIterator for slice::IterMut<'a, A> {} +unsafe impl TrustedIterator for slice::Iter<'_, A> {} +unsafe impl TrustedIterator for slice::IterMut<'_, A> {} unsafe impl TrustedIterator for ::std::ops::Range {} // FIXME: These indices iter are dubious -- size needs to be checked up front. unsafe impl TrustedIterator for IndicesIter where D: Dimension {} unsafe impl TrustedIterator for IndicesIterF where D: Dimension {} +unsafe impl TrustedIterator for IntoIter where D: Dimension {} /// Like Iterator::collect, but only for trusted length iterators pub fn to_vec(iter: I) -> Vec -where - I: TrustedIterator + ExactSizeIterator, +where I: TrustedIterator + ExactSizeIterator { to_vec_mapped(iter, |x| x) } diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index e0abbc537..1c2ab6a85 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -1,50 +1,52 @@ -use super::ElementsBase; +use std::marker::PhantomData; + +use super::Baseiter; use crate::imp_prelude::*; use crate::IntoDimension; use crate::Layout; use crate::NdProducer; +use crate::Slice; /// Window producer and iterable /// -/// See [`.windows()`](../struct.ArrayBase.html#method.windows) for more +/// See [`.windows()`](ArrayBase::windows) for more /// information. 
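A minimal sketch of the window producer in use, assuming only the public `windows` method documented above; windows slide by one step along every axis:

use ndarray::array;

fn main()
{
    let a = array![[1, 2, 3], [4, 5, 6]];
    // 2x2 windows over a 2x3 array: (2 - 2 + 1) * (3 - 2 + 1) = 2 windows.
    let count = a.windows((2, 2)).into_iter().count();
    assert_eq!(count, 2);
}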
-pub struct Windows<'a, A, D> { - base: ArrayView<'a, A, D>, +pub struct Windows<'a, A, D> +{ + base: RawArrayView, + life: PhantomData<&'a A>, window: D, strides: D, } -impl<'a, A, D: Dimension> Windows<'a, A, D> { +impl<'a, A, D: Dimension> Windows<'a, A, D> +{ pub(crate) fn new(a: ArrayView<'a, A, D>, window_size: E) -> Self - where - E: IntoDimension, + where E: IntoDimension + { + let window = window_size.into_dimension(); + let ndim = window.ndim(); + + let mut unit_stride = D::zeros(ndim); + unit_stride.slice_mut().fill(1); + + Windows::new_with_stride(a, window, unit_stride) + } + + pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, window_size: E, axis_strides: E) -> Self + where E: IntoDimension { let window = window_size.into_dimension(); - ndassert!( - a.ndim() == window.ndim(), - concat!( - "Window dimension {} does not match array dimension {} ", - "(with array of shape {:?})" - ), - window.ndim(), - a.ndim(), - a.shape() - ); - let mut size = a.dim; - for (sz, &ws) in size.slice_mut().iter_mut().zip(window.slice()) { - assert_ne!(ws, 0, "window-size must not be zero!"); - // cannot use std::cmp::max(0, ..) since arithmetic underflow panics - *sz = if *sz < ws { 0 } else { *sz - ws + 1 }; - } + let strides = axis_strides.into_dimension(); let window_strides = a.strides.clone(); - unsafe { - Windows { - base: ArrayView::from_shape_ptr(size.clone().strides(a.strides), a.ptr.as_ptr()), - window, - strides: window_strides, - } + let base = build_base(a, window.clone(), strides); + Windows { + base: base.into_raw_view(), + life: PhantomData, + window, + strides: window_strides, } } } @@ -54,6 +56,7 @@ impl_ndproducer! { [Clone => 'a, A, D: Clone ] Windows { base, + life, window, strides, } @@ -75,9 +78,11 @@ where { type Item = ::Item; type IntoIter = WindowsIter<'a, A, D>; - fn into_iter(self) -> Self::IntoIter { + fn into_iter(self) -> Self::IntoIter + { WindowsIter { - iter: self.base.into_elements_base(), + iter: self.base.into_base_iter(), + life: self.life, window: self.window, strides: self.strides, } @@ -86,10 +91,12 @@ where /// Window iterator. /// -/// See [`.windows()`](../struct.ArrayBase.html#method.windows) for more +/// See [`.windows()`](ArrayBase::windows) for more /// information. -pub struct WindowsIter<'a, A, D> { - iter: ElementsBase<'a, A, D>, +pub struct WindowsIter<'a, A, D> +{ + iter: Baseiter, + life: PhantomData<&'a A>, window: D, strides: D, } @@ -99,19 +106,186 @@ impl_iterator! { [Clone => 'a, A, D: Clone] WindowsIter { iter, + life, window, strides, } WindowsIter<'a, A, D> { type Item = ArrayView<'a, A, D>; - fn item(&mut self, elt) { + fn item(&mut self, ptr) { unsafe { - ArrayView::new_( - elt, + ArrayView::new( + ptr, self.window.clone(), self.strides.clone()) } } } } + +send_sync_read_only!(Windows); +send_sync_read_only!(WindowsIter); + +/// Window producer and iterable +/// +/// See [`.axis_windows()`](ArrayBase::axis_windows) for more +/// information. 
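A tentative sketch of the new axis-windows producer, assuming the `axis_windows(Axis, usize)` entry point implied by the constructor below; each window spans the full extent of every other axis:

use ndarray::{array, Axis};

fn main()
{
    let a = array![[1, 2], [3, 4], [5, 6]];
    // Windows of 2 successive rows along axis 0: rows {0, 1} and {1, 2}.
    let count = a.axis_windows(Axis(0), 2).into_iter().count();
    assert_eq!(count, 2);
}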
+pub struct AxisWindows<'a, A, D> +{ + base: ArrayView<'a, A, D>, + axis_idx: usize, + window: D, + strides: D, +} + +impl<'a, A, D: Dimension> AxisWindows<'a, A, D> +{ + pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self + { + let window_strides = a.strides.clone(); + let axis_idx = axis.index(); + + let mut window = a.raw_dim(); + window[axis_idx] = window_size; + + let ndim = window.ndim(); + let mut unit_stride = D::zeros(ndim); + unit_stride.slice_mut().fill(1); + + let base = build_base(a, window.clone(), unit_stride); + AxisWindows { + base, + axis_idx, + window, + strides: window_strides, + } + } +} + +impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D> +{ + type Item = ArrayView<'a, A, D>; + type Dim = Ix1; + type Ptr = *mut A; + type Stride = isize; + + fn raw_dim(&self) -> Ix1 + { + Ix1(self.base.raw_dim()[self.axis_idx]) + } + + fn layout(&self) -> Layout + { + self.base.layout() + } + + fn as_ptr(&self) -> *mut A + { + self.base.as_ptr() as *mut _ + } + + fn contiguous_stride(&self) -> isize + { + self.base.contiguous_stride() + } + + unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item + { + ArrayView::new_(ptr, self.window.clone(), self.strides.clone()) + } + + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A + { + let mut d = D::zeros(self.base.ndim()); + d[self.axis_idx] = i[0]; + self.base.uget_ptr(&d) + } + + fn stride_of(&self, axis: Axis) -> isize + { + assert_eq!(axis, Axis(0)); + self.base.stride_of(Axis(self.axis_idx)) + } + + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + { + assert_eq!(axis, Axis(0)); + let (a, b) = self.base.split_at(Axis(self.axis_idx), index); + ( + AxisWindows { + base: a, + axis_idx: self.axis_idx, + window: self.window.clone(), + strides: self.strides.clone(), + }, + AxisWindows { + base: b, + axis_idx: self.axis_idx, + window: self.window, + strides: self.strides, + }, + ) + } + + private_impl!{} +} + +impl<'a, A, D> IntoIterator for AxisWindows<'a, A, D> +where + D: Dimension, + A: 'a, +{ + type Item = ::Item; + type IntoIter = WindowsIter<'a, A, D>; + fn into_iter(self) -> Self::IntoIter + { + WindowsIter { + iter: self.base.into_base_iter(), + life: PhantomData, + window: self.window, + strides: self.strides, + } + } +} + +/// build the base array of the `Windows` and `AxisWindows` structs +fn build_base(a: ArrayView, window: D, strides: D) -> ArrayView +where D: Dimension +{ + ndassert!( + a.ndim() == window.ndim(), + concat!( + "Window dimension {} does not match array dimension {} ", + "(with array of shape {:?})" + ), + window.ndim(), + a.ndim(), + a.shape() + ); + + ndassert!( + a.ndim() == strides.ndim(), + concat!( + "Stride dimension {} does not match array dimension {} ", + "(with array of shape {:?})" + ), + strides.ndim(), + a.ndim(), + a.shape() + ); + + let mut base = a; + base.slice_each_axis_inplace(|ax_desc| { + let len = ax_desc.len; + let wsz = window[ax_desc.axis.index()]; + let stride = strides[ax_desc.axis.index()]; + + if len < wsz { + Slice::new(0, Some(0), 1) + } else { + Slice::new(0, Some((len - wsz + 1) as isize), stride as isize) + } + }); + base +} diff --git a/src/itertools.rs b/src/itertools.rs index 96732f903..d3562e687 100644 --- a/src/itertools.rs +++ b/src/itertools.rs @@ -23,8 +23,7 @@ use std::iter; /// } /// ``` pub(crate) fn enumerate(iterable: I) -> iter::Enumerate -where - I: IntoIterator, +where I: IntoIterator { iterable.into_iter().enumerate() } @@ -64,11 +63,9 @@ where /// The special cases of one and two arguments produce the equivalent 
of /// `$a.into_iter()` and `$a.into_iter().zip($b)` respectively. /// -/// Prefer this macro `izip!()` over [`multizip`] for the performance benefits +/// Prefer this macro `izip!()` over `multizip` for the performance benefits /// of using the standard library `.zip()`. /// -/// [`multizip`]: fn.multizip.html -/// /// ``` /// #[macro_use] extern crate itertools; /// # fn main() { @@ -88,7 +85,8 @@ where /// **Note:** To enable the macros in this crate, use the `#[macro_use]` /// attribute when importing the crate: /// -/// ``` +/// ```no_run +/// # #[allow(unused_imports)] /// #[macro_use] extern crate itertools; /// # fn main() { } /// ``` diff --git a/src/layout/layoutfmt.rs b/src/layout/layoutfmt.rs index cb799ca81..f20f0caaa 100644 --- a/src/layout/layoutfmt.rs +++ b/src/layout/layoutfmt.rs @@ -7,14 +7,15 @@ // except according to those terms. use super::Layout; -use super::LayoutPriv; -const LAYOUT_NAMES: &[&str] = &["C", "F"]; +const LAYOUT_NAMES: &[&str] = &["C", "F", "c", "f"]; use std::fmt; -impl fmt::Debug for Layout { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl fmt::Debug for Layout +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { if self.0 == 0 { write!(f, "Custom")? } else { diff --git a/src/layout/mod.rs b/src/layout/mod.rs index 741a0e054..026688d63 100644 --- a/src/layout/mod.rs +++ b/src/layout/mod.rs @@ -1,61 +1,243 @@ mod layoutfmt; -// public but users don't interact with it +// Layout it a bitset used for internal layout description of +// arrays, producers and sets of producers. +// The type is public but users don't interact with it. #[doc(hidden)] /// Memory layout description #[derive(Copy, Clone)] pub struct Layout(u32); -pub trait LayoutPriv: Sized { - fn new(x: u32) -> Self; - fn and(self, flag: Self) -> Self; - fn is(self, flag: u32) -> bool; - fn flag(self) -> u32; -} +impl Layout +{ + pub(crate) const CORDER: u32 = 0b01; + pub(crate) const FORDER: u32 = 0b10; + pub(crate) const CPREFER: u32 = 0b0100; + pub(crate) const FPREFER: u32 = 0b1000; -impl LayoutPriv for Layout { #[inline(always)] - fn new(x: u32) -> Self { - Layout(x) + pub(crate) fn is(self, flag: u32) -> bool + { + self.0 & flag != 0 } + /// Return layout common to both inputs #[inline(always)] - fn is(self, flag: u32) -> bool { - self.0 & flag != 0 + pub(crate) fn intersect(self, other: Layout) -> Layout + { + Layout(self.0 & other.0) } + + /// Return a layout that simultaneously "is" what both of the inputs are #[inline(always)] - fn and(self, flag: Layout) -> Layout { - Layout(self.0 & flag.0) + pub(crate) fn also(self, other: Layout) -> Layout + { + Layout(self.0 | other.0) } #[inline(always)] - fn flag(self) -> u32 { - self.0 + pub(crate) fn one_dimensional() -> Layout + { + Layout::c().also(Layout::f()) } -} -impl Layout { - #[doc(hidden)] #[inline(always)] - pub fn one_dimensional() -> Layout { - Layout(CORDER | FORDER) + pub(crate) fn c() -> Layout + { + Layout(Layout::CORDER | Layout::CPREFER) } - #[doc(hidden)] + #[inline(always)] - pub fn c() -> Layout { - Layout(CORDER) + pub(crate) fn f() -> Layout + { + Layout(Layout::FORDER | Layout::FPREFER) } - #[doc(hidden)] + #[inline(always)] - pub fn f() -> Layout { - Layout(FORDER) + pub(crate) fn cpref() -> Layout + { + Layout(Layout::CPREFER) } + + #[inline(always)] + pub(crate) fn fpref() -> Layout + { + Layout(Layout::FPREFER) + } + #[inline(always)] - #[doc(hidden)] - pub fn none() -> Layout { + pub(crate) fn none() -> Layout + { Layout(0) } + + /// A simple "score" method which scores positive 
+    /// Subject to change when we can describe other layouts
+    #[inline]
+    pub(crate) fn tendency(self) -> i32
+    {
+        (self.is(Layout::CORDER) as i32 - self.is(Layout::FORDER) as i32)
+            + (self.is(Layout::CPREFER) as i32 - self.is(Layout::FPREFER) as i32)
+    }
 }
 
-pub const CORDER: u32 = 0b01;
-pub const FORDER: u32 = 0b10;
+#[cfg(test)]
+mod tests
+{
+    use super::*;
+    use crate::imp_prelude::*;
+    use crate::NdProducer;
+
+    type M = Array2<f32>;
+    type M1 = Array1<f32>;
+    type M0 = Array0<f32>;
+
+    macro_rules! assert_layouts {
+        ($mat:expr, $($layout:ident),*) => {{
+            let layout = $mat.view().layout();
+            $(
+                assert!(layout.is(Layout::$layout),
+                        "Assertion failed: array {:?} is not layout {}",
+                        $mat,
+                        stringify!($layout));
+            )*
+        }};
+    }
+
+    macro_rules! assert_not_layouts {
+        ($mat:expr, $($layout:ident),*) => {{
+            let layout = $mat.view().layout();
+            $(
+                assert!(!layout.is(Layout::$layout),
+                        "Assertion failed: array {:?} should not have layout {}",
+                        $mat,
+                        stringify!($layout));
+            )*
+        }};
+    }
+
+    #[test]
+    fn contig_layouts()
+    {
+        let a = M::zeros((5, 5));
+        let b = M::zeros((5, 5).f());
+        let ac = a.view().layout();
+        let af = b.view().layout();
+        assert!(ac.is(Layout::CORDER) && ac.is(Layout::CPREFER));
+        assert!(!ac.is(Layout::FORDER) && !ac.is(Layout::FPREFER));
+        assert!(!af.is(Layout::CORDER) && !af.is(Layout::CPREFER));
+        assert!(af.is(Layout::FORDER) && af.is(Layout::FPREFER));
+    }
+
+    #[test]
+    fn contig_cf_layouts()
+    {
+        let a = M::zeros((5, 1));
+        let b = M::zeros((1, 5).f());
+        assert_layouts!(a, CORDER, CPREFER, FORDER, FPREFER);
+        assert_layouts!(b, CORDER, CPREFER, FORDER, FPREFER);
+
+        let a = M1::zeros(5);
+        let b = M1::zeros(5.f());
+        assert_layouts!(a, CORDER, CPREFER, FORDER, FPREFER);
+        assert_layouts!(b, CORDER, CPREFER, FORDER, FPREFER);
+
+        let a = M0::zeros(());
+        assert_layouts!(a, CORDER, CPREFER, FORDER, FPREFER);
+
+        let a = M::zeros((5, 5));
+        let b = M::zeros((5, 5).f());
+        let arow = a.slice(s![..1, ..]);
+        let bcol = b.slice(s![.., ..1]);
+        assert_layouts!(arow, CORDER, CPREFER, FORDER, FPREFER);
+        assert_layouts!(bcol, CORDER, CPREFER, FORDER, FPREFER);
+
+        let acol = a.slice(s![.., ..1]);
+        let brow = b.slice(s![..1, ..]);
+        assert_not_layouts!(acol, CORDER, CPREFER, FORDER, FPREFER);
+        assert_not_layouts!(brow, CORDER, CPREFER, FORDER, FPREFER);
+    }
+
+    #[test]
+    fn stride_layouts()
+    {
+        let a = M::zeros((5, 5));
+
+        {
+            let v1 = a.slice(s![1.., ..]).layout();
+            let v2 = a.slice(s![.., 1..]).layout();
+
+            assert!(v1.is(Layout::CORDER) && v1.is(Layout::CPREFER));
+            assert!(!v1.is(Layout::FORDER) && !v1.is(Layout::FPREFER));
+            assert!(!v2.is(Layout::CORDER) && v2.is(Layout::CPREFER));
+            assert!(!v2.is(Layout::FORDER) && !v2.is(Layout::FPREFER));
+        }
+
+        let b = M::zeros((5, 5).f());
+
+        {
+            let v1 = b.slice(s![1.., ..]).layout();
+            let v2 = b.slice(s![.., 1..]).layout();
+
+            assert!(!v1.is(Layout::CORDER) && !v1.is(Layout::CPREFER));
+            assert!(!v1.is(Layout::FORDER) && v1.is(Layout::FPREFER));
+            assert!(!v2.is(Layout::CORDER) && !v2.is(Layout::CPREFER));
+            assert!(v2.is(Layout::FORDER) && v2.is(Layout::FPREFER));
+        }
+    }
+
+    #[test]
+    fn no_layouts()
+    {
+        let a = M::zeros((5, 5));
+        let b = M::zeros((5, 5).f());
+
+        // 2D row/column matrices
+        let arow = a.slice(s![0..1, ..]);
+        let acol = a.slice(s![.., 0..1]);
+        let brow = b.slice(s![0..1, ..]);
+        let bcol = b.slice(s![.., 0..1]);
+        assert_layouts!(arow, CORDER, FORDER);
+        assert_not_layouts!(acol, CORDER, CPREFER, FORDER, FPREFER);
+        assert_layouts!(bcol, CORDER,
FORDER); + assert_not_layouts!(brow, CORDER, CPREFER, FORDER, FPREFER); + + // 2D row/column matrixes - now made with insert axis + for &axis in &[Axis(0), Axis(1)] { + let arow = a.slice(s![0, ..]).insert_axis(axis); + let acol = a.slice(s![.., 0]).insert_axis(axis); + let brow = b.slice(s![0, ..]).insert_axis(axis); + let bcol = b.slice(s![.., 0]).insert_axis(axis); + assert_layouts!(arow, CORDER, FORDER); + assert_not_layouts!(acol, CORDER, CPREFER, FORDER, FPREFER); + assert_layouts!(bcol, CORDER, FORDER); + assert_not_layouts!(brow, CORDER, CPREFER, FORDER, FPREFER); + } + } + + #[test] + fn skip_layouts() + { + let a = M::zeros((5, 5)); + { + let v1 = a.slice(s![..;2, ..]).layout(); + let v2 = a.slice(s![.., ..;2]).layout(); + + assert!(!v1.is(Layout::CORDER) && v1.is(Layout::CPREFER)); + assert!(!v1.is(Layout::FORDER) && !v1.is(Layout::FPREFER)); + assert!(!v2.is(Layout::CORDER) && !v2.is(Layout::CPREFER)); + assert!(!v2.is(Layout::FORDER) && !v2.is(Layout::FPREFER)); + } + + let b = M::zeros((5, 5).f()); + { + let v1 = b.slice(s![..;2, ..]).layout(); + let v2 = b.slice(s![.., ..;2]).layout(); + + assert!(!v1.is(Layout::CORDER) && !v1.is(Layout::CPREFER)); + assert!(!v1.is(Layout::FORDER) && !v1.is(Layout::FPREFER)); + assert!(!v2.is(Layout::CORDER) && !v2.is(Layout::CPREFER)); + assert!(!v2.is(Layout::FORDER) && v2.is(Layout::FPREFER)); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index f9a208df1..b163f16a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2014-2016 bluss and ndarray developers. +// Copyright 2014-2020 bluss and ndarray developers. // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -6,40 +6,47 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. #![crate_name = "ndarray"] -#![doc(html_root_url = "https://docs.rs/ndarray/0.13/")] +#![doc(html_root_url = "https://docs.rs/ndarray/0.15/")] +#![doc(html_logo_url = "https://rust-ndarray.github.io/images/rust-ndarray_logo.svg")] #![allow( - clippy::many_single_char_names, + unstable_name_collisions, // our `PointerExt` collides with upcoming inherent methods on `NonNull` clippy::deref_addrof, - clippy::unreadable_literal, - clippy::many_single_char_names + clippy::manual_map, // is not an error + clippy::while_let_on_iterator, // is not an error + clippy::from_iter_instead_of_collect, // using from_iter is good style + clippy::incompatible_msrv, // false positive PointerExt::offset )] +#![doc(test(attr(deny(warnings))))] +#![doc(test(attr(allow(unused_variables))))] +#![doc(test(attr(allow(deprecated))))] +#![cfg_attr(not(feature = "std"), no_std)] //! The `ndarray` crate provides an *n*-dimensional container for general elements //! and for numerics. //! -//! In *n*-dimensional we include for example 1-dimensional rows or columns, +//! In *n*-dimensional we include, for example, 1-dimensional rows or columns, //! 2-dimensional matrices, and higher dimensional arrays. If the array has *n* //! dimensions, then an element in the array is accessed by using that many indices. //! Each dimension is also called an *axis*. //! -//! - **[`ArrayBase`](struct.ArrayBase.html)**: +//! - **[`ArrayBase`]**: //! The *n*-dimensional array type itself.
//! It is used to implement both the owned arrays and the views; see its docs //! for an overview of all array features.
-//! - The main specific array type is **[`Array`](type.Array.html)**, which owns -//! its elements. +//! - The main specific array type is **[`Array`]**, which owns +//! its elements. //! //! ## Highlights //! //! - Generic *n*-dimensional array -//! - Slicing, also with arbitrary step size, and negative indices to mean -//! elements from the end of the axis. +//! - [Slicing](ArrayBase#slicing), also with arbitrary step size, and negative +//! indices to mean elements from the end of the axis. //! - Views and subviews of arrays; iterators that yield subviews. //! - Higher order operations and arithmetic are performant //! - Array views can be used to slice and mutate any `[T]` data using //! `ArrayView::from` and `ArrayViewMut::from`. -//! - [`Zip`](struct.Zip.html) for lock step function application across two or more arrays or other -//! item producers ([`NdProducer`](trait.NdProducer.html) trait). +//! - [`Zip`] for lock step function application across two or more arrays or other +//! item producers ([`NdProducer`] trait). //! //! ## Crate Status //! @@ -53,45 +60,39 @@ //! - Performance: //! + Prefer higher order methods and arithmetic operations on arrays first, //! then iteration, and as a last priority using indexed algorithms. -//! + The higher order functions like ``.map()``, ``.map_inplace()``, -//! ``.zip_mut_with()``, ``Zip`` and ``azip!()`` are the most efficient ways +//! + The higher order functions like [`.map()`](ArrayBase::map), +//! [`.map_inplace()`](ArrayBase::map_inplace), [`.zip_mut_with()`](ArrayBase::zip_mut_with), +//! [`Zip`] and [`azip!()`](azip) are the most efficient ways //! to perform single traversal and lock step traversal respectively. //! + Performance of an operation depends on the memory layout of the array //! or array view. Especially if it's a binary operation, which //! needs matching memory layout to be efficient (with some exceptions). //! + Efficient floating point matrix multiplication even for very large //! matrices; can optionally use BLAS to improve it further. -//! - **Requires Rust 1.37 or later** +//! +//! - **MSRV: Requires Rust 1.64 or later** //! //! ## Crate Feature Flags //! //! The following crate feature flags are available. They are configured in your -//! `Cargo.toml`. +//! `Cargo.toml`. See [`doc::crate_feature_flags`] for more information. //! -//! - `serde` -//! - Optional, compatible with Rust stable -//! - Enables serialization support for serde 1.x -//! - `rayon` -//! - Optional, compatible with Rust stable -//! - Enables parallel iterators, parallelized methods and [`par_azip!`]. -//! - `approx` -//! - Optional, compatible with Rust stable -//! - Enables implementations of traits from the [`approx`] crate. -//! - `blas` -//! - Optional and experimental, compatible with Rust stable -//! - Enable transparent BLAS support for matrix multiplication. -//! Uses ``blas-src`` for pluggable backend, which needs to be configured -//! separately. +//! - `std`: Rust standard library-using functionality (enabled by default) +//! - `serde`: serialization support for serde 1.x +//! - `rayon`: Parallel iterators, parallelized methods, the [`parallel`] module and [`par_azip!`]. +//! - `approx` Implementations of traits from the [`approx`] crate. +//! - `blas`: transparent BLAS support for matrix multiplication, needs configuration. +//! - `matrixmultiply-threading`: Use threading from `matrixmultiply`. //! //! ## Documentation //! -//! * The docs for [`ArrayBase`](struct.ArrayBase.html) provide an overview of +//! 
* The docs for [`ArrayBase`] provide an overview of //! the *n*-dimensional array type. Other good pages to look at are the -//! documentation for the [`s![]`](macro.s.html) and -//! [`azip!()`](macro.azip.html) macros. +//! documentation for the [`s![]`](s!) and +//! [`azip!()`](azip!) macros. //! //! * If you have experience with NumPy, you may also be interested in -//! [`ndarray_for_numpy_users`](doc/ndarray_for_numpy_users/index.html). +//! [`ndarray_for_numpy_users`](doc::ndarray_for_numpy_users). //! //! ## The ndarray ecosystem //! @@ -105,36 +106,55 @@ //! but more advanced routines can be found in [`ndarray-stats`](https://crates.io/crates/ndarray-stats). //! //! If you are looking to generate random arrays instead, check out [`ndarray-rand`](https://crates.io/crates/ndarray-rand). +//! +//! For conversion between `ndarray`, [`nalgebra`](https://crates.io/crates/nalgebra) and +//! [`image`](https://crates.io/crates/image) check out [`nshare`](https://crates.io/crates/nshare). + +extern crate alloc; + +#[cfg(not(feature = "std"))] +extern crate core as std; +#[cfg(feature = "std")] +extern crate std; -#[cfg(feature = "blas")] -extern crate blas_src; #[cfg(feature = "blas")] extern crate cblas_sys; #[cfg(feature = "docs")] pub mod doc; +#[cfg(target_has_atomic = "ptr")] +use alloc::sync::Arc; + +#[cfg(not(target_has_atomic = "ptr"))] +use portable_atomic_util::Arc; + use std::marker::PhantomData; -use std::sync::Arc; pub use crate::dimension::dim::*; pub use crate::dimension::{Axis, AxisDescription, Dimension, IntoDimension, RemoveAxis}; +pub use crate::dimension::{DimAdd, DimMax}; pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; -pub use crate::slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex}; +pub use crate::order::Order; +pub use crate::slice::{MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceInfoElem, SliceNextDim}; use crate::iterators::Baseiter; -use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, LanesMut}; +use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut}; pub use crate::arraytraits::AsArray; -pub use crate::linalg_traits::{LinalgScalar, NdFloat}; -pub use crate::stacking::stack; +pub use crate::linalg_traits::LinalgScalar; +#[cfg(feature = "std")] +pub use crate::linalg_traits::NdFloat; + +pub use crate::stacking::{concatenate, stack}; pub use crate::impl_views::IndexLonger; -pub use crate::shape_builder::ShapeBuilder; +pub use crate::math_cell::MathCell; +pub use crate::shape_builder::{Shape, ShapeArg, ShapeBuilder, StrideShape}; #[macro_use] mod macro_utils; @@ -143,20 +163,18 @@ mod private; mod aliases; #[macro_use] mod itertools; -#[cfg(feature = "approx")] -mod array_approx; +mod argument_traits; #[cfg(feature = "serde")] mod array_serde; mod arrayformat; mod arraytraits; +pub use crate::argument_traits::AssignElem; +mod data_repr; mod data_traits; pub use crate::aliases::*; -#[allow(deprecated)] -pub use crate::data_traits::{ - Data, DataClone, DataMut, DataOwned, DataShared, RawData, RawDataClone, RawDataMut, -}; +pub use crate::data_traits::{Data, DataMut, DataOwned, DataShared, RawData, RawDataClone, RawDataMut, RawDataSubst}; mod free_functions; pub use crate::free_functions::*; @@ -170,12 +188,21 @@ mod iterators; mod layout; mod linalg_traits; mod linspace; +#[cfg(feature = "std")] +pub use crate::linspace::{linspace, range, Linspace}; mod logspace; +#[cfg(feature = "std")] +pub use 
crate::logspace::{logspace, Logspace}; +mod math_cell; mod numeric_util; +mod order; +mod partial; mod shape_builder; #[macro_use] mod slice; +mod split_at; mod stacking; +mod low_level_util; #[macro_use] mod zip; @@ -186,13 +213,24 @@ pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip}; pub use crate::layout::Layout; /// Implementation's prelude. Common types used everywhere. -mod imp_prelude { +mod imp_prelude +{ pub use crate::dimension::DimensionExt; pub use crate::prelude::*; pub use crate::ArcArray; pub use crate::{ - CowRepr, Data, DataMut, DataOwned, DataShared, Ix, Ixs, RawData, RawDataMut, RawViewRepr, - RemoveAxis, ViewRepr, + CowRepr, + Data, + DataMut, + DataOwned, + DataShared, + Ix, + Ixs, + RawData, + RawDataMut, + RawViewRepr, + RemoveAxis, + ViewRepr, }; } @@ -205,9 +243,13 @@ pub type Ixs = isize; /// An *n*-dimensional array. /// -/// The array is a general container of elements. It cannot grow or shrink, but -/// can be sliced into subsets of its data. -/// The array supports arithmetic operations by applying them elementwise. +/// The array is a general container of elements. +/// The array supports arithmetic operations by applying them elementwise, if the +/// elements are numeric, but it supports non-numeric elements too. +/// +/// The arrays rarely grow or shrink, since those operations can be costly. On +/// the other hand there is a rich set of methods and operations for taking views, +/// slices, and making traversals over one or more arrays. /// /// In *n*-dimensional we include for example 1-dimensional rows or columns, /// 2-dimensional matrices, and higher dimensional arrays. If the array has *n* @@ -218,13 +260,7 @@ pub type Ixs = isize; /// /// Type aliases [`Array`], [`ArcArray`], [`CowArray`], [`ArrayView`], and /// [`ArrayViewMut`] refer to `ArrayBase` with different types for the data -/// container. -/// -/// [`Array`]: type.Array.html -/// [`ArcArray`]: type.ArcArray.html -/// [`ArrayView`]: type.ArrayView.html -/// [`ArrayViewMut`]: type.ArrayViewMut.html -/// [`CowArray`]: type.CowArray.html +/// container: arrays with different kinds of ownership or different kinds of array views. /// /// ## Contents /// @@ -248,7 +284,7 @@ pub type Ixs = isize; /// /// ## `Array` /// -/// [`Array`](type.Array.html) is an owned array that owns the underlying array +/// [`Array`] is an owned array that owns the underlying array /// elements directly (just like a `Vec`) and it is the default way to create and /// store n-dimensional data. `Array` has two type parameters: `A` for /// the element type, and `D` for the dimensionality. A particular @@ -267,17 +303,16 @@ pub type Ixs = isize; /// /// ## `ArcArray` /// -/// [`ArcArray`](type.ArcArray.html) is an owned array with reference counted +/// [`ArcArray`] is an owned array with reference counted /// data (shared ownership). /// Sharing requires that it uses copy-on-write for mutable operations. /// Calling a method for mutating elements on `ArcArray`, for example -/// [`view_mut()`](#method.view_mut) or [`get_mut()`](#method.get_mut), +/// [`view_mut()`](Self::view_mut) or [`get_mut()`](Self::get_mut), /// will break sharing and require a clone of the data (if it is not uniquely held). /// /// ## `CowArray` /// -/// [`CowArray`](type.CowArray.html) is analogous to -/// [`std::borrow::Cow`](https://doc.rust-lang.org/std/borrow/enum.Cow.html). +/// [`CowArray`] is analogous to [`std::borrow::Cow`]. /// It can represent either an immutable view or a uniquely owned array. 
If a /// `CowArray` instance is the immutable view variant, then calling a method /// for mutating elements in the array will cause it to be converted into the @@ -288,7 +323,7 @@ pub type Ixs = isize; /// /// [`ArrayView`] and [`ArrayViewMut`] are read-only and read-write array views /// respectively. They use dimensionality, indexing, and almost all other -/// methods the same was as the other array types. +/// methods the same way as the other array types. /// /// Methods for `ArrayBase` apply to array views too, when the trait bounds /// allow. @@ -296,9 +331,10 @@ pub type Ixs = isize; /// Please see the documentation for the respective array view for an overview /// of methods specific to array views: [`ArrayView`], [`ArrayViewMut`]. /// -/// A view is created from an array using `.view()`, `.view_mut()`, using -/// slicing (`.slice()`, `.slice_mut()`) or from one of the many iterators -/// that yield array views. +/// A view is created from an array using [`.view()`](ArrayBase::view), +/// [`.view_mut()`](ArrayBase::view_mut), using +/// slicing ([`.slice()`](ArrayBase::slice), [`.slice_mut()`](ArrayBase::slice_mut)) or from one of +/// the many iterators that yield array views. /// /// You can also create an array view from a regular slice of data not /// allocated with `Array` — see array view methods or their `From` impls. @@ -341,16 +377,16 @@ pub type Ixs = isize; /// /// Important traits and types for dimension and indexing: /// -/// - A [`Dim`](struct.Dim.html) value represents a dimensionality or index. -/// - Trait [`Dimension`](trait.Dimension.html) is implemented by all -/// dimensionalities. It defines many operations for dimensions and indices. -/// - Trait [`IntoDimension`](trait.IntoDimension.html) is used to convert into a -/// `Dim` value. -/// - Trait [`ShapeBuilder`](trait.ShapeBuilder.html) is an extension of -/// `IntoDimension` and is used when constructing an array. A shape describes -/// not just the extent of each axis but also their strides. -/// - Trait [`NdIndex`](trait.NdIndex.html) is an extension of `Dimension` and is -/// for values that can be used with indexing syntax. +/// - A [`struct@Dim`] value represents a dimensionality or index. +/// - Trait [`Dimension`] is implemented by all +/// dimensionalities. It defines many operations for dimensions and indices. +/// - Trait [`IntoDimension`] is used to convert into a +/// `Dim` value. +/// - Trait [`ShapeBuilder`] is an extension of +/// `IntoDimension` and is used when constructing an array. A shape describes +/// not just the extent of each axis but also their strides. +/// - Trait [`NdIndex`] is an extension of `Dimension` and is +/// for values that can be used with indexing syntax. /// /// /// The default memory order of an array is *row major* order (a.k.a “c” order), @@ -364,10 +400,10 @@ pub type Ixs = isize; /// /// ## Loops, Producers and Iterators /// -/// Using [`Zip`](struct.Zip.html) is the most general way to apply a procedure +/// Using [`Zip`] is the most general way to apply a procedure /// across one or several arrays or *producers*. /// -/// [`NdProducer`](trait.NdProducer.html) is like an iterable but for +/// [`NdProducer`] is like an iterable but for /// multidimensional data. All producers have dimensions and axes, like an /// array view, and they can be split and used with parallelization using `Zip`. /// @@ -400,16 +436,16 @@ pub type Ixs = isize; /// /// The `outer_iter` and `axis_iter` are one dimensional producers. 
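To make the lock-step traversal described above concrete, here is a minimal sketch (illustrative only, not part of the patch) using the `Zip` API these docs refer to; the arrays and shapes are arbitrary:

```rust
use ndarray::{Array2, Zip};

fn main() {
    // Elementwise c = a + b over three same-shape arrays in one lock-step pass.
    let a = Array2::<f64>::ones((4, 3));
    let b = Array2::<f64>::ones((4, 3));
    let mut c = Array2::<f64>::zeros((4, 3));
    Zip::from(&mut c).and(&a).and(&b).for_each(|c, &a, &b| *c = a + b);
    assert_eq!(c[[0, 0]], 2.0);
}
```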
/// -/// ## `.genrows()`, `.gencolumns()` and `.lanes()` +/// ## `.rows()`, `.columns()` and `.lanes()` /// -/// [`.genrows()`][gr] is a producer (and iterable) of all rows in an array. +/// [`.rows()`][gr] is a producer (and iterable) of all rows in an array. /// /// ``` /// use ndarray::Array; /// /// // 1. Loop over the rows of a 2D array /// let mut a = Array::zeros((10, 10)); -/// for mut row in a.genrows_mut() { +/// for mut row in a.rows_mut() { /// row.fill(1.); /// } /// @@ -417,9 +453,9 @@ pub type Ixs = isize; /// use ndarray::Zip; /// let mut b = Array::zeros(a.nrows()); /// -/// Zip::from(a.genrows()) +/// Zip::from(a.rows()) /// .and(&mut b) -/// .apply(|a_row, b_elt| { +/// .for_each(|a_row, b_elt| { /// *b_elt = a_row[a.ncols() - 1] - a_row[0]; /// }); /// ``` @@ -435,21 +471,21 @@ pub type Ixs = isize; /// has *a m* rows. It's composed of *a* times the previous array, so it /// has *a* times as many rows. /// -/// All methods: [`.genrows()`][gr], [`.genrows_mut()`][grm], -/// [`.gencolumns()`][gc], [`.gencolumns_mut()`][gcm], +/// All methods: [`.rows()`][gr], [`.rows_mut()`][grm], +/// [`.columns()`][gc], [`.columns_mut()`][gcm], /// [`.lanes(axis)`][l], [`.lanes_mut(axis)`][lm]. /// -/// [gr]: #method.genrows -/// [grm]: #method.genrows_mut -/// [gc]: #method.gencolumns -/// [gcm]: #method.gencolumns_mut -/// [l]: #method.lanes -/// [lm]: #method.lanes_mut +/// [gr]: Self::rows +/// [grm]: Self::rows_mut +/// [gc]: Self::columns +/// [gcm]: Self::columns_mut +/// [l]: Self::lanes +/// [lm]: Self::lanes_mut /// -/// Yes, for 2D arrays `.genrows()` and `.outer_iter()` have about the same +/// Yes, for 2D arrays `.rows()` and `.outer_iter()` have about the same /// effect: /// -/// + `genrows()` is a producer with *n* - 1 dimensions of 1 dimensional items +/// + `rows()` is a producer with *n* - 1 dimensions of 1 dimensional items /// + `outer_iter()` is a producer with 1 dimension of *n* - 1 dimensional items /// /// ## Slicing @@ -458,33 +494,45 @@ pub type Ixs = isize; /// the array. Slicing methods include [`.slice()`], [`.slice_mut()`], /// [`.slice_move()`], and [`.slice_collapse()`]. /// -/// The slicing argument can be passed using the macro [`s![]`](macro.s!.html), +/// The slicing argument can be passed using the macro [`s![]`](s!), /// which will be used in all examples. (The explicit form is an instance of -/// [`&SliceInfo`]; see its docs for more information.) -/// -/// [`&SliceInfo`]: struct.SliceInfo.html +/// [`SliceInfo`] or another type which implements [`SliceArg`]; see their docs +/// for more information.) /// /// If a range is used, the axis is preserved. If an index is used, that index /// is selected and the axis is removed; this selects a subview. See -/// [*Subviews*](#subviews) for more information about subviews. Note that -/// [`.slice_collapse()`] behaves like [`.collapse_axis()`] by preserving -/// the number of dimensions. -/// -/// [`.slice()`]: #method.slice -/// [`.slice_mut()`]: #method.slice_mut -/// [`.slice_move()`]: #method.slice_move -/// [`.slice_collapse()`]: #method.slice_collapse +/// [*Subviews*](#subviews) for more information about subviews. If a +/// [`NewAxis`] instance is used, a new axis is inserted. Note that +/// [`.slice_collapse()`] panics on `NewAxis` elements and behaves like +/// [`.collapse_axis()`] by preserving the number of dimensions. 
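A minimal sketch (illustrative only, not part of the patch) of the slicing rules described above, with an arbitrary example array: a range keeps the axis, an index removes it, and `.slice_collapse()` preserves the number of axes:

```rust
use ndarray::{arr2, s};

fn main() {
    let a = arr2(&[[1, 2, 3], [4, 5, 6]]);
    // A range keeps the axis; an index removes it (a subview).
    assert_eq!(a.slice(s![0..1, ..]).shape(), &[1, 3]);
    assert_eq!(a.slice(s![0, ..]).shape(), &[3]);
    // `.slice_collapse()` keeps the axis count, collapsing the indexed axis to length 1.
    let mut v = a.view();
    v.slice_collapse(s![0, ..]);
    assert_eq!(v.shape(), &[1, 3]);
}
```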
+/// +/// [`.slice()`]: Self::slice +/// [`.slice_mut()`]: Self::slice_mut +/// [`.slice_move()`]: Self::slice_move +/// [`.slice_collapse()`]: Self::slice_collapse +/// +/// When slicing arrays with generic dimensionality, creating an instance of +/// [`SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`] +/// is awkward. In these cases, it's usually more convenient to use +/// [`.slice_each_axis()`]/[`.slice_each_axis_mut()`]/[`.slice_each_axis_inplace()`] +/// or to create a view and then slice individual axes of the view using +/// methods such as [`.slice_axis_inplace()`] and [`.collapse_axis()`]. +/// +/// [`.slice_each_axis()`]: Self::slice_each_axis +/// [`.slice_each_axis_mut()`]: Self::slice_each_axis_mut +/// [`.slice_each_axis_inplace()`]: Self::slice_each_axis_inplace +/// [`.slice_axis_inplace()`]: Self::slice_axis_inplace +/// [`.collapse_axis()`]: Self::collapse_axis /// /// It's possible to take multiple simultaneous *mutable* slices with /// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only) /// [`.multi_slice_move()`]. /// -/// [`.multi_slice_mut()`]: #method.multi_slice_mut -/// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move +/// [`.multi_slice_mut()`]: Self::multi_slice_mut +/// [`.multi_slice_move()`]: ArrayViewMut#method.multi_slice_move /// /// ``` -/// -/// use ndarray::{arr2, arr3, s}; +/// use ndarray::{arr2, arr3, s, ArrayBase, DataMut, Dimension, NewAxis, Slice}; /// /// // 2 submatrices of 2 rows with 3 elements per row, means a shape of `[2, 2, 3]`. /// @@ -519,16 +567,17 @@ pub type Ixs = isize; /// assert_eq!(d, e); /// assert_eq!(d.shape(), &[2, 1, 3]); /// -/// // Let’s create a slice while selecting a subview with +/// // Let’s create a slice while selecting a subview and inserting a new axis with /// // /// // - Both submatrices of the greatest dimension: `..` /// // - The last row in each submatrix, removing that axis: `-1` /// // - Row elements in reverse order: `..;-1` -/// let f = a.slice(s![.., -1, ..;-1]); -/// let g = arr2(&[[ 6, 5, 4], -/// [12, 11, 10]]); +/// // - A new axis at the end. +/// let f = a.slice(s![.., -1, ..;-1, NewAxis]); +/// let g = arr3(&[[ [6], [5], [4]], +/// [[12], [11], [10]]]); /// assert_eq!(f, g); -/// assert_eq!(f.shape(), &[2, 3]); +/// assert_eq!(f.shape(), &[2, 3, 1]); /// /// // Let's take two disjoint, mutable slices of a matrix with /// // @@ -543,6 +592,21 @@ pub type Ixs = isize; /// [5, 7]]); /// assert_eq!(s0, i); /// assert_eq!(s1, j); +/// +/// // Generic function which assigns the specified value to the elements which +/// // have indices in the lower half along all axes. +/// fn fill_lower(arr: &mut ArrayBase, x: S::Elem) +/// where +/// S: DataMut, +/// S::Elem: Clone, +/// D: Dimension, +/// { +/// arr.slice_each_axis_mut(|ax| Slice::from(0..ax.len / 2)).fill(x); +/// } +/// fill_lower(&mut h, 9); +/// let k = arr2(&[[9, 9, 2, 3], +/// [4, 5, 6, 7]]); +/// assert_eq!(h, k); /// ``` /// /// ## Subviews @@ -563,16 +627,16 @@ pub type Ixs = isize; /// Methods for selecting an individual subview take two arguments: `axis` and /// `index`. 
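A small sketch (illustrative only, not part of the patch) of selecting a subview with an `axis` and `index` pair as described above; the example data is arbitrary:

```rust
use ndarray::{arr3, Axis};

fn main() {
    let a = arr3(&[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
    // Select index 1 along Axis(0); that axis is removed, leaving a 2-D subview.
    let sub = a.index_axis(Axis(0), 1);
    assert_eq!(sub.shape(), &[2, 2]);
    assert_eq!(sub[[0, 1]], 6);
}
```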
/// -/// [`.axis_iter()`]: #method.axis_iter -/// [`.axis_iter_mut()`]: #method.axis_iter_mut -/// [`.fold_axis()`]: #method.fold_axis -/// [`.index_axis()`]: #method.index_axis -/// [`.index_axis_inplace()`]: #method.index_axis_inplace -/// [`.index_axis_mut()`]: #method.index_axis_mut -/// [`.index_axis_move()`]: #method.index_axis_move -/// [`.collapse_axis()`]: #method.collapse_axis -/// [`.outer_iter()`]: #method.outer_iter -/// [`.outer_iter_mut()`]: #method.outer_iter_mut +/// [`.axis_iter()`]: Self::axis_iter +/// [`.axis_iter_mut()`]: Self::axis_iter_mut +/// [`.fold_axis()`]: Self::fold_axis +/// [`.index_axis()`]: Self::index_axis +/// [`.index_axis_inplace()`]: Self::index_axis_inplace +/// [`.index_axis_mut()`]: Self::index_axis_mut +/// [`.index_axis_move()`]: Self::index_axis_move +/// [`.collapse_axis()`]: Self::collapse_axis +/// [`.outer_iter()`]: Self::outer_iter +/// [`.outer_iter_mut()`]: Self::outer_iter_mut /// /// ``` /// @@ -655,10 +719,10 @@ pub type Ixs = isize; /// /// ### Binary Operators with Array and Scalar /// -/// The trait [`ScalarOperand`](trait.ScalarOperand.html) marks types that can be used in arithmetic +/// The trait [`ScalarOperand`] marks types that can be used in arithmetic /// with arrays directly. For a scalar `K` the following combinations of operands /// are supported (scalar can be on either the left or right side, but -/// `ScalarOperand` docs has the detailed condtions). +/// `ScalarOperand` docs has the detailed conditions). /// /// - `&A @ K` or `K @ &A` which produces a new `Array` /// - `B @ K` or `K @ B` which consumes `B`, updates it with the result and returns it @@ -678,7 +742,7 @@ pub type Ixs = isize; /// Arrays support limited *broadcasting*, where arithmetic operations with /// array operands of different sizes can be carried out by repeating the /// elements of the smaller dimension array. See -/// [`.broadcast()`](#method.broadcast) for a more detailed +/// [`.broadcast()`](Self::broadcast) for a more detailed /// description. /// /// ``` @@ -824,12 +888,12 @@ pub type Ixs = isize; ///

/// -/// [`CowArray::from(a)`](type.CowArray.html#impl-From%2C%20D>>) +/// [`CowArray::from(a)`](CowArray#impl-From%2C%20D>>) /// /// /// -/// [`CowArray::from(a.into_owned())`](type.CowArray.html#impl-From%2C%20D>>) +/// [`CowArray::from(a.into_owned())`](CowArray#impl-From%2C%20D>>) /// /// @@ -839,12 +903,12 @@ pub type Ixs = isize; /// /// -/// [`CowArray::from(a)`](type.CowArray.html#impl-From%2C%20D>>) +/// [`CowArray::from(a)`](CowArray#impl-From%2C%20D>>) /// /// /// -/// [`CowArray::from(a.view())`](type.CowArray.html#impl-From%2C%20D>>) +/// [`CowArray::from(a.view())`](CowArray#impl-From%2C%20D>>) /// ///