Skip to content
Snippets Groups Projects
Commit fa77f687 authored by Steve Reis's avatar Steve Reis
Browse files

Feat: Integration nominal variable result and algorithms list

parent acf0ee15
No related branches found
No related tags found
No related merge requests found
......@@ -8,17 +8,20 @@ import {
} from 'src/common/interfaces/utilities.interface';
import { errorAxiosHandler } from 'src/common/utils/shared.utils';
import { ENGINE_MODULE_OPTIONS } from 'src/engine/engine.constants';
import EngineService from 'src/engine/engine.service';
import ConnectorConfiguration from 'src/engine/interfaces/connector-configuration.interface';
import Connector from 'src/engine/interfaces/connector.interface';
import EngineOptions from 'src/engine/interfaces/engine-options.interface';
import { Domain } from 'src/engine/models/domain.model';
import { Algorithm } from 'src/engine/models/experiment/algorithm.model';
import { AllowedLink } from 'src/engine/models/experiment/algorithm/nominal-parameter.model';
import { Experiment } from 'src/engine/models/experiment/experiment.model';
import { RawResult } from 'src/engine/models/result/raw-result.model';
import {
TableResult,
TableStyle,
} from 'src/engine/models/result/table-result.model';
import { Variable } from 'src/engine/models/variable.model';
import { ExperimentCreateInput } from 'src/experiments/models/input/experiment-create.input';
import { User } from 'src/users/models/user.model';
import {
......@@ -27,6 +30,8 @@ import {
transformToDomain,
transformToHisto,
transformToTable,
transformToTableNominal,
transfoToHistoNominal as transformToHistoNominal,
} from './transformations';
export default class DataShieldConnector implements Connector {
......@@ -35,6 +40,7 @@ export default class DataShieldConnector implements Connector {
constructor(
@Inject(ENGINE_MODULE_OPTIONS) private readonly options: EngineOptions,
private readonly httpService: HttpService,
private readonly engineService: EngineService,
) {}
getConfiguration(): ConnectorConfiguration {
......@@ -77,17 +83,65 @@ export default class DataShieldConnector implements Connector {
}
async getAlgorithms(): Promise<Algorithm[]> {
return [];
return [
{
id: 'linear-regression',
label: 'Linear Regression',
description:
'Linear regression analysis is a method of statistical analysis that fits a linear function in order to predict the value of a covariate as a function of one or more variables. Linear regression is a simple model that is easy to understand and interpret.',
variable: {
isRequired: true,
allowedTypes: ['number'],
hasMultiple: false,
},
coVariable: {
isRequired: true,
allowedTypes: ['number'],
hasMultiple: true,
},
},
{
id: 'logistic-regression',
label: 'Logistic Regression',
description:
'Logistic regression is a statistical method for predicting the probability of a binary event.',
variable: {
isRequired: true,
allowedTypes: ['nominal'],
hasMultiple: false,
hint: 'A binary event to predict',
},
coVariable: {
isRequired: true,
allowedTypes: ['number'],
hasMultiple: true,
},
parameters: [
{
name: 'pos-level',
label: 'Positive level',
linkedTo: AllowedLink.VARIABLE,
isRequired: true,
},
{
name: 'neg-level',
label: 'Negative level',
linkedTo: AllowedLink.VARIABLE,
isRequired: true,
},
],
},
];
}
async getHistogram(
variable: string,
variable: Variable,
datasets: string[],
cookie?: string,
): Promise<RawResult> {
const url = new URL(this.options.baseurl + `histogram`);
url.searchParams.append('var', variable);
url.searchParams.append('var', variable.id);
url.searchParams.append('type', 'combine');
url.searchParams.append('cohorts', datasets.join(','));
......@@ -114,10 +168,20 @@ export default class DataShieldConnector implements Connector {
};
}
const title = variable.replace(/\./g, ' ').trim();
const title = variable.label ?? variable.id;
const data = { ...response.data, title };
const chart = transformToHisto.evaluate(data);
if (variable.type === 'nominal' && variable.enumerations) {
data['lookup'] = variable.enumerations.reduce((prev, curr) => {
prev[curr.value] = curr.label;
return prev;
}, {});
}
const chart =
variable.type === 'nominal'
? transformToHistoNominal.evaluate(data)
: transformToHisto.evaluate(data);
return {
rawdata: {
......@@ -128,13 +192,13 @@ export default class DataShieldConnector implements Connector {
}
async getDescriptiveStats(
variable: string,
variable: Variable,
datasets: string[],
cookie?: string,
): Promise<TableResult> {
const url = new URL(this.options.baseurl + 'quantiles');
url.searchParams.append('var', variable);
url.searchParams.append('var', variable.id);
url.searchParams.append('type', 'split');
url.searchParams.append('cohorts', datasets.join(','));
......@@ -148,9 +212,35 @@ export default class DataShieldConnector implements Connector {
}),
);
const title = variable.replace(/\./g, ' ').trim();
const title = variable.label ?? variable.id;
const data = { ...response.data, title };
const table = transformToTable.evaluate(data);
const table = (
variable.enumerations
? transformToTableNominal.evaluate(data)
: transformToTable.evaluate(data)
) as TableResult;
if (
table &&
table.headers &&
variable.type === 'nominal' &&
variable.enumerations
) {
table.headers = table.headers.map((header) => {
const category = variable.enumerations.find(
(v) => v.value === header.name,
);
if (!category || !category.label) return header;
return {
...header,
name: category.label,
};
});
}
return {
...table,
tableStyle: TableStyle.DEFAULT,
......@@ -168,6 +258,11 @@ export default class DataShieldConnector implements Connector {
const expResult: Experiment = {
id: `${data.algorithm.id}-${Date.now()}`,
variables: data.variables,
coVariables: data.coVariables,
author: {
username: user.username,
fullname: user.fullname ?? user.username,
},
name: data.name,
domain: data.domain,
datasets: data.datasets,
......@@ -176,21 +271,37 @@ export default class DataShieldConnector implements Connector {
},
};
const allVariablesId = [...data.variables, ...data.coVariables];
const allVariables = await this.engineService.getVariables(
expResult.domain,
allVariablesId,
request,
);
switch (data.algorithm.id) {
case 'MULTIPLE_HISTOGRAMS': {
expResult.results = await Promise.all<RawResult>(
data.variables.map((variable) =>
allVariables.map((variable) =>
this.getHistogram(variable, expResult.datasets, cookie),
),
);
break;
}
case 'DESCRIPTIVE_STATS': {
expResult.results = await Promise.all<TableResult>(
[...data.variables, ...data.coVariables].map((variable) =>
this.getDescriptiveStats(variable, expResult.datasets, cookie),
),
);
// Cannot be done in parallel because Datashield API has an issue with parallel request (response mismatching)
const results = [];
for (const variable of allVariables) {
const result = await this.getDescriptiveStats(
variable,
expResult.datasets,
cookie,
);
results.push(result);
}
expResult.results = results;
break;
}
}
......
......@@ -30,6 +30,38 @@ export const transformToDomain = jsonata(`
}
`);
export const transfoToHistoNominal = jsonata(`
(
{
"chart": {
"type": 'column'
},
"legend": {
"enabled": false
},
"series": [{
"name": "Count",
"data": global.*
}],
"title": {
"text": title ? title : ''
},
"tooltip": {
"enabled": true
},
"xAxis": {
"categories": $keys(global).(
$param := $lookup($$.lookup, $);
$param ? $param : $
)
},
"yAxis": {
"min": 0,
"minRange": 0.1,
"allowDecimals": true
}
})`);
export const transformToHisto = jsonata(`
(
$nbBreaks := $count(global.breaks);
......@@ -74,6 +106,21 @@ export const transformToTable = jsonata(`
"name": $,
"type": "string"
},
"data": $.$each(function($v, $k) {
$not($k in $params) ? $append($k,$v) : undefined
})[]
})
`);
export const transformToTableNominal = jsonata(`
(
$params := ["title"];
{
"name": "Descriptive Statistics",
"headers": $append(title, $keys($.*)).{
"name": $,
"type": "string"
},
"data": $.$each(function($v, $k) {
$not($k in $params) ? $append($k,$v.*) : undefined
})[]
......
......@@ -29,6 +29,7 @@ import {
import { ListExperiments } from './models/experiment/list-experiments.model';
import { FilterConfiguration } from './models/filter/filter-configuration';
import { FormulaOperation } from './models/formula/formula-operation.model';
import { Variable } from './models/variable.model';
/**
* Engine service.
......@@ -46,7 +47,7 @@ export default class EngineService implements Connector {
) {
import(`./connectors/${options.type}/${options.type}.connector`).then(
(conn) => {
const instance = new conn.default(options, httpService);
const instance = new conn.default(options, httpService, this);
if (instance.createExperiment && instance.runExperiment)
throw new InternalServerErrorException(
......@@ -96,7 +97,7 @@ export default class EngineService implements Connector {
const result = await fn();
this.cacheManager.set(key, result);
this.cacheManager.set(key, result, { ttl: this.cacheConf.ttl });
return result;
}
......@@ -118,6 +119,29 @@ export default class EngineService implements Connector {
);
}
/**
* It takes a domain ID and a list of variable IDs, and returns a list of variables that match the IDs
* @param {string} domainId - The domain ID of the domain you want to get variables from.
* @param {string[]} varIds - The list of variable IDs to get.
* @param {Request} request - The request object from the HTTP request.
* @returns An array of variables
*/
async getVariables(
domainId: string,
varIds: string[],
request: Request,
): Promise<Variable[]> {
if (!domainId || varIds.length === 0) return [];
const domains = await this.getDomains([], request);
return (
domains
.find((d) => d.id === domainId)
?.variables?.filter((v) => varIds.includes(v.id)) ?? []
);
}
async createExperiment(
data: ExperimentCreateInput,
isTransient: boolean,
......
import { Field, ObjectType } from '@nestjs/graphql';
import { BaseParameter } from './algorithm/base-parameter.model';
import { NominalParameter } from './algorithm/nominal-parameter.model';
import { NumberParameter } from './algorithm/number-parameter.model';
import { VariableParameter } from './algorithm/variable-parameter.model';
type Parameter = BaseParameter | NumberParameter | NominalParameter;
@ObjectType()
export class Algorithm {
@Field()
id: string;
@Field(() => [BaseParameter], { nullable: true, defaultValue: [] })
parameters?: BaseParameter[];
parameters?: Parameter[];
@Field(() => VariableParameter)
variable?: VariableParameter;
......
import { Field, ObjectType, registerEnumType } from '@nestjs/graphql';
import { BaseParameter } from './base-parameter.model';
enum AllowedLink {
export enum AllowedLink {
VARIABLE = 'VARIABLE',
COVARIABLE = 'COVARIABLE',
}
......
......@@ -6,11 +6,14 @@ export class VariableParameter {
hint?: string;
@Field({ nullable: true, defaultValue: false })
isRequired: boolean;
isRequired?: boolean;
@Field({ nullable: true, defaultValue: false })
hasMultiple: boolean;
hasMultiple?: boolean;
@Field(() => [String], { nullable: true })
allowedTypes: string[];
@Field(() => [String], {
nullable: true,
description: 'If undefined, all types are allowed',
})
allowedTypes?: string[];
}
......@@ -147,6 +147,8 @@ type VariableParameter {
hint: String
isRequired: Boolean
hasMultiple: Boolean
"""If undefined, all types are allowed"""
allowedTypes: [String!]
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment