Skip to content

Commit

Permalink
Add tests for tf.conv1d gradients (tensorflow#992)
Browse files Browse the repository at this point in the history
  • Loading branch information
easadler authored and dsmilkov committed May 3, 2018
1 parent 52cc2f6 commit b062378
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions src/ops/conv1d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,62 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
tf.conv1d(x, {} as tf.Tensor3D, stride, pad, dataFormat, dilation))
.toThrowError(/Argument 'filter' passed to 'conv1d' must be a Tensor/);
});

it('conv1d gradients, input=2x2x1,d2=1,f=1,s=1,d=1,p=same', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const filterShape: [number, number, number] =
[fSize, inputDepth, outputDepth];
const pad = 'same';
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], filterShape);

const dy = tf.tensor3d([3, 2, 1, 0], inputShape);

const grads = tf.grads((x: tf.Tensor3D, w: tf.Tensor3D) => tf.conv1d(
x, w, stride, pad, dataFormat, dilation));
const [dx, dw] = grads([x, w], dy);

expect(dx.shape).toEqual(x.shape);
expectArraysClose(dx, [9, 6, 3, 0]);

expect(dw.shape).toEqual(w.shape);
expectArraysClose(dw, [10]);
});

it('conv1d gradients input=14x1,d2=1,f=3x1x1,s=1,p=valid', () => {
const inputDepth = 1;
const inputShape: [number, number] = [14, inputDepth];

const outputDepth = 1;
const fSize = 3;
const pad = 'valid';
const stride = 1;
const dataFormat = 'NWC';

const x = tf.tensor2d(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], inputShape);
const w = tf.tensor3d([3, 2, 1], [fSize, inputDepth, outputDepth]);

const dy = tf.tensor2d(
[3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0], [12, inputDepth]);

const grads = tf.grads((x: tf.Tensor2D, w: tf.Tensor3D) => tf.conv1d(
x, w, stride, pad, dataFormat));
const [dx, dw] = grads([x, w], dy);

expect(dx.shape).toEqual(x.shape);
expectArraysClose(dx,
[9, 12, 10, 4, 10, 12, 10, 4, 10, 12, 10, 4, 1, 0]);

expect(dw.shape).toEqual(w.shape);
expectArraysClose(dw, [102, 120, 138]);
});

});

0 comments on commit b062378

Please sign in to comment.