feat: Add audio classification support (#28)
* feat: Add audio classification support

* chore: Update icon for TextToVideo task; update inputTexts, inputPreview, and sampleImages
ShivanshShalabh authored Jul 10, 2024
1 parent 6bc9617 commit 7fbfde1
Showing 14 changed files with 383 additions and 140 deletions.
14 changes: 13 additions & 1 deletion src/components/Experiment/QuickInput/QuickInput.stories.js
@@ -21,6 +21,7 @@ import {
textTo3D,
textClassification,
audioToAudio,
audioClassification,
} from "../../../helpers/TaskIDs";
import {
SampleImageClassificationInputs,
@@ -38,6 +39,7 @@ import {
SampleTextTo3DInputs,
SampleTextClassification,
SampleAudioToAudioInputs,
SampleAudioClassificationInputs,
} from "../../../helpers/sampleImages";

export default {
@@ -149,7 +151,7 @@ AudioToText.args = {
{
title: "automatic-speech-recognition-input(3).flac",
src: "https://xlab1.netlify.app/automatic-speech-recognition-input.flac"
},
},
],
model: {
output: {
@@ -273,4 +275,14 @@ AudioToAudio.args = {
type: audioToAudio,
},
},
};

export const AudioClassification = Template.bind({});
AudioClassification.args = {
sampleInputs: SampleAudioClassificationInputs,
model: {
output: {
type: audioClassification,
},
},
};
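For context, the new story assumes a SampleAudioClassificationInputs export from src/helpers/sampleImages, which is not among the hunks shown here. A minimal sketch of its likely shape, modeled on the AudioToText sample inputs above; the title and URL are placeholders:

// Hypothetical addition to src/helpers/sampleImages.js; illustrative values only.
export const SampleAudioClassificationInputs = [
  {
    title: "audio-classification-input.flac",
    src: "https://xlab1.netlify.app/audio-classification-input.flac",
  },
];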
19 changes: 14 additions & 5 deletions src/components/Experiment/QuickOutput/InputPreview.js
@@ -5,36 +5,45 @@ import useBEMNaming from "../../../common/useBEMNaming";
const defaultProps = {
className: "input-preview",
input: "",
onBackClicked: () => {},
onBackClicked: () => { },
inputType: "image", // TODO: Change this default?
};

export default function InputPreview(givenProps) {
const props = { ...defaultProps, ...givenProps };
const { getBlock, getElement } = useBEMNaming(props.className);

const inputTypes = {
image: "Image",
audio: "Audio",
text: "Text",
};


const getInput = () => {
switch (props.inputType) {
case "text":
return <p className={getElement("text")}>{props.input}</p>;
case "audio": // Currently not being used
case "audio":
return <audio className={getElement("audio")} controls src={props.input} />;
case "image":
default:
return <img className={getElement("image")} src={props.input} />;
default:
return <p>Not currently supported</p>;
}
};

return (
<div className={getBlock()}>
<h3 className={getElement("title")}>
Input {props.inputType === "image" ? "Image" : "Text"}
Input {inputTypes[props.inputType]}
</h3>
{getInput()}
<button
className={getElement("back-button")}
onClick={props.onBackClicked}
>
Try a different {props.inputType === "image" ? "image" : "value"}
Try a different {inputTypes[props.inputType]?.toLowerCase()}
</button>
</div>
);
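For reference, the new audio branch added here is what QuickOutput uses for audio classification later in this commit. A minimal usage sketch under that assumption; the component name, clip URL, and click handler are placeholders:

import React from "react";
import InputPreview from "./InputPreview";

// Renders the clip in an <audio controls> element, titles the block
// "Input Audio", and labels the back button "Try a different audio".
export default function AudioInputPreviewExample() {
  return (
    <InputPreview
      input="https://xlab1.netlify.app/automatic-speech-recognition-input.flac"
      inputType="audio"
      onBackClicked={() => console.log("back clicked")}
    />
  );
}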
49 changes: 49 additions & 0 deletions src/components/Experiment/QuickOutput/Outputs/AudioClassification/AudioClassificationOutput.js
@@ -0,0 +1,49 @@
import React from 'react';
import TopPrediction from "../Classification/TopPrediction";
import "../Classification/ClassificationOutput.scss";
import PredictionExpander from "../../../../Common/PredictionExpander";
import NoPredictions from "../_Common/components/NoPredictions";
import Task from "../../../../../helpers/Task";
import OutputDuration from "../_Common/components/OutputDuration";
import DurationConverter from "../_Common/utils/DurationConverter";
import useBEMNaming from "../../../../../common/useBEMNaming";

const defaultProps = {
className: "audio-classification-output",
features: []
};

export default function AudioClassificationOutput(givenProps) {
const props = { ...defaultProps, ...givenProps };
const { getBlock, getElement } = useBEMNaming(props.className);
const task = Task.audio_classification;
if (props?.trial?.results?.responses[0]?.features) {
props.features = props?.trial.results.responses[0].features;
}

const getPredictionBody = () => {
if (props.features.length > 0)
return <div className={getElement('predictions')}>
<TopPrediction hideRating={props.hideRating} feature={props.features[0]} />
<PredictionExpander predictions={props.features} />
</div>;

return <NoPredictions modelId={props.modelId} />;
};
return (
<>

<div className={getBlock()}>
<div className={getElement("title-row")}>
<h3 className={getElement('title')}>Output</h3>
{!props.hideDuration &&
<OutputDuration duration={DurationConverter(props.trial.results.duration)} />
}
</div>
<div className={getElement('subtitle')}>{task.outputText}
</div>
{getPredictionBody()}
</div>
</>
);
}
@@ -0,0 +1,14 @@
import React from "react";
import AudioClassificationOutput from "./AudioClassificationOutput";
import { TestAudioClassificationOutput } from "./testData/testAudioClassification";
import QuickOutput from "../../QuickOutput";

export default {
title: "Experiments/Quick Output/Audio Classification",
component: AudioClassificationOutput,
};

const template = (args) => <QuickOutput {...args} />;

export const Default = template.bind({});
Default.args = { trialOutput: TestAudioClassificationOutput, hideHeader: true };
51 changes: 51 additions & 0 deletions src/components/Experiment/QuickOutput/Outputs/AudioClassification/testData/testAudioClassification.js
@@ -0,0 +1,51 @@
import { DefaultAudioClassificationModel } from "../../../../../../helpers/DefaultModels";

export const TestAudioClassificationOutputGeneratedToken = {
id: "sampleidhere"
};

export const TestAudioClassificationOutput = {
id: "sampletestaudioclassificationoutputidhere",
inputs: [
{
title: "audio1.flac",
src: "https://xlab1.netlify.app/audio-to-audio-input.flac"
},
],
completed_at: "2023-06-03T18:17:14.513854Z",
results: {
'duration': "9.216154124s",
'duration_for_inference': "9.193807904s",
'responses': [
{

'features':
[
{
classification: {
label: 'eng'
},
"probability": 0.9846002459526062
},
{
classification: {
"label": "lat"
},
"probability": 0.012036120519042015

},
{
classification: {
"label": "frn"
},
"probability": 0.0033636766020208597
}

],
'id': "sampletestaudioclassificationoutputresponseidhere"
}
]
},
model: DefaultAudioClassificationModel,
};
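A minimal sketch of how this fixture feeds the new output component when it is rendered directly, outside the Storybook QuickOutput wrapper above; the demo component and its location are hypothetical:

import React from "react";
import AudioClassificationOutput from "../AudioClassificationOutput";
import { TestAudioClassificationOutput } from "./testAudioClassification";

// AudioClassificationOutput reads trial.results.responses[0].features, so this
// shows "eng" (probability ~0.98) as the top prediction, the remaining labels
// in the expander, and the duration taken from trial.results.duration.
export default function AudioClassificationFixtureDemo() {
  return <AudioClassificationOutput trial={TestAudioClassificationOutput} />;
}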

2 changes: 1 addition & 1 deletion src/components/Experiment/QuickOutput/Outputs/Classification/ClassificationOutput.scss
@@ -1,6 +1,6 @@
@import "../../../../../App";

.classification-output, .text-classification-output {
.classification-output, .text-classification-output, .audio-classification-output {
background-color: white;
border-radius: 5px;
box-shadow: 0 0 15px -2px rgba(0, 0, 0, 0.25);
12 changes: 12 additions & 0 deletions src/components/Experiment/QuickOutput/QuickOutput.js
@@ -22,6 +22,7 @@ import {
textTo3D,
textClassification,
imageToText,
audioClassification,
} from "../../../helpers/TaskIDs";
import ObjectDetection from "./Outputs/ObjectDetection/ObjectDetection";
import ImageEnhancement from "./Outputs/ImageEnhancement/ImageEnhancement";
@@ -42,6 +43,7 @@ import TextToImageOutput from "./Outputs/TextToImage/TextToImageOutput";
import TextClassificationOutput from "./Outputs/TextClassification/TextClassificationOutput";
import ImageToTextOutput from "./Outputs/ImageToText/ImageToTextOutput";
import { AudioToText, ImageTo3D, TextTo3D } from "../QuickInput/QuickInput.stories";
import AudioClassificationOutput from "./Outputs/AudioClassification/AudioClassificationOutput";

const defaultProps = {
className: "quick-output",
@@ -208,6 +210,16 @@ export default function QuickOutput(givenProps) {
trial={props.trialOutput}
/>
);
case audioClassification:
return (
<>
<InputPreview input={props.trialOutput.inputs[0].src} inputType="audio" onBackClicked={props.onBackClicked} />
<AudioClassificationOutput
features={props.features}
trial={props.trialOutput}
/>
</>
);
default:
return (
<>
1 change: 1 addition & 0 deletions src/components/ModelDetailPage/ModelDetailPage.jsx
@@ -85,6 +85,7 @@ const ModelDetailPage = (props) => {
const getInputType = () => {
switch (outputType) {
case audioToText:
case audioClassification:
return "audio";
case textToText:
return "text";
55 changes: 54 additions & 1 deletion src/helpers/DefaultModels.js
@@ -11,7 +11,8 @@ import {
visualQuestionAnswering,
imageToText,
textClassification,
audioToAudio,
audioToAudio,
audioClassification,
} from "./TaskIDs";

export const DefaultImageClassificationModel = {
@@ -1022,4 +1023,56 @@ export const DefaultAudioToAudioModel = {
link2: "",
},
version: "1.0",
};

export const DefaultAudioClassificationModel = {
id: 193,
created_at: "2022-04-29T20:48:47.370171Z",
updated_at: "2022-04-29T20:48:47.370171Z",
attributes: {
Top1: "",
Top5: "",
kind: "CNN",
manifest_author: "Jingning Tang",
training_dataset: "PASCAL VOC 2012",
},
description:
"TensorFlow Chatbot model, which is trained on the COCO (Common Objects in Context) dataset. Use deeplabv3_mnv2_dm05_pascal_train_aug(deeplabv3_mnv2_dm05_pascal_train_aug_2018_10_01) from TensorFlow DeepLab Model Zoo.\n",
short_description:
"DeepLabv3 is a deep convolutional neural networks for semantic chatbotness. It employ atrous convolution in cascade or in parallel to capture multi-scale context by adopting multiple atrous rates.",
model: {
graph_checksum: "0336ceb67b378df8ada0efe9eadb5ac8",
graph_path:
"https://s3.amazonaws.com/store.carml.org/models/tensorflow/models/deeplabv3_mnv2_dm05_pascal_train_aug_2018_10_01/frozen_inference_graph.pb",
weights_checksum: "",
weights_path: "",
},
framework: {
id: 4,
name: "TensorFlow",
version: "1.14.0",
architectures: [
{
name: "amd64",
},
],
},
input: {
description: "text to be responded to",
type: "text",
},
license: "Apache License, Version 2.0",
name: "DeepLabv3_MobileNet_v2_DM_05_PASCAL_VOC_Train_Aug",
output: {
description: "the chatbot's response to the inputted text",
type: audioClassification,
},
url: {
github:
"https://github.com/rai-project/tensorflow/blob/master/builtin_models/DeepLabv3_MobileNet_v2_DM_05_PASCAL_VOC_Train_Aug.yml",
citation: "https://arxiv.org/pdf/1802.02611v3.pdf",
link1: "https://arxiv.org/pdf/1706.05587.pdf",
link2: "",
},
version: "1.0",
};
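A small sanity-check sketch, hypothetical and not part of the commit, tying the new default model to the audioClassification task ID that the QuickOutput and ModelDetailPage branches above match against:

import { audioClassification } from "./TaskIDs";
import { DefaultAudioClassificationModel } from "./DefaultModels";

// The default model declares audioClassification as its output type, the same
// ID used by the new switch cases in QuickOutput.js and ModelDetailPage.jsx.
console.assert(DefaultAudioClassificationModel.output.type === audioClassification);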