Skip to content
Snippets Groups Projects
Commit ce54d36d authored by ThanKarab's avatar ThanKarab
Browse files

Added a check to validate that algorithm specifications don't use the same name.

https://team-1617704806227.atlassian.net/browse/MIP-838
Modified logistic regression with fed average and changed its name in the specifications.
parent 216a2630
No related branches found
No related tags found
No related merge requests found
...@@ -21,6 +21,7 @@ from exareme2.algorithms.metrics import roc_curve ...@@ -21,6 +21,7 @@ from exareme2.algorithms.metrics import roc_curve
from exareme2.algorithms.preprocessing import DummyEncoder from exareme2.algorithms.preprocessing import DummyEncoder
from exareme2.algorithms.preprocessing import LabelBinarizer from exareme2.algorithms.preprocessing import LabelBinarizer
from exareme2.algorithms.specifications import AlgorithmName from exareme2.algorithms.specifications import AlgorithmName
from exareme2.algorithms.specifications import AlgorithmSpecification
from exareme2.udfgen import relation from exareme2.udfgen import relation
from exareme2.udfgen import secure_transfer from exareme2.udfgen import secure_transfer
from exareme2.udfgen import udf from exareme2.udfgen import udf
...@@ -34,9 +35,19 @@ class LogRegCVFedAverageDataLoader(AlgorithmDataLoader, algname=ALGORITHM_NAME): ...@@ -34,9 +35,19 @@ class LogRegCVFedAverageDataLoader(AlgorithmDataLoader, algname=ALGORITHM_NAME):
class LogRegCVFedAverageAlgorithm(Algorithm, algname=ALGORITHM_NAME): class LogRegCVFedAverageAlgorithm(Algorithm, algname=ALGORITHM_NAME):
@staticmethod @classmethod
def get_specification(): def get_specification(cls):
return LogisticRegressionCVAlgorithm.get_specification() # Use the LR with CV specification but change the name
LR_with_cv_specification = LogisticRegressionCVAlgorithm.get_specification()
LR_with_cv_fedavg = AlgorithmSpecification(
name=ALGORITHM_NAME,
desc=LR_with_cv_specification.desc,
label=LR_with_cv_specification.label,
enabled=LR_with_cv_specification.enabled,
inputdata=LR_with_cv_specification.inputdata,
parameters=LR_with_cv_specification.parameters,
)
return LR_with_cv_fedavg
def run(self, data, metadata): def run(self, data, metadata):
X, y = data X, y = data
......
...@@ -3,9 +3,12 @@ from enum import Enum ...@@ -3,9 +3,12 @@ from enum import Enum
from enum import unique from enum import unique
from importlib.resources import open_text from importlib.resources import open_text
from typing import Dict from typing import Dict
from typing import List
from typing import Type
import envtoml import envtoml
from exareme2 import Algorithm
from exareme2 import AttrDict from exareme2 import AttrDict
from exareme2 import algorithm_classes from exareme2 import algorithm_classes
from exareme2 import controller from exareme2 import controller
...@@ -29,16 +32,22 @@ else: ...@@ -29,16 +32,22 @@ else:
config = AttrDict(envtoml.load(fp)) config = AttrDict(envtoml.load(fp))
def _get_algorithms_specifications() -> Dict[str, AlgorithmSpecification]: def _get_algorithms_specifications(
specs = { algorithms: List[Type[Algorithm]],
algo_name: algorithm.get_specification() ) -> Dict[str, AlgorithmSpecification]:
for algo_name, algorithm in algorithm_classes.items() specs = {}
if algorithm.get_specification().enabled for algorithm in algorithms:
} if algorithm.get_specification().enabled:
algo_name = algorithm.get_specification().name
if algo_name in specs.keys():
raise ValueError(
f"The algorithm name '{algo_name}' exists more than once in the algorithm specifications."
)
specs[algo_name] = algorithm.get_specification()
return specs return specs
algorithms_specifications = _get_algorithms_specifications() algorithms_specifications = _get_algorithms_specifications(algorithm_classes.values())
transformers_specifications = { transformers_specifications = {
LongitudinalTransformerRunner.get_transformer_name(): LongitudinalTransformerRunner.get_specification() LongitudinalTransformerRunner.get_transformer_name(): LongitudinalTransformerRunner.get_specification()
} }
...@@ -68,7 +68,7 @@ class CategoricalNBTesting_predict(Algorithm, algname=ALGNAME_PRED): ...@@ -68,7 +68,7 @@ class CategoricalNBTesting_predict(Algorithm, algname=ALGNAME_PRED):
# but remove the "n_splits" parameter since this is a CV specific parameter # but remove the "n_splits" parameter since this is a CV specific parameter
categoricalNB_with_cv_specification = CategoricalNBAlgorithm.get_specification() categoricalNB_with_cv_specification = CategoricalNBAlgorithm.get_specification()
categoricalNB_predict_specification = AlgorithmSpecification( categoricalNB_predict_specification = AlgorithmSpecification(
name=ALGNAME_FIT, name=ALGNAME_PRED,
desc=categoricalNB_with_cv_specification.desc, desc=categoricalNB_with_cv_specification.desc,
label=categoricalNB_with_cv_specification.label, label=categoricalNB_with_cv_specification.label,
enabled=categoricalNB_with_cv_specification.enabled, enabled=categoricalNB_with_cv_specification.enabled,
......
import pytest
from exareme2 import Algorithm
from exareme2.algorithms.specifications import AlgorithmSpecification
from exareme2.algorithms.specifications import InputDataSpecification
from exareme2.algorithms.specifications import InputDataSpecifications
from exareme2.algorithms.specifications import InputDataStatType
from exareme2.algorithms.specifications import InputDataType
from exareme2.controller import _get_algorithms_specifications
def test_algorithm_specifications_with_same_name_raises_error():
    """Check that duplicate specification names are rejected.

    Two algorithm classes are defined that both return the very same
    ``AlgorithmSpecification`` (and therefore the same ``name``). Passing
    both to ``_get_algorithms_specifications`` is expected to raise a
    ``ValueError`` that names the duplicated specification.
    """
    # A minimal, enabled specification; only ``name`` and ``enabled`` matter
    # for the check under test, the rest is filler required by the constructor.
    spec = AlgorithmSpecification(
        name="algorithm_name",
        desc="",
        label="",
        enabled=True,
        inputdata=InputDataSpecifications(
            y=InputDataSpecification(
                label="",
                desc="",
                types=[InputDataType.TEXT],
                stattypes=[InputDataStatType.NOMINAL],
                notblank=True,
                multiple=True,
            )
        ),
        parameters=None,
    )

    # NOTE(review): both classes below also reuse algname="algorithm1".
    # Presumably this is irrelevant here because the duplicate check keys on
    # the specification name, not on algname -- confirm against Algorithm.
    class Algorithm1(Algorithm, algname="algorithm1"):
        @classmethod
        def get_specification(cls):
            return spec

        def run(self, data, metadata):
            pass

    class Algorithm2(Algorithm, algname="algorithm1"):
        @classmethod
        def get_specification(cls):
            return spec

        def run(self, data, metadata):
            pass

    algorithms = [Algorithm1, Algorithm2]

    # ``match`` is a regular expression, so use a raw string and escape the
    # trailing dot; the duplicated name is pinned so an unrelated ValueError
    # with a similar message cannot satisfy the assertion.
    with pytest.raises(
        ValueError,
        match=r"The algorithm name 'algorithm_name' exists more than once in the algorithm specifications\.",
    ):
        _get_algorithms_specifications(algorithms)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment