Skip to content

Commit

Permalink
Export GPGPUContext and add getCanvas() to the WebGLBackend. (tensorf…
Browse files Browse the repository at this point in the history
…low#982)

BREAKING
FEATURE
  • Loading branch information
Nikhil Thorat authored Apr 29, 2018
1 parent 81278ae commit 0396f9f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 23 deletions.
12 changes: 2 additions & 10 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@ import * as environment from './environment';
import {Environment} from './environment';
// Serialization.
import * as io from './io/io';
import * as gpgpu_util from './kernels/webgl/gpgpu_util';
import * as webgl_util from './kernels/webgl/webgl_util';
import * as test_util from './test_util';
import * as util from './util';
import {version} from './version';
import * as webgl from './webgl';

// Optimizers.
export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
Expand Down Expand Up @@ -64,14 +63,7 @@ export {doc} from './doc';
export const nextFrame = BrowserUtil.nextFrame;

// Second level exports.
export {environment, io, test_util, util};

// WebGL specific utils.
export const webgl = {
webgl_util,
gpgpu_util
};
export {WebGLTimingInfo} from './kernels/backend_webgl';
export {environment, io, test_util, util, webgl};

// Backend specific.
export {KernelBackend, BackendTimingInfo} from './kernels/backend';
21 changes: 13 additions & 8 deletions src/kernels/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'
import * as types from '../types';
import {DataType, DataTypeMap, RecursiveArray, TypedArray} from '../types';
import * as util from '../util';

import {KernelBackend} from './backend';
import * as backend_util from './backend_util';
import {ArgMinMaxProgram} from './webgl/argminmax_gpu';
Expand All @@ -44,6 +45,7 @@ import {GatherProgram} from './webgl/gather_gpu';
import {GPGPUContext} from './webgl/gpgpu_context';
import * as gpgpu_math from './webgl/gpgpu_math';
import {GPGPUBinary, GPGPUProgram, TensorData} from './webgl/gpgpu_math';
import * as gpgpu_util from './webgl/gpgpu_util';
import {WhereProgram} from './webgl/logical_gpu';
import {LRNProgram} from './webgl/lrn_gpu';
import {MaxPool2DBackpropProgram} from './webgl/max_pool_backprop_gpu';
Expand Down Expand Up @@ -296,21 +298,25 @@ export class MathBackendWebGL implements KernelBackend {
if (ENV.get('WEBGL_VERSION') < 1) {
throw new Error('WebGL is not supported on this device');
}
if (typeof document !== 'undefined') {
this.canvas = document.createElement('canvas');
}
if (gpgpu == null) {
this.gpgpu = new GPGPUContext();
this.gpgpu = new GPGPUContext(gpgpu_util.createWebGLContext(this.canvas));
this.gpgpuCreatedLocally = true;
} else {
this.gpgpuCreatedLocally = false;
}
if (typeof document !== 'undefined') {
this.canvas = document.createElement('canvas');
}

this.textureManager = new TextureManager(this.gpgpu);
}

getGPGPUContext(): GPGPUContext {
return this.gpgpu;
}
getCanvas(): HTMLCanvasElement {
return this.canvas;
}

slice<T extends Tensor>(x: T, begin: number[], size: number[]): T {
const program = new SliceProgram(size);
Expand Down Expand Up @@ -886,10 +892,9 @@ export class MathBackendWebGL implements KernelBackend {
resizeNearestNeighbor(
x: Tensor4D, newHeight: number, newWidth: number,
alignCorners: boolean): Tensor4D {
const program =
new ResizeNearestNeighborProgram(x.shape, newHeight,
newWidth, alignCorners);
return this.compileAndRun(program, [x]);
const program = new ResizeNearestNeighborProgram(
x.shape, newHeight, newWidth, alignCorners);
return this.compileAndRun(program, [x]);
}

multinomial(
Expand Down
10 changes: 5 additions & 5 deletions src/tracking_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
*/

import * as tf from './index';
import {CPU_ENVS, WEBGL_ENVS} from './test_util';
import {describeWithFlags} from './jasmine_util';
import {CPU_ENVS, WEBGL_ENVS} from './test_util';

describeWithFlags('time webgl', WEBGL_ENVS, () => {
it('upload + compute', async () => {
const a = tf.zeros([10, 10]);
const time = await tf.time(() => a.square()) as tf.WebGLTimingInfo;
const time = await tf.time(() => a.square()) as tf.webgl.WebGLTimingInfo;
expect(time.uploadWaitMs > 0);
expect(time.downloadWaitMs === 0);
expect(time.kernelMs > 0);
Expand All @@ -32,7 +32,7 @@ describeWithFlags('time webgl', WEBGL_ENVS, () => {
it('upload + compute + dataSync', async () => {
const a = tf.zeros([10, 10]);
const time =
await tf.time(() => a.square().dataSync()) as tf.WebGLTimingInfo;
await tf.time(() => a.square().dataSync()) as tf.webgl.WebGLTimingInfo;
expect(time.uploadWaitMs > 0);
expect(time.downloadWaitMs > 0);
expect(time.kernelMs > 0);
Expand All @@ -42,7 +42,7 @@ describeWithFlags('time webgl', WEBGL_ENVS, () => {
it('upload + compute + data', async () => {
const a = tf.zeros([10, 10]);
const time = await tf.time(async () => await a.square().data()) as
tf.WebGLTimingInfo;
tf.webgl.WebGLTimingInfo;
expect(time.uploadWaitMs > 0);
expect(time.downloadWaitMs > 0);
expect(time.kernelMs > 0);
Expand All @@ -53,7 +53,7 @@ describeWithFlags('time webgl', WEBGL_ENVS, () => {
const a = tf.zeros([10, 10]);
// Pre-upload a on gpu.
a.square();
const time = await tf.time(() => a.sqrt()) as tf.WebGLTimingInfo;
const time = await tf.time(() => a.sqrt()) as tf.webgl.WebGLTimingInfo;
// The tensor was already on gpu.
expect(time.uploadWaitMs === 0);
expect(time.downloadWaitMs === 0);
Expand Down
24 changes: 24 additions & 0 deletions src/webgl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as webgl_util from './kernels/webgl/webgl_util';
import * as gpgpu_util from './kernels/webgl/webgl_util';

export {MathBackendWebGL, WebGLTimingInfo} from './kernels/backend_webgl';
export {GPGPUContext} from './kernels/webgl/gpgpu_context';
// WebGL specific utils.
export {gpgpu_util, webgl_util};

0 comments on commit 0396f9f

Please sign in to comment.