/**
 * Synchronously decodes the weights carried by `artifacts` and finishes
 * loading the model with them.
 *
 * @param artifacts Model artifacts whose `weightData` and `weightSpecs`
 *     are passed to `io.decodeWeights`.
 * @returns The result of `loadWithWeightMap`.
 *
 * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
 */
loadSync(artifacts: io.ModelArtifacts) {
  const weightMap =
      this.io.decodeWeights(artifacts.weightData, artifacts.weightSpecs);

  return this.loadWithWeightMap(artifacts, weightMap);
}

/**
 * Finishes loading the model from artifacts that expose their weights as a
 * stream instead of an in-memory buffer.
 *
 * The weights are decoded incrementally with `decodeWeightsStream`, and the
 * resulting weight map is handed to `loadWithWeightMap`.
 *
 * @param artifacts Model artifacts; `getWeightStream` must be defined.
 * @returns A promise for the result of `loadWithWeightMap`.
 * @throws Error if `artifacts.getWeightStream` is null or undefined.
 */
private async loadStreaming(artifacts: io.ModelArtifacts) {
  if (artifacts.getWeightStream == null) {
    // Name the property that is actually missing. `streamWeights` is the
    // load option that requests streaming, not a field of the artifacts.
    throw new Error('Model artifacts missing getWeightStream function');
  }

  const weightMap = await decodeWeightsStream(
      artifacts.getWeightStream(), artifacts.weightSpecs);

  return this.loadWithWeightMap(artifacts, weightMap);
}
implements this.signature = signature; this.version = `${graph.versions.producer}.${graph.versions.minConsumer}`; - const weightMap = this.io.decodeWeights( - this.artifacts.weightData, this.artifacts.weightSpecs); this.executor = new GraphExecutor( OperationMapper.Instance.transformGraph(graph, this.signature)); this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap); diff --git a/tfjs-converter/src/executor/graph_model_test.ts b/tfjs-converter/src/executor/graph_model_test.ts index 2b6d4ca104e..4ccd826b1ce 100644 --- a/tfjs-converter/src/executor/graph_model_test.ts +++ b/tfjs-converter/src/executor/graph_model_test.ts @@ -25,6 +25,8 @@ import {GraphNode} from '../operations/types'; import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model'; import {HASH_TABLE_MODEL_V2} from './test_data/hash_table_v2_model_loader'; import {STRUCTURED_OUTPUTS_MODEL} from './test_data/structured_outputs_model_loader'; +// tslint:disable-next-line: no-imports-from-dist +import {expectArrayBuffersEqual} from '@tensorflow/tfjs-core/dist/test_util'; const HOST = 'http://example.org'; const MODEL_URL = `${HOST}/model.json`; @@ -125,6 +127,24 @@ const SIMPLE_HTTP_MODEL_LOADER = { } }; +const SIMPLE_STREAMING_MODEL_LOADER = { + load: async () => { + return { + modelTopology: SIMPLE_MODEL, + weightSpecs: weightsManifest, + getWeightStream: () => { + const data = bias.dataSync(); + const blob = new Blob([data]); + return blob.stream(); + }, + format: 'tfjs-graph-model', + generatedBy: '1.15', + convertedBy: '1.3.1', + userDefinedMetadata: {signature: SIGNATURE} + }; + } +}; + const NO_INPUT_SIGNATURE_MODEL_LOADER = { load: async () => { return { @@ -438,7 +458,7 @@ describe('loadGraphModel', () => { }); it('Pass a fetchFunc', async () => { - const fetchFunc = () => {}; + const fetchFunc = (() => {}) as unknown as typeof fetch; spyIo.getLoadHandlers.and.returnValue([CUSTOM_HTTP_MODEL_LOADER]); await loadGraphModel(MODEL_URL, {fetchFunc}, spyIo); 
expect(spyIo.getLoadHandlers).toHaveBeenCalledWith(MODEL_URL, {fetchFunc}); @@ -594,7 +614,13 @@ describe('Model', () => { describe('simple model', () => { beforeEach(() => { - spyIo.getLoadHandlers.and.returnValue([SIMPLE_HTTP_MODEL_LOADER]); + spyIo.getLoadHandlers.and.callFake((_url: string|string[], + loadOptions?: io.LoadOptions) => { + if (loadOptions.streamWeights) { + return [SIMPLE_STREAMING_MODEL_LOADER]; + } + return [SIMPLE_HTTP_MODEL_LOADER]; + }); spyIo.browserHTTPRequest.and.returnValue(SIMPLE_HTTP_MODEL_LOADER); }); it('load', async () => { @@ -776,6 +802,14 @@ describe('Model', () => { expect(model).toBeDefined(); }); + it('should stream graph model weights', async () => { + const model = await loadGraphModel(MODEL_URL, {streamWeights: true}, + spyIo); + expect(model).toBeDefined(); + expectArrayBuffersEqual(model.weights['Const'][0].dataSync(), + bias.dataSync()); + }); + describe('InferenceModel interface', () => { it('should expose inputs', async () => { await model.load(); diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index c30ce501dd3..a8ba2da62ca 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -27,8 +27,8 @@ import {assert} from '../util'; import {getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils'; import {CompositeArrayBuffer} from './composite_array_buffer'; import {IORouter, IORouterRegistry} from './router_registry'; -import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types'; -import {loadWeightsAsArrayBuffer} from './weights_loader'; +import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {loadWeightsAsArrayBuffer, streamWeights} from './weights_loader'; const OCTET_STREAM_MIME_TYPE = 'application/octet-stream'; const JSON_TYPE 
= 'application/json'; @@ -36,7 +36,7 @@ export class HTTPRequest implements IOHandler { protected readonly path: string; protected readonly requestInit: RequestInit; - private readonly fetch: Function; + private readonly fetch: typeof fetch; private readonly weightUrlConverter: (weightName: string) => Promise; readonly DEFAULT_METHOD = 'POST'; @@ -44,14 +44,13 @@ export class HTTPRequest implements IOHandler { static readonly URL_SCHEME_REGEX = /^https?:\/\//; private readonly weightPathPrefix: string; - private readonly onProgress: OnProgressCallback; + private readonly loadOptions: LoadOptions; constructor(path: string, loadOptions?: LoadOptions) { if (loadOptions == null) { loadOptions = {}; } this.weightPathPrefix = loadOptions.weightPathPrefix; - this.onProgress = loadOptions.onProgress; this.weightUrlConverter = loadOptions.weightUrlConverter; if (loadOptions.fetchFunc != null) { @@ -84,6 +83,7 @@ export class HTTPRequest implements IOHandler { 'requestInit is expected to have no pre-existing body, but has one.'); } this.requestInit = loadOptions.requestInit || {}; + this.loadOptions = loadOptions; } async save(modelArtifacts: ModelArtifacts): Promise { @@ -135,15 +135,7 @@ export class HTTPRequest implements IOHandler { } } - /** - * Load model artifacts via HTTP request(s). - * - * See the documentation to `tf.io.http` for details on the saved - * artifacts. - * - * @returns The loaded model artifacts (if loading succeeds). - */ - async load(): Promise { + private async loadModelJSON(): Promise { const modelConfigRequest = await this.fetch(this.path, this.requestInit); if (!modelConfigRequest.ok) { @@ -182,18 +174,45 @@ export class HTTPRequest implements IOHandler { `topology or manifest for weights.`); } + return modelJSON; + } + + /** + * Load model artifacts via HTTP request(s). + * + * See the documentation to `tf.io.http` for details on the saved + * artifacts. + * + * @returns The loaded model artifacts (if loading succeeds). 
+ */ + async load(): Promise { + if (this.loadOptions.streamWeights) { + return this.loadStream(); + } + const modelJSON = await this.loadModelJSON(); return getModelArtifactsForJSON( modelJSON, (weightsManifest) => this.loadWeights(weightsManifest)); } - private async loadWeights(weightsManifest: WeightsManifestConfig): - Promise<[WeightsManifestEntry[], WeightData]> { + private async loadStream(): Promise { + const modelJSON = await this.loadModelJSON(); + const fetchURLs = await this.getWeightUrls(modelJSON.weightsManifest); + const weightSpecs = getWeightSpecs(modelJSON.weightsManifest); + const stream = () => streamWeights(fetchURLs, this.loadOptions); + + return { + ...modelJSON, + weightSpecs, + getWeightStream: stream, + }; + } + + private async getWeightUrls(weightsManifest: WeightsManifestConfig): + Promise { const weightPath = Array.isArray(this.path) ? this.path[1] : this.path; const [prefix, suffix] = parseUrl(weightPath); const pathPrefix = this.weightPathPrefix || prefix; - const weightSpecs = getWeightSpecs(weightsManifest); - const fetchURLs: string[] = []; const urlPromises: Array> = []; for (const weightsGroup of weightsManifest) { @@ -209,12 +228,15 @@ export class HTTPRequest implements IOHandler { if (this.weightUrlConverter) { fetchURLs.push(...await Promise.all(urlPromises)); } + return fetchURLs; + } + + private async loadWeights(weightsManifest: WeightsManifestConfig): + Promise<[WeightsManifestEntry[], WeightData]> { + const fetchURLs = await this.getWeightUrls(weightsManifest); + const weightSpecs = getWeightSpecs(weightsManifest); - const buffers = await loadWeightsAsArrayBuffer(fetchURLs, { - requestInit: this.requestInit, - fetchFunc: this.fetch, - onProgress: this.onProgress - }); + const buffers = await loadWeightsAsArrayBuffer(fetchURLs, this.loadOptions); return [weightSpecs, buffers]; } } diff --git a/tfjs-core/src/io/io.ts b/tfjs-core/src/io/io.ts index 49e9a1e2e06..3c1c8724e11 100644 --- a/tfjs-core/src/io/io.ts +++ 
b/tfjs-core/src/io/io.ts @@ -22,7 +22,7 @@ import './local_storage'; import {browserFiles} from './browser_files'; import {browserHTTPRequest, http, isHTTPScheme} from './http'; -import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils'; +import {concatenateArrayBuffers, decodeWeights, decodeWeightsStream, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils'; import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from './passthrough'; import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry'; import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry, WeightData} from './types'; @@ -36,6 +36,7 @@ export { CompositeArrayBuffer, concatenateArrayBuffers, decodeWeights, + decodeWeightsStream, encodeWeights, fromMemory, fromMemorySync, diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index fa9005a9ba8..25dbe5fd46d 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -23,6 +23,11 @@ import {sizeFromShape} from '../util'; import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, ModelJSON, WeightData, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; import {CompositeArrayBuffer} from './composite_array_buffer'; +import {Tensor} from '../tensor'; +import {backend} from '../globals'; +import {DataId} from '../tensor_info'; +import {env} from '../environment'; +import {getBackend} from '../globals'; /** Number of bytes reserved for the length of the string. (32bit integer). 
/** Number of bytes reserved for the length of the string. (32bit integer). */
const NUM_BYTES_STRING_LENGTH = 4;

/**
 * Float16 decoder, created on first use and then reused.
 *
 * Building the decoder is the expensive part, so it is built once instead of
 * once per float16-quantized weight.
 */
let sharedFloat16Decoder: ((buffer: Uint16Array) => Float32Array)|undefined;

/**
 * Decodes a flat buffer of weight bytes into named tensors.
 *
 * Weights are read back to back in the order given by `specs`.
 *
 * @param weightData The binary weight values.
 * @param specs One manifest entry per weight, in storage order.
 * @returns A map from weight name to the decoded tensor.
 * @throws Error on an unsupported dtype or quantization setting.
 */
export function decodeWeights(
    weightData: WeightData, specs: WeightsManifestEntry[]): NamedTensorMap {
  // TODO(adarob, cais): Support quantization.
  const compositeBuffer = new CompositeArrayBuffer(weightData);
  const out: NamedTensorMap = {};
  let offset = 0;
  for (const spec of specs) {
    // The callback receives offsets relative to the start of this weight.
    const byteLength = getWeightBytelength(
        spec,
        (start, end) => compositeBuffer.slice(offset + start, offset + end));
    out[spec.name] =
        decodeWeight(spec, compositeBuffer.slice(offset, offset + byteLength));
    offset += byteLength;
  }
  return out;
}

/**
 * Whether the byte length of a weight can only be found by reading its data.
 * This is the case for unquantized string weights, whose elements each carry
 * their own length prefix.
 */
function hasDynamicByteLength(spec: WeightsManifestEntry): boolean {
  return !('quantization' in spec) && spec.dtype === 'string';
}

/**
 * Byte length of a weight whose elements all have the same size: element
 * count times the size of the (quantized, if present) storage dtype.
 * Yields NaN when the dtype is not in `DTYPE_VALUE_SIZE_MAP`.
 */
function getStaticWeightBytelength(spec: WeightsManifestEntry): number {
  const storageDtype =
      'quantization' in spec ? spec.quantization.dtype : spec.dtype;
  return sizeFromShape(spec.shape) * DTYPE_VALUE_SIZE_MAP[storageDtype];
}

/**
 * Computes how many bytes a weight occupies in the serialized data.
 *
 * @param spec Manifest entry of the weight.
 * @param slice Returns bytes `[start, end)` relative to the start of this
 *     weight. Only called for string weights, to read the length prefixes.
 * @returns The number of bytes occupied by the weight.
 */
function getWeightBytelength(
    spec: WeightsManifestEntry,
    slice: (start: number, end: number) => ArrayBuffer): number {
  if (!hasDynamicByteLength(spec)) {
    return getStaticWeightBytelength(spec);
  }

  // Can not statically determine string length.
  const size = sizeFromShape(spec.shape);
  let byteLength = 0;
  for (let i = 0; i < size; i++) {
    const prefix = slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH);
    byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(prefix)[0];
  }
  return byteLength;
}

/**
 * Asynchronous variant of `getWeightBytelength` for data that may not have
 * arrived yet.
 *
 * @param spec Manifest entry of the weight.
 * @param slice Resolves to bytes `[start, end)` relative to the start of this
 *     weight. Only called for string weights.
 * @returns A promise for the number of bytes occupied by the weight.
 */
async function getWeightBytelengthAsync(
    spec: WeightsManifestEntry,
    slice: (start: number, end: number) => Promise<ArrayBuffer>):
    Promise<number> {
  if (!hasDynamicByteLength(spec)) {
    return getStaticWeightBytelength(spec);
  }

  // Can not statically determine string length.
  const size = sizeFromShape(spec.shape);
  let byteLength = 0;
  for (let i = 0; i < size; i++) {
    const prefix =
        await slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH);
    byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(prefix)[0];
  }
  return byteLength;
}

/**
 * Decodes the bytes of one weight into a tensor.
 *
 * @param spec Manifest entry describing name, dtype, shape and optional
 *     quantization of the weight.
 * @param byteBuffer Exactly the bytes that belong to this weight.
 * @returns The decoded tensor.
 * @throws Error on an unsupported dtype, an unknown quantization dtype,
 *     missing `min`/`scale` metadata, or float16 quantization of a
 *     non-float32 weight.
 */
function decodeWeight(
    spec: WeightsManifestEntry, byteBuffer: ArrayBuffer): Tensor {
  const {name, dtype, shape} = spec;
  let values: TypedArray|string[]|Uint8Array[];

  if ('quantization' in spec) {
    const quantization = spec.quantization;
    // Validate the quantization metadata before touching the data.
    if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
      if (!('min' in quantization && 'scale' in quantization)) {
        throw new Error(
            `Weight ${spec.name} with quantization ${quantization.dtype} ` +
            `doesn't have corresponding metadata min and scale.`);
      }
    } else if (quantization.dtype === 'float16') {
      if (dtype !== 'float32') {
        throw new Error(
            `Weight ${spec.name} is quantized with ${quantization.dtype} ` +
            `which only supports weights of type float32 not ${dtype}.`);
      }
    } else {
      throw new Error(
          `Weight ${spec.name} has unknown ` +
          `quantization dtype ${quantization.dtype}. ` +
          `Supported quantization dtypes are: ` +
          `'uint8', 'uint16', and 'float16'.`);
    }

    const quantizedArray = (quantization.dtype === 'uint8') ?
        new Uint8Array(byteBuffer) :
        new Uint16Array(byteBuffer);

    if (dtype === 'float32') {
      if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
        const dequantized = new Float32Array(quantizedArray.length);
        for (let i = 0; i < quantizedArray.length; i++) {
          dequantized[i] =
              quantizedArray[i] * quantization.scale + quantization.min;
        }
        values = dequantized;
      } else if (quantization.dtype === 'float16') {
        if (sharedFloat16Decoder === undefined) {
          sharedFloat16Decoder = getFloat16Decoder();
        }
        values = sharedFloat16Decoder(quantizedArray as Uint16Array);
      } else {
        throw new Error(
            `Unsupported quantization type ${quantization.dtype} ` +
            `for weight type float32.`);
      }
    } else if (dtype === 'int32') {
      if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
        throw new Error(
            `Unsupported quantization type ${quantization.dtype} ` +
            `for weight type int32.`);
      }
      const dequantized = new Int32Array(quantizedArray.length);
      for (let i = 0; i < quantizedArray.length; i++) {
        dequantized[i] = Math.round(
            quantizedArray[i] * quantization.scale + quantization.min);
      }
      values = dequantized;
    } else {
      throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
    }
  } else if (dtype === 'string') {
    // Each element is a 32-bit length prefix followed by that many bytes.
    const size = sizeFromShape(shape);
    const strings: Uint8Array[] = [];
    let offset = 0;
    for (let i = 0; i < size; i++) {
      const byteLength = new Uint32Array(
          byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
      offset += NUM_BYTES_STRING_LENGTH;
      strings.push(
          new Uint8Array(byteBuffer.slice(offset, offset + byteLength)));
      offset += byteLength;
    }
    values = strings;
  } else if (dtype === 'float32') {
    values = new Float32Array(byteBuffer);
  } else if (dtype === 'int32') {
    values = new Int32Array(byteBuffer);
  } else if (dtype === 'bool') {
    values = new Uint8Array(byteBuffer);
  } else if (dtype === 'complex64') {
    // Stored as interleaved (real, imaginary) float32 pairs.
    const interleaved = new Float32Array(byteBuffer);
    const real = new Float32Array(interleaved.length / 2);
    const image = new Float32Array(interleaved.length / 2);
    for (let i = 0; i < real.length; i++) {
      real[i] = interleaved[i * 2];
      image[i] = interleaved[i * 2 + 1];
    }
    const realTensor = tensor(real, shape, 'float32');
    const imageTensor = tensor(image, shape, 'float32');
    const complexTensor = complex(realTensor, imageTensor);
    realTensor.dispose();
    imageTensor.dispose();
    return complexTensor;
  } else {
    throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
  }

  return tensor(values, shape, dtype);
}

/**
 * Reads from `reader` until at least `length` bytes are available.
 *
 * @param reader Source of further chunks.
 * @param initialData Bytes that were already read; they form the start of
 *     the result.
 * @param length Minimum number of bytes the result must contain.
 * @returns `initialData` itself when it is already long enough, otherwise a
 *     new buffer holding `initialData` followed by every chunk that was read
 *     (possibly more than `length` bytes).
 * @throws Error if the reader finishes before `length` bytes are available.
 */
async function readToLength(
    reader: ReadableStreamDefaultReader<ArrayBuffer>,
    initialData: ArrayBuffer, length: number): Promise<ArrayBuffer> {
  let available = initialData.byteLength;
  // Written as a negation so that a NaN length falls through unchanged.
  if (!(available < length)) {
    return initialData;
  }

  // Collect the chunks and concatenate once, instead of re-allocating and
  // re-copying everything each time a chunk arrives.
  const chunks: Uint8Array[] = [new Uint8Array(initialData)];
  while (available < length) {
    const {done, value} = await reader.read();
    if (value == null) {
      if (done) {
        const missing = length - available;
        throw new Error(
            `Reader is done but ${missing} bytes are still expected`);
      }
      throw new Error('Reader returned neither data nor a done signal');
    }
    const chunk = new Uint8Array(value);
    chunks.push(chunk);
    available += chunk.byteLength;
  }

  const merged = new Uint8Array(available);
  let position = 0;
  for (const chunk of chunks) {
    merged.set(chunk, position);
    position += chunk.byteLength;
  }
  return merged.buffer;
}

/**
 * Decodes weights from a stream, creating each tensor as soon as its bytes
 * have arrived.
 *
 * The lock on `weightStream` is released before this function settles. If
 * decoding fails part-way, the tensors created so far are disposed.
 *
 * @param weightStream Stream of the serialized weights, in `specs` order.
 * @param specs One manifest entry per weight, in storage order.
 * @returns A promise for a map from weight name to the decoded tensor.
 * @throws Error if the stream ends early or a weight cannot be decoded.
 */
export async function decodeWeightsStream(
    weightStream: ReadableStream<ArrayBuffer>,
    specs: WeightsManifestEntry[]): Promise<NamedTensorMap> {
  const tensors: NamedTensorMap = {};
  const reader = weightStream.getReader();
  // Bytes that were read from the stream but not yet consumed by a weight.
  let data = new ArrayBuffer(0);

  try {
    for (const spec of specs) {
      const byteLength =
          await getWeightBytelengthAsync(spec, async (start, end) => {
            data = await readToLength(reader, data, end);
            return data.slice(start, end);
          });
      data = await readToLength(reader, data, byteLength);

      // Slice the tensor out
      const tensorData = data.slice(0, byteLength);
      data = data.slice(byteLength);

      const weightTensor = decodeWeight(spec, tensorData);
      tensors[spec.name] = weightTensor;

      // TODO(mattsoulanille): Better way to call uploadToGPU.
      // TODO(mattsoulanille): Make this work for webgl too.
      if (getBackend() === 'webgpu') {
        const b = backend();

        if ('uploadToGPU' in b &&
            sizeFromShape(weightTensor.shape) >=
                (env().get('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD') as number)) {
          (b.uploadToGPU as (dataId: DataId) => void)(weightTensor.dataId);
        }
      }
    }
  } catch (error) {
    // The partial result is never returned, so free what was created.
    for (const name of Object.keys(tensors)) {
      tensors[name].dispose();
    }
    throw error;
  } finally {
    // Leave the stream unlocked so that the caller can cancel or reuse it.
    reader.releaseLock();
  }

  return tensors;
}
= new ArrayBuffer(4); - // tslint:disable-next-line:no-any - const specs: any = [ - { - name: 'x', - dtype: 'int16', - shape: [], - }, - {name: 'y', dtype: 'int16', shape: []} - ]; - expect(() => tf.io.decodeWeights(buffer, specs)) - .toThrowError(/Unsupported dtype in weight \'x\': int16/); - }); - - it('support quantization uint8 weights', async () => { - const manifestSpecs: WeightsManifestEntry[] = [ - { - 'name': 'weight0', - 'dtype': 'float32', - 'shape': [3], - 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint8'} - }, - { - 'name': 'weight1', - 'dtype': 'int32', - 'shape': [3], - 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint8'} + function toStream(buffer: ArrayBuffer): ReadableStream { + let position = 0; + const chunkSize = 14; // something relatively small for testing + return new ReadableStream({ + pull: (controller) => { + if (position < buffer.byteLength) { + const chunk = buffer.slice(position, position + chunkSize); + position += chunkSize; + controller.enqueue(chunk); + } else { + controller.close(); + } } - ]; - const data = new Uint8Array([0, 48, 255, 0, 48, 255]); - const decoded = tf.io.decodeWeights(data.buffer, manifestSpecs); - const weight0 = decoded['weight0']; - expectArraysClose(await weight0.data(), [-1, 3.8, 24.5]); - expect(weight0.shape).toEqual([3]); - expect(weight0.dtype).toEqual('float32'); + }); + } + + async function decodeAsBuffer(data: ArrayBuffer, + specs: tf.io.WeightsManifestEntry[]) { + const result = tf.io.decodeWeights(data, specs); + // Make sure it doesn't return a promise. + expect(result).not.toBeInstanceOf(Promise); + // Wrap it in a promise to work with the tests. 
+ return Promise.resolve(result); + } + + async function decodeAsStream(data: ArrayBuffer, + specs: tf.io.WeightsManifestEntry[]) { + return tf.io.decodeWeightsStream(toStream(data), specs); + } + + for (const [name, decode] of [['from arraybuffer', decodeAsBuffer], + ['from stream', decodeAsStream]] as const) { + describe(name, () => { + it('Mixed dtype tensors', async () => { + const tensors: NamedTensorMap = { + x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), + x2: scalar(13.37, 'float32'), + x3: tensor1d([true, false, false], 'bool'), + x4: tensor2d([['здраво', 'a'], ['b', 'c']], [2, 2], 'string'), + x5: tensor1d([''], 'string'), // Empty string. + x6: scalar('hello'), // Single string. + y1: tensor2d([-10, -20, -30], [3, 1], 'float32'), + y2: tf.complex([1, 1], [2, 2]) + }; + const dataAndSpecs = await tf.io.encodeWeights(tensors); + const data = dataAndSpecs.data; + const specs = dataAndSpecs.specs; + const res = await decode(data, specs); + expect(Object.keys(res).length).toEqual(8); + expectArraysEqual(await res['x1'].data(), await tensors['x1'].data()); + expectArraysEqual(await res['x2'].data(), await tensors['x2'].data()); + expectArraysEqual(await res['x3'].data(), await tensors['x3'].data()); + expectArraysEqual(await res['x4'].data(), await tensors['x4'].data()); + expectArraysEqual(await res['x5'].data(), await tensors['x5'].data()); + expectArraysEqual(await res['x6'].data(), await tensors['x6'].data()); + expectArraysEqual(await res['y1'].data(), await tensors['y1'].data()); + expectArraysEqual(await res['y2'].data(), await tensors['y2'].data()); + }); - const weight1 = decoded['weight1']; - expectArraysEqual(await weight1.data(), [-1, 4, 25]); - expect(weight1.shape).toEqual([3]); - expect(weight1.dtype).toEqual('int32'); - }); + it('Unsupported dtype raises Error', async () => { + const buffer = new ArrayBuffer(4); + // tslint:disable-next-line:no-any + const specs: any = [ + { + name: 'x', + dtype: 'int16', + shape: [], + }, + {name: 'y', 
dtype: 'int16', shape: []} + ]; + await expectAsync(decode(buffer, specs)) + .toBeRejectedWithError(/Unsupported dtype in weight \'x\': int16/); + }); - it('support quantization uint16 weights', async () => { - const manifestSpecs: WeightsManifestEntry[] = [ - { - 'name': 'weight0', - 'dtype': 'float32', - 'shape': [3], - 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint16'} - }, - { - 'name': 'weight1', - 'dtype': 'int32', - 'shape': [3], - 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint16'} - } - ]; - const data = new Uint16Array([0, 48, 255, 0, 48, 255]); - const decoded = tf.io.decodeWeights(data.buffer, manifestSpecs); - const weight0 = decoded['weight0']; - expectArraysClose(await weight0.data(), [-1, 3.8, 24.5]); - expect(weight0.shape).toEqual([3]); - expect(weight0.dtype).toEqual('float32'); - - const weight1 = decoded['weight1']; - expectArraysEqual(await weight1.data(), [-1, 4, 25]); - expect(weight1.shape).toEqual([3]); - expect(weight1.dtype).toEqual('int32'); - }); - it('support quantization float16 weights', async () => { - const manifestSpecs: WeightsManifestEntry[] = [ - { - name: 'weight0', - dtype: 'float32', - shape: [3], - quantization: { dtype: 'float16' }, - }, - ]; - const data = new Uint16Array([13312, 14336, 14848]); - const decoded = tf.io.decodeWeights(data.buffer, manifestSpecs); - const weight0 = decoded['weight0']; - expectArraysClose(await weight0.data(), [0.25, 0.5, 0.75]); - expect(weight0.shape).toEqual([3]); - expect(weight0.dtype).toEqual('float32'); - }); + it('support quantization uint8 weights', async () => { + const manifestSpecs: WeightsManifestEntry[] = [ + { + 'name': 'weight0', + 'dtype': 'float32', + 'shape': [3], + 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint8'} + }, + { + 'name': 'weight1', + 'dtype': 'int32', + 'shape': [3], + 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint8'} + } + ]; + const data = new Uint8Array([0, 48, 255, 0, 48, 255]); + const decoded = await 
decode(data.buffer, manifestSpecs); + const weight0 = decoded['weight0']; + expectArraysClose(await weight0.data(), [-1, 3.8, 24.5]); + expect(weight0.shape).toEqual([3]); + expect(weight0.dtype).toEqual('float32'); + + const weight1 = decoded['weight1']; + expectArraysEqual(await weight1.data(), [-1, 4, 25]); + expect(weight1.shape).toEqual([3]); + expect(weight1.dtype).toEqual('int32'); + }); + + it('support quantization uint16 weights', async () => { + const manifestSpecs: WeightsManifestEntry[] = [ + { + 'name': 'weight0', + 'dtype': 'float32', + 'shape': [3], + 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint16'} + }, + { + 'name': 'weight1', + 'dtype': 'int32', + 'shape': [3], + 'quantization': {'min': -1, 'scale': 0.1, 'dtype': 'uint16'} + } + ]; + const data = new Uint16Array([0, 48, 255, 0, 48, 255]); + const decoded = await decode(data.buffer, manifestSpecs); + const weight0 = decoded['weight0']; + expectArraysClose(await weight0.data(), [-1, 3.8, 24.5]); + expect(weight0.shape).toEqual([3]); + expect(weight0.dtype).toEqual('float32'); + + const weight1 = decoded['weight1']; + expectArraysEqual(await weight1.data(), [-1, 4, 25]); + expect(weight1.shape).toEqual([3]); + expect(weight1.dtype).toEqual('int32'); + }); + it('support quantization float16 weights', async () => { + const manifestSpecs: WeightsManifestEntry[] = [ + { + name: 'weight0', + dtype: 'float32', + shape: [3], + quantization: { dtype: 'float16' }, + }, + ]; + const data = new Uint16Array([13312, 14336, 14848]); + const decoded = await decode(data.buffer, manifestSpecs); + const weight0 = decoded['weight0']; + expectArraysClose(await weight0.data(), [0.25, 0.5, 0.75]); + expect(weight0.shape).toEqual([3]); + expect(weight0.dtype).toEqual('float32'); + }); + }); + } }); describe('stringByteLength', () => { diff --git a/tfjs-core/src/io/progress.ts b/tfjs-core/src/io/progress.ts index 8d6b3d7fa8a..73e1e19d54c 100644 --- a/tfjs-core/src/io/progress.ts +++ 
b/tfjs-core/src/io/progress.ts @@ -27,8 +27,8 @@ import {OnProgressCallback} from './types'; * @param startFraction Optional fraction start. Default to 0. * @param endFraction Optional fraction end. Default to 1. */ -export function monitorPromisesProgress( - promises: Array>, onProgress: OnProgressCallback, +export function monitorPromisesProgress( + promises: Array>, onProgress: OnProgressCallback, startFraction?: number, endFraction?: number) { checkPromises(promises); startFraction = startFraction == null ? 0 : startFraction; @@ -36,7 +36,7 @@ export function monitorPromisesProgress( checkFraction(startFraction, endFraction); let resolvedPromise = 0; - const registerMonitor = (promise: Promise<{}>) => { + const registerMonitor = (promise: Promise) => { promise.then(value => { const fraction = startFraction + ++resolvedPromise / promises.length * (endFraction - startFraction); @@ -47,7 +47,7 @@ export function monitorPromisesProgress( return promise; }; - function checkPromises(promises: Array>): void { + function checkPromises(promises: Array>): void { assert( promises != null && Array.isArray(promises) && promises.length > 0, () => 'promises must be a none empty array'); diff --git a/tfjs-core/src/io/router_registry_test.ts b/tfjs-core/src/io/router_registry_test.ts index 834e8e3d10e..079a03a5602 100644 --- a/tfjs-core/src/io/router_registry_test.ts +++ b/tfjs-core/src/io/router_registry_test.ts @@ -136,7 +136,7 @@ describeWithFlags('IORouterRegistry', BROWSER_ENVS, () => { const loadOptions: LoadOptions = { onProgress: (fraction: number) => {}, - fetchFunc: () => {} + fetchFunc: ((() => {}) as unknown as typeof fetch), }; const loadHandler = tf.io.getLoadHandlers('foo:///123', loadOptions); expect(loadHandler.length).toEqual(1); diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index 2dc0893a82f..177884f2ef1 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -257,6 +257,13 @@ export declare interface ModelArtifacts { */ 
weightData?: WeightData; + /** + * Returns a stream of the weights. Some models are too large to fit in + * V8's memory heap, and `getWeightStream` loads their weights without storing + * them all in memory at the same time. + */ + getWeightStream?: () => ReadableStream; + /** * Hard-coded format name for models saved from TensorFlow.js or converted * by TensorFlow.js Converter. @@ -482,7 +489,7 @@ export interface LoadOptions { /** * A function used to override the `window.fetch` function. */ - fetchFunc?: Function; + fetchFunc?: typeof fetch; /** * Strict loading model: whether extraneous weights or missing @@ -532,6 +539,12 @@ export interface LoadOptions { * With this func you can convert the weight file name to any URL. */ weightUrlConverter?: (weightFileName: string) => Promise; + + /** + * Whether to stream the model directly to the backend or cache all its + * weights on CPU first. Useful for large models. + */ + streamWeights?: boolean; } /** diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 8ad0ef2f85b..9a09a798c45 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -71,6 +71,84 @@ export async function loadWeightsAsArrayBuffer( return buffers; }
+/**
+ * Streams the contents of several weight files back to back.
+ *
+ * Files are fetched lazily and strictly in `fetchURLs` order: the request for
+ * a file is only issued once the previous response body has been read to the
+ * end, so at most one response is open at a time.
+ *
+ * @param fetchURLs URLs of the weight files, in concatenation order.
+ * @param loadOptions Load options. `fetchFunc` (falling back to
+ *     `env().platform.fetch`), `requestInit` and `onProgress` are used.
+ *     `onProgress` is called with 0 up front and with the fraction of
+ *     completed files each time a file has been read to the end.
+ * @returns A stream that yields the chunks of every file in order and closes
+ *     after the last one. It errors if a response is not OK or has no body.
+ */
+export function streamWeights(
+    fetchURLs: string[], loadOptions: LoadOptions): ReadableStream {
+  const fetchFunc = loadOptions.fetchFunc ?? env().platform.fetch;
+
+  let fetchIndex = 0;
+  let chunkReader: ReadableStreamDefaultReader | undefined;
+  // Set by `cancel` so that a `pull` that is still awaiting stops instead of
+  // moving on to the next file.
+  let cancelled = false;
+  loadOptions.onProgress?.(0);
+  return new ReadableStream({
+    pull: async (controller) => {
+      while (fetchIndex < fetchURLs.length) {
+        if (!chunkReader) {
+          const url = fetchURLs[fetchIndex];
+          const response = await fetchFunc(
+              url, loadOptions.requestInit, {isBinary: true});
+          if (!response.ok) {
+            // Do not stream an error page as if it were weight data.
+            await response.body?.cancel();
+            throw new Error(
+                `Request to ${url} failed with status code ` +
+                `${response.status}.`);
+          }
+          if (response.body == null) {
+            throw new Error(`Response for ${url} has no body to stream.`);
+          }
+          chunkReader = response.body.getReader();
+        }
+
+        if (cancelled) {
+          // Cancelled while the request was in flight: release the response.
+          await chunkReader.cancel();
+          return;
+        }
+
+        const {done, value} = await chunkReader.read();
+
+        if (cancelled) {
+          // `cancel` has already released the reader; `done` here only means
+          // the read was aborted, not that the file is complete.
+          return;
+        }
+        if (done) {
+          fetchIndex++;
+          chunkReader = undefined;
+          loadOptions.onProgress?.(fetchIndex / fetchURLs.length);
+          continue;
+        }
+        controller.enqueue(value);
+        return;
+      }
+      controller.close();
+    },
+    cancel: async (reason) => {
+      // Without this the response that is currently being read stays open
+      // after the consumer has given up on the stream.
+      cancelled = true;
+      await chunkReader?.cancel(reason);
+    },
+  });
+}
+
 /** * Reads a weights manifest JSON configuration, fetches the weights and * returns them as `Tensor`s. diff --git a/tfjs/yarn.lock b/tfjs/yarn.lock index efddf89820c..5dbc13b3f00 100644 --- a/tfjs/yarn.lock +++ b/tfjs/yarn.lock @@ -1976,14 +1976,6 @@ resolved "https://registry.yarnpkg.com/@types/minimatch/-/minimatch-3.0.4.tgz#f0ec25dbf2f0e4b18647313ac031134ca5b24b21" integrity sha512-1z8k4wzFnNjVK/tlxvrWuK5WMt6mydWWP7+zvH5eFep4oj+UkrfiJTRtjCeBXNpwaA/FYqqtb4/QS4ianFpIRA== -"@types/node-fetch@^2.1.2": - version "2.6.4" - resolved "https://registry.yarnpkg.com/@types/node-fetch/-/node-fetch-2.6.4.tgz#1bc3a26de814f6bf466b25aeb1473fa1afe6a660" - integrity sha512-1ZX9fcN4Rvkvgv4E6PAY5WXUFWFcRWxZa3EW83UjycOB9ljJCedb2CupIP4RZMEwF/M3eTcCihbBRgwtGbg5Rg== - dependencies: - "@types/node" "*" - form-data "^3.0.0" - "@types/node@*", "@types/node@>=10.0.0": version "18.11.9" resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4" @@ -2153,11 +2145,6 @@ async@^3.0.1: resolved "https://registry.yarnpkg.com/async/-/async-3.2.0.tgz#b3a2685c5ebb641d3de02d161002c60fc9f85720" integrity sha512-TR2mEZFVOj2pLStYxLht7TyfuRzaydfpxr3k9RpHIzMgw7A64dzsdqCxH1WJyQdoe8T10nDXd9wnEigmiuHIZw== -asynckit@^0.4.0: - version "0.4.0" - resolved
"https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79" - integrity sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q== - available-typed-arrays@^1.0.2: version "1.0.2" resolved "https://registry.yarnpkg.com/available-typed-arrays/-/available-typed-arrays-1.0.2.tgz#6b098ca9d8039079ee3f77f7b783c4480ba513f5" @@ -2586,13 +2573,6 @@ combine-source-map@^0.8.0: lodash.memoize "~3.0.3" source-map "~0.5.3" -combined-stream@^1.0.8: - version "1.0.8" - resolved "https://registry.yarnpkg.com/combined-stream/-/combined-stream-1.0.8.tgz#c3d45a8b34fd730631a110a8a2520682b31d5a7f" - integrity sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg== - dependencies: - delayed-stream "~1.0.0" - commander@^2.12.1, commander@^2.20.0: version "2.20.3" resolved "https://registry.yarnpkg.com/commander/-/commander-2.20.3.tgz#fd485e84c03eb4881c20722ba48035e8531aeb33" @@ -2802,11 +2782,6 @@ define-properties@^1.1.3: dependencies: object-keys "^1.0.12" -delayed-stream@~1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/delayed-stream/-/delayed-stream-1.0.0.tgz#df3ae199acadfb7d440aaae0b29e2272b24ec619" - integrity sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ== - depd@~1.1.2: version "1.1.2" resolved "https://registry.yarnpkg.com/depd/-/depd-1.1.2.tgz#9bcd52e14c097763e749b274c4346ed2e560b5a9" @@ -3089,15 +3064,6 @@ foreach@^2.0.5: resolved "https://registry.yarnpkg.com/foreach/-/foreach-2.0.5.tgz#0bee005018aeb260d0a3af3ae658dd0136ec1b99" integrity sha1-C+4AUBiusmDQo6865ljdATbsG5k= -form-data@^3.0.0: - version "3.0.1" - resolved "https://registry.yarnpkg.com/form-data/-/form-data-3.0.1.tgz#ebd53791b78356a99af9a300d4282c4d5eb9755f" - integrity sha512-RHkBKtLWUVwd7SqRIvCZMEvAMoGUp0XU+seQiZejj0COz3RI3hWP4sCv3gZWWLjJTd7rGwcsF5eKZGii0r/hbg== - dependencies: - asynckit "^0.4.0" - combined-stream "^1.0.8" - 
mime-types "^2.1.12" - from@~0: version "0.1.7" resolved "https://registry.yarnpkg.com/from/-/from-0.1.7.tgz#83c60afc58b9c56997007ed1a768b3ab303a44fe" @@ -3927,13 +3893,6 @@ mime-db@1.52.0: resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.52.0.tgz#bbabcdc02859f4987301c856e3387ce5ec43bf70" integrity sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg== -mime-types@^2.1.12, mime-types@~2.1.34: - version "2.1.35" - resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.35.tgz#381a871b62a734450660ae3deee44813f70d959a" - integrity sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw== - dependencies: - mime-db "1.52.0" - mime-types@~2.1.24: version "2.1.34" resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.34.tgz#5a712f9ec1503511a945803640fafe09d3793c24" @@ -3941,6 +3900,13 @@ mime-types@~2.1.24: dependencies: mime-db "1.51.0" +mime-types@~2.1.34: + version "2.1.35" + resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.35.tgz#381a871b62a734450660ae3deee44813f70d959a" + integrity sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw== + dependencies: + mime-db "1.52.0" + mime@^2.5.2: version "2.6.0" resolved "https://registry.yarnpkg.com/mime/-/mime-2.6.0.tgz#a2a682a95cd4d0cb1d6257e28f83da7e35800367"