Skip to content

Commit d5e7f67

Browse files
committed
Take generic ArrayBase in concatenate and stack
1 parent 27af864 commit d5e7f67

2 files changed

Lines changed: 9 additions & 7 deletions

File tree

src/stacking.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ use crate::imp_prelude::*;
3333
/// [3., 3.]]))
3434
/// );
3535
/// ```
36-
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
36+
pub fn concatenate<S, D, A>(axis: Axis, arrays: &[ArrayBase<S, D, A>]) -> Result<Array<A, D>, ShapeError>
3737
where
38+
S: Data<Elem = A>,
3839
A: Clone,
3940
D: RemoveAxis,
4041
{
@@ -66,7 +67,7 @@ where
6667
};
6768

6869
for array in arrays {
69-
res.append(axis, array.clone())?;
70+
res.append(axis, array.view())?;
7071
}
7172
debug_assert_eq!(res.len_of(axis), stacked_dim);
7273
Ok(res)
@@ -96,8 +97,9 @@ where
9697
/// );
9798
/// # }
9899
/// ```
99-
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D::Larger>, ShapeError>
100+
pub fn stack<S, D, A>(axis: Axis, arrays: &[ArrayBase<S, D, A>]) -> Result<Array<A, D::Larger>, ShapeError>
100101
where
102+
S: Data<Elem = A>,
101103
A: Clone,
102104
D: Dimension,
103105
D::Larger: RemoveAxis,
@@ -129,7 +131,7 @@ where
129131
};
130132

131133
for array in arrays {
132-
res.append(axis, array.clone().insert_axis(axis))?;
134+
res.append(axis, array.view().insert_axis(axis))?;
133135
}
134136

135137
debug_assert_eq!(res.len_of(axis), arrays.len());

tests/stacking.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use ndarray::{arr2, arr3, aview1, aview2, concatenate, stack, Array2, Axis, ErrorKind, Ix1};
1+
use ndarray::{arr2, arr3, aview1, aview2, concatenate, stack, Array2, Axis, ErrorKind, Ix1, ViewRepr};
22

33
#[test]
44
fn concatenating()
@@ -29,7 +29,7 @@ fn concatenating()
2929
let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]);
3030
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);
3131

32-
let res: Result<Array2<f64>, _> = ndarray::concatenate(Axis(0), &[]);
32+
let res: Result<Array2<f64>, _> = ndarray::concatenate::<ViewRepr<&f64>, _, _>(Axis(0), &[]);
3333
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
3434
}
3535

@@ -50,6 +50,6 @@ fn stacking()
5050
let res = ndarray::stack(Axis(3), &[a.view(), a.view()]);
5151
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);
5252

53-
let res: Result<Array2<f64>, _> = ndarray::stack::<_, Ix1>(Axis(0), &[]);
53+
let res: Result<Array2<f64>, _> = ndarray::stack::<ViewRepr<&f64>, Ix1, _>(Axis(0), &[]);
5454
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
5555
}

0 commit comments

Comments
 (0)