diff --git a/app/javascript/projects/modelling/components/segment_component.ts b/app/javascript/projects/modelling/components/segment_component.ts index 7ac0f96..3c6fd1e 100644 --- a/app/javascript/projects/modelling/components/segment_component.ts +++ b/app/javascript/projects/modelling/components/segment_component.ts @@ -9,7 +9,7 @@ import { createXYZ } from "ol/tilegrid" import { Point, Polygon } from "ol/geom" import { Coordinate } from "ol/coordinate" -async function retrieveSegmentationMasks(prompts: string, threshold: string, projectProps: ProjectProperties) : Promise{ +async function retrieveSegmentationMasks(prompts: string, confidence: string, projectProps: ProjectProperties) : Promise{ const tileGrid = createXYZ() @@ -18,7 +18,7 @@ async function retrieveSegmentationMasks(prompts: string, threshold: string, pro const segs = await fetch("https://landscapes.wearepal.ai/api/v1/segment?" + new URLSearchParams( { labels: prompts, - threshold, + confidence, bbox: projectProps.extent.join(","), layer: "rgb:full_mosaic_3857", height: outputTileRange.getHeight().toString(), @@ -68,11 +68,11 @@ async function retrieveSegmentationMasks(prompts: string, threshold: string, pro ) result.set(featureTileRange.maxX, featureTileRange.minY, true) - confBox.set(featureTileRange.maxX, featureTileRange.minY, pred.score) + confBox.set(featureTileRange.maxX, featureTileRange.minY, pred.confidence) }) const predBox = pred.box - const predExtent = [Math.min(predBox.xmin, predBox.xmin), Math.min(predBox.ymin, predBox.ymax), Math.max(predBox.xmin, predBox.xmax), Math.max(predBox.ymin, predBox.ymax)] + const predExtent = [predBox.xmin, predBox.ymin, predBox.xmax, predBox.ymax] const featureTileRange = tileGrid.getTileRangeForExtentAndZ( predExtent, @@ -91,18 +91,20 @@ async function retrieveSegmentationMasks(prompts: string, threshold: string, pro } export class SegmentComponent extends BaseComponent { + cache: Map projectProps: ProjectProperties constructor(projectProps: ProjectProperties) { super("Segmentation Model") this.category = "Inputs" this.projectProps = projectProps + this.cache = new Map() } async builder(node: Node) { - if (!('threshold' in node.data)) { - node.data.threshold = "0.1" + if (!('confidence' in node.data)) { + node.data.confidence = "10" } if (!('prompt' in node.data)) { @@ -110,11 +112,11 @@ export class SegmentComponent extends BaseComponent { } node.addOutput(new Output('mask', 'Segmentation Mask', booleanDataSocket)) - node.addOutput(new Output('conf', 'Segmentation Mask (Confidence)', numericDataSocket)) - node.addOutput(new Output('box', 'Segmentation Box', booleanDataSocket)) + node.addOutput(new Output('conf', 'Confidence', numericDataSocket)) + node.addOutput(new Output('box', 'Detection Box', booleanDataSocket)) node.addControl(new TextControl(this.editor, 'prompt', 'Prompt', '500px')) - node.addControl(new TextControl(this.editor, 'threshold', 'Threshold', '100px')) + node.addControl(new TextControl(this.editor, 'confidence', 'Confidence (%)', '100px')) } @@ -124,13 +126,20 @@ export class SegmentComponent extends BaseComponent { if (editorNode === undefined) { return } const prompts = node.data.prompt as string - const threshold = node.data.threshold as string - - const result = await retrieveSegmentationMasks(prompts, threshold, this.projectProps) - - outputs['mask'] = result[0] - outputs['box'] = result[1] - outputs['conf'] = result[2] + const confidence = node.data.confidence as string + + if (this.cache.has(`${prompts}_${confidence}%`)) { + const result = this.cache.get(`${prompts}_${confidence}%`)! + outputs['mask'] = result[0] + outputs['box'] = result[1] + outputs['conf'] = result[2] + }else{ + const result = await retrieveSegmentationMasks(prompts, confidence, this.projectProps) + this.cache.set(`${prompts}_${confidence}%`, result) + outputs['mask'] = result[0] + outputs['box'] = result[1] + outputs['conf'] = result[2] + } } -} \ No newline at end of file +}