Skip to content

Commit

Permalink
Fix tensor2d/3d/4d to require shape to have the correct length (tenso…
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmilkov authored Apr 27, 2018
1 parent 44c950c commit ef6a33e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/ops/array_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ export class ArrayOps {
static tensor2d(
values: TensorLike2D, shape?: [number, number],
dtype: DataType = 'float32'): Tensor2D {
if (shape != null && shape.length !== 2) {
throw new Error('tensor2d() requires shape to have two numbers');
}
const inferredShape = util.inferShape(values);
if (inferredShape.length !== 2 && inferredShape.length !== 1) {
throw new Error(
Expand Down Expand Up @@ -186,6 +189,9 @@ export class ArrayOps {
static tensor3d(
values: TensorLike3D, shape?: [number, number, number],
dtype: DataType = 'float32'): Tensor3D {
if (shape != null && shape.length !== 3) {
throw new Error('tensor3d() requires shape to have three numbers');
}
const inferredShape = util.inferShape(values);
if (inferredShape.length !== 3 && inferredShape.length !== 1) {
throw new Error(
Expand Down Expand Up @@ -225,6 +231,9 @@ export class ArrayOps {
static tensor4d(
values: TensorLike4D, shape?: [number, number, number, number],
dtype: DataType = 'float32'): Tensor4D {
if (shape != null && shape.length !== 4) {
throw new Error('tensor4d() requires shape to have four numbers');
}
const inferredShape = util.inferShape(values);
if (inferredShape.length !== 4 && inferredShape.length !== 1) {
throw new Error(
Expand Down
18 changes: 18 additions & 0 deletions src/tensor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expectArraysClose(a, [1, 2, 3, 4, 5, 6]);
});

it('tf.tensor2d() requires shape to be of length 2', () => {
// tslint:disable-next-line:no-any
const shape: any = [4];
expect(() => tf.tensor2d([1, 2, 3, 4], shape)).toThrowError();
});

it('tf.tensor2d() from number[][], but shape does not match', () => {
// Actual shape is [2, 3].
expect(() => tf.tensor2d([[1, 2, 3], [4, 5, 6]], [3, 2])).toThrowError();
Expand All @@ -305,6 +311,12 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expect(() => tf.tensor3d([1, 2, 3, 4])).toThrowError();
});

it('tf.tensor3d() requires shape to be of length 3', () => {
// tslint:disable-next-line:no-any
const shape: any = [4, 1];
expect(() => tf.tensor3d([1, 2, 3, 4], shape)).toThrowError();
});

it('tensor4d() from number[][][][]', () => {
const a = tf.tensor4d([[[[1]], [[2]]], [[[4]], [[5]]]], [2, 2, 1, 1]);
expectArraysClose(a, [1, 2, 4, 5]);
Expand All @@ -322,6 +334,12 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expect(() => tf.tensor4d([1, 2, 3, 4])).toThrowError();
});

it('tf.tensor4d() requires shape to be of length 4', () => {
// tslint:disable-next-line:no-any
const shape: any = [4, 1];
expect(() => tf.tensor4d([1, 2, 3, 4], shape)).toThrowError();
});

it('default dtype', () => {
const a = tf.scalar(3);
expect(a.dtype).toBe('float32');
Expand Down

0 comments on commit ef6a33e

Please sign in to comment.