Skip to content

Commit

Permalink
[tfjs-data] support async generator (#8408)
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik authored Oct 8, 2024
1 parent 3daf152 commit 636c616
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
19 changes: 9 additions & 10 deletions tfjs-data/src/readers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,12 @@ export function func<T extends TensorContainer>(

/**
* Create a `Dataset` that produces each element from provided JavaScript
* generator, which is a function*
* (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions),
* or a function that returns an
* iterator
* (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions).
* generator, which is a function that returns a (potentially async) iterator.
*
* The returned iterator should have `.next()` function that returns element in
* format of `{value: TensorContainer, done:boolean}`.
* For more information on iterators and generators, see
* https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators .
* For the iterator protocol, see
* https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Iteration_protocols .
*
* Example of creating a dataset from an iterator factory:
* ```js
Expand Down Expand Up @@ -188,8 +186,8 @@ export function func<T extends TensorContainer>(
* await ds.forEachAsync(e => console.log(e));
* ```
*
* @param generator A JavaScript generator function that returns a JavaScript
* iterator.
* @param generator A JavaScript function that returns
* a (potentially async) JavaScript iterator.
*
* @doc {
* heading: 'Data',
Expand All @@ -199,7 +197,8 @@ export function func<T extends TensorContainer>(
* }
*/
export function generator<T extends TensorContainer>(
generator: () => Iterator<T>| Promise<Iterator<T>>): Dataset<T> {
generator: () => Iterator<T> | Promise<Iterator<T>> | AsyncIterator<T>,
): Dataset<T> {
return datasetFromIteratorFn(async () => {
const gen = await generator();
return iteratorFromFunction(() => gen.next());
Expand Down
15 changes: 15 additions & 0 deletions tfjs-data/src/readers_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ describeAllEnvs('readers', () => {
expect(result).toEqual([0, 1, 2, 3, 4]);
});

it('generate dataset from async generator', async () => {
async function* dataGenerator() {
const numElements = 5;
let index = 0;
while (index < numElements) {
const x = index;
index++;
yield x;
}
}
const ds = tfd.generator(dataGenerator);
const result = await ds.toArrayForTest();
expect(result).toEqual([0, 1, 2, 3, 4]);
});

it('generate multiple datasets from JavaScript generator', async () => {
function* dataGenerator() {
const numElements = 5;
Expand Down

0 comments on commit 636c616

Please sign in to comment.