Skip to content

Commit

Permalink
Completed text-classification modality (#24)
Browse files Browse the repository at this point in the history
* Completed text-classification modality

* Refactored textClassification output element
  • Loading branch information
ShivanshShalabh authored Jul 3, 2024
1 parent ffe436d commit 24cb796
Show file tree
Hide file tree
Showing 12 changed files with 325 additions and 48 deletions.
13 changes: 12 additions & 1 deletion src/components/Experiment/QuickInput/QuickInput.stories.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
imageTo3D,
imageToText,
textTo3D,
textClassification,
} from "../../../helpers/TaskIDs";
import {
SampleImageClassificationInputs,
Expand All @@ -33,7 +34,8 @@ import {
SampleTextToImage,
SampleTextToVideo,
SampleImageToText,
SampleTextTo3DInputs
SampleTextTo3DInputs,
SampleTextClassification,
} from "../../../helpers/sampleImages";

export default {
Expand Down Expand Up @@ -251,3 +253,12 @@ TextTo3D.args = {
},
};

export const TextClassification = Template.bind({});
TextClassification.args = {
sampleInputs: SampleTextClassification,
model: {
output: {
type: textClassification,
},
},
};
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@import "../../../../../App";

.classification-output {
.classification-output, .text-classification-output {
background-color: white;
border-radius: 5px;
box-shadow: 0 0 15px -2px rgba(0, 0, 0, 0.25);
Expand All @@ -10,7 +10,6 @@
margin-left: 16px;

padding: 32px;
width: 0;

@include maxWidth(1000px) {
margin-left: 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import React from 'react';
import TopPrediction from "../Classification/TopPrediction";
import "../Classification/ClassificationOutput.scss";
import PredictionExpander from "../../../../Common/PredictionExpander";
import NoPredictions from "../_Common/components/NoPredictions";
import Task from "../../../../../helpers/Task";
import OutputDuration from "../_Common/components/OutputDuration";
import DurationConverter from "../_Common/utils/DurationConverter";
import useBEMNaming from "../../../../../common/useBEMNaming";
import InputPreview from '../../InputPreview';
const defaultProps = {
className: "text-classification-output",
features: []
};

export default function TextClassificationOutput(givenProps) {
const props = { ...defaultProps, ...givenProps };
const { getBlock, getElement } = useBEMNaming(props.className);
const task = Task.text_classification;
if(props?.trial?.results?.responses[0]?.features) {
props.features = props?.trial.results.responses[0].features;
}
const getPredictionBody = () => {
if (props.features.length > 0)
return <div className={getElement('predictions')}>
<TopPrediction hideRating={props.hideRating} feature={props.features[0]} />
<PredictionExpander predictions={props.features} />
</div>;

return <NoPredictions modelId={props.modelId} />;
};
return (
<div className={getBlock()}>
<div className={getElement("title-row")}>
<h3 className={getElement('title')}>Output</h3>
{!props.hideDuration &&
<OutputDuration duration={DurationConverter(props.trial.results.duration)} />
}
</div>
<div className={getElement('subtitle')}>{task.outputText}
</div>
<InputPreview input={props.trial.inputs[0]} inputType="text" onBackClicked={props.onBackClicked} />
{getPredictionBody()}
</div>
);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import React from "react";
import TextClassificationOutput from "./TextClassificationOutput";
import { TestTextClassificationOutput} from "./testData/testTextClassification";

export default {
title: "Experiments/Quick Output/Text Classification",
component: TextClassificationOutput,
};

const template = (args) => <TextClassificationOutput {...args} />;

export const Default = template.bind({});
Default.args = { trial: TestTextClassificationOutput };
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import React from 'react';
import expect from 'expect';
import {shallow} from 'enzyme';
import ClassificationOutput from './ClassificationOutput';
import TopPrediction from "./TopPrediction";
import PredictionExpander from "../../../../Common/PredictionExpander";
import TestFeatures, {TestImageClassificationResult} from "./Features";

describe('Classification Output Component', () => {
describe('Renders', () => {
let wrapper;

beforeEach(() => {
wrapper = shallow(<ClassificationOutput trial={TestImageClassificationResult} features={TestFeatures}
modelId={1}/>);
});

it('with a container div', () => {
expect(wrapper.at(0).type()).toBe('div');
expect(wrapper.at(0).prop('className')).toBe('classification-output');
});

it('with a title', () => {
const titleElement = wrapper.childAt(0).childAt(0);
expect(titleElement.type()).toBe('h3');
expect(titleElement.prop('className')).toBe('classification-output__title');
expect(titleElement.text()).toBe('Output');
});

it('with a subtitle', () => {
expect(wrapper.childAt(1).type()).toBe('div');
expect(wrapper.childAt(1).prop('className')).toBe('classification-output__subtitle');
expect(wrapper.childAt(1).text()).toBe('How this model identified the object in this image:');
});

describe('with a list of predictions', () => {
describe('beginning with the top prediction component', () => {
it('that has been passed the first prediction', () => {
const topPrediction = wrapper.childAt(2).childAt(0);

expect(topPrediction.type()).toBe(TopPrediction);
expect(topPrediction.prop('feature')).toBe(TestFeatures[0]);
});
});

it('that shows a prediction expander', () => {
const predictions = wrapper.childAt(2);

expect(predictions.prop('className')).toBe('classification-output__predictions');
expect(predictions.childAt(1).type()).toBe(PredictionExpander);
expect(predictions.childAt(1).prop('predictions')).toBe(TestFeatures);
});

});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
export const TestTextClassificationOutputGeneratedToken = {
id: "sampleidhere"
};

export const TestTextClassificationOutput = {
id: "sampletesttextclassificationoutputidhere",
inputs: ["The weather is very pleasant today."],
completed_at: "2023-06-03T18:17:14.513854Z",
results: {
'duration': "9.216154124s",
'duration_for_inference': "9.193807904s",
'responses': [
{

'features':
[
{
classification: {
label: 'positive'
},
"probability": 0.9846002459526062
},
{
classification: {
"label": "neutral"
},
"probability": 0.012036120519042015

},
{
classification: {
"label": "negative"
},
"probability": 0.0033636766020208597
}

],
'id': "sampletesttextclassificationoutputresponseidhere"
}
]
}
};
71 changes: 45 additions & 26 deletions src/components/Experiment/QuickOutput/QuickOutput.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ import {
visualQuestionAnswering,
documentQuestionAnswering,
textToVideo,
textTo3D
textTo3D,
textClassification,
imageToText,
} from "../../../helpers/TaskIDs";
import ObjectDetection from "./Outputs/ObjectDetection/ObjectDetection";
import ImageEnhancement from "./Outputs/ImageEnhancement/ImageEnhancement";
Expand All @@ -32,18 +34,20 @@ import TextOutput from "./Outputs/Text/TextOutput";
import TextToCodeOutput from "./Outputs/TextToCode/TextToCodeOutput";
import TextConversationOutput from "./Outputs/TextConversation/TextConversationOutput";
import StyleTransferOutput from "./Outputs/StyleTransfer/StyleTransferOutput";
import TextGuidedImageToImageOutput from "./Outputs/TextGuidedImageToImage/TextGuidedImageToImageOutput"
import VisualQuestionAnsweringOutput from "./Outputs/VisualQuestionAnswering/VisualQuestionAnsweringOutput"
import DocumentQuestionAnsweringOutput from "./Outputs/DocumentQuestionAnswering/DocumentQuestionAnsweringOutput"
import TextToVideoOutput from "./Outputs/TextToVideo/TextToVideoOutput"
import TextToImageOutput from "./Outputs/TextToImage/TextToImageOutput"
import TextGuidedImageToImageOutput from "./Outputs/TextGuidedImageToImage/TextGuidedImageToImageOutput";
import VisualQuestionAnsweringOutput from "./Outputs/VisualQuestionAnswering/VisualQuestionAnsweringOutput";
import DocumentQuestionAnsweringOutput from "./Outputs/DocumentQuestionAnswering/DocumentQuestionAnsweringOutput";
import TextToVideoOutput from "./Outputs/TextToVideo/TextToVideoOutput";
import TextToImageOutput from "./Outputs/TextToImage/TextToImageOutput";
import TextClassificationOutput from "./Outputs/TextClassification/TextClassificationOutput";
import ImageToTextOutput from "./Outputs/ImageToText/ImageToTextOutput";
import { AudioToText, ImageTo3D, TextTo3D } from "../QuickInput/QuickInput.stories";

const defaultProps = {
className: "quick-output",
features: [],
input: "",
compare: () => {},
compare: () => { },
processFailed: false,
inputType: "image", // Todo: Change this default?
};
Expand Down Expand Up @@ -107,18 +111,18 @@ export default function QuickOutput(givenProps) {
);
case styleTransfer:
return (
<StyleTransferOutput
<StyleTransferOutput
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
/>
)
);
case imageTo3D:
return (
<ImageTo3D
<ImageTo3D
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
trial={props.trialOutput}
/>
)
);
case textToText:
return (
<TextOutput
Expand All @@ -134,61 +138,76 @@ export default function QuickOutput(givenProps) {
/>
);
case audioToText:
return(
<AudioToText
return (
<AudioToText
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
trial={props.trialOutput}
/>
)
);
case textConversation:
return (
<TextConversationOutput
<TextConversationOutput
trial={props.trialOutput}
onSubmit={props.runTrial}
/>
)
);
case textGuidedImageToImage:
return (
<TextGuidedImageToImageOutput
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
/>
)
);
case visualQuestionAnswering:
return (
<VisualQuestionAnsweringOutput
onBackClicked={props.onBackClicked}
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
/>
)
);
case documentQuestionAnswering:
return (
<DocumentQuestionAnsweringOutput
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
/>
)
);
case textToVideo:
return (
<TextToVideoOutput
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
/>
)
);
case textToAudio:
return (
<TextToImageOutput
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
/>
)
);
case textTo3D:
return (
<TextTo3D
<TextTo3D
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
/>
)
);
case imageToText:
return (
<ImageToTextOutput
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
/>
);

case textClassification:
return (
<TextClassificationOutput
onBackClicked={props.onBackClicked}
trial={props.trialOutput}
/>
);
default:
return (
<>
Expand Down
Loading

0 comments on commit 24cb796

Please sign in to comment.