diff --git a/src/utils/core.js b/src/utils/core.js index 6a6137dff..5fe0a8d05 100644 --- a/src/utils/core.js +++ b/src/utils/core.js @@ -1,10 +1,10 @@ /** * @file Core utility functions/classes for Transformers.js. - * + * * These are only used internally, meaning an end-user shouldn't * need to access anything here. - * + * * @module utils/core */ @@ -46,7 +46,7 @@ export function escapeRegExp(string) { * Check if a value is a typed array. * @param {*} val The value to check. * @returns {boolean} True if the value is a `TypedArray`, false otherwise. - * + * * Adapted from https://stackoverflow.com/a/71091338/13989043 */ export function isTypedArray(val) { @@ -63,6 +63,15 @@ export function isIntegralNumber(x) { return Number.isInteger(x) || typeof x === 'bigint' } +/** + * Determine if a provided width or height is nullish. + * @param {*} x The value to check. + * @returns {boolean} True if the value is `null`, `undefined` or `-1`, false otherwise. + */ +export function isNullishDimension(x) { + return x === null || x === undefined || x === -1 || x === '-1'; +} + /** * Calculates the dimensions of a nested array. * @@ -132,9 +141,9 @@ export function calculateReflectOffset(i, w) { } /** - * - * @param {Object} o - * @param {string[]} props + * + * @param {Object} o + * @param {string[]} props * @returns {Object} */ export function pick(o, props) { @@ -151,7 +160,7 @@ export function pick(o, props) { /** * Calculate the length of a string, taking multi-byte characters into account. * This mimics the behavior of Python's `len` function. - * @param {string} s The string to calculate the length of. + * @param {string} s The string to calculate the length of. * @returns {number} The length of the string. */ export function len(s) { diff --git a/src/utils/image.js b/src/utils/image.js index 33bdf11d8..3e123fe67 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -1,13 +1,14 @@ /** - * @file Helper module for image processing. - * - * These functions and classes are only used internally, + * @file Helper module for image processing. + * + * These functions and classes are only used internally, * meaning an end-user shouldn't need to access anything here. - * + * * @module utils/image */ +import { isNullishDimension } from './core.js'; import { getFile } from './hub.js'; import { env } from '../env.js'; import { Tensor } from './tensor.js'; @@ -91,7 +92,7 @@ export class RawImage { this.channels = channels; } - /** + /** * Returns the size of the image (width, height). * @returns {[number, number]} The size of the image (width, height). */ @@ -101,9 +102,9 @@ export class RawImage { /** * Helper method for reading an image from a variety of input types. - * @param {RawImage|string|URL} input + * @param {RawImage|string|URL} input * @returns The image object. - * + * * **Example:** Read image from a URL. * ```javascript * let image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg'); @@ -181,7 +182,7 @@ export class RawImage { /** * Helper method to create a new Image from a tensor - * @param {Tensor} tensor + * @param {Tensor} tensor */ static fromTensor(tensor, channel_format = 'CHW') { if (tensor.dims.length !== 3) { @@ -306,8 +307,8 @@ export class RawImage { /** * Resize the image to the given dimensions. This method uses the canvas API to perform the resizing. - * @param {number} width The width of the new image. - * @param {number} height The height of the new image. + * @param {number} width The width of the new image. `null` or `-1` will preserve the aspect ratio. + * @param {number} height The height of the new image. `null` or `-1` will preserve the aspect ratio. * @param {Object} options Additional options for resizing. * @param {0|1|2|3|4|5|string} [options.resample] The resampling method to use. * @returns {Promise} `this` to support chaining. @@ -319,6 +320,18 @@ export class RawImage { // Ensure resample method is a string let resampleMethod = RESAMPLING_MAPPING[resample] ?? resample; + // Calculate width / height to maintain aspect ratio, in the event that + // the user passed a null value in. + // This allows users to pass in something like `resize(320, null)` to + // resize to 320 width, but maintain aspect ratio. + if (isNullishDimension(width) && isNullishDimension(height)) { + return this; + } else if (isNullishDimension(width)) { + width = (height / this.height) * this.width; + } else if (isNullishDimension(height)) { + height = (width / this.width) * this.height; + } + if (BROWSER_ENV) { // TODO use `resample` in browser environment @@ -355,7 +368,7 @@ export class RawImage { case 'nearest': case 'bilinear': case 'bicubic': - // Perform resizing using affine transform. + // Perform resizing using affine transform. // This matches how the python Pillow library does it. img = img.affine([width / this.width, 0, 0, height / this.height], { interpolator: resampleMethod @@ -368,7 +381,7 @@ export class RawImage { img = img.resize({ width, height, fit: 'fill', - kernel: 'lanczos3', // PIL Lanczos uses a kernel size of 3 + kernel: 'lanczos3', // PIL Lanczos uses a kernel size of 3 }); break; @@ -425,6 +438,31 @@ export class RawImage { } } + /** + * Pad the image to a square. + * @param {*} dim The length of one of the square's side. + * @returns {Promise} `this` to support chaining. + */ + async padToSquare(dim) { + // We cannot pad to a square if the image is larger than provided size. + if (this.width > dim || this.height > dim) { + return this; + } + + // If no value was provided, then use the largest side. + if (dim === undefined) { + dim = Math.max(this.width, this.height); + } + + // Odd numbers will add extra padding to the right and bottom. + return this.pad([ + Math.floor((dim - this.width) / 2), + Math.ceil((dim - this.width) / 2), + Math.floor((dim - this.height) / 2), + Math.ceil((dim - this.height) / 2), + ]); + } + async crop([x_min, y_min, x_max, y_max]) { // Ensure crop bounds are within the image x_min = Math.max(x_min, 0); @@ -447,7 +485,7 @@ export class RawImage { // Create canvas object for this image const canvas = this.toCanvas(); - // Create a new canvas of the desired size. This is needed since if the + // Create a new canvas of the desired size. This is needed since if the // image is too small, we need to pad it with black pixels. const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d'); @@ -495,7 +533,7 @@ export class RawImage { // Create canvas object for this image const canvas = this.toCanvas(); - // Create a new canvas of the desired size. This is needed since if the + // Create a new canvas of the desired size. This is needed since if the // image is too small, we need to pad it with black pixels. const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d'); @@ -742,4 +780,4 @@ export class RawImage { } }); } -} \ No newline at end of file +} diff --git a/tests/utils/utils.test.js b/tests/utils/utils.test.js index 8a1891f19..cd20ad012 100644 --- a/tests/utils/utils.test.js +++ b/tests/utils/utils.test.js @@ -1,5 +1,6 @@ import { AutoProcessor, hamming, hanning, mel_filter_bank } from "../../src/transformers.js"; import { getFile } from "../../src/utils/hub.js"; +import { RawImage } from "../../src/utils/image.js"; import { MAX_TEST_EXECUTION_TIME } from "../init.js"; import { compare } from "../test_utils.js"; @@ -59,4 +60,55 @@ describe("Utilities", () => { expect(await data.text()).toBe("Hello, world!"); }); }); + + describe("Image utilities", () => { + it("Read image from URL", async () => { + const image = await RawImage.fromURL("https://picsum.photos/300/200"); + expect(image.width).toBe(300); + expect(image.height).toBe(200); + expect(image.channels).toBe(3); + }); + + it("Can resize image", async () => { + const image = await RawImage.fromURL("https://picsum.photos/300/200"); + const resized = await image.resize(150, 100); + expect(resized.width).toBe(150); + expect(resized.height).toBe(100); + }); + + it("Can resize with aspect ratio", async () => { + const image = await RawImage.fromURL("https://picsum.photos/300/200"); + const resized = await image.resize(150, null); + expect(resized.width).toBe(150); + expect(resized.height).toBe(100); + }); + + it("Returns original image if width and height are null", async () => { + const image = await RawImage.fromURL("https://picsum.photos/300/200"); + const resized = await image.resize(null, null); + expect(resized.width).toBe(300); + expect(resized.height).toBe(200); + }); + + it("Can pad to a square with no args", async () => { + const image = await RawImage.fromURL("https://picsum.photos/300/200"); + const padded = await image.padToSquare(); + expect(padded.width).toBe(300); + expect(padded.height).toBe(300); + }); + + it("Can pad to a square with larger sides", async () => { + const image = await RawImage.fromURL("https://picsum.photos/300/200"); + const padded = await image.padToSquare(400); + expect(padded.width).toBe(400); + expect(padded.height).toBe(400); + }); + + it("Cannot pad to square if dim is smaller than image", async () => { + const image = await RawImage.fromURL("https://picsum.photos/300/200"); + const padded = await image.padToSquare(100); + expect(padded.width).toBe(300); + expect(padded.height).toBe(200); + }); + }); });