Skip to content
Snippets Groups Projects
Commit b7da1f38 authored by Steve Reis's avatar Steve Reis
Browse files

feat(exareme2): Logistic regression CV integration

parent 7e5175cd
No related branches found
No related tags found
No related merge requests found
import { Domain } from '../../../../models/domain.model';
import { Experiment } from '../../../../models/experiment/experiment.model';
import LogisticRegressionCVHandler from './logistic-regression-cv.handler';
const data = [
{
dependent_var: 'gender',
indep_vars: [
'Intercept',
'ppmicategory[PD]',
'ppmicategory[PRODROMA]',
'lefthippocampus',
'righthippocampus',
],
summary: {
row_names: ['fold_1', 'fold_2', 'average', 'stdev'],
n_obs: [40, 40, null, null],
accuracy: [0.575, 0.5, 0.5375, 0.053033008588991036],
precision: [
0.5833333333333334, 0.45, 0.5166666666666667, 0.09428090415820635,
],
recall: [
0.3684210526315789, 0.5, 0.4342105263157895, 0.09304036594559838,
],
fscore: [
0.4516129032258065, 0.4736842105263158, 0.46264855687606116,
0.015606771061842295,
],
},
confusion_matrix: { tp: 16, fp: 16, tn: 27, fn: 21 },
roc_curves: [
{
name: 'fold_0',
tpr: [0.0, 0.0, 0.3684210526315789, 0.6842105263157895, 1.0],
fpr: [0.0, 0.0, 0.23809523809523814, 0.6666666666666667, 1.0],
auc: 0.550125313283208,
},
{
name: 'fold_1',
tpr: [0.0, 0.0, 0.5, 1.0, 1.0],
fpr: [0.0, 0.09090909090909094, 0.5, 0.9545454545454546, 1.0],
auc: 0.48863636363636365,
},
],
},
];
const domain: Domain = {
id: 'dummy-id',
groups: [],
rootGroup: {
id: 'dummy-id',
},
datasets: [{ id: 'desd-synthdata', label: 'Dead Synthdata' }],
variables: [
{ id: 'ppmicategory', label: 'PPMI Category' },
{ id: 'righthippocampus', label: 'Right Hippo Campus' },
],
};
const createExperiment = (): Experiment => ({
id: 'dummy-id',
name: 'Testing purpose',
algorithm: {
name: LogisticRegressionCVHandler.ALGO_NAME,
},
datasets: ['desd-synthdata'],
domain: 'dementia',
variables: ['ppmicategory'],
coVariables: ['righthippocampus'],
results: [],
});
describe('Logistic regression CV result handler', () => {
let logisticCVHandler: LogisticRegressionCVHandler;
let experiment: Experiment;
beforeEach(() => {
logisticCVHandler = new LogisticRegressionCVHandler();
experiment = createExperiment();
});
describe('handle', () => {
it('should return exactly 3 results', () => {
logisticCVHandler.handle(experiment, data, domain);
expect(experiment.results).toHaveLength(3);
});
});
});
import { Domain } from '../../../../models/domain.model';
import { Experiment } from '../../../../models/experiment/experiment.model';
import { HeatMapResult } from '../../../../models/result/heat-map-result.model';
import {
LineChartResult,
LineResult,
} from '../../../../models/result/line-chart-result.model';
import { TableResult } from '../../../../models/result/table-result.model';
import BaseHandler from '../base.handler';
const NUMBER_PRECISION = 4;
const lookupDict = {
dependent_var: 'Dependent variable',
indep_vars: 'Independent variables',
n_obs: 'Number of observations',
fscore: 'F-score',
accuracy: 'Accuracy',
precision: 'Precision',
recall: 'Recall',
average: 'Average',
stdev: 'Standard deviation',
blank: '',
};
const keys = ['n_obs', 'accuracy', 'recall', 'precision', 'fscore'];
type LineCurve = {
name: string;
tpr: number[];
fpr: number[];
auc: number;
};
type InputData = {
dependent_var: string;
indep_vars: string[];
summary: Record<string, number[] | string[]>;
confusion_matrix: Record<string, number>;
roc_curves: LineCurve[];
};
export default class LogisticRegressionCVHandler extends BaseHandler {
public static readonly ALGO_NAME = 'logistic_regression_cv';
private canHandle(experiment: Experiment, data: unknown): boolean {
return (
experiment.algorithm.name.toLowerCase() ===
LogisticRegressionCVHandler.ALGO_NAME &&
!!data &&
!!data[0] &&
!!data[0]['summary']
);
}
getSummary(data: InputData): TableResult {
return {
name: 'Summary',
headers: ['blank', ...keys].map((key) => ({
name: lookupDict[key],
type: 'string',
})),
data: data.summary['row_names'].map((key: any, i: number) => {
// could be optimized
return [key, ...keys.map((k) => data['summary'][k][i])];
}),
};
}
getConfusionMatrix(data: InputData): HeatMapResult {
const matrix = data['confusion_matrix'];
return {
name: 'Confusion matrix',
matrix: [[matrix['tp']], [matrix['fp']], [matrix['fn']], [matrix['tn']]],
xAxis: {
categories: ['Positive', 'Negative'],
label: 'Actual Values',
},
yAxis: {
categories: ['Negative', 'Positive'],
label: 'Predicted Values',
},
};
}
getROC(data: InputData): LineChartResult {
return {
name: 'ROC Curves',
lines: data.roc_curves.map((line: LineCurve) => {
return {
label: `${line.name} (AUC: ${line.auc.toPrecision(
NUMBER_PRECISION,
)})`,
x: line.fpr,
y: line.tpr,
} as LineResult;
}),
xAxis: {
label: 'False Positive Rate',
},
yAxis: {
label: 'True Positive Rate',
},
hasBisector: true,
};
}
handle(experiment: Experiment, data: unknown, domain?: Domain): void {
if (!this.canHandle(experiment, data))
return super.handle(experiment, data, domain);
const extractedData = data[0];
const varIds = [...experiment.variables, ...(experiment.coVariables ?? [])];
const variables = domain.variables.filter((v) => varIds.includes(v.id));
let jsonData = JSON.stringify(extractedData);
variables.forEach((v) => {
const regEx = new RegExp(v.id, 'gi');
jsonData = jsonData.replaceAll(regEx, v.label);
});
const improvedData = JSON.parse(jsonData);
const results = [
this.getSummary(improvedData),
this.getConfusionMatrix(improvedData),
this.getROC(improvedData),
];
results.filter((r) => !!r).forEach((r) => experiment.results.push(r));
}
}
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment