Skip to content

Commit

Permalink
ranking intervention charts improvement (#6126)
Browse files Browse the repository at this point in the history
  • Loading branch information
shawnyama authored Jan 17, 2025
1 parent 6d48187 commit f4d1109
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Dataset } from '@/types/Types';
import { WorkflowPortStatus } from '@/types/workflow';
import { renameFnGenerator } from '@/components/workflow/ops/calibrate-ciemss/calibrate-utils';

import { createRankingInterventionsChart } from '@/services/charts';
import { createRankingInterventionsChart, CATEGORICAL_SCHEME } from '@/services/charts';
import { DATASET_VAR_NAME_PREFIX, getDatasetResultCSV, mergeResults, getDataset } from '@/services/dataset';
import {
DataArray,
Expand Down Expand Up @@ -184,6 +184,7 @@ export function generateRankingCharts(
props,
modelConfigIdToInterventionPolicyIdMap,
chartData,
datasets,
interventionPolicies
) {
// Reset charts
Expand All @@ -198,6 +199,7 @@ export function generateRankingCharts(
})
.flat();
const allRankedCriteriaValues: { score: number; name: string }[][] = [];
const interventionNameColorMap: Record<string, string> = {};

props.node.state.criteriaOfInterestCards.forEach((card) => {
if (!card.selectedConfigurationId || !chartData.value) return;
Expand All @@ -208,12 +210,29 @@ export function generateRankingCharts(
: chartData.value.resultSummary[chartData.value.resultSummary.length - 1];

const rankingCriteriaValues: { score: number; name: string }[] = [];
interventionPolicies.value.forEach((policy, index) => {

let colorIndex = 0;
datasets.value.forEach(({ metadata }, index: number) => {
const policy = interventionPolicies.value.find(
({ id }) => id === metadata.simulationAttributes?.interventionPolicyId
);

// Skip this intervention policy if a configuration is not using it
if (!policy.id || !policy.name || !commonInterventionPolicyIds.includes(policy.id) || !card.selectedVariable) {
if (
!policy ||
!policy.id ||
!policy.name ||
!commonInterventionPolicyIds.includes(policy.id) ||
!card.selectedVariable
) {
return;
}

if (!interventionNameColorMap[policy.name]) {
interventionNameColorMap[policy.name] = CATEGORICAL_SCHEME[colorIndex];
colorIndex++;
}

rankingCriteriaValues.push({
score: pointOfComparison[`${chartData.value?.pyciemssMap[card.selectedVariable]}_mean:${index}`] ?? 0,
name: policy.name ?? ''
Expand All @@ -225,32 +244,37 @@ export function generateRankingCharts(
? rankingCriteriaValues.sort((a, b) => b.score - a.score)
: rankingCriteriaValues.sort((a, b) => a.score - b.score);

sortedRankingCriteriaValues.forEach((value, index) => {
value.score = index + 1;
});

rankingCriteriaCharts.value.push(createRankingInterventionsChart(sortedRankingCriteriaValues, card.name));
rankingCriteriaCharts.value.push(
createRankingInterventionsChart(
sortedRankingCriteriaValues,
interventionNameColorMap,
card.name,
card.selectedVariable
)
);
allRankedCriteriaValues.push(sortedRankingCriteriaValues);
});

// Sum up the scores of the same intervention policy
const scoreMap: Record<string, number> = {};
// Sum up the values of the same intervention policy
const valueMap: Record<string, number> = {};
allRankedCriteriaValues.flat().forEach(({ score, name }) => {
if (scoreMap[name]) {
scoreMap[name] += score;
if (valueMap[name]) {
valueMap[name] += score;
} else {
scoreMap[name] = score;
valueMap[name] = score;
}
});

const rankingResultsValues = Object.keys(scoreMap)
const rankingResultsScores: { score: number; name: string }[] = Object.keys(valueMap)
.map((name) => ({
name,
score: scoreMap[name]
score: valueMap[name]
}))
.sort((a, b) => a.score - b.score);
.sort((a, b) => a.score - b.score)
// Instead of the values, we want to rank by score
.map((value, index) => ({ ...value, score: index + 1 }));

rankingResultsChart.value = createRankingInterventionsChart(rankingResultsValues, '');
rankingResultsChart.value = createRankingInterventionsChart(rankingResultsScores, interventionNameColorMap);
}

export async function generateImpactCharts(
Expand All @@ -269,13 +293,15 @@ export async function generateImpactCharts(
}

// TODO: this should probably be split up into smaller functions but for now it's at least not duplicated in the node and drilldown
// TODO: Please type the function params in this file for a later pass
export async function initialize(
props,
isFetchingDatasets,
datasets,
datasetResults,
modelConfigIdToInterventionPolicyIdMap,
chartData,
impactChartData,
rankingChartData,
baselineDatasetIndex,
selectedPlotType,
modelConfigurations,
Expand Down Expand Up @@ -311,7 +337,7 @@ export async function initialize(
datasetResults.value = await fetchDatasetResults(datasets.value);
isFetchingDatasets.value = false;

await generateImpactCharts(chartData, datasets, datasetResults, baselineDatasetIndex, selectedPlotType);
await generateImpactCharts(impactChartData, datasets, datasetResults, baselineDatasetIndex, selectedPlotType);
const modelConfigurationIds = Object.keys(modelConfigIdToInterventionPolicyIdMap.value);
if (isEmpty(modelConfigurationIds)) return;
const modelConfigurationPromises = modelConfigurationIds.map((id) => getModelConfigurationById(id));
Expand All @@ -326,12 +352,22 @@ export async function initialize(
interventionPolicies.value = policies.filter((policy) => policy !== null);
});

if (!rankingChartData) return;

rankingChartData.value = buildChartData(
datasets.value,
datasetResults.value,
baselineDatasetIndex.value,
PlotValue.VALUE
);

generateRankingCharts(
rankingCriteriaCharts,
rankingResultsChart,
props,
modelConfigIdToInterventionPolicyIdMap,
chartData,
rankingChartData,
datasets,
interventionPolicies
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
:options="compareOptions"
option-label="label"
option-value="value"
@change="outputPanelBehavior"
/>
<!-- Pascale asked me to hide this until the feature is implemented -->
<!-- <tera-checkbox
Expand All @@ -43,9 +44,7 @@
option-value="id"
:loading="isFetchingDatasets"
placeholder="Optional"
@change="
generateImpactCharts(chartData, datasets, datasetResults, baselineDatasetIndex, selectedPlotType)
"
@change="onChangeImpactComparison"
/>
<label>Comparison tables</label>
<tera-checkbox v-model="isATESelected" label="Average treatment effect (ATE)" />
Expand Down Expand Up @@ -167,9 +166,7 @@
v-model="knobs.selectedPlotType"
:value="option.value"
name="plotValues"
@change="
generateImpactCharts(chartData, datasets, datasetResults, baselineDatasetIndex, selectedPlotType)
"
@change="onChangeImpactComparison"
/>
<label class="pl-2 py-1" :for="option.value">{{ option.label }}</label>
</div>
Expand Down Expand Up @@ -261,11 +258,16 @@ const onRun = () => {
rankingResultsChart,
props,
modelConfigIdToInterventionPolicyIdMap,
chartData,
rankingChartData,
datasets,
interventionPolicies
);
};
function onChangeImpactComparison() {
generateImpactCharts(impactChartData, datasets, datasetResults, baselineDatasetIndex, selectedPlotType);
}
interface BasicKnobs {
criteriaOfInterestCards: CriteriaOfInterestCard[];
selectedCompareOption: CompareValue;
Expand All @@ -281,7 +283,7 @@ const knobs = ref<BasicKnobs>({
});
const addCriteria = () => {
knobs.value.criteriaOfInterestCards.push(blankCriteriaOfInterest);
knobs.value.criteriaOfInterestCards.push(cloneDeep(blankCriteriaOfInterest));
};
const deleteCriteria = (index: number) => {
Expand All @@ -306,21 +308,22 @@ const {
const outputPanel = ref(null);
const chartSize = useDrilldownChartSize(outputPanel);
const chartData = ref<ChartData | null>(null);
const impactChartData = ref<ChartData | null>(null);
const rankingChartData = ref<ChartData | null>(null);
const rankingResultsChart = ref<any>(null);
const rankingCriteriaCharts = ref<any>([]);
const variableNames = computed(() => {
if (chartData.value === null) return [];
if (impactChartData.value === null) return [];
const excludes = ['timepoint_id', 'sample_id', 'timepoint_unknown'];
return Object.keys(chartData.value.pyciemssMap).filter((key) => !excludes.includes(key));
return Object.keys(impactChartData.value.pyciemssMap).filter((key) => !excludes.includes(key));
});
const { generateAnnotation, getChartAnnotationsByChartId, useCompareDatasetCharts } = useCharts(
props.node.id,
null,
null,
chartData,
impactChartData,
chartSize,
null,
null
Expand All @@ -331,18 +334,29 @@ const baselineDatasetIndex = computed(() =>
);
const variableCharts = useCompareDatasetCharts(selectedVariableSettings, selectedPlotType, baselineDatasetIndex);
function outputPanelBehavior() {
if (knobs.value.selectedCompareOption === CompareValue.RANK) {
isOutputSettingsOpen.value = false;
} else if (knobs.value.selectedCompareOption === CompareValue.IMPACT) {
isOutputSettingsOpen.value = true;
}
}
onMounted(() => {
const state = cloneDeep(props.node.state);
knobs.value = Object.assign(knobs.value, state);
if (!knobs.value.selectedDataset) knobs.value.selectedDataset = datasets.value[0]?.id ?? null;
outputPanelBehavior();
initialize(
props,
isFetchingDatasets,
datasets,
datasetResults,
modelConfigIdToInterventionPolicyIdMap,
chartData,
impactChartData,
rankingChartData,
baselineDatasetIndex,
selectedPlotType,
modelConfigurations,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ onMounted(() => {
datasetResults,
modelConfigIdToInterventionPolicyIdMap,
chartData,
null,
baselineDatasetIndex,
selectedPlotType,
modelConfigurations,
Expand All @@ -94,6 +95,7 @@ watch(
datasetResults,
modelConfigIdToInterventionPolicyIdMap,
chartData,
null,
baselineDatasetIndex,
selectedPlotType,
modelConfigurations,
Expand Down
27 changes: 23 additions & 4 deletions packages/client/hmi-client/src/services/charts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1793,7 +1793,12 @@ export function createFunmanParameterCharts(
};
}

export function createRankingInterventionsChart(values: { score: number; name: string }[], title: string) {
export function createRankingInterventionsChart(
values: { score: number; name: string }[],
interventionNameColorMap: Record<string, string>,
title: string | null = null,
variableName: string | null = null
) {
const globalFont = 'Figtree';

return {
Expand Down Expand Up @@ -1831,7 +1836,20 @@ export function createRankingInterventionsChart(values: { score: number; name: s
y: {
field: 'score',
type: 'quantitative',
title: 'Score'
// If a specific variable is selected the score should hold its actual value
title: variableName || 'Score'
},
color: {
field: 'name',
type: 'nominal',
scale: {
domain: Object.keys(interventionNameColorMap),
range: Object.values(interventionNameColorMap)
},
legend: {
title: null,
orient: 'top'
}
}
},
transform: [{ window: [{ op: 'row_number', as: 'index' }] }],
Expand All @@ -1846,13 +1864,14 @@ export function createRankingInterventionsChart(values: { score: number; name: s
align: 'right',
baseline: 'bottom',
dy: -15,
angle: 270
angle: 270,
fill: 'black'
// FIXME:
// I don't know how to fix the text to the bottom of the bar, its origin seems to be around the top
// and giving it the proper dx shift varies depending on the bar size
},
encoding: {
text: { field: 'name', type: 'nominal' }
text: { field: 'name', type: 'nominal', color: 'black' }
}
}
]
Expand Down

0 comments on commit f4d1109

Please sign in to comment.