diff --git a/tfjs-data/src/readers.ts b/tfjs-data/src/readers.ts index ed93a55577c..070874c4f5a 100644 --- a/tfjs-data/src/readers.ts +++ b/tfjs-data/src/readers.ts @@ -140,14 +140,12 @@ export function func( /** * Create a `Dataset` that produces each element from provided JavaScript - * generator, which is a function* - * (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions), - * or a function that returns an - * iterator - * (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions). + * generator, which is a function that returns a (potentially async) iterator. * - * The returned iterator should have `.next()` function that returns element in - * format of `{value: TensorContainer, done:boolean}`. + * For more information on iterators and generators, see + * https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators . + * For the iterator protocol, see + * https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Iteration_protocols . * * Example of creating a dataset from an iterator factory: * ```js @@ -188,8 +186,8 @@ export function func( * await ds.forEachAsync(e => console.log(e)); * ``` * - * @param generator A JavaScript generator function that returns a JavaScript - * iterator. + * @param generator A JavaScript function that returns + * a (potentially async) JavaScript iterator. * * @doc { * heading: 'Data', @@ -199,7 +197,8 @@ export function func( * } */ export function generator( - generator: () => Iterator| Promise>): Dataset { + generator: () => Iterator | Promise> | AsyncIterator, +): Dataset { return datasetFromIteratorFn(async () => { const gen = await generator(); return iteratorFromFunction(() => gen.next()); diff --git a/tfjs-data/src/readers_test.ts b/tfjs-data/src/readers_test.ts index 740ef2d6cbb..4a4bdb6306c 100644 --- a/tfjs-data/src/readers_test.ts +++ b/tfjs-data/src/readers_test.ts @@ -45,6 +45,21 @@ describeAllEnvs('readers', () => { expect(result).toEqual([0, 1, 2, 3, 4]); }); + it('generate dataset from async generator', async () => { + async function* dataGenerator() { + const numElements = 5; + let index = 0; + while (index < numElements) { + const x = index; + index++; + yield x; + } + } + const ds = tfd.generator(dataGenerator); + const result = await ds.toArrayForTest(); + expect(result).toEqual([0, 1, 2, 3, 4]); + }); + it('generate multiple datasets from JavaScript generator', async () => { function* dataGenerator() { const numElements = 5;