diff --git a/mipengine/algorithm_specification.py b/mipengine/algorithm_specification.py index 4a1f7b007963e3023450f91073c25e5ae80b0ed3..024b6b9539e5ab9dc32c5f6aa6880bf9822dcbf4 100644 --- a/mipengine/algorithm_specification.py +++ b/mipengine/algorithm_specification.py @@ -62,7 +62,7 @@ class InputDataSpecifications(ImmutableBaseModel): class ParameterEnumSpecification(ImmutableBaseModel): type: ParameterEnumType - source: Any + source: List[str] class ParameterSpecification(ImmutableBaseModel): @@ -79,11 +79,26 @@ class ParameterSpecification(ImmutableBaseModel): max: Optional[float] +def _validate_parameter_with_enums_type_fixed_var_CDE_enums(param_value, cls_values): + if len(param_value.enums.source) != 1: + raise ValueError( + f"In algorithm '{cls_values['label']}', parameter '{param_value.label}' has enums type 'fixed_var_CDE_enums' " + f"that supports only one value. Value given: {param_value.enums.source}." + ) + + def _validate_parameter_with_enums_type_input_var_CDE_enums(param_value, cls_values): - if param_value.enums.source not in ["x", "y"]: + if len(param_value.enums.source) != 1: + raise ValueError( + f"In algorithm '{cls_values['label']}', parameter '{param_value.label}' has enums type 'input_var_CDE_enums' " + f"that supports only one value. Value given: {param_value.enums.source}." + ) + + value = param_value.enums.source[0] # Only one value is allowed + if value not in ["x", "y"]: raise ValueError( f"In algorithm '{cls_values['label']}', parameter '{param_value.label}' has enums type 'input_var_CDE_enums' " - f"that supports only 'x' or 'y' as source. Value given: '{param_value.enums.source}'." + f"that supports only 'x' or 'y' as source. Value given: '{value}'." ) if param_value.multiple: raise ValueError( @@ -91,9 +106,7 @@ def _validate_parameter_with_enums_type_input_var_CDE_enums(param_value, cls_val f"that doesn't support 'multiple=True', in the parameter." ) inputdata_var = ( - cls_values["inputdata"].x - if param_value.enums.source == "x" - else cls_values["inputdata"].y + cls_values["inputdata"].x if value == "x" else cls_values["inputdata"].y ) if inputdata_var.multiple: raise ValueError( @@ -115,9 +128,11 @@ def _validate_parameter_with_enums_type_input_var_names(param_value, cls_values) def _validate_parameter_enums(param_value, cls_values): if not param_value.enums: return - if param_value.enums.type == ParameterEnumType.INPUT_VAR_CDE_ENUMS: + if param_value.enums.type == ParameterEnumType.FIXED_VAR_CDE_ENUMS: + _validate_parameter_with_enums_type_fixed_var_CDE_enums(param_value, cls_values) + elif param_value.enums.type == ParameterEnumType.INPUT_VAR_CDE_ENUMS: _validate_parameter_with_enums_type_input_var_CDE_enums(param_value, cls_values) - if param_value.enums.type == ParameterEnumType.INPUT_VAR_NAMES: + elif param_value.enums.type == ParameterEnumType.INPUT_VAR_NAMES: _validate_parameter_with_enums_type_input_var_names(param_value, cls_values) diff --git a/mipengine/algorithms/linear_regression_longitudinal.py b/mipengine/algorithms/linear_regression_longitudinal.py index 19faee426f7ac3a1043dfff51dd86941a3a45639..1008fe3f135eac5254d2019fa8a80a5cd0d6e68b 100644 --- a/mipengine/algorithms/linear_regression_longitudinal.py +++ b/mipengine/algorithms/linear_regression_longitudinal.py @@ -49,7 +49,7 @@ class LinearRegressionLongitudinal(Algorithm, algname="linear_regression_longitu notblank=True, multiple=False, enums=ParameterEnumSpecification( - type=ParameterEnumType.FIXED_VAR_CDE_ENUMS, source="visitid" + type=ParameterEnumType.FIXED_VAR_CDE_ENUMS, source=["visitid"] ), ), "visit2": ParameterSpecification( @@ -59,7 +59,7 @@ class LinearRegressionLongitudinal(Algorithm, algname="linear_regression_longitu notblank=True, multiple=False, enums=ParameterEnumSpecification( - type=ParameterEnumType.FIXED_VAR_CDE_ENUMS, source="visitid" + type=ParameterEnumType.FIXED_VAR_CDE_ENUMS, source=["visitid"] ), ), "strategies": ParameterSpecification( diff --git a/mipengine/algorithms/logistic_regression.py b/mipengine/algorithms/logistic_regression.py index 8154f445bfd70c971568e04725e17b68bd77de7a..95140364041db0ce4c25488376fb70cde75d2998 100644 --- a/mipengine/algorithms/logistic_regression.py +++ b/mipengine/algorithms/logistic_regression.py @@ -65,7 +65,7 @@ class LogisticRegressionAlgorithm(Algorithm, algname="logistic_regression"): multiple=False, enums=ParameterEnumSpecification( type=ParameterEnumType.INPUT_VAR_CDE_ENUMS, - source="y", + source=["y"], ), ), }, diff --git a/mipengine/algorithms/logistic_regression_cv.py b/mipengine/algorithms/logistic_regression_cv.py index 7be694441c7f5d5c9b345ab274e0857cb3cf1435..9fa66568910b9b26f567831ec37a081a09186a9f 100644 --- a/mipengine/algorithms/logistic_regression_cv.py +++ b/mipengine/algorithms/logistic_regression_cv.py @@ -59,7 +59,7 @@ class LogisticRegressionCVAlgorithm(Algorithm, algname="logistic_regression_cv") multiple=False, enums=ParameterEnumSpecification( type=ParameterEnumType.INPUT_VAR_CDE_ENUMS, - source="y", + source=["y"], ), ), "n_splits": ParameterSpecification( diff --git a/mipengine/controller/api/validator.py b/mipengine/controller/api/validator.py index 49ec4fb1cf539983d4400fd922e90227477dfb48..86d1883c938ee4945ab0c9b70debbb3d4076ddb9 100644 --- a/mipengine/controller/api/validator.py +++ b/mipengine/controller/api/validator.py @@ -369,17 +369,20 @@ def _validate_param_enums_of_type_fixed_var_CDE_enums( parameter_spec_label: str, data_model_cdes: Dict[str, CommonDataElement], ): - if parameter_spec_enums.source not in data_model_cdes.keys(): + param_spec_enums_source = parameter_spec_enums.source[ + 0 + ] # Fixed var CDE enums allows only one source value + if param_spec_enums_source not in data_model_cdes.keys(): raise ValueError( - f"Parameter's '{parameter_spec_label}' enums source '{parameter_spec_enums.source}' does " + f"Parameter's '{parameter_spec_label}' enums source '{param_spec_enums_source}' does " f"not exist in the data model provided." ) fixed_var_CDE_enums = list( - data_model_cdes[parameter_spec_enums.source].enumerations.keys() + data_model_cdes[param_spec_enums_source].enumerations.keys() ) if parameter_value not in fixed_var_CDE_enums: raise BadUserInput( - f"Parameter's '{parameter_spec_label}' enums, that are taken from the CDE '{parameter_spec_enums.source}', " + f"Parameter's '{parameter_spec_label}' enums, that are taken from the CDE '{param_spec_enums_source}', " f"should be one of the following: '{list(fixed_var_CDE_enums)}'." ) @@ -391,9 +394,12 @@ def _validate_param_enums_of_type_input_var_CDE_enums( inputdata: AlgorithmInputDataDTO, data_model_cdes: Dict[str, CommonDataElement], ): - if parameter_spec_enums.source == "x": + param_spec_enums_source = parameter_spec_enums.source[ + 0 + ] # Input var CDE enums allows only one source value + if param_spec_enums_source == "x": input_vars = inputdata.x - elif parameter_spec_enums.source == "y": + elif param_spec_enums_source == "y": input_vars = inputdata.y else: raise NotImplementedError(f"Source should be either 'x' or 'y'.") diff --git a/tests/standalone_tests/test_validate_algorithm_request.py b/tests/standalone_tests/test_validate_algorithm_request.py index 2742c90b66c9a3b1158ef00fc9271164fb599981..e14e990ff4c03bb9ae8943b0bba0541477ea5e6f 100644 --- a/tests/standalone_tests/test_validate_algorithm_request.py +++ b/tests/standalone_tests/test_validate_algorithm_request.py @@ -4,7 +4,6 @@ from mipengine.algorithm_specification import AlgorithmSpecification from mipengine.algorithm_specification import InputDataSpecification from mipengine.algorithm_specification import InputDataSpecifications from mipengine.algorithm_specification import ParameterEnumSpecification -from mipengine.algorithm_specification import ParameterEnumType from mipengine.algorithm_specification import ParameterSpecification from mipengine.controller.api.algorithm_request_dto import AlgorithmInputDataDTO from mipengine.controller.api.algorithm_request_dto import AlgorithmRequestDTO @@ -347,7 +346,7 @@ def algorithms_specs(): multiple=False, enums=ParameterEnumSpecification( type=ParameterEnumType.INPUT_VAR_CDE_ENUMS, - source="y", + source=["y"], ), ), "param_with_enum_type_fixed_var_CDE_enums": ParameterSpecification( @@ -358,7 +357,7 @@ def algorithms_specs(): multiple=False, enums=ParameterEnumSpecification( type=ParameterEnumType.FIXED_VAR_CDE_ENUMS, - source="text_cde_categ", + source=["text_cde_categ"], ), ), "param_with_enum_type_fixed_var_CDE_enums_wrong_CDE": ParameterSpecification( @@ -369,7 +368,7 @@ def algorithms_specs(): multiple=False, enums=ParameterEnumSpecification( type=ParameterEnumType.FIXED_VAR_CDE_ENUMS, - source="non_existing_CDE", + source=["non_existing_CDE"], ), ), "param_with_enum_type_input_var_names": ParameterSpecification( diff --git a/tests/standalone_tests/test_validate_algorithm_specifications.py b/tests/standalone_tests/test_validate_algorithm_specifications.py index db5d7e63a78b15611948016cd7fc2b5a6aa8264d..b6545ee9013ef4340ae18f59b64a5d2f6db7a334 100644 --- a/tests/standalone_tests/test_validate_algorithm_specifications.py +++ b/tests/standalone_tests/test_validate_algorithm_specifications.py @@ -42,7 +42,8 @@ def test_validate_parameter_spec_input_var_CDE_enums_source_is_x_or_y(): notblank=False, multiple=False, enums=ParameterEnumSpecification( - type=ParameterEnumType.INPUT_VAR_CDE_ENUMS, source="not_x_or_y" + type=ParameterEnumType.INPUT_VAR_CDE_ENUMS, + source=["not_x_or_y"], ), ), }, @@ -79,7 +80,7 @@ def test_validate_parameter_spec_input_var_CDE_enums_multiple_false(): notblank=False, multiple=True, enums=ParameterEnumSpecification( - type=ParameterEnumType.INPUT_VAR_CDE_ENUMS, source="y" + type=ParameterEnumType.INPUT_VAR_CDE_ENUMS, source=["y"] ), ), }, @@ -116,7 +117,7 @@ def test_validate_parameter_spec_input_var_CDE_enums_inputdata_has_multiple_fals notblank=False, multiple=False, enums=ParameterEnumSpecification( - type=ParameterEnumType.INPUT_VAR_CDE_ENUMS, source="y" + type=ParameterEnumType.INPUT_VAR_CDE_ENUMS, source=["y"] ), ), }, @@ -160,6 +161,82 @@ def test_validate_parameter_spec_input_var_names_type_must_be_text(): ) +def test_validate_parameter_spec_input_var_CDE_enums_only_one_value(): + exception_type = ValidationError + exception_message = ( + ".*In algorithm 'sample_algo', parameter 'sample_label' has enums type 'input_var_CDE_enums' " + "that supports only one value." + ) + with pytest.raises(exception_type, match=exception_message): + AlgorithmSpecification( + name="sample_algo", + desc="sample", + label="sample_algo", + enabled=True, + inputdata=InputDataSpecifications( + y=InputDataSpecification( + label="y", + desc="y", + types=[InputDataType.TEXT], + stattypes=[InputDataStatType.NOMINAL], + notblank=True, + multiple=False, + ) + ), + parameters={ + "inputdata_cde_enum_param": ParameterSpecification( + label="sample_label", + desc="sample", + types=[ParameterType.TEXT], + notblank=False, + multiple=True, + enums=ParameterEnumSpecification( + type=ParameterEnumType.INPUT_VAR_CDE_ENUMS, + source=["y", "second_value"], + ), + ), + }, + ) + + +def test_validate_parameter_spec_fixed_var_CDE_enums_only_one_value(): + exception_type = ValidationError + exception_message = ( + ".*In algorithm 'sample_algo', parameter 'sample_label' has enums type 'fixed_var_CDE_enums' " + "that supports only one value." + ) + with pytest.raises(exception_type, match=exception_message): + AlgorithmSpecification( + name="sample_algo", + desc="sample", + label="sample_algo", + enabled=True, + inputdata=InputDataSpecifications( + y=InputDataSpecification( + label="y", + desc="y", + types=[InputDataType.TEXT], + stattypes=[InputDataStatType.NOMINAL], + notblank=True, + multiple=False, + ) + ), + parameters={ + "inputdata_cde_enum_param": ParameterSpecification( + label="sample_label", + desc="sample", + types=[ParameterType.TEXT], + notblank=False, + multiple=True, + enums=ParameterEnumSpecification( + type=ParameterEnumType.FIXED_VAR_CDE_ENUMS, + source=["y", "second_value"], + ), + ), + }, + ) + + def test_validate_parameter_dict_type_given_with_other_type(): exception_type = ValidationError exception_message = (