Skip to content

Commit

Permalink
Merge pull request #462 from wearepal/segmentation-component-extra-paras
Browse files Browse the repository at this point in the history
n_repeats and classifier confidence added to segmentation component
  • Loading branch information
paulthatjazz authored Dec 5, 2024
2 parents 7e59086 + a961841 commit 387519e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 16 deletions.
68 changes: 54 additions & 14 deletions app/javascript/projects/modelling/components/segment_component.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import { ProjectProperties } from "."
import { TextControl } from "../controls/text"
import { BooleanTileGrid, NumericTileGrid } from "../tile_grid"
import { createXYZ } from "ol/tilegrid"
import { Point, Polygon } from "ol/geom"
import { Point } from "ol/geom"
import { Coordinate } from "ol/coordinate"

async function retrieveSegmentationMasks(prompts: string, confidence: string, projectProps: ProjectProperties) : Promise<any[]>{
async function retrieveSegmentationMasks(prompts: string, det_conf: string, clf_conf: string, n_repeats: string, projectProps: ProjectProperties, err: (err: string) => void) : Promise<any[]>{

const tileGrid = createXYZ()

Expand All @@ -18,18 +18,30 @@ async function retrieveSegmentationMasks(prompts: string, confidence: string, pr
const segs = await fetch("https://landscapes.wearepal.ai/api/v1/segment?" + new URLSearchParams(
{
labels: prompts,
confidence,
det_conf,
clf_conf,
n_repeats,
bbox: projectProps.extent.join(","),
layer: "rgb:full_mosaic_3857",
height: outputTileRange.getHeight().toString(),
width: outputTileRange.getWidth().toString(),
}
))

if(segs.status !== 200){
err(segs.statusText)
return []
}

const segsJson = await segs.json()

const preds = segsJson.predictions

if(preds === null){
err("No predictions found")
return []
}

const result = new BooleanTileGrid(
projectProps.zoom,
outputTileRange.minX,
Expand Down Expand Up @@ -103,8 +115,23 @@ export class SegmentComponent extends BaseComponent {

async builder(node: Node) {

if (!('confidence' in node.data)) {
node.data.confidence = "10"
node.meta.toolTip = "This node takes in 4 inputs: a prompt, a detector confidence, "
+"a classifier confidence, and the number of repeats. It then returns a segmentation mask, a detection box, "
+"and a confidence value. The prompt is the object you want to segment, detector confidence is the confidence "
+"threshold for the detector (it is recommended that this is set low for high recall), classifier confidence is "
+"the confidence threshold for the classifier (it is recommendeded that this is set higher for increased accuracy."
+" please note: setting this to 0 will disable this function), and the number of repeats is the number of times you want to repeat the segmentation process."

if (!('det_conf' in node.data)) {
node.data.det_conf = "5"
}

if (!('cls_conf' in node.data)) {
node.data.cls_conf = "75"
}

if (!('n_repeats' in node.data)) {
node.data.n_repeats = "5"
}

if (!('prompt' in node.data)) {
Expand All @@ -116,7 +143,9 @@ export class SegmentComponent extends BaseComponent {
node.addOutput(new Output('box', 'Detection Box', booleanDataSocket))

node.addControl(new TextControl(this.editor, 'prompt', 'Prompt', '500px'))
node.addControl(new TextControl(this.editor, 'confidence', 'Confidence (%)', '100px'))
node.addControl(new TextControl(this.editor, 'det_conf', 'Detector Confidence (%)', '100px'))
node.addControl(new TextControl(this.editor, 'cls_conf', 'Classifier Confidence (%)', '100px'))
node.addControl(new TextControl(this.editor, 'n_repeats', 'Repeats', '100px'))

}

Expand All @@ -126,19 +155,30 @@ export class SegmentComponent extends BaseComponent {
if (editorNode === undefined) { return }

const prompts = node.data.prompt as string
const confidence = node.data.confidence as string
const det_conf = node.data.det_conf as string
const cls_conf = node.data.cls_conf as string
const n_repeats = node.data.n_repeats as string

if (this.cache.has(`${prompts}_${confidence}%`)) {
const result = this.cache.get(`${prompts}_${confidence}%`)!
if (this.cache.has(`${prompts}_${cls_conf}%${det_conf}%${n_repeats}`)) {
const result = this.cache.get(`${prompts}_${cls_conf}%${det_conf}%${n_repeats}`)!
outputs['mask'] = result[0]
outputs['box'] = result[1]
outputs['conf'] = result[2]
}else{
const result = await retrieveSegmentationMasks(prompts, confidence, this.projectProps)
this.cache.set(`${prompts}_${confidence}%`, result)
outputs['mask'] = result[0]
outputs['box'] = result[1]
outputs['conf'] = result[2]
let nodeErr = ""
const result = await retrieveSegmentationMasks(prompts, det_conf, cls_conf, n_repeats, this.projectProps, (err) => {
nodeErr = err
})
if (result.length === 0) {
editorNode.meta.errorMessage = nodeErr
editorNode.update()
}else{
delete editorNode.meta.errorMessage
this.cache.set(`${prompts}_${cls_conf}%${det_conf}%${n_repeats}`, result)
outputs['mask'] = result[0]
outputs['box'] = result[1]
outputs['conf'] = result[2]
}
}

}
Expand Down
4 changes: 2 additions & 2 deletions app/javascript/projects/node_component.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ export class NodeComponent extends Node {

return (
<div className={`node ${selected}`} style={{
boxShadow: "0px 0px 8px rgba(0, 0, 0, 0.25)",
boxShadow: "0px 0px 8px rgba(0, 0, 0, 0.25)",
background: "rgba(0, 0, 0, 0.5)",
color: "white",
borderRadius: "4px",
border: "solid 3px transparent",
border: node.meta.errorMessage ? "solid 2px rgba(210, 0, 0, .71)" : "solid 3px transparent",
cursor: "pointer",
minWidth: "250px",
height: "auto",
Expand Down

0 comments on commit 387519e

Please sign in to comment.