diff --git a/api/src/engine/connectors/datashield/datashield.connector.ts b/api/src/engine/connectors/datashield/datashield.connector.ts index e9d325e722b3d3ee3fee0d44597ffabdf2914574..d91cfa45cb57ab0b89dfd263be4d43775bb522fa 100644 --- a/api/src/engine/connectors/datashield/datashield.connector.ts +++ b/api/src/engine/connectors/datashield/datashield.connector.ts @@ -8,17 +8,20 @@ import { } from 'src/common/interfaces/utilities.interface'; import { errorAxiosHandler } from 'src/common/utils/shared.utils'; import { ENGINE_MODULE_OPTIONS } from 'src/engine/engine.constants'; +import EngineService from 'src/engine/engine.service'; import ConnectorConfiguration from 'src/engine/interfaces/connector-configuration.interface'; import Connector from 'src/engine/interfaces/connector.interface'; import EngineOptions from 'src/engine/interfaces/engine-options.interface'; import { Domain } from 'src/engine/models/domain.model'; import { Algorithm } from 'src/engine/models/experiment/algorithm.model'; +import { AllowedLink } from 'src/engine/models/experiment/algorithm/nominal-parameter.model'; import { Experiment } from 'src/engine/models/experiment/experiment.model'; import { RawResult } from 'src/engine/models/result/raw-result.model'; import { TableResult, TableStyle, } from 'src/engine/models/result/table-result.model'; +import { Variable } from 'src/engine/models/variable.model'; import { ExperimentCreateInput } from 'src/experiments/models/input/experiment-create.input'; import { User } from 'src/users/models/user.model'; import { @@ -27,6 +30,8 @@ import { transformToDomain, transformToHisto, transformToTable, + transformToTableNominal, + transfoToHistoNominal as transformToHistoNominal, } from './transformations'; export default class DataShieldConnector implements Connector { @@ -35,6 +40,7 @@ export default class DataShieldConnector implements Connector { constructor( @Inject(ENGINE_MODULE_OPTIONS) private readonly options: EngineOptions, private readonly httpService: HttpService, + private readonly engineService: EngineService, ) {} getConfiguration(): ConnectorConfiguration { @@ -77,17 +83,65 @@ export default class DataShieldConnector implements Connector { } async getAlgorithms(): Promise<Algorithm[]> { - return []; + return [ + { + id: 'linear-regression', + label: 'Linear Regression', + description: + 'Linear regression analysis is a method of statistical analysis that fits a linear function in order to predict the value of a covariate as a function of one or more variables. Linear regression is a simple model that is easy to understand and interpret.', + variable: { + isRequired: true, + allowedTypes: ['number'], + hasMultiple: false, + }, + coVariable: { + isRequired: true, + allowedTypes: ['number'], + hasMultiple: true, + }, + }, + { + id: 'logistic-regression', + label: 'Logistic Regression', + description: + 'Logistic regression is a statistical method for predicting the probability of a binary event.', + variable: { + isRequired: true, + allowedTypes: ['nominal'], + hasMultiple: false, + hint: 'A binary event to predict', + }, + coVariable: { + isRequired: true, + allowedTypes: ['number'], + hasMultiple: true, + }, + parameters: [ + { + name: 'pos-level', + label: 'Positive level', + linkedTo: AllowedLink.VARIABLE, + isRequired: true, + }, + { + name: 'neg-level', + label: 'Negative level', + linkedTo: AllowedLink.VARIABLE, + isRequired: true, + }, + ], + }, + ]; } async getHistogram( - variable: string, + variable: Variable, datasets: string[], cookie?: string, ): Promise<RawResult> { const url = new URL(this.options.baseurl + `histogram`); - url.searchParams.append('var', variable); + url.searchParams.append('var', variable.id); url.searchParams.append('type', 'combine'); url.searchParams.append('cohorts', datasets.join(',')); @@ -114,10 +168,20 @@ export default class DataShieldConnector implements Connector { }; } - const title = variable.replace(/\./g, ' ').trim(); + const title = variable.label ?? variable.id; const data = { ...response.data, title }; - const chart = transformToHisto.evaluate(data); + if (variable.type === 'nominal' && variable.enumerations) { + data['lookup'] = variable.enumerations.reduce((prev, curr) => { + prev[curr.value] = curr.label; + return prev; + }, {}); + } + + const chart = + variable.type === 'nominal' + ? transformToHistoNominal.evaluate(data) + : transformToHisto.evaluate(data); return { rawdata: { @@ -128,13 +192,13 @@ export default class DataShieldConnector implements Connector { } async getDescriptiveStats( - variable: string, + variable: Variable, datasets: string[], cookie?: string, ): Promise<TableResult> { const url = new URL(this.options.baseurl + 'quantiles'); - url.searchParams.append('var', variable); + url.searchParams.append('var', variable.id); url.searchParams.append('type', 'split'); url.searchParams.append('cohorts', datasets.join(',')); @@ -148,9 +212,35 @@ export default class DataShieldConnector implements Connector { }), ); - const title = variable.replace(/\./g, ' ').trim(); + const title = variable.label ?? variable.id; const data = { ...response.data, title }; - const table = transformToTable.evaluate(data); + + const table = ( + variable.enumerations + ? transformToTableNominal.evaluate(data) + : transformToTable.evaluate(data) + ) as TableResult; + + if ( + table && + table.headers && + variable.type === 'nominal' && + variable.enumerations + ) { + table.headers = table.headers.map((header) => { + const category = variable.enumerations.find( + (v) => v.value === header.name, + ); + + if (!category || !category.label) return header; + + return { + ...header, + name: category.label, + }; + }); + } + return { ...table, tableStyle: TableStyle.DEFAULT, @@ -168,6 +258,11 @@ export default class DataShieldConnector implements Connector { const expResult: Experiment = { id: `${data.algorithm.id}-${Date.now()}`, variables: data.variables, + coVariables: data.coVariables, + author: { + username: user.username, + fullname: user.fullname ?? user.username, + }, name: data.name, domain: data.domain, datasets: data.datasets, @@ -176,21 +271,37 @@ export default class DataShieldConnector implements Connector { }, }; + const allVariablesId = [...data.variables, ...data.coVariables]; + + const allVariables = await this.engineService.getVariables( + expResult.domain, + allVariablesId, + request, + ); + switch (data.algorithm.id) { case 'MULTIPLE_HISTOGRAMS': { expResult.results = await Promise.all<RawResult>( - data.variables.map((variable) => + allVariables.map((variable) => this.getHistogram(variable, expResult.datasets, cookie), ), ); break; } case 'DESCRIPTIVE_STATS': { - expResult.results = await Promise.all<TableResult>( - [...data.variables, ...data.coVariables].map((variable) => - this.getDescriptiveStats(variable, expResult.datasets, cookie), - ), - ); + // Cannot be done in parallel because Datashield API has an issue with parallel request (response mismatching) + const results = []; + for (const variable of allVariables) { + const result = await this.getDescriptiveStats( + variable, + expResult.datasets, + cookie, + ); + + results.push(result); + } + + expResult.results = results; break; } } diff --git a/api/src/engine/connectors/datashield/transformations.ts b/api/src/engine/connectors/datashield/transformations.ts index 88c7abce38bce6f9774d123bc1255656bf017355..b1991aa7fa1e3331c8295068787f226c7a801209 100644 --- a/api/src/engine/connectors/datashield/transformations.ts +++ b/api/src/engine/connectors/datashield/transformations.ts @@ -30,6 +30,38 @@ export const transformToDomain = jsonata(` } `); +export const transfoToHistoNominal = jsonata(` +( + { + "chart": { + "type": 'column' + }, + "legend": { + "enabled": false + }, + "series": [{ + "name": "Count", + "data": global.* + }], + "title": { + "text": title ? title : '' + }, + "tooltip": { + "enabled": true + }, + "xAxis": { + "categories": $keys(global).( + $param := $lookup($$.lookup, $); + $param ? $param : $ + ) + }, + "yAxis": { + "min": 0, + "minRange": 0.1, + "allowDecimals": true + } +})`); + export const transformToHisto = jsonata(` ( $nbBreaks := $count(global.breaks); @@ -74,6 +106,21 @@ export const transformToTable = jsonata(` "name": $, "type": "string" }, + "data": $.$each(function($v, $k) { + $not($k in $params) ? $append($k,$v) : undefined + })[] +}) +`); + +export const transformToTableNominal = jsonata(` +( + $params := ["title"]; + { + "name": "Descriptive Statistics", + "headers": $append(title, $keys($.*)).{ + "name": $, + "type": "string" + }, "data": $.$each(function($v, $k) { $not($k in $params) ? $append($k,$v.*) : undefined })[] diff --git a/api/src/engine/engine.service.ts b/api/src/engine/engine.service.ts index a7713c797f50cb21442df160c46d99e64236feef..7008ebe8daefe54091c5aee824b5eadd30da3fed 100644 --- a/api/src/engine/engine.service.ts +++ b/api/src/engine/engine.service.ts @@ -29,6 +29,7 @@ import { import { ListExperiments } from './models/experiment/list-experiments.model'; import { FilterConfiguration } from './models/filter/filter-configuration'; import { FormulaOperation } from './models/formula/formula-operation.model'; +import { Variable } from './models/variable.model'; /** * Engine service. @@ -46,7 +47,7 @@ export default class EngineService implements Connector { ) { import(`./connectors/${options.type}/${options.type}.connector`).then( (conn) => { - const instance = new conn.default(options, httpService); + const instance = new conn.default(options, httpService, this); if (instance.createExperiment && instance.runExperiment) throw new InternalServerErrorException( @@ -96,7 +97,7 @@ export default class EngineService implements Connector { const result = await fn(); - this.cacheManager.set(key, result); + this.cacheManager.set(key, result, { ttl: this.cacheConf.ttl }); return result; } @@ -118,6 +119,29 @@ export default class EngineService implements Connector { ); } + /** + * It takes a domain ID and a list of variable IDs, and returns a list of variables that match the IDs + * @param {string} domainId - The domain ID of the domain you want to get variables from. + * @param {string[]} varIds - The list of variable IDs to get. + * @param {Request} request - The request object from the HTTP request. + * @returns An array of variables + */ + async getVariables( + domainId: string, + varIds: string[], + request: Request, + ): Promise<Variable[]> { + if (!domainId || varIds.length === 0) return []; + + const domains = await this.getDomains([], request); + + return ( + domains + .find((d) => d.id === domainId) + ?.variables?.filter((v) => varIds.includes(v.id)) ?? [] + ); + } + async createExperiment( data: ExperimentCreateInput, isTransient: boolean, diff --git a/api/src/engine/models/experiment/algorithm.model.ts b/api/src/engine/models/experiment/algorithm.model.ts index 761a0986d17aee99b192e939c9385637e0251f86..b697092e7d7ae4240dea01403e6ac2f12956d571 100644 --- a/api/src/engine/models/experiment/algorithm.model.ts +++ b/api/src/engine/models/experiment/algorithm.model.ts @@ -1,14 +1,18 @@ import { Field, ObjectType } from '@nestjs/graphql'; import { BaseParameter } from './algorithm/base-parameter.model'; +import { NominalParameter } from './algorithm/nominal-parameter.model'; +import { NumberParameter } from './algorithm/number-parameter.model'; import { VariableParameter } from './algorithm/variable-parameter.model'; +type Parameter = BaseParameter | NumberParameter | NominalParameter; + @ObjectType() export class Algorithm { @Field() id: string; @Field(() => [BaseParameter], { nullable: true, defaultValue: [] }) - parameters?: BaseParameter[]; + parameters?: Parameter[]; @Field(() => VariableParameter) variable?: VariableParameter; diff --git a/api/src/engine/models/experiment/algorithm/nominal-parameter.model.ts b/api/src/engine/models/experiment/algorithm/nominal-parameter.model.ts index 2008b60e545a5e91820295b5f285478df28e80fa..b52f6261578adb3db68feb42a4bd90fd4103e282 100644 --- a/api/src/engine/models/experiment/algorithm/nominal-parameter.model.ts +++ b/api/src/engine/models/experiment/algorithm/nominal-parameter.model.ts @@ -1,7 +1,7 @@ import { Field, ObjectType, registerEnumType } from '@nestjs/graphql'; import { BaseParameter } from './base-parameter.model'; -enum AllowedLink { +export enum AllowedLink { VARIABLE = 'VARIABLE', COVARIABLE = 'COVARIABLE', } diff --git a/api/src/engine/models/experiment/algorithm/variable-parameter.model.ts b/api/src/engine/models/experiment/algorithm/variable-parameter.model.ts index 0f9e55f7d4831de18538b3070b31bd8e390650e3..2ef85536371278a96d16bcb377df8c7e50a4a242 100644 --- a/api/src/engine/models/experiment/algorithm/variable-parameter.model.ts +++ b/api/src/engine/models/experiment/algorithm/variable-parameter.model.ts @@ -6,11 +6,14 @@ export class VariableParameter { hint?: string; @Field({ nullable: true, defaultValue: false }) - isRequired: boolean; + isRequired?: boolean; @Field({ nullable: true, defaultValue: false }) - hasMultiple: boolean; + hasMultiple?: boolean; - @Field(() => [String], { nullable: true }) - allowedTypes: string[]; + @Field(() => [String], { + nullable: true, + description: 'If undefined, all types are allowed', + }) + allowedTypes?: string[]; } diff --git a/api/src/schema.gql b/api/src/schema.gql index 57d7d6c62e43a152a9a073a63bc4ff86d6736327..8a6d3ba0d14b20e3c5575c86a64c0e297c84a7c0 100644 --- a/api/src/schema.gql +++ b/api/src/schema.gql @@ -147,6 +147,8 @@ type VariableParameter { hint: String isRequired: Boolean hasMultiple: Boolean + + """If undefined, all types are allowed""" allowedTypes: [String!] }