Skip to content
Snippets Groups Projects
Commit 34b217ec authored by Steve Reis's avatar Steve Reis
Browse files

Merge branch 'feat/linear-regression-cv' into 'develop'

Linear Regression (CV) integrated

See merge request sibmip/gateway!70
parents 8abfa4ce d46ce3e7
No related branches found
No related tags found
No related merge requests found
import { Domain } from 'src/engine/models/domain.model';
import { TableResult } from 'src/engine/models/result/table-result.model';
import { Experiment } from '../../../../models/experiment/experiment.model';
import LinearRegressionCVHandler from './linear-regression-cv.handler';
const data = {
dependent_var: 'leftocpoccipitalpole',
indep_vars: [
'Intercept',
'righthippocampus',
'rightsogsuperioroccipitalgyrus',
'leftppplanumpolare',
],
n_obs: [497, 498, 498, 499],
mean_sq_error: [0.3296455054532643, 0.02930654997949175],
r_squared: [0.5886631959286948, 0.04365853383949705],
mean_abs_error: [0.2585157369288272, 0.019919123005319055],
};
const domain: Domain = {
id: 'dummy-id',
groups: [],
rootGroup: {
id: 'dummy-id',
},
datasets: [{ id: 'desd-synthdata', label: 'Dead Synthdata' }],
variables: [
{ id: 'leftocpoccipitalpole', label: 'Left OCP occipital Pole' },
{ id: 'righthippocampus', label: 'Right Hippo Campus' },
{ id: 'rightsogsuperioroccipitalgyrus', label: 'Right superior occipital' },
{ id: 'leftppplanumpolare', label: 'Left Planum polare' },
],
};
const createExperiment = (): Experiment => ({
id: 'dummy-id',
name: 'Testing purpose',
algorithm: {
name: 'LINEAR_REGRESSION_CROSS_VALIDATION',
},
datasets: ['desd-synthdata'],
domain: 'dementia',
variables: ['leftocpoccipitalpole'],
coVariables: [
'righthippocampus',
'rightsogsuperioroccipitalgyrus',
'leftppplanumpolare',
],
results: [],
});
describe('Linear regression CV result handler', () => {
let linearHandler: LinearRegressionCVHandler;
let experiment: Experiment;
beforeEach(() => {
linearHandler = new LinearRegressionCVHandler();
experiment = createExperiment();
});
describe('Handle', () => {
it('with standard linear algo data', () => {
const expectedDataPoints = [
['Intercept', 497],
['Right Hippo Campus', 498],
['Right superior occipital', 498],
['Left Planum polare', 499],
];
const expectedScoresData = [
['Root mean squared error', '0.3296', '0.02931'],
['R-squared', '0.5887', '0.04366'],
['Mean absolute error', '0.2585', '0.01992'],
];
linearHandler.handle(experiment, data, domain);
const json = JSON.stringify(experiment.results);
const dataPoints = experiment.results[0] as TableResult;
const scoresData = experiment.results[1] as TableResult;
expect(dataPoints.data).toStrictEqual(expectedDataPoints);
expect(scoresData.data).toStrictEqual(expectedScoresData);
expect(json.includes(domain.variables[0].label)).toBeTruthy();
expect(experiment.results.length === 2);
});
it('Should be empty with another algo', () => {
experiment.algorithm.name = 'dummy_algo';
linearHandler.handle(experiment, data, domain);
expect(experiment.results.length === 0);
});
});
});
import { Domain } from 'src/engine/models/domain.model';
import { Variable } from 'src/engine/models/variable.model';
import { isNumber } from '../../../../../common/utils/shared.utils';
import { Experiment } from '../../../../models/experiment/experiment.model';
import {
TableResult,
TableStyle,
} from '../../../../models/result/table-result.model';
import BaseHandler from '../base.handler';
const NUMBER_PRECISION = 4;
const ALGO_NAME = 'linear_regression_cross_validation';
const lookupDict = {
dependent_var: 'Dependent variable',
indep_vars: 'Independent variables',
n_obs: 'Number of observations',
mean_sq_error: 'Root mean squared error',
r_squared: 'R-squared',
mean_abs_error: 'Mean absolute error',
};
export default class LinearRegressionCVHandler extends BaseHandler {
private getModel(data: any): TableResult | undefined {
return {
name: 'Data points',
tableStyle: TableStyle.NORMAL,
headers: ['', `${lookupDict['n_obs']} (${data['dependent_var']})`].map(
(name) => ({ name, type: 'string' }),
),
data: data['indep_vars'].map((name: string, i: number) => [
name,
data['n_obs'][i],
]),
};
}
private getScores(data: any): TableResult | undefined {
return {
name: 'Scores',
tableStyle: TableStyle.NORMAL,
headers: ['', 'Mean', 'Standard deviation'].map((name) => ({
name: name,
type: 'string',
})),
data: ['mean_sq_error', 'r_squared', 'mean_abs_error'].map((variable) => [
lookupDict[variable],
...data[variable].map((val: unknown) =>
isNumber(val) ? val.toPrecision(NUMBER_PRECISION) : val,
),
]),
};
}
getLabelFromVariableId(id: string, vars: Variable[]): string {
const varible = vars.find((v) => v.id === id);
return varible.label ?? id;
}
handle(experiment: Experiment, data: any, domain: Domain): void {
if (experiment.algorithm.name.toLowerCase() !== ALGO_NAME)
return super.handle(experiment, data, domain);
const varIds = [...experiment.variables, ...(experiment.coVariables ?? [])];
const variables = domain.variables.filter((v) => varIds.includes(v.id));
let jsonData = JSON.stringify(data);
variables.forEach((v) => {
const regEx = new RegExp(v.id, 'gi');
jsonData = jsonData.replaceAll(regEx, v.label);
});
const improvedData = JSON.parse(jsonData);
const model = this.getModel(improvedData);
if (model) experiment.results.push(model);
const coefs = this.getScores(improvedData);
if (coefs) experiment.results.push(coefs);
}
}
......@@ -35,7 +35,7 @@ export default class LinearRegressionHandler extends BaseHandler {
const tableModel: TableResult = {
name: 'Model',
tableStyle: TableStyle.NORMAL,
headers: ['name', 'value'].map((name) => ({ name, type: 'string' })),
headers: ['', 'name', 'value'].map((name) => ({ name, type: 'string' })),
data: [
'dependent_var',
'n_obs',
......@@ -72,7 +72,7 @@ export default class LinearRegressionHandler extends BaseHandler {
const tableCoef: TableResult = {
name: 'Coefficients',
tableStyle: TableStyle.NORMAL,
headers: keys.map((name) => ({
headers: ['', ...keys].map((name) => ({
name: lookupDict[name],
type: 'string',
})),
......
......@@ -2,21 +2,24 @@ import { Domain } from 'src/engine/models/domain.model';
import { Experiment } from '../../../../engine/models/experiment/experiment.model';
import AnovaOneWayHandler from './algorithms/anova-one-way.handler';
import DescriptiveHandler from './algorithms/descriptive.handler';
import LinearRegressionCVHandler from './algorithms/linear-regression-cv.handler';
import LinearRegressionHandler from './algorithms/linear-regression.handler';
import PCAHandler from './algorithms/PCA.handler';
import PearsonHandler from './algorithms/pearson.handler';
import RawHandler from './algorithms/raw.handler';
import ResultHandler from './result-handler.interface';
const start = new PearsonHandler();
const start = new PearsonHandler() as ResultHandler;
start
.setNext(new DescriptiveHandler())
.setNext(new AnovaOneWayHandler())
.setNext(new PCAHandler())
.setNext(new LinearRegressionHandler())
.setNext(new LinearRegressionCVHandler())
.setNext(new RawHandler()); // should be last handler as it works as a fallback (if other handlers could not process the results)
export default (exp: Experiment, data: unknown, domain: Domain): Experiment => {
start.handle(exp, data);
start.handle(exp, data, domain);
return exp;
};
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment