diff --git a/.github/workflows/smpc_env_tests.yml b/.github/workflows/smpc_env_tests.yml index b3ecded4e2c40fd8b5e9b16c009f90abf5306ff3..e5e1a66feff38bc9ba8ef425dc77625c15787534 100644 --- a/.github/workflows/smpc_env_tests.yml +++ b/.github/workflows/smpc_env_tests.yml @@ -159,6 +159,7 @@ jobs: run: | kubectl taint nodes master node-role.kubernetes.io/master- kubectl label node master nodeType=master + kubectl label node master smpcType=player kubectl label node localnode1 nodeType=worker kubectl label node localnode2 nodeType=worker kubectl label node localnode1 smpcType=player @@ -349,6 +350,6 @@ jobs: with: run: kubectl logs -l nodeType=localnode --tail -1 -c smpc-client -# TODO SMPC currently doesn't support decimals +# SMPC tests currently hang due to https://team-1617704806227.atlassian.net/browse/MIP-608 # - name: Run the first 5 algorithm validation tests from each algorithm -# run: poetry run pytest tests/algorithm_validation_tests/ -k "input0- or input1- or input2- or input3- or input4-" -vvvv +# run: timeout 60 poetry run pytest tests/algorithm_validation_tests/ -k "input0- or input1- or input2- or input3- or input4-" -vvvv -x diff --git a/kubernetes/templates/mipengine-controller.yaml b/kubernetes/templates/mipengine-controller.yaml index 6250169c2174a6e51000f4370fb3fed005fda590..6aebfd49846b0fbcef055560bbbc6354f97c4556 100644 --- a/kubernetes/templates/mipengine-controller.yaml +++ b/kubernetes/templates/mipengine-controller.yaml @@ -37,15 +37,15 @@ spec: - name: DEPLOYMENT_TYPE value: "KUBERNETES" - name: NODE_LANDSCAPE_AGGREGATOR_UPDATE_INTERVAL - value: "{{ .Values.controller.node_landscape_aggregator_update_interval }}" + value: {{ quote .Values.controller.node_landscape_aggregator_update_interval }} - name: NODES_CLEANUP_INTERVAL - value: "{{ .Values.controller.nodes_cleanup_interval }}" + value: {{ quote .Values.controller.nodes_cleanup_interval }} - name: NODES_CLEANUP_CONTEXTID_RELEASE_TIMELIMIT value: "86400" # One day in seconds - name: CELERY_TASKS_TIMEOUT - value: "{{ .Values.controller.celery_tasks_timeout }}" + value: {{ quote .Values.controller.celery_tasks_timeout }} - name: CELERY_RUN_UDF_TASK_TIMEOUT - value: "{{ .Values.controller.celery_run_udf_task_timeout }}" + value: {{ quote .Values.controller.celery_run_udf_task_timeout }} - name: LOCALNODES_DNS value: "mipengine-nodes-service" - name: LOCALNODES_PORT @@ -55,6 +55,16 @@ spec: {{ if .Values.smpc.enabled }} - name: SMPC_OPTIONAL value: {{ quote .Values.smpc.optional }} + - name: SMPC_COORDINATOR_IP + valueFrom: + fieldRef: + fieldPath: status.podIP + - name: SMPC_COORDINATOR_ADDRESS + value: "http://$(SMPC_COORDINATOR_IP):12314" + - name: SMPC_GET_RESULT_INTERVAL + value: {{ quote .Values.smpc.get_result_interval }} + - name: SMPC_GET_RESULT_MAX_RETRIES + value: {{ quote .Values.smpc.get_result_max_retries }} {{ end }} ### --- SMPC components --- @@ -85,7 +95,7 @@ spec: imagePullPolicy: IfNotPresent command: ["python", "coordinator.py"] ports: - - containerPort: 12134 + - containerPort: 12314 env: - name: POD_IP valueFrom: diff --git a/kubernetes/templates/mipengine-localnode.yaml b/kubernetes/templates/mipengine-localnode.yaml index 4193e4ccd313c52038760979979472f0742bcb21..bc0cf7022dfb866540ba98d4a0e3c3f96eeca6cb 100644 --- a/kubernetes/templates/mipengine-localnode.yaml +++ b/kubernetes/templates/mipengine-localnode.yaml @@ -117,7 +117,7 @@ spec: fieldRef: fieldPath: status.podIP - name: SMPC_CLIENT_ADDRESS - value: "http://$(SMPC_CLIENT_IP)9000" + value: "http://$(SMPC_CLIENT_IP):9000" {{ end }} {{ 
if .Values.smpc.enabled }} diff --git a/kubernetes/values.yaml b/kubernetes/values.yaml index 6f3e225d7e81a944758938a8091989827d7d5068..ce00a72264c1227e55252fc218dbdc5e3271ac1b 100644 --- a/kubernetes/values.yaml +++ b/kubernetes/values.yaml @@ -1,4 +1,4 @@ -localnodes: 3 +localnodes: 2 mipengine_images: repository: madgik @@ -18,8 +18,10 @@ controller: cleanup_file_folder: /opt/cleanup smpc: - enabled: false + enabled: true optional: false image: gpikra/coordinator:v6.0.0 db_image: mongo:5.0.8 queue_image: redis:alpine3.15 + get_result_interval: 5 + get_result_max_retries: 100 diff --git a/mipengine/algorithms/linear_regression.py b/mipengine/algorithms/linear_regression.py index 67e8ff74ebcef539f935d0c4b1c546769d8e3643..247e0a2a73071821fc3c960307eb9ca67b159f59 100644 --- a/mipengine/algorithms/linear_regression.py +++ b/mipengine/algorithms/linear_regression.py @@ -108,10 +108,18 @@ class LinearRegression: sy = float(y.sum()) n_obs = len(y) stransfer = {} - stransfer["xTx"] = {"data": xTx.to_numpy().tolist(), "operation": "sum"} - stransfer["xTy"] = {"data": xTy.to_numpy().tolist(), "operation": "sum"} - stransfer["sy"] = {"data": sy, "operation": "sum"} - stransfer["n_obs"] = {"data": n_obs, "operation": "sum"} + stransfer["xTx"] = { + "data": xTx.to_numpy().tolist(), + "operation": "sum", + "type": "float", + } + stransfer["xTy"] = { + "data": xTy.to_numpy().tolist(), + "operation": "sum", + "type": "float", + } + stransfer["sy"] = {"data": sy, "operation": "sum", "type": "float"} + stransfer["n_obs"] = {"data": n_obs, "operation": "sum", "type": "int"} return stransfer @staticmethod @@ -201,8 +209,8 @@ class LinearRegression: tss = float(sum((y - y_mean) ** 2)) stransfer = {} - stransfer["rss"] = {"data": rss, "operation": "sum"} - stransfer["tss"] = {"data": tss, "operation": "sum"} + stransfer["rss"] = {"data": rss, "operation": "sum", "type": "float"} + stransfer["tss"] = {"data": tss, "operation": "sum", "type": "float"} return stransfer @staticmethod diff --git a/mipengine/algorithms/one_way_anova.py b/mipengine/algorithms/one_way_anova.py index a63e66f233b1e2e4b985e15cb1f3999605a16c16..b346b6806c76b49df3c3066672678de114a3b96d 100644 --- a/mipengine/algorithms/one_way_anova.py +++ b/mipengine/algorithms/one_way_anova.py @@ -106,7 +106,7 @@ T = TypeVar("T") y=relation(schema=S), x=relation(schema=T), covar_enums=literal(), - return_type=[secure_transfer(sum_op=True), transfer()], + return_type=[secure_transfer(sum_op=True, min_op=True, max_op=True), transfer()], ) def local1(y, x, covar_enums): import sys @@ -159,35 +159,46 @@ def local1(y, x, covar_enums): group_stats_df = group_stats.append(diff_df) sec_transfer_ = {} - sec_transfer_["n_obs"] = {"data": n_obs, "operation": "sum"} + sec_transfer_["n_obs"] = {"data": n_obs, "operation": "sum", "type": "int"} sec_transfer_["overall_stats_sum"] = { "data": overall_stats["sum"].tolist(), "operation": "sum", + "type": "float", } sec_transfer_["overall_stats_count"] = { "data": overall_stats["count"].tolist(), "operation": "sum", + "type": "float", + } + sec_transfer_["overall_ssq"] = { + "data": overall_ssq.item(), + "operation": "sum", + "type": "float", } - sec_transfer_["overall_ssq"] = {"data": overall_ssq.item(), "operation": "sum"} sec_transfer_["group_stats_sum"] = { "data": group_stats_df["sum"].tolist(), "operation": "sum", + "type": "float", } sec_transfer_["group_stats_count"] = { "data": group_stats_df["count"].tolist(), "operation": "sum", + "type": "float", } sec_transfer_["group_stats_ssq"] = { "data": 
group_stats_df["group_ssq"].tolist(), "operation": "sum", + "type": "float", } sec_transfer_["min_per_group"] = { "data": group_stats_df["min_per_group"].tolist(), "operation": "min", + "type": "float", } sec_transfer_["max_per_group"] = { "data": group_stats_df["max_per_group"].tolist(), "operation": "max", + "type": "float", } transfer_ = { "var_label": var_label, diff --git a/mipengine/algorithms/pca.py b/mipengine/algorithms/pca.py index 20de55d313488dddc45fd63c07222d7797952c99..3c8ea6948ba9fbe3fb3eebb2cb315d5ee6290982 100644 --- a/mipengine/algorithms/pca.py +++ b/mipengine/algorithms/pca.py @@ -70,9 +70,9 @@ def local1(x): sxx = (x**2).sum(axis=0) transfer_ = {} - transfer_["n_obs"] = {"data": n_obs, "operation": "sum"} - transfer_["sx"] = {"data": sx.tolist(), "operation": "sum"} - transfer_["sxx"] = {"data": sxx.tolist(), "operation": "sum"} + transfer_["n_obs"] = {"data": n_obs, "operation": "sum", "type": "int"} + transfer_["sx"] = {"data": sx.tolist(), "operation": "sum", "type": "float"} + transfer_["sxx"] = {"data": sxx.tolist(), "operation": "sum", "type": "float"} return transfer_ @@ -103,7 +103,13 @@ def local2(x, global_transfer): x /= sigmas gramian = x.T @ x - transfer_ = {"gramian": {"data": gramian.values.tolist(), "operation": "sum"}} + transfer_ = { + "gramian": { + "data": gramian.values.tolist(), + "operation": "sum", + "type": "float", + } + } return transfer_ diff --git a/mipengine/algorithms/pearson.py b/mipengine/algorithms/pearson.py index 6f7dd5d7e6be7ce3d9e69b04e780f828c5213c11..b59a37007055ee03b8e502ccf518f3d4befc1dfb 100644 --- a/mipengine/algorithms/pearson.py +++ b/mipengine/algorithms/pearson.py @@ -121,12 +121,12 @@ def local1(y, x): syy = (Y**2).sum(axis=0) transfer_ = {} - transfer_["n_obs"] = {"data": n_obs, "operation": "sum"} - transfer_["sx"] = {"data": sx.tolist(), "operation": "sum"} - transfer_["sxx"] = {"data": sxx.tolist(), "operation": "sum"} - transfer_["sxy"] = {"data": sxy.tolist(), "operation": "sum"} - transfer_["sy"] = {"data": sy.tolist(), "operation": "sum"} - transfer_["syy"] = {"data": syy.tolist(), "operation": "sum"} + transfer_["n_obs"] = {"data": n_obs, "operation": "sum", "type": "int"} + transfer_["sx"] = {"data": sx.tolist(), "operation": "sum", "type": "float"} + transfer_["sxx"] = {"data": sxx.tolist(), "operation": "sum", "type": "float"} + transfer_["sxy"] = {"data": sxy.tolist(), "operation": "sum", "type": "float"} + transfer_["sy"] = {"data": sy.tolist(), "operation": "sum", "type": "float"} + transfer_["syy"] = {"data": syy.tolist(), "operation": "sum", "type": "float"} return transfer_ diff --git a/mipengine/controller/algorithm_executor.py b/mipengine/controller/algorithm_executor.py index 5f9bc9428ffba0d6d0ed945ec3a2a2287b18f4db..c0c21f312237e5fff3abcac821cd98be044d5bdf 100644 --- a/mipengine/controller/algorithm_executor.py +++ b/mipengine/controller/algorithm_executor.py @@ -450,7 +450,7 @@ class _AlgorithmExecutionInterface: command_id, local_nodes_smpc_tables ) - (sum_op, min_op, max_op, union_op) = trigger_smpc_operations( + (sum_op, min_op, max_op) = trigger_smpc_operations( logger=self._logger, context_id=self._global_node.context_id, command_id=command_id, @@ -464,14 +464,12 @@ class _AlgorithmExecutionInterface: sum_op=sum_op, min_op=min_op, max_op=max_op, - union_op=union_op, ) ( sum_op_result_table, min_op_result_table, max_op_result_table, - union_op_result_table, ) = get_smpc_results( node=self._global_node, context_id=self._global_node.context_id, @@ -479,7 +477,6 @@ class 
_AlgorithmExecutionInterface: sum_op=sum_op, min_op=min_op, max_op=max_op, - union_op=union_op, ) return GlobalNodeSMPCTables( @@ -489,7 +486,6 @@ class _AlgorithmExecutionInterface: sum_op=sum_op_result_table, min_op=min_op_result_table, max_op=max_op_result_table, - union_op=union_op_result_table, ), ) @@ -720,9 +716,6 @@ class _SingleLocalNodeAlgorithmExecutionInterface(_AlgorithmExecutionInterface): sum_op=local_nodes_data.nodes_smpc_tables[self._global_node].sum_op, min_op=local_nodes_data.nodes_smpc_tables[self._global_node].min_op, max_op=local_nodes_data.nodes_smpc_tables[self._global_node].max_op, - union_op=local_nodes_data.nodes_smpc_tables[ - self._global_node - ].union_op, ), ) diff --git a/mipengine/controller/algorithm_executor_node_data_objects.py b/mipengine/controller/algorithm_executor_node_data_objects.py index 6ed7455629f684e8e1da2717f02a110b0454941a..fdb2be4cc95a7d03f83ae4fb846efd511b8c8cc1 100644 --- a/mipengine/controller/algorithm_executor_node_data_objects.py +++ b/mipengine/controller/algorithm_executor_node_data_objects.py @@ -65,14 +65,12 @@ class SMPCTableNames(NodeData): sum_op: TableName min_op: TableName max_op: TableName - union_op: TableName - def __init__(self, template, sum_op, min_op, max_op, union_op): + def __init__(self, template, sum_op, min_op, max_op): self.template = template self.sum_op = sum_op self.min_op = min_op self.max_op = max_op - self.union_op = union_op def __repr__(self): return self.full_table_name diff --git a/mipengine/controller/algorithm_executor_nodes.py b/mipengine/controller/algorithm_executor_nodes.py index 8e08c1c69f7999d64895ea85c3886e472973731a..77022d612382994f23e6287d3989355566e41bf8 100644 --- a/mipengine/controller/algorithm_executor_nodes.py +++ b/mipengine/controller/algorithm_executor_nodes.py @@ -289,9 +289,6 @@ class LocalNode(_Node): max_op=create_node_table_from_node_table_dto( result.value.max_op_values ), - union_op=create_node_table_from_node_table_dto( - result.value.union_op_values - ), ) ) else: diff --git a/mipengine/controller/algorithm_executor_smpc_helper.py b/mipengine/controller/algorithm_executor_smpc_helper.py index 19d3efd51c708baa90eead5c306d20612e321670..8000306feccc94042bf1c755df8127d88f172fad 100644 --- a/mipengine/controller/algorithm_executor_smpc_helper.py +++ b/mipengine/controller/algorithm_executor_smpc_helper.py @@ -52,14 +52,10 @@ def load_data_to_smpc_clients( max_op_smpc_clients = load_operation_data_to_smpc_clients( command_id, smpc_tables.max_op_local_nodes_table, SMPCRequestType.MAX ) - union_op_smpc_clients = load_operation_data_to_smpc_clients( - command_id, smpc_tables.union_op_local_nodes_table, SMPCRequestType.UNION - ) return ( sum_op_smpc_clients, min_op_smpc_clients, max_op_smpc_clients, - union_op_smpc_clients, ) @@ -89,13 +85,12 @@ def trigger_smpc_operations( logger: Logger, context_id: str, command_id: int, - smpc_clients_per_op: Tuple[List[str], List[str], List[str], List[str]], -) -> Tuple[bool, bool, bool, bool]: + smpc_clients_per_op: Tuple[List[str], List[str], List[str]], +) -> Tuple[bool, bool, bool]: ( sum_op_smpc_clients, min_op_smpc_clients, max_op_smpc_clients, - union_op_smpc_clients, ) = smpc_clients_per_op sum_op = trigger_smpc_operation( logger, context_id, command_id, SMPCRequestType.SUM, sum_op_smpc_clients @@ -106,10 +101,7 @@ def trigger_smpc_operations( max_op = trigger_smpc_operation( logger, context_id, command_id, SMPCRequestType.MAX, max_op_smpc_clients ) - union_op = trigger_smpc_operation( - logger, context_id, command_id, 
SMPCRequestType.UNION, union_op_smpc_clients - ) - return sum_op, min_op, max_op, union_op + return sum_op, min_op, max_op def wait_for_smpc_result_to_be_ready( @@ -163,7 +155,6 @@ def wait_for_smpc_results_to_be_ready( sum_op: bool, min_op: bool, max_op: bool, - union_op: bool, ): wait_for_smpc_result_to_be_ready( logger, context_id, command_id, SMPCRequestType.SUM @@ -174,9 +165,6 @@ def wait_for_smpc_results_to_be_ready( wait_for_smpc_result_to_be_ready( logger, context_id, command_id, SMPCRequestType.MAX ) if max_op else None - wait_for_smpc_result_to_be_ready( - logger, context_id, command_id, SMPCRequestType.UNION - ) if union_op else None def get_smpc_results( @@ -186,8 +174,7 @@ def get_smpc_results( sum_op: bool, min_op: bool, max_op: bool, - union_op: bool, -) -> Tuple[TableName, TableName, TableName, TableName]: +) -> Tuple[TableName, TableName, TableName]: sum_op_result_table = ( TableName( table_name=node.get_smpc_result( @@ -233,25 +220,9 @@ def get_smpc_results( if max_op else None ) - union_op_result_table = ( - TableName( - table_name=node.get_smpc_result( - jobid=get_smpc_job_id( - context_id=context_id, - command_id=command_id, - operation=SMPCRequestType.UNION, - ), - command_id=str(command_id), - command_subid="3", - ) - ) - if union_op - else None - ) return ( sum_op_result_table, min_op_result_table, max_op_result_table, - union_op_result_table, ) diff --git a/mipengine/controller/algorithm_flow_data_objects.py b/mipengine/controller/algorithm_flow_data_objects.py index ef7f68dd8c3ae8c9adc832dfde2c1869056dad75..5894993ac61c028fcdf469b310811618b578d868 100644 --- a/mipengine/controller/algorithm_flow_data_objects.py +++ b/mipengine/controller/algorithm_flow_data_objects.py @@ -188,15 +188,6 @@ class LocalNodesSMPCTables(LocalNodesData): nodes_tables[node] = tables.max_op return LocalNodesTable(nodes_tables) - @property - def union_op_local_nodes_table(self) -> Optional[LocalNodesTable]: - nodes_tables = {} - for node, tables in self.nodes_smpc_tables.items(): - if not tables.union_op: - return None - nodes_tables[node] = tables.union_op - return LocalNodesTable(nodes_tables) - class GlobalNodeSMPCTables(GlobalNodeData): _node: GlobalNode @@ -286,9 +277,6 @@ def _algoexec_udf_arg_to_node_udf_arg( max_op_values=create_node_table_dto_from_global_node_table( algoexec_arg.smpc_tables.max_op ), - union_op_values=create_node_table_dto_from_global_node_table( - algoexec_arg.smpc_tables.union_op - ), ) ) else: diff --git a/mipengine/controller/api/error_handlers.py b/mipengine/controller/api/error_handlers.py index 40633651794eaa2958dcde4270669d93dbd45033..4ba823a07095e2add6d17da07301f6d3b0c87385 100644 --- a/mipengine/controller/api/error_handlers.py +++ b/mipengine/controller/api/error_handlers.py @@ -66,7 +66,7 @@ def handle_privacy_error(error: InsufficientDataError): @error_handlers.app_errorhandler(SMPCUsageError) -def handle_privacy_error(error: SMPCUsageError): +def handle_smpc_error(error: SMPCUsageError): return error.message, HTTPStatusCode.SMPC_USAGE_ERROR diff --git a/mipengine/node/tasks/smpc.py b/mipengine/node/tasks/smpc.py index 751142f2bec7cff85a3bcb4fb8c8a722f06e3dfb..dff0446de74650636c78524386e89ba43f77e688 100644 --- a/mipengine/node/tasks/smpc.py +++ b/mipengine/node/tasks/smpc.py @@ -17,7 +17,7 @@ from mipengine.node_tasks_DTOs import ColumnInfo from mipengine.node_tasks_DTOs import TableSchema from mipengine.node_tasks_DTOs import TableType from mipengine.smpc_cluster_comm_helpers import SMPCComputationError -from mipengine.smpc_DTOs import 
SMPCRequestType +from mipengine.smpc_cluster_comm_helpers import SMPCUsageError from mipengine.smpc_DTOs import SMPCResponseWithOutput from mipengine.table_data_DTOs import ColumnData @@ -172,4 +172,8 @@ def _create_smpc_results_table( def _get_smpc_values_from_table_data(table_data: List[ColumnData]): node_id_column, values_column = table_data + + if not values_column.data: + raise SMPCUsageError("A node doesn't have data to contribute to the SMPC.") + return values_column.data diff --git a/mipengine/node/tasks/udfs.py b/mipengine/node/tasks/udfs.py index 1a19b933f8eff6f7325ce312d310b19c157ca25b..00e33c7b1f547b143bf894ea9aabd2e7e09538ba 100644 --- a/mipengine/node/tasks/udfs.py +++ b/mipengine/node/tasks/udfs.py @@ -186,17 +186,11 @@ def _convert_smpc_udf2udfgen_arg(udf_argument: NodeSMPCDTO): if udf_argument.value.max_op_values else None ) - union_op = ( - _create_table_info_from_tablename(udf_argument.value.union_op_values.value) - if udf_argument.value.union_op_values - else None - ) return SMPCTablesInfo( template=template, sum_op_values=sum_op, min_op_values=min_op, max_op_values=max_op, - union_op_values=union_op, ) @@ -274,10 +268,6 @@ def _get_all_table_results_from_smpc_result( table_results.append( smpc_result.max_op_values ) if smpc_result.max_op_values else None - table_results.append( - smpc_result.union_op_values - ) if smpc_result.union_op_values else None - return table_results @@ -375,21 +365,12 @@ def _convert_udfgen2udf_smpc_result_and_mapping( else: max_op_udf_result = None - if udfgen_result.union_op_values: - (union_op_udf_result, mapping,) = _convert_udfgen2udf_table_result_and_mapping( - udfgen_result.union_op_values, context_id, command_id, command_subid + 4 - ) - table_names_tmpl_mapping.update(mapping) - else: - union_op_udf_result = None - result = NodeSMPCDTO( value=NodeSMPCValueDTO( template=template_udf_result, sum_op_values=sum_op_udf_result, min_op_values=min_op_udf_result, max_op_values=max_op_udf_result, - union_op_values=union_op_udf_result, ) ) return result, table_names_tmpl_mapping diff --git a/mipengine/node_tasks_DTOs.py b/mipengine/node_tasks_DTOs.py index 60cc53549668890374d4cf87238a64fc46023443..2241ac12746e2d48c9e153e4b6499d9ee997b5b3 100644 --- a/mipengine/node_tasks_DTOs.py +++ b/mipengine/node_tasks_DTOs.py @@ -183,7 +183,6 @@ class NodeSMPCValueDTO(ImmutableBaseModel): sum_op_values: NodeTableDTO = None min_op_values: NodeTableDTO = None max_op_values: NodeTableDTO = None - union_op_values: NodeTableDTO = None class NodeSMPCDTO(NodeUDFDTO): diff --git a/mipengine/smpc_DTOs.py b/mipengine/smpc_DTOs.py index 53380bbe18757cfd0853eb09dbb93c36a0a075a5..d1f6d5b51e6e414a21eef8135af2cc82dd0d07f5 100644 --- a/mipengine/smpc_DTOs.py +++ b/mipengine/smpc_DTOs.py @@ -5,11 +5,9 @@ from pydantic import BaseModel class SMPCRequestType(enum.Enum): - SUM = "sum" - MIN = "min" - MAX = "max" - UNION = "union" - PRODUCT = "product" + SUM = "fsum" + MIN = "fmin" + MAX = "fmax" def __str__(self): return self.name @@ -38,7 +36,7 @@ class SMPCResponse(BaseModel): class SMPCResponseWithOutput(BaseModel): - computationOutput: List[int] + computationOutput: List[float] computationType: SMPCRequestType jobId: str status: SMPCResponseStatus diff --git a/mipengine/smpc_cluster_comm_helpers.py b/mipengine/smpc_cluster_comm_helpers.py index 01cc8507bba379f1afb214ea656acc4f43386778..c705da739840a0176fd6118cbbc900550dce0d90 100644 --- a/mipengine/smpc_cluster_comm_helpers.py +++ b/mipengine/smpc_cluster_comm_helpers.py @@ -12,15 +12,21 @@ TRIGGER_COMPUTATION_ENDPOINT = 
"/api/secure-aggregation/job-id/" GET_RESULT_ENDPOINT = "/api/get-result/job-id/" +def _get_smpc_load_data_request_data_structure(data_values: str): + """ + The current approach with the SMPC cluster is to send all computations as floats. + That way we don't need to have separate operations for sum-int and sum-float. + """ + data = {"type": "float", "data": json.loads(data_values)} + return json.dumps(data) + + def load_data_to_smpc_client(client_address: str, jobid: str, values: str): request_url = client_address + ADD_DATASET_ENDPOINT + jobid request_headers = {"Content-type": "application/json", "Accept": "text/plain"} - # TODO (SMPC) Currently only ints are supported so it's hardcoded - # https://team-1617704806227.atlassian.net/browse/MIP-518 - data = {"type": "int", "data": json.loads(values)} response = requests.post( url=request_url, - data=json.dumps(data), + data=_get_smpc_load_data_request_data_structure(values), headers=request_headers, ) if response.status_code != 200: diff --git a/mipengine/udfgen/udfgen_DTOs.py b/mipengine/udfgen/udfgen_DTOs.py index 5ff9d81e650a7132776e7bd6d46ad2d97c5dea8e..35258cab39891b6934a93d011184278b5aa7ff35 100644 --- a/mipengine/udfgen/udfgen_DTOs.py +++ b/mipengine/udfgen/udfgen_DTOs.py @@ -46,7 +46,6 @@ class SMPCUDFGenResult(UDFGenResult): sum_op_values: Optional[TableUDFGenResult] = None min_op_values: Optional[TableUDFGenResult] = None max_op_values: Optional[TableUDFGenResult] = None - union_op_values: Optional[TableUDFGenResult] = None def __eq__(self, other): if self.template != other.template: @@ -57,8 +56,6 @@ class SMPCUDFGenResult(UDFGenResult): return False if self.max_op_values != other.max_op_values: return False - if self.union_op_values != other.union_op_values: - return False return True def __repr__(self): @@ -68,7 +65,6 @@ class SMPCUDFGenResult(UDFGenResult): f"sum_op_values={self.sum_op_values}, " f"min_op_values={self.min_op_values}, " f"max_op_values={self.max_op_values}, " - f"union_op_values={self.union_op_values}" f")" ) @@ -96,7 +92,6 @@ class SMPCTablesInfo(UDFGenBaseModel): sum_op_values: Optional[TableInfo] = None min_op_values: Optional[TableInfo] = None max_op_values: Optional[TableInfo] = None - union_op_values: Optional[TableInfo] = None def __repr__(self): return ( @@ -105,6 +100,5 @@ class SMPCTablesInfo(UDFGenBaseModel): f"sum_op_values={self.sum_op_values}, " f"min_op_values={self.min_op_values}, " f"max_op_values={self.max_op_values}, " - f"union_op_values={self.union_op_values}" f")" ) diff --git a/mipengine/udfgen/udfgenerator.py b/mipengine/udfgen/udfgenerator.py index ba9832d037d40d4fba0ca65a003c986a1e743d10..f859f07ed0bb5801544278dd40de16ffca330e49 100644 --- a/mipengine/udfgen/udfgenerator.py +++ b/mipengine/udfgen/udfgenerator.py @@ -184,7 +184,7 @@ Local UDF step Example ... def local_step(x, y): ... state["x"] = x["key"] ... state["y"] = y["key"] -... transfer["sum"] = {"data": x["key"] + y["key"], "operation": "sum"} +... transfer["sum"] = {"data": x["key"] + y["key"], "operation": "sum", "type": "float"} ... 
return state, transfer Global UDF step Example @@ -375,9 +375,8 @@ def get_smpc_build_template(secure_transfer_type): stmts.extend(get_smpc_op_template(secure_transfer_type.sum_op, "sum_op")) stmts.extend(get_smpc_op_template(secure_transfer_type.min_op, "min_op")) stmts.extend(get_smpc_op_template(secure_transfer_type.max_op, "max_op")) - stmts.extend(get_smpc_op_template(secure_transfer_type.union_op, "union_op")) stmts.append( - "{varname} = udfio.construct_secure_transfer_dict(__template,__sum_op_values,__min_op_values,__max_op_values,__union_op_values)" + "{varname} = udfio.construct_secure_transfer_dict(__template,__sum_op_values,__min_op_values,__max_op_values)" ) return LN.join(stmts) @@ -437,14 +436,13 @@ def _get_secure_transfer_op_return_stmt_template(op_enabled, table_name_tmpl, op def _get_secure_transfer_main_return_stmt_template(output_type, smpc_used): if smpc_used: return_stmts = [ - "template, sum_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict({return_name})" + "template, sum_op, min_op, max_op = udfio.split_secure_transfer_dict({return_name})" ] ( _, sum_op_tmpl, min_op_tmpl, max_op_tmpl, - union_op_tmpl, ) = _get_smpc_table_template_names(_get_main_table_template_name()) return_stmts.extend( _get_secure_transfer_op_return_stmt_template( @@ -461,11 +459,6 @@ def _get_secure_transfer_main_return_stmt_template(output_type, smpc_used): output_type.max_op, max_op_tmpl, "max_op" ) ) - return_stmts.extend( - _get_secure_transfer_op_return_stmt_template( - output_type.union_op, union_op_tmpl, "union_op" - ) - ) return_stmts.append("return json.dumps(template)") return LN.join(return_stmts) else: @@ -500,14 +493,13 @@ def _get_secure_transfer_sec_return_stmt_template( ): if smpc_used: return_stmts = [ - "template, sum_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict({return_name})" + "template, sum_op, min_op, max_op = udfio.split_secure_transfer_dict({return_name})" ] ( template_tmpl, sum_op_tmpl, min_op_tmpl, max_op_tmpl, - union_op_tmpl, ) = _get_smpc_table_template_names(tablename_placeholder) return_stmts.append( '_conn.execute(f"INSERT INTO $' @@ -529,11 +521,6 @@ def _get_secure_transfer_sec_return_stmt_template( output_type.max_op, max_op_tmpl, "max_op" ) ) - return_stmts.extend( - _get_secure_transfer_op_return_stmt_template( - output_type.union_op, union_op_tmpl, "union_op" - ) - ) return LN.join(return_stmts) else: # Treated as a TransferType @@ -905,13 +892,11 @@ class SecureTransferType(DictType, InputType, LoopbackOutputType): _sum_op: bool _min_op: bool _max_op: bool - _union_op: bool - def __init__(self, sum_op=False, min_op=False, max_op=False, union_op=False): + def __init__(self, sum_op=False, min_op=False, max_op=False): self._sum_op = sum_op self._min_op = min_op self._max_op = max_op - self._union_op = union_op @property def sum_op(self): @@ -925,17 +910,13 @@ class SecureTransferType(DictType, InputType, LoopbackOutputType): def max_op(self): return self._max_op - @property - def union_op(self): - return self._union_op - -def secure_transfer(sum_op=False, min_op=False, max_op=False, union_op=False): - if not sum_op and not min_op and not max_op and not union_op: +def secure_transfer(sum_op=False, min_op=False, max_op=False): + if not sum_op and not min_op and not max_op: raise UDFBadDefinition( "In a secure_transfer at least one operation should be enabled." 
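
# --- Hedged, self-contained sketch (a toy mirror, not the mipengine code) ---
# After dropping union_op, secure_transfer() accepts exactly three flags and,
# as the UDFBadDefinition raise here shows, still rejects the all-disabled case:
def secure_transfer_flags(sum_op=False, min_op=False, max_op=False):
    if not (sum_op or min_op or max_op):
        raise ValueError("In a secure_transfer at least one operation should be enabled.")
    return sum_op, min_op, max_op

assert secure_transfer_flags(sum_op=True, max_op=True) == (True, False, True)
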
) - return SecureTransferType(sum_op, min_op, max_op, union_op) + return SecureTransferType(sum_op, min_op, max_op) class StateType(DictType, InputType, LoopbackOutputType): @@ -1068,7 +1049,6 @@ class SMPCSecureTransferArg(UDFArgument): sum_op_values_table_name: str min_op_values_table_name: str max_op_values_table_name: str - union_op_values_table_name: str def __init__( self, @@ -1076,26 +1056,21 @@ class SMPCSecureTransferArg(UDFArgument): sum_op_values_table_name: str, min_op_values_table_name: str, max_op_values_table_name: str, - union_op_values_table_name: str, ): sum_op = False min_op = False max_op = False - union_op = False if sum_op_values_table_name: sum_op = True if min_op_values_table_name: min_op = True if max_op_values_table_name: max_op = True - if union_op_values_table_name: - union_op = True - self.type = SecureTransferType(sum_op, min_op, max_op, union_op) + self.type = SecureTransferType(sum_op, min_op, max_op) self.template_table_name = template_table_name self.sum_op_values_table_name = sum_op_values_table_name self.min_op_values_table_name = min_op_values_table_name self.max_op_values_table_name = max_op_values_table_name - self.union_op_values_table_name = union_op_values_table_name class LiteralArg(UDFArgument): @@ -1253,7 +1228,6 @@ class SMPCBuild(ASTNode): sum_op_values_table_name=self.arg.sum_op_values_table_name, min_op_values_table_name=self.arg.min_op_values_table_name, max_op_values_table_name=self.arg.max_op_values_table_name, - union_op_values_table_name=self.arg.union_op_values_table_name, ) @@ -2055,21 +2029,17 @@ def convert_smpc_udf_input_to_udf_arg(smpc_udf_input: SMPCTablesInfo): sum_op_table_name = None min_op_table_name = None max_op_table_name = None - union_op_table_name = None if smpc_udf_input.sum_op_values: sum_op_table_name = smpc_udf_input.sum_op_values.name if smpc_udf_input.min_op_values: min_op_table_name = smpc_udf_input.min_op_values.name if smpc_udf_input.max_op_values: max_op_table_name = smpc_udf_input.max_op_values.name - if smpc_udf_input.union_op_values: - union_op_table_name = smpc_udf_input.union_op_values.name return SMPCSecureTransferArg( template_table_name=smpc_udf_input.template.name, sum_op_values_table_name=sum_op_table_name, min_op_values_table_name=min_op_table_name, max_op_values_table_name=max_op_table_name, - union_op_values_table_name=union_op_table_name, ) @@ -2544,7 +2514,6 @@ def _get_smpc_table_template_names(prefix: str): prefix + "_sum_op", prefix + "_min_op", prefix + "_max_op", - prefix + "_union_op", ) @@ -2574,27 +2543,22 @@ def _create_smpc_udf_output(output_type: SecureTransferType, table_name_prefix: sum_op_tmpl, min_op_tmpl, max_op_tmpl, - union_op_tmpl, ) = _get_smpc_table_template_names(table_name_prefix) template = _create_table_udf_output(output_type, template_tmpl) sum_op = None min_op = None max_op = None - union_op = None if output_type.sum_op: sum_op = _create_table_udf_output(output_type, sum_op_tmpl) if output_type.min_op: min_op = _create_table_udf_output(output_type, min_op_tmpl) if output_type.max_op: max_op = _create_table_udf_output(output_type, max_op_tmpl) - if output_type.union_op: - union_op = _create_table_udf_output(output_type, union_op_tmpl) return SMPCUDFGenResult( template=template, sum_op_values=sum_op, min_op_values=min_op, max_op_values=max_op, - union_op_values=union_op, ) diff --git a/mipengine/udfgen/udfio.py b/mipengine/udfgen/udfio.py index dd03ddd14df21aa519883ecf263a62360cdedf16..2e68f3d85cacf46e071f91c27da09dd36454f1cb 100644 --- a/mipengine/udfgen/udfio.py +++ 
b/mipengine/udfgen/udfio.py @@ -5,6 +5,7 @@ from functools import partial from functools import reduce from typing import Any from typing import List +from typing import Set from typing import Tuple from typing import Type @@ -121,8 +122,17 @@ def merge_tensor_to_list(columns): # ~~~~~~~~~~~~~~~~~~~~~~~~ Secure Transfer methods ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +smpc_sum_op = "sum" +smpc_min_op = "min" +smpc_max_op = "max" +smpc_numeric_operations = [smpc_sum_op, smpc_min_op, smpc_max_op] -numeric_operations = ["sum", "min", "max"] +smpc_int_type = "int" +smpc_float_type = "float" +smpc_numeric_types = [smpc_int_type, smpc_float_type] +smpc_transfer_op_key = "operation" +smpc_transfer_val_type_key = "type" +smpc_transfer_data_key = "data" def secure_transfers_to_merged_dict(transfers: List[dict]): @@ -135,35 +145,86 @@ def secure_transfers_to_merged_dict(transfers: List[dict]): # Get all keys from a list of dicts all_keys = set().union(*(d.keys() for d in transfers)) + _validate_transfers_have_all_keys(transfers, all_keys) for key in all_keys: - operation = transfers[0][key]["operation"] - if operation in numeric_operations: - result[key] = _operation_on_secure_transfer_key_data( - key, transfers, operation - ) - else: - raise NotImplementedError( - f"Secure transfer operation not supported: {operation}" - ) + _validate_transfers_operation(transfers, key) + _validate_transfers_type(transfers, key) + result[key] = _operation_on_secure_transfer_key_data( + key, + transfers, + transfers[0][key][smpc_transfer_op_key], + ) return result -def _operation_on_secure_transfer_key_data(key, transfers: List[dict], operation: str): +def _validate_transfers_have_all_keys(transfers: List[dict], all_keys: Set[str]): + for key in all_keys: + for transfer in transfers: + if key not in transfer.keys(): + raise ValueError( + f"All secure transfer dicts should have the same keys. Transfer: {transfer} doesn't have key: {key}" + ) + + +def _validate_transfers_operation(transfers: List[dict], key: str): """ - Given a list of secure_transfer dicts, it makes the appropriate operation on the data of the key provided. + Validates that all transfer dicts have proper 'operation' values for the 'key' provided. """ - result = transfers[0][key]["data"] + _validate_transfer_key_operation(transfers[0][key]) + first_transfer_operation = transfers[0][key][smpc_transfer_op_key] for transfer in transfers[1:]: - if transfer[key]["operation"] not in numeric_operations: + _validate_transfer_key_operation(transfer[key]) + if transfer[key][smpc_transfer_op_key] != first_transfer_operation: raise ValueError( - f"Secure Transfer operation is not supported: {transfer[key]['operation']}" + f"Similar secure transfer keys should have the same operation value. " + f"'{first_transfer_operation}' != {transfer[key][smpc_transfer_op_key]}" ) - if transfer[key]["operation"] != operation: + + +def _validate_transfer_key_operation(transfer_value: dict): + try: + operation = transfer_value[smpc_transfer_op_key] + except KeyError: + raise ValueError( + "Secure Transfer operation is not provided. Expected format: {'a' : {'data': X, 'operation': Y, 'type': Z}}" + ) + if operation not in smpc_numeric_operations: + raise ValueError(f"Secure Transfer operation is not supported: '{operation}'.") + + +def _validate_transfers_type(transfers: List[dict], key: str): + """ + Validates that all transfer dicts have proper 'type' values for the 'key' provided. 
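
# --- Hedged example of what the new type validation rejects ---
# Two local nodes tagging the same key with different "type" values now fail
# fast instead of being merged silently (error text as in _validate_transfers_type):
mismatched_transfers = [
    {"n": {"data": 1, "operation": "sum", "type": "int"}},
    {"n": {"data": 2, "operation": "sum", "type": "float"}},
]
# udfio.secure_transfers_to_merged_dict(mismatched_transfers) raises:
#   ValueError: Similar secure transfer keys should have the same type value. 'int' != float
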
+ """ + _validate_transfer_key_type(transfers[0][key]) + first_transfer_value_type = transfers[0][key][smpc_transfer_val_type_key] + for transfer in transfers[1:]: + _validate_transfer_key_type(transfer[key]) + if transfer[key][smpc_transfer_val_type_key] != first_transfer_value_type: raise ValueError( - f"All secure transfer keys should have the same 'operation' value. " - f"'{operation}' != {transfer[key]['operation']}" + f"Similar secure transfer keys should have the same type value. " + f"'{first_transfer_value_type}' != {transfer[key][smpc_transfer_val_type_key]}" ) - result = _calc_values(result, transfer[key]["data"], operation) + + +def _validate_transfer_key_type(transfer_value: dict): + try: + values_type = transfer_value[smpc_transfer_val_type_key] + except KeyError: + raise ValueError( + "Secure Transfer type is not provided. Expected format: {'a' : {'data': X, 'operation': Y, 'type': Z}}" + ) + if values_type not in smpc_numeric_types: + raise ValueError(f"Secure Transfer type is not supported: '{values_type}'.") + + +def _operation_on_secure_transfer_key_data(key, transfers: List[dict], operation: str): + """ + Given a list of secure_transfer dicts, it makes the appropriate operation on the data of the key provided. + """ + result = transfers[0][key][smpc_transfer_data_key] + for transfer in transfers[1:]: + result = _calc_values(result, transfer[key][smpc_transfer_data_key], operation) return result @@ -197,35 +258,37 @@ def _validate_calc_values(value1, value2): def _calc_numeric_values(value1: Any, value2: Any, operation: str): - if operation == "sum": + if operation == smpc_sum_op: return value1 + value2 - elif operation == "min": - return value1 if value1 < value2 else value2 - elif operation == "max": - return value1 if value1 > value2 else value2 + elif operation == smpc_min_op: + return min(value1, value2) + elif operation == smpc_max_op: + return max(value1, value2) else: raise NotImplementedError -def split_secure_transfer_dict(dict_: dict) -> Tuple[dict, list, list, list, list]: +def split_secure_transfer_dict(secure_transfer: dict) -> Tuple[dict, list, list, list]: """ When SMPC is used, a secure transfer dict should be split in different parts: 1) The template of the dict with relative positions instead of values, 2) flattened lists for each operation, containing the values. 
""" secure_transfer_template = {} - op_flat_data = {"sum": [], "min": [], "max": []} - op_indexes = {"sum": 0, "min": 0, "max": 0} - for key, data_transfer in dict_.items(): + op_flat_data = {smpc_sum_op: [], smpc_min_op: [], smpc_max_op: []} + op_indexes = {smpc_sum_op: 0, smpc_min_op: 0, smpc_max_op: 0} + for key, data_transfer in secure_transfer.items(): _validate_secure_transfer_item(key, data_transfer) - cur_op = data_transfer["operation"] + cur_op = data_transfer[smpc_transfer_op_key] try: ( data_transfer_tmpl, cur_flat_data, op_indexes[cur_op], ) = _flatten_data_and_keep_relative_positions( - op_indexes[cur_op], data_transfer["data"], [list, int, float] + op_indexes[cur_op], + data_transfer[smpc_transfer_data_key], + [list, int, float], ) except TypeError as e: raise TypeError( @@ -233,30 +296,42 @@ def split_secure_transfer_dict(dict_: dict) -> Tuple[dict, list, list, list, lis ) op_flat_data[cur_op].extend(cur_flat_data) - secure_transfer_template[key] = dict_[key] - secure_transfer_template[key]["data"] = data_transfer_tmpl - + secure_transfer_key_template = { + smpc_transfer_op_key: secure_transfer[key][smpc_transfer_op_key], + smpc_transfer_val_type_key: secure_transfer[key][ + smpc_transfer_val_type_key + ], + smpc_transfer_data_key: data_transfer_tmpl, + } + secure_transfer_template[key] = secure_transfer_key_template return ( secure_transfer_template, - op_flat_data["sum"], - op_flat_data["min"], - op_flat_data["max"], - [], + op_flat_data[smpc_sum_op], + op_flat_data[smpc_min_op], + op_flat_data[smpc_max_op], ) def _validate_secure_transfer_item(key: str, data_transfer: dict): - if "operation" not in data_transfer.keys(): + if smpc_transfer_op_key not in data_transfer.keys(): raise ValueError( f"Each Secure Transfer key should contain an operation. Key: {key}" ) - if "data" not in data_transfer.keys(): + if smpc_transfer_val_type_key not in data_transfer.keys(): + raise ValueError(f"Each Secure Transfer key should contain a type. Key: {key}") + + if smpc_transfer_data_key not in data_transfer.keys(): raise ValueError(f"Each Secure Transfer key should contain data. Key: {key}") - if data_transfer["operation"] not in numeric_operations: + if data_transfer[smpc_transfer_op_key] not in smpc_numeric_operations: + raise ValueError( + f"Secure Transfer operation is not supported: {data_transfer[smpc_transfer_op_key]}" + ) + + if data_transfer[smpc_transfer_val_type_key] not in smpc_numeric_types: raise ValueError( - f"Secure Transfer operation is not supported: {data_transfer['operation']}" + f"Secure Transfer type is not supported: {data_transfer[smpc_transfer_val_type_key]}" ) @@ -265,30 +340,40 @@ def construct_secure_transfer_dict( sum_op_values: List[int] = None, min_op_values: List[int] = None, max_op_values: List[int] = None, - union_op_values: List[int] = None, ) -> dict: """ When SMPC is used, a secure_transfer dict is broken into template and values. In order to be used from a udf it needs to take it's final key - value form. 
""" final_dict = {} - for key, data_transfer in template.items(): - if data_transfer["operation"] == "sum": - unflattened_data = _unflatten_data_using_relative_positions( - data_transfer["data"], sum_op_values, [int, float] + for key, data_transfer_tmpl in template.items(): + if data_transfer_tmpl[smpc_transfer_op_key] == smpc_sum_op: + structured_data = _structure_data_using_relative_positions( + data_transfer_tmpl[smpc_transfer_data_key], + data_transfer_tmpl[smpc_transfer_val_type_key], + sum_op_values, + [int, float], ) - elif data_transfer["operation"] == "min": - unflattened_data = _unflatten_data_using_relative_positions( - data_transfer["data"], min_op_values, [int, float] + elif data_transfer_tmpl[smpc_transfer_op_key] == smpc_min_op: + structured_data = _structure_data_using_relative_positions( + data_transfer_tmpl[smpc_transfer_data_key], + data_transfer_tmpl[smpc_transfer_val_type_key], + min_op_values, + [int, float], ) - elif data_transfer["operation"] == "max": - unflattened_data = _unflatten_data_using_relative_positions( - data_transfer["data"], max_op_values, [int, float] + elif data_transfer_tmpl[smpc_transfer_op_key] == smpc_max_op: + structured_data = _structure_data_using_relative_positions( + data_transfer_tmpl[smpc_transfer_data_key], + data_transfer_tmpl[smpc_transfer_val_type_key], + max_op_values, + [int, float], ) else: - raise ValueError(f"Operation not supported: {data_transfer['operation']}") + raise ValueError( + f"Operation not supported: {data_transfer_tmpl[smpc_transfer_op_key]}" + ) - final_dict[key] = unflattened_data + final_dict[key] = structured_data return final_dict @@ -328,8 +413,9 @@ def _flatten_data_and_keep_relative_positions( return index, [data], index + 1 -def _unflatten_data_using_relative_positions( +def _structure_data_using_relative_positions( data_tmpl: Any, + data_values_type: str, flat_values: List[Any], allowed_types: List[Type], ): @@ -340,11 +426,16 @@ def _unflatten_data_using_relative_positions( """ if isinstance(data_tmpl, list): return [ - _unflatten_data_using_relative_positions(elem, flat_values, allowed_types) + _structure_data_using_relative_positions( + elem, data_values_type, flat_values, allowed_types + ) for elem in data_tmpl ] if type(data_tmpl) not in allowed_types: raise TypeError(f"Types allowed: {allowed_types}") + if data_values_type == smpc_int_type: + return int(flat_values[data_tmpl]) + return flat_values[data_tmpl] diff --git a/tests/algorithms/orphan_udfs.py b/tests/algorithms/orphan_udfs.py index 55406ac9daee370330548d625c6ae864d4ab3323..b0a16da5acd57d5e5ff18a528a839ab11cc8d3b9 100644 --- a/tests/algorithms/orphan_udfs.py +++ b/tests/algorithms/orphan_udfs.py @@ -34,7 +34,7 @@ def smpc_local_step(table: DataFrame): sum_ = 0 for element, *_ in table.values: sum_ += element - secure_transfer_ = {"sum": {"data": int(sum_), "operation": "sum"}} + secure_transfer_ = {"sum": {"data": int(sum_), "operation": "sum", "type": "int"}} return secure_transfer_ diff --git a/tests/algorithms/smpc_standard_deviation.py b/tests/algorithms/smpc_standard_deviation.py index 739feff932cb98a282487661642877f272cfca0a..f1463881fc2de03cd36ec34f51d173ead25c7b03 100644 --- a/tests/algorithms/smpc_standard_deviation.py +++ b/tests/algorithms/smpc_standard_deviation.py @@ -92,10 +92,10 @@ def smpc_local_step_1(table): if element > max_value: max_value = element secure_transfer_ = { - "sum": {"data": float(sum_), "operation": "sum"}, - "min": {"data": float(min_value), "operation": "min"}, - "max": {"data": float(max_value), 
"operation": "max"}, - "count": {"data": len(table), "operation": "sum"}, + "sum": {"data": float(sum_), "operation": "sum", "type": "float"}, + "min": {"data": float(min_value), "operation": "min", "type": "float"}, + "max": {"data": float(max_value), "operation": "max", "type": "float"}, + "count": {"data": len(table), "operation": "sum", "type": "float"}, } return state_, secure_transfer_ @@ -129,7 +129,7 @@ def smpc_local_step_2(prev_state, global_transfer): secure_transfer_ = { "deviation_sum": { "data": float(deviation_sum), - "type": "int", + "type": "float", "operation": "sum", } } diff --git a/tests/algorithms/smpc_standard_deviation_int_only.py b/tests/algorithms/smpc_standard_deviation_int_only.py index cf16b191ba067b7c78745eb369ed57b18f36edc8..cdef7e960cd76966fe2c1397fe4c022af00c4ba7 100644 --- a/tests/algorithms/smpc_standard_deviation_int_only.py +++ b/tests/algorithms/smpc_standard_deviation_int_only.py @@ -92,10 +92,10 @@ def smpc_local_step_1(table): if element > max_value: max_value = element secure_transfer_ = { - "sum": {"data": int(sum_), "operation": "sum"}, - "min": {"data": int(min_value), "operation": "min"}, - "max": {"data": int(max_value), "operation": "max"}, - "count": {"data": len(table), "operation": "sum"}, + "sum": {"data": int(sum_), "operation": "sum", "type": "int"}, + "min": {"data": int(min_value), "operation": "min", "type": "int"}, + "max": {"data": int(max_value), "operation": "max", "type": "int"}, + "count": {"data": len(table), "operation": "sum", "type": "int"}, } return state_, secure_transfer_ diff --git a/tests/smpc_env_tests/deployment_configs/kubernetes_values.yaml b/tests/smpc_env_tests/deployment_configs/kubernetes_values.yaml index 268ba55aac8c7fad26864cd55bf5a476aeabaa12..4e3888b8762accdd28e8bddfcd4d9764cb1fb4a9 100644 --- a/tests/smpc_env_tests/deployment_configs/kubernetes_values.yaml +++ b/tests/smpc_env_tests/deployment_configs/kubernetes_values.yaml @@ -11,8 +11,8 @@ monetdb_storage: /opt/mipengine/db csvs_datapath: /opt/mipengine/csvs controller: - node_landscape_aggregator_update_interval: 20 - celery_tasks_timeout: 10 + node_landscape_aggregator_update_interval: 30 + celery_tasks_timeout: 20 celery_run_udf_task_timeout: 120 nodes_cleanup_interval: 60 cleanup_file_folder: /opt/cleanup @@ -23,3 +23,5 @@ smpc: image: gpikra/coordinator:v6.0.0 db_image: mongo:5.0.8 queue_image: redis:alpine3.15 + get_result_interval: 5 + get_result_max_retries: 100 diff --git a/tests/standalone_tests/test_smpc_algorithms.py b/tests/standalone_tests/test_smpc_algorithms.py index b5cf9c5a5fd23656212de80b715a319d8a0f48fe..ef00954c0f79f620fb81b26932e3e90b0e1e5794 100644 --- a/tests/standalone_tests/test_smpc_algorithms.py +++ b/tests/standalone_tests/test_smpc_algorithms.py @@ -1,6 +1,5 @@ import json import re -import time import pytest import requests @@ -79,7 +78,14 @@ def get_parametrization_list_success_cases(): {"name": "max_value", "data": [4.0], "type": "FLOAT"}, ], } - parametrization_list.append((algorithm_name, request_dict, expected_response)) + parametrization_list.append( + pytest.param( + algorithm_name, + request_dict, + expected_response, + id="smpc std dev ints only without smpc flag", + ) + ) # END ~~~~~~~~~~success case 1~~~~~~~~~~ # ~~~~~~~~~~success case 2~~~~~~~~~~ @@ -153,109 +159,174 @@ def get_parametrization_list_success_cases(): {"name": "max_value", "data": [4.0], "type": "FLOAT"}, ], } - parametrization_list.append((algorithm_name, request_dict, expected_response)) + parametrization_list.append( + pytest.param( + 
algorithm_name, + request_dict, + expected_response, + id="smpc std dev ints only with smpc flag", + ) + ) # END ~~~~~~~~~~success case 2~~~~~~~~~~ - # - # # ~~~~~~~~~~success case 3~~~~~~~~~~ - # algorithm_name = "smpc_standard_deviation" - # request_dict = { - # "inputdata": { - # "data_model": "dementia:0.1", - # "datasets": ["edsd"], - # "x": [ - # "lefthippocampus", - # ], - # "filters": { - # "condition": "AND", - # "rules": [ - # { - # "id": "dataset", - # "type": "string", - # "value": ["edsd"], - # "operator": "in", - # }, - # { - # "condition": "AND", - # "rules": [ - # { - # "id": variable, - # "type": "string", - # "operator": "is_not_null", - # "value": None, - # } - # for variable in [ - # "lefthippocampus", - # ] - # ], - # }, - # ], - # "valid": True, - # }, - # }, - # } - # expected_response = { - # "title": "Standard Deviation", - # "columns": [ - # {"name": "variable", "data": ["lefthippocampus"], "type": "STR"}, - # {"name": "std_deviation", "data": [0.3634506955662605], "type": "FLOAT"}, - # {"name": "min_value", "data": [1.3047], "type": "FLOAT"}, - # {"name": "max_value", "data": [4.4519], "type": "FLOAT"}, - # ], - # } - # parametrization_list.append((algorithm_name, request_dict, expected_response)) - # # END ~~~~~~~~~~success case 3~~~~~~~~~~ - # - # # ~~~~~~~~~~success case 4~~~~~~~~~~ - # algorithm_name = "smpc_standard_deviation" - # request_dict = { - # "inputdata": { - # "data_model": "dementia:0.1", - # "datasets": ["edsd"], - # "x": [ - # "lefthippocampus", - # ], - # "filters": { - # "condition": "AND", - # "rules": [ - # { - # "id": "dataset", - # "type": "string", - # "value": ["edsd"], - # "operator": "in", - # }, - # { - # "condition": "AND", - # "rules": [ - # { - # "id": variable, - # "type": "string", - # "operator": "is_not_null", - # "value": None, - # } - # for variable in [ - # "lefthippocampus", - # ] - # ], - # }, - # ], - # "valid": True, - # }, - # }, - # "flags": { - # "smpc": True, - # }, - # } - # expected_response = { - # "title": "Standard Deviation", - # "columns": [ - # {"name": "variable", "data": ["lefthippocampus"], "type": "STR"}, - # {"name": "std_deviation", "data": [0.3634506955662605], "type": "FLOAT"}, - # {"name": "min_value", "data": [1.3047], "type": "FLOAT"}, - # {"name": "max_value", "data": [4.4519], "type": "FLOAT"}, - # ], - # } - # parametrization_list.append((algorithm_name, request_dict, expected_response)) - # # END ~~~~~~~~~~success case 4~~~~~~~~~~ + + # ~~~~~~~~~~success case 3~~~~~~~~~~ + algorithm_name = "smpc_standard_deviation" + request_dict = { + "inputdata": { + "data_model": "dementia:0.1", + "datasets": [ + "edsd0", + "edsd1", + "edsd2", + "edsd3", + "edsd4", + "edsd5", + "edsd6", + "edsd7", + "edsd8", + "edsd9", + ], + "x": [ + "lefthippocampus", + ], + "filters": { + "condition": "AND", + "rules": [ + { + "id": "dataset", + "type": "string", + "value": [ + "edsd0", + "edsd1", + "edsd2", + "edsd3", + "edsd4", + "edsd5", + "edsd6", + "edsd7", + "edsd8", + "edsd9", + ], + "operator": "in", + }, + { + "condition": "AND", + "rules": [ + { + "id": variable, + "type": "string", + "operator": "is_not_null", + "value": None, + } + for variable in [ + "lefthippocampus", + ] + ], + }, + ], + "valid": True, + }, + }, + } + expected_response = { + "title": "Standard Deviation", + "columns": [ + {"name": "variable", "data": ["lefthippocampus"], "type": "STR"}, + {"name": "std_deviation", "data": [0.3634506955662605], "type": "FLOAT"}, + {"name": "min_value", "data": [1.3047], "type": "FLOAT"}, + {"name": 
"max_value", "data": [4.4519], "type": "FLOAT"}, + ], + } + parametrization_list.append( + pytest.param( + algorithm_name, + request_dict, + expected_response, + id="smpc std dev floats/ints without smpc flag", + ) + ) + # END ~~~~~~~~~~success case 3~~~~~~~~~~ + + # ~~~~~~~~~~success case 4~~~~~~~~~~ + algorithm_name = "smpc_standard_deviation" + request_dict = { + "inputdata": { + "data_model": "dementia:0.1", + "datasets": [ + "edsd0", + "edsd1", + "edsd2", + "edsd3", + "edsd4", + "edsd5", + "edsd6", + "edsd7", + "edsd8", + "edsd9", + ], + "x": [ + "lefthippocampus", + ], + "filters": { + "condition": "AND", + "rules": [ + { + "id": "dataset", + "type": "string", + "value": [ + "edsd0", + "edsd1", + "edsd2", + "edsd3", + "edsd4", + "edsd5", + "edsd6", + "edsd7", + "edsd8", + "edsd9", + ], + "operator": "in", + }, + { + "condition": "AND", + "rules": [ + { + "id": variable, + "type": "string", + "operator": "is_not_null", + "value": None, + } + for variable in [ + "lefthippocampus", + ] + ], + }, + ], + "valid": True, + }, + }, + "flags": { + "smpc": True, + }, + } + expected_response = { + "title": "Standard Deviation", + "columns": [ + {"name": "variable", "data": ["lefthippocampus"], "type": "STR"}, + {"name": "std_deviation", "data": [0.3634506955662605], "type": "FLOAT"}, + {"name": "min_value", "data": [1.3047], "type": "FLOAT"}, + {"name": "max_value", "data": [4.4519], "type": "FLOAT"}, + ], + } + parametrization_list.append( + pytest.param( + algorithm_name, + request_dict, + expected_response, + id="smpc std dev floats/ints with smpc flag", + ) + ) + # END ~~~~~~~~~~success case 4~~~~~~~~~~ return parametrization_list diff --git a/tests/standalone_tests/test_smpc_node_tasks.py b/tests/standalone_tests/test_smpc_node_tasks.py index a851abf9c7668b7668bbc4c0908e4e20ca15b4a6..01ebc7ebb2c23ecd83cd19cc4f4a92178115af83 100644 --- a/tests/standalone_tests/test_smpc_node_tasks.py +++ b/tests/standalone_tests/test_smpc_node_tasks.py @@ -88,8 +88,12 @@ def create_table_with_secure_transfer_results_with_smpc_off( secure_transfer_1_value = 100 secure_transfer_2_value = 11 - secure_transfer_1 = {"sum": {"data": secure_transfer_1_value, "operation": "sum"}} - secure_transfer_2 = {"sum": {"data": secure_transfer_2_value, "operation": "sum"}} + secure_transfer_1 = { + "sum": {"data": secure_transfer_1_value, "operation": "sum", "type": "int"} + } + secure_transfer_2 = { + "sum": {"data": secure_transfer_2_value, "operation": "sum", "type": "int"} + } values = [ ["localnode1", json.dumps(secure_transfer_1)], ["localnode2", json.dumps(secure_transfer_2)], @@ -116,8 +120,12 @@ def create_table_with_multiple_secure_transfer_templates( table_name = create_secure_transfer_table(celery_app) - secure_transfer_template = {"sum": {"data": [0, 1, 2, 3], "operation": "sum"}} - different_secure_transfer_template = {"sum": {"data": 0, "operation": "sum"}} + secure_transfer_template = { + "sum": {"data": [0, 1, 2, 3], "operation": "sum", "type": "int"} + } + different_secure_transfer_template = { + "sum": {"data": 0, "operation": "sum", "type": "int"} + } if similar: values = [ @@ -222,7 +230,9 @@ def test_secure_transfer_output_with_smpc_off( secure_transfer_result = results[0] assert isinstance(secure_transfer_result, NodeTableDTO) - expected_result = {"sum": {"data": input_table_name_sum, "operation": "sum"}} + expected_result = { + "sum": {"data": input_table_name_sum, "operation": "sum", "type": "int"} + } validate_dict_table_data_match_expected( celery_app=localnode1_celery_app, 
get_table_data_task_signature=get_table_data_task, @@ -369,7 +379,7 @@ def test_secure_transfer_run_udf_flow_with_smpc_on( assert isinstance(smpc_result, NodeSMPCDTO) assert smpc_result.value.template is not None - expected_template = {"sum": {"data": 0, "operation": "sum"}} + expected_template = {"sum": {"data": 0, "operation": "sum", "type": "int"}} validate_dict_table_data_match_expected( celery_app=smpc_localnode1_celery_app, get_table_data_task_signature=get_table_data_task, diff --git a/tests/standalone_tests/test_udfgenerator.py b/tests/standalone_tests/test_udfgenerator.py index a259b5a13a2347b41e1824fc1c0bf293fd260294..05366f27066040f4b484302678f34414bc4862a7 100644 --- a/tests/standalone_tests/test_udfgenerator.py +++ b/tests/standalone_tests/test_udfgenerator.py @@ -1048,11 +1048,6 @@ class TestUDFGenBase: udf_output.max_op_values.tablename_placeholder ) template_mapping[tablename_placeholder] = tablename_placeholder - if udf_output.union_op_values: - tablename_placeholder = ( - udf_output.union_op_values.tablename_placeholder - ) - template_mapping[tablename_placeholder] = tablename_placeholder else: pytest.fail( f"A udf_output must be of the format TableUDFOutput or SMPCUDFOutput." @@ -1122,10 +1117,6 @@ class TestUDFGenBase: queries.extend( self._concrete_table_udf_outputs(udf_output.max_op_values) ) - if udf_output.union_op_values: - queries.extend( - self._concrete_table_udf_outputs(udf_output.union_op_values) - ) else: pytest.fail( f"A udf_output must be of the format TableUDFOutput or SMPCUDFOutput." @@ -1171,13 +1162,13 @@ class TestUDFGenBase: "CREATE TABLE test_secure_transfer_table(node_id VARCHAR(500), secure_transfer CLOB)" ) globalnode_db_cursor.execute( - 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": 1, "operation": "sum"}}\')' + 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": 1, "operation": "sum", "type": "int"}}\')' ) globalnode_db_cursor.execute( - 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(2, \'{"sum": {"data": 10, "operation": "sum"}}\')' + 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(2, \'{"sum": {"data": 10, "operation": "sum", "type": "int"}}\')' ) globalnode_db_cursor.execute( - 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(3, \'{"sum": {"data": 100, "operation": "sum"}}\')' + 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(3, \'{"sum": {"data": 100, "operation": "sum", "type": "int"}}\')' ) @pytest.fixture(scope="function") @@ -1186,7 +1177,7 @@ class TestUDFGenBase: "CREATE TABLE test_smpc_template_table(node_id VARCHAR(500), secure_transfer CLOB)" ) globalnode_db_cursor.execute( - 'INSERT INTO test_smpc_template_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": [0,1,2], "operation": "sum"}}\')' + 'INSERT INTO test_smpc_template_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": [0,1,2], "operation": "sum", "type": "int"}}\')' ) @pytest.fixture(scope="function") @@ -1204,8 +1195,8 @@ class TestUDFGenBase: "CREATE TABLE test_smpc_template_table(node_id VARCHAR(500), secure_transfer CLOB)" ) globalnode_db_cursor.execute( - 'INSERT INTO test_smpc_template_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": [0,1,2], "operation": "sum"}, ' - '"max": {"data": 0, "operation": "max"}}\')' + 'INSERT INTO test_smpc_template_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": [0,1,2], "operation": "sum", "type": "int"}, ' 
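
# --- Hedged recap of the merge path these fixtures exercise with SMPC off ---
# The three rows inserted above each carry {"sum": {"data": 1|10|100, ...}},
# and the global node folds them key-wise with the "sum" operation:
fixture_transfers = [
    {"sum": {"data": 1,   "operation": "sum", "type": "int"}},
    {"sum": {"data": 10,  "operation": "sum", "type": "int"}},
    {"sum": {"data": 100, "operation": "sum", "type": "int"}},
]
# udfio.secure_transfers_to_merged_dict(fixture_transfers) == {"sum": 111}
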
+ '"max": {"data": 0, "operation": "max", "type": "int"}}\')' ) @pytest.fixture(scope="function") @@ -4490,9 +4481,9 @@ class TestUDFGen_SecureTransferOutput_with_SMPC_off( ) def f(state): result = { - "sum": {"data": state["num"], "operation": "sum"}, - "min": {"data": state["num"], "operation": "min"}, - "max": {"data": state["num"], "operation": "max"}, + "sum": {"data": state["num"], "operation": "sum", "type": "int"}, + "min": {"data": state["num"], "operation": "min", "type": "int"}, + "max": {"data": state["num"], "operation": "max", "type": "int"}, } return result @@ -4527,9 +4518,9 @@ LANGUAGE PYTHON import json __state_str = _conn.execute("SELECT state from test_state_table;")["state"][0] state = pickle.loads(__state_str) - result = {'sum': {'data': state['num'], 'operation': 'sum'}, 'min': {'data': - state['num'], 'operation': 'min'}, 'max': {'data': state['num'], - 'operation': 'max'}} + result = {'sum': {'data': state['num'], 'operation': 'sum', 'type': 'int'}, + 'min': {'data': state['num'], 'operation': 'min', 'type': 'int'}, 'max': + {'data': state['num'], 'operation': 'max', 'type': 'int'}} return json.dumps(result) }""" @@ -4575,9 +4566,9 @@ FROM ).fetchone() result = json.loads(secure_transfer_) assert result == { - "sum": {"data": 5, "operation": "sum"}, - "min": {"data": 5, "operation": "min"}, - "max": {"data": 5, "operation": "max"}, + "sum": {"data": 5, "operation": "sum", "type": "int"}, + "min": {"data": 5, "operation": "min", "type": "int"}, + "max": {"data": 5, "operation": "max", "type": "int"}, } @@ -4592,8 +4583,8 @@ class TestUDFGen_SecureTransferOutput_with_SMPC_on( ) def f(state): result = { - "sum": {"data": state["num"], "operation": "sum"}, - "max": {"data": state["num"], "operation": "max"}, + "sum": {"data": state["num"], "operation": "sum", "type": "int"}, + "max": {"data": state["num"], "operation": "max", "type": "int"}, } return result @@ -4628,9 +4619,9 @@ LANGUAGE PYTHON import json __state_str = _conn.execute("SELECT state from test_state_table;")["state"][0] state = pickle.loads(__state_str) - result = {'sum': {'data': state['num'], 'operation': 'sum'}, 'max': {'data': - state['num'], 'operation': 'max'}} - template, sum_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict(result) + result = {'sum': {'data': state['num'], 'operation': 'sum', 'type': 'int'}, + 'max': {'data': state['num'], 'operation': 'max', 'type': 'int'}} + template, sum_op, min_op, max_op = udfio.split_secure_transfer_dict(result) _conn.execute(f"INSERT INTO $main_output_table_name_sum_op VALUES ('$node_id', '{json.dumps(sum_op)}');") _conn.execute(f"INSERT INTO $main_output_table_name_max_op VALUES ('$node_id', '{json.dumps(max_op)}');") return json.dumps(template) @@ -4704,8 +4695,8 @@ FROM ).fetchone() template = json.loads(template_str) assert template == { - "max": {"data": 0, "operation": "max"}, - "sum": {"data": 0, "operation": "sum"}, + "max": {"data": 0, "operation": "max", "type": "int"}, + "sum": {"data": 0, "operation": "sum", "type": "int"}, } sum_op_values_str, *_ = globalnode_db_cursor.execute( @@ -4735,9 +4726,9 @@ class TestUDFGen_SecureTransferOutputAs2ndOutput_with_SMPC_off( ) def f(state): result = { - "sum": {"data": state["num"], "operation": "sum"}, - "min": {"data": state["num"], "operation": "min"}, - "max": {"data": state["num"], "operation": "max"}, + "sum": {"data": state["num"], "operation": "sum", "type": "int"}, + "min": {"data": state["num"], "operation": "min", "type": "int"}, + "max": {"data": state["num"], "operation": "max", 
"type": "int"}, } return state, result @@ -4772,9 +4763,9 @@ LANGUAGE PYTHON import json __state_str = _conn.execute("SELECT state from test_state_table;")["state"][0] state = pickle.loads(__state_str) - result = {'sum': {'data': state['num'], 'operation': 'sum'}, 'min': {'data': - state['num'], 'operation': 'min'}, 'max': {'data': state['num'], - 'operation': 'max'}} + result = {'sum': {'data': state['num'], 'operation': 'sum', 'type': 'int'}, + 'min': {'data': state['num'], 'operation': 'min', 'type': 'int'}, 'max': + {'data': state['num'], 'operation': 'max', 'type': 'int'}} _conn.execute(f"INSERT INTO $loopback_table_name_0 VALUES ('$node_id', '{json.dumps(result)}');") return pickle.dumps(state) }""" @@ -4829,9 +4820,9 @@ FROM ).fetchone() result = json.loads(secure_transfer_) assert result == { - "sum": {"data": 5, "operation": "sum"}, - "min": {"data": 5, "operation": "min"}, - "max": {"data": 5, "operation": "max"}, + "sum": {"data": 5, "operation": "sum", "type": "int"}, + "min": {"data": 5, "operation": "min", "type": "int"}, + "max": {"data": 5, "operation": "max", "type": "int"}, } @@ -4849,9 +4840,9 @@ class TestUDFGen_SecureTransferOutputAs2ndOutput_with_SMPC_on( ) def f(state): result = { - "sum": {"data": state["num"], "operation": "sum"}, - "min": {"data": state["num"], "operation": "min"}, - "max": {"data": state["num"], "operation": "max"}, + "sum": {"data": state["num"], "operation": "sum", "type": "int"}, + "min": {"data": state["num"], "operation": "min", "type": "int"}, + "max": {"data": state["num"], "operation": "max", "type": "int"}, } return state, result @@ -4886,10 +4877,10 @@ LANGUAGE PYTHON import json __state_str = _conn.execute("SELECT state from test_state_table;")["state"][0] state = pickle.loads(__state_str) - result = {'sum': {'data': state['num'], 'operation': 'sum'}, 'min': {'data': - state['num'], 'operation': 'min'}, 'max': {'data': state['num'], - 'operation': 'max'}} - template, sum_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict(result) + result = {'sum': {'data': state['num'], 'operation': 'sum', 'type': 'int'}, + 'min': {'data': state['num'], 'operation': 'min', 'type': 'int'}, 'max': + {'data': state['num'], 'operation': 'max', 'type': 'int'}} + template, sum_op, min_op, max_op = udfio.split_secure_transfer_dict(result) _conn.execute(f"INSERT INTO $loopback_table_name_0 VALUES ('$node_id', '{json.dumps(template)}');") _conn.execute(f"INSERT INTO $loopback_table_name_0_sum_op VALUES ('$node_id', '{json.dumps(sum_op)}');") _conn.execute(f"INSERT INTO $loopback_table_name_0_min_op VALUES ('$node_id', '{json.dumps(min_op)}');") @@ -4979,9 +4970,9 @@ FROM ).fetchone() template = json.loads(template_str) assert template == { - "sum": {"data": 0, "operation": "sum"}, - "min": {"data": 0, "operation": "min"}, - "max": {"data": 0, "operation": "max"}, + "sum": {"data": 0, "operation": "sum", "type": "int"}, + "min": {"data": 0, "operation": "min", "type": "int"}, + "max": {"data": 0, "operation": "max", "type": "int"}, } sum_op_values_str, *_ = globalnode_db_cursor.execute( @@ -5160,8 +5151,7 @@ LANGUAGE PYTHON __min_op_values = None __max_op_values_str = _conn.execute("SELECT secure_transfer from test_smpc_max_op_values_table;")["secure_transfer"][0] __max_op_values = json.loads(__max_op_values_str) - __union_op_values = None - transfer = udfio.construct_secure_transfer_dict(__template,__sum_op_values,__min_op_values,__max_op_values,__union_op_values) + transfer = 
udfio.construct_secure_transfer_dict(__template,__sum_op_values,__min_op_values,__max_op_values) return json.dumps(transfer) }""" diff --git a/tests/standalone_tests/test_udfio.py b/tests/standalone_tests/test_udfio.py index 24a1100117937d82a6cc8c79f3141fdf6ab3a28e..5ec2c0cebb454059aa298f9f2d39d1c6e4ffc9d8 100644 --- a/tests/standalone_tests/test_udfio.py +++ b/tests/standalone_tests/test_udfio.py @@ -73,65 +73,86 @@ def test_merge_tensor_to_list_no_nodeid(): def get_secure_transfers_to_merged_dict_success_cases(): secure_transfers_cases = [ - ( + pytest.param( [ { - "a": {"data": 2, "operation": "sum"}, + "a": {"data": 2, "operation": "sum", "type": "int"}, }, { - "a": {"data": 3, "operation": "sum"}, + "a": {"data": 3, "operation": "sum", "type": "int"}, }, ], {"a": 5}, + id="sum operation with ints", ), - ( + pytest.param( [ { - "a": {"data": 2, "operation": "sum"}, - "b": {"data": 5, "operation": "sum"}, + "a": {"data": 2.5, "operation": "sum", "type": "float"}, }, { - "a": {"data": 3, "operation": "sum"}, - "b": {"data": 7, "operation": "sum"}, + "a": {"data": 3.6, "operation": "sum", "type": "float"}, }, ], - {"a": 5, "b": 12}, + {"a": 6.1}, + id="sum operation with floats", ), - ( + pytest.param( + [ + { + "a": {"data": 2, "operation": "sum", "type": "int"}, + "b": {"data": 5, "operation": "sum", "type": "int"}, + "c": {"data": 5.123, "operation": "sum", "type": "float"}, + }, + { + "a": {"data": 3, "operation": "sum", "type": "int"}, + "b": {"data": 7, "operation": "sum", "type": "int"}, + "c": {"data": 5.456, "operation": "sum", "type": "float"}, + }, + ], + {"a": 5, "b": 12, "c": 10.579}, + id="multiple sum operations with ints/floats", + ), + pytest.param( [ { - "a": {"data": [1, 2, 3], "operation": "sum"}, + "a": {"data": [1, 2, 3], "operation": "sum", "type": "int"}, }, { - "a": {"data": [9, 8, 7], "operation": "sum"}, + "a": {"data": [9, 8, 7], "operation": "sum", "type": "int"}, }, ], { "a": [10, 10, 10], }, + id="sum operation with list of ints", ), - ( + pytest.param( [ { - "a": {"data": 10, "operation": "sum"}, + "a": {"data": 10, "operation": "sum", "type": "int"}, "b": { "data": [10, 20, 30, 40, 50, 60], "operation": "sum", + "type": "int", }, "c": { "data": [[10, 20, 30, 40, 50, 60], [70, 80, 90]], "operation": "sum", + "type": "int", }, }, { - "a": {"data": 100, "operation": "sum"}, + "a": {"data": 100, "operation": "sum", "type": "int"}, "b": { "data": [100, 200, 300, 400, 500, 600], "operation": "sum", + "type": "int", }, "c": { "data": [[100, 200, 300, 400, 500, 600], [700, 800, 900]], "operation": "sum", + "type": "int", }, }, ], @@ -140,29 +161,34 @@ def get_secure_transfers_to_merged_dict_success_cases(): "b": [110, 220, 330, 440, 550, 660], "c": [[110, 220, 330, 440, 550, 660], [770, 880, 990]], }, + id="complex sum operations with nested lists", ), - ( + pytest.param( [ { - "sum": {"data": 10, "operation": "sum"}, + "sum": {"data": 10, "operation": "sum", "type": "int"}, "min": { "data": [10, 200, 30, 400, 50, 600], "operation": "min", + "type": "int", }, "max": { "data": [[100, 20, 300, 40, 500, 60], [700, 80, 900]], "operation": "max", + "type": "int", }, }, { - "sum": {"data": 100, "operation": "sum"}, + "sum": {"data": 100, "operation": "sum", "type": "int"}, "min": { "data": [100, 20, 300, 40, 500, 60], "operation": "min", + "type": "int", }, "max": { "data": [[10, 200, 30, 400, 50, 600], [70, 800, 90]], "operation": "max", + "type": "int", }, }, ], @@ -171,6 +197,66 @@ def get_secure_transfers_to_merged_dict_success_cases(): "min": [10, 20, 30, 40, 50, 
60], "max": [[100, 200, 300, 400, 500, 600], [700, 800, 900]], }, + id="mixed sum/min/max operations with nested lists", + ), + pytest.param( + [ + { + "sum": {"data": 10, "operation": "sum", "type": "int"}, + "sumfloat": {"data": 10.677, "operation": "sum", "type": "float"}, + "min": { + "data": [10, 200, 30, 400, 50, 600], + "operation": "min", + "type": "int", + }, + "max": { + "data": [[100, 20, 300, 40, 500, 60], [700, 80, 900]], + "operation": "max", + "type": "int", + }, + "maxfloat": { + "data": [ + [100.5, 200.3, 30.4, 400.5678, 50.6, 600.7], + [70.8, 800.9, 90.01], + ], + "operation": "max", + "type": "float", + }, + }, + { + "sum": {"data": 100, "operation": "sum", "type": "int"}, + "sumfloat": {"data": 0.678, "operation": "sum", "type": "float"}, + "min": { + "data": [100, 20, 300, 40, 500, 60], + "operation": "min", + "type": "int", + }, + "max": { + "data": [[10, 200, 30, 400, 50, 600], [70, 800, 90]], + "operation": "max", + "type": "int", + }, + "maxfloat": { + "data": [ + [10.1, 200.25, 30.3, 400.4, 50.5, 600.6], + [72, 8000.8, 90.9], + ], + "operation": "max", + "type": "float", + }, + }, + ], + { + "sum": 110, + "sumfloat": 11.355, + "min": [10, 20, 30, 40, 50, 60], + "max": [[100, 200, 300, 400, 500, 600], [700, 800, 900]], + "maxfloat": [ + [100.5, 200.3, 30.4, 400.5678, 50.6, 600.7], + [72, 8000.8, 90.9], + ], + }, + id="mixed sum/min/max operations with ints/floats and nested lists", ), ] return secure_transfers_cases @@ -183,150 +269,308 @@ def test_secure_transfer_to_merged_dict(transfers, result): assert secure_transfers_to_merged_dict(transfers) == result +def get_secure_transfers_merged_to_dict_fail_cases(): + secure_transfers_fail_cases = [ + ( + [ + { + "a": {"data": 2, "type": "int"}, + }, + { + "a": {"data": 3, "operation": "sum", "type": "int"}, + }, + ], + ( + ValueError, + "Secure Transfer operation is not provided. Expected format: .*", + ), + ), + ( + [ + { + "a": {"data": 2, "operation": "sum", "type": "int"}, + }, + { + "a": {"data": 3, "operation": "sum"}, + }, + ], + (ValueError, "Secure Transfer type is not provided. 
Expected format: .*"), + ), + ( + [ + { + "a": {"data": 2, "operation": "sum", "type": "int"}, + }, + { + "a": {"data": 3, "operation": "whatever", "type": "int"}, + }, + ], + ( + ValueError, + "Secure Transfer operation is not supported: .*", + ), + ), + ( + [ + { + "a": {"data": 2, "operation": "sum", "type": "int"}, + }, + { + "a": {"data": 3, "operation": "sum", "type": "whatever"}, + }, + ], + ( + ValueError, + "Secure Transfer type is not supported: .*", + ), + ), + ( + [ + { + "a": {"data": 2, "operation": "sum", "type": "int"}, + }, + { + "a": {"data": 3, "operation": "min", "type": "int"}, + }, + ], + ( + ValueError, + "Similar secure transfer keys should have the same operation .*", + ), + ), + ( + [ + { + "a": {"data": 2, "operation": "sum", "type": "int"}, + }, + { + "a": {"data": 3, "operation": "sum", "type": "float"}, + }, + ], + ( + ValueError, + "Similar secure transfer keys should have the same type .*", + ), + ), + ( + [ + { + "a": {"data": 2, "operation": "sum", "type": "int"}, + }, + { + "a": {"data": [3], "operation": "sum", "type": "int"}, + }, + ], + (ValueError, "Secure transfers' data should have the same structure."), + ), + ( + [ + { + "a": {"data": "tet", "operation": "sum", "type": "int"}, + }, + { + "a": {"data": "tet", "operation": "sum", "type": "int"}, + }, + ], + ( + TypeError, + "Secure transfer data must have one of the following types: .*", + ), + ), + ] + return secure_transfers_fail_cases + + +@pytest.mark.parametrize( + "transfers, exception", get_secure_transfers_merged_to_dict_fail_cases() +) +def test_secure_transfers_to_merged_dict_fail_cases(transfers, exception): + exception_type, exception_message = exception + with pytest.raises(exception_type, match=exception_message): + secure_transfers_to_merged_dict(transfers) + + def get_secure_transfer_dict_success_cases(): secure_transfer_cases = [ - ( + pytest.param( { - "a": {"data": 2, "operation": "sum"}, + "a": {"data": 2, "operation": "sum", "type": "int"}, }, ( { - "a": {"data": 0, "operation": "sum"}, + "a": {"data": 0, "operation": "sum", "type": "int"}, }, [2], [], [], - [], ), { "a": 2, }, + id="sum operation with int", ), - ( + pytest.param( { - "a": {"data": 2, "operation": "sum"}, - "b": {"data": 5, "operation": "sum"}, + "a": {"data": 2.5, "operation": "sum", "type": "float"}, }, ( { - "a": {"data": 0, "operation": "sum"}, - "b": {"data": 1, "operation": "sum"}, + "a": {"data": 0, "operation": "sum", "type": "float"}, }, - [2, 5], + [2.5], [], [], + ), + { + "a": 2.5, + }, + id="sum operation with float", + ), + pytest.param( + { + "a": {"data": 2, "operation": "sum", "type": "int"}, + "b": {"data": 5, "operation": "sum", "type": "int"}, + }, + ( + { + "a": {"data": 0, "operation": "sum", "type": "int"}, + "b": {"data": 1, "operation": "sum", "type": "int"}, + }, + [2, 5], + [], [], ), {"a": 2, "b": 5}, + id="sum operation with ints", ), - ( + pytest.param( { - "a": {"data": [1, 2, 3], "operation": "sum"}, + "a": {"data": 2, "operation": "sum", "type": "int"}, + "b": {"data": 5.5, "operation": "sum", "type": "float"}, }, ( { - "a": {"data": [0, 1, 2], "operation": "sum"}, + "a": {"data": 0, "operation": "sum", "type": "int"}, + "b": {"data": 1, "operation": "sum", "type": "float"}, }, - [1, 2, 3], + [2, 5.5], [], [], + ), + {"a": 2, "b": 5.5}, + id="sum operation with int/float", + ), + pytest.param( + { + "a": {"data": [1, 2, 3], "operation": "sum", "type": "int"}, + }, + ( + { + "a": {"data": [0, 1, 2], "operation": "sum", "type": "int"}, + }, + [1, 2, 3], + [], [], ), { "a": [1, 2, 
3], }, + id="sum operation with list of ints", ), - ( + pytest.param( { - "a": {"data": 10, "operation": "sum"}, + "a": {"data": 10, "operation": "sum", "type": "int"}, "b": { "data": [10, 20, 30, 40, 50, 60], "operation": "sum", + "type": "int", }, "c": { "data": [[10, 20, 30, 40, 50, 60], [70, 80, 90]], "operation": "sum", + "type": "int", }, }, ( { - "a": {"data": 0, "operation": "sum"}, + "a": {"data": 0, "operation": "sum", "type": "int"}, "b": { "data": [1, 2, 3, 4, 5, 6], "operation": "sum", + "type": "int", }, "c": { "data": [[7, 8, 9, 10, 11, 12], [13, 14, 15]], "operation": "sum", + "type": "int", }, }, [10, 10, 20, 30, 40, 50, 60, 10, 20, 30, 40, 50, 60, 70, 80, 90], [], [], - [], ), { "a": 10, "b": [10, 20, 30, 40, 50, 60], "c": [[10, 20, 30, 40, 50, 60], [70, 80, 90]], }, + id="sum operation with nested lists of ints", ), - ( + pytest.param( { - "min": {"data": [2, 5.6], "operation": "min"}, + "min": {"data": [2, 5.6], "operation": "min", "type": "float"}, }, ( { - "min": {"data": [0, 1], "operation": "min"}, + "min": {"data": [0, 1], "operation": "min", "type": "float"}, }, [], [2, 5.6], [], - [], ), { "min": [2, 5.6], }, + id="min operation with int/float", ), - ( + pytest.param( { - "max": {"data": [2, 5.6], "operation": "max"}, + "max": {"data": [2, 5.6], "operation": "max", "type": "float"}, }, ( { - "max": {"data": [0, 1], "operation": "max"}, + "max": {"data": [0, 1], "operation": "max", "type": "float"}, }, [], [], [2, 5.6], - [], ), { "max": [2, 5.6], }, + id="max operation with int/float", ), - ( + pytest.param( { - "sum1": {"data": [1, 2, 3, 4.5], "operation": "sum"}, - "sum2": {"data": [6, 7.8], "operation": "sum"}, - "min1": {"data": [6, 7.8], "operation": "min"}, - "min2": {"data": [1.5, 2.0], "operation": "min"}, - "max1": {"data": [6.8, 7], "operation": "max"}, - "max2": {"data": [1.5, 2], "operation": "max"}, + "sum1": {"data": [1, 2, 3, 4.5], "operation": "sum", "type": "float"}, + "sum2": {"data": [6, 7.8], "operation": "sum", "type": "float"}, + "min1": {"data": [6, 7.8], "operation": "min", "type": "float"}, + "min2": {"data": [1.5, 2.0], "operation": "min", "type": "float"}, + "max1": {"data": [6.8, 7], "operation": "max", "type": "float"}, + "max2": {"data": [1.5, 2], "operation": "max", "type": "float"}, }, ( { - "sum1": {"data": [0, 1, 2, 3], "operation": "sum"}, - "sum2": {"data": [4, 5], "operation": "sum"}, - "min1": {"data": [0, 1], "operation": "min"}, - "min2": {"data": [2, 3], "operation": "min"}, - "max1": {"data": [0, 1], "operation": "max"}, - "max2": {"data": [2, 3], "operation": "max"}, + "sum1": {"data": [0, 1, 2, 3], "operation": "sum", "type": "float"}, + "sum2": {"data": [4, 5], "operation": "sum", "type": "float"}, + "min1": {"data": [0, 1], "operation": "min", "type": "float"}, + "min2": {"data": [2, 3], "operation": "min", "type": "float"}, + "max1": {"data": [0, 1], "operation": "max", "type": "float"}, + "max2": {"data": [2, 3], "operation": "max", "type": "float"}, }, [1, 2, 3, 4.5, 6, 7.8], [6, 7.8, 1.5, 2.0], [6.8, 7, 1.5, 2], - [], ), { "sum1": [1, 2, 3, 4.5], @@ -336,26 +580,38 @@ def get_secure_transfer_dict_success_cases(): "max1": [6.8, 7], "max2": [1.5, 2], }, + id="mixed sum/min/max operation with mixed ints/floats", ), - ( + pytest.param( { - "sum": {"data": [100, 200, 300], "operation": "sum"}, - "max": {"data": 58, "operation": "max"}, + "sum": {"data": [100, 200, 300], "operation": "sum", "type": "int"}, + "sumfloat": { + "data": [1.2, 2.3, 3.4], + "operation": "sum", + "type": "float", + }, + "max": {"data": 58, 
"operation": "max", "type": "int"}, }, ( { - "sum": {"data": [0, 1, 2], "operation": "sum"}, - "max": {"data": 0, "operation": "max"}, + "sum": {"data": [0, 1, 2], "operation": "sum", "type": "int"}, + "sumfloat": { + "data": [3, 4, 5], + "operation": "sum", + "type": "float", + }, + "max": {"data": 0, "operation": "max", "type": "int"}, }, - [100, 200, 300], + [100, 200, 300, 1.2, 2.3, 3.4], [], [58], - [], ), { "sum": [100, 200, 300], + "sumfloat": [1.2, 2.3, 3.4], "max": 58, }, + id="sum operations with ints/floats in separate keys", ), ] return secure_transfer_cases @@ -369,96 +625,29 @@ def test_split_secure_transfer_dict(secure_transfer, smpc_parts, final_result): assert split_secure_transfer_dict(secure_transfer) == smpc_parts -@pytest.mark.parametrize( - "secure_transfer, smpc_parts, final_result", - get_secure_transfer_dict_success_cases(), -) -def test_construct_secure_transfer_dict(secure_transfer, smpc_parts, final_result): - assert construct_secure_transfer_dict(*smpc_parts) == final_result - - -def get_secure_transfers_merged_to_dict_fail_cases(): - secure_transfers_fail_cases = [ +def get_split_secure_transfer_dict_fail_cases(): + split_secure_transfer_dict_fail_cases = [ ( - [ - { - "a": {"data": 2, "operation": "sum"}, - }, - { - "a": {"data": 3, "operation": "whatever"}, - }, - ], + { + "a": {"data": 3, "operation": "whatever", "type": "int"}, + }, ( ValueError, "Secure Transfer operation is not supported: .*", ), ), - ( - [ - { - "a": {"data": 2, "operation": "sum"}, - }, - { - "a": {"data": 3, "operation": "min"}, - }, - ], - ( - ValueError, - "All secure transfer keys should have the same 'operation' .*", - ), - ), - ( - [ - { - "a": {"data": 2, "operation": "sum"}, - }, - { - "a": {"data": [3], "operation": "sum"}, - }, - ], - (ValueError, "Secure transfers' data should have the same structure."), - ), - ( - [ - { - "a": {"data": "tet", "operation": "sum"}, - }, - { - "a": {"data": "tet", "operation": "sum"}, - }, - ], - ( - TypeError, - "Secure transfer data must have one of the following types: .*", - ), - ), - ] - return secure_transfers_fail_cases - - -@pytest.mark.parametrize( - "transfers, exception", get_secure_transfers_merged_to_dict_fail_cases() -) -def test_secure_transfers_to_merged_dict_fail_cases(transfers, exception): - exception_type, exception_message = exception - with pytest.raises(exception_type, match=exception_message): - secure_transfers_to_merged_dict(transfers) - - -def get_split_secure_transfer_dict_fail_cases(): - split_secure_transfer_dict_fail_cases = [ ( { - "a": {"data": 3, "operation": "whatever"}, + "a": {"data": 3, "operation": "sum", "type": "whatever"}, }, ( ValueError, - "Secure Transfer operation is not supported: .*", + "Secure Transfer type is not supported: .*", ), ), ( { - "a": {"data": "tet", "operation": "sum"}, + "a": {"data": "tet", "operation": "sum", "type": "int"}, }, ( TypeError, @@ -467,7 +656,7 @@ def get_split_secure_transfer_dict_fail_cases(): ), ( { - "a": {"llalal": 0, "operation": "sum"}, + "a": {"llalal": 0, "operation": "sum", "type": "int"}, }, ( ValueError, @@ -483,6 +672,15 @@ def get_split_secure_transfer_dict_fail_cases(): "Each Secure Transfer key should contain an operation.", ), ), + ( + { + "a": {"data": 0, "operation": "sum"}, + }, + ( + ValueError, + "Each Secure Transfer key should contain a type.", + ), + ), ] return split_secure_transfer_dict_fail_cases @@ -494,3 +692,34 @@ def test_split_secure_transfer_dict_fail_cases(result, exception): exception_type, exception_message = exception with 
pytest.raises(exception_type, match=exception_message):
         split_secure_transfer_dict(result)
+
+
+@pytest.mark.parametrize(
+    "secure_transfer, smpc_parts, final_result",
+    get_secure_transfer_dict_success_cases(),
+)
+def test_construct_secure_transfer_dict(secure_transfer, smpc_parts, final_result):
+    assert construct_secure_transfer_dict(*smpc_parts) == final_result
+
+
+def test_proper_int_casting_in_construct_secure_transfer_dict():
+    """
+    SMPC only returns floats, so the final values must be cast back to the
+    proper int type whenever a key declares "type": "int".
+    """
+    input_values = (
+        {
+            "sum": {"data": [0, 1, 2], "operation": "sum", "type": "int"},
+            "min": {"data": [0, 1, 2], "operation": "min", "type": "int"},
+            "max": {"data": 0, "operation": "max", "type": "int"},
+        },
+        [100.0, 200.0, 300.0],
+        [10.0, 20.0, 30.0],
+        [1.0],
+    )
+
+    assert construct_secure_transfer_dict(*input_values) == {
+        "sum": [100, 200, 300],
+        "min": [10, 20, 30],
+        "max": 1,
+    }
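
Reviewer note: the sketch below illustrates the round-trip semantics these test
changes pin down. Every secure-transfer key now carries a "type" ("int" or
"float") alongside "data" and "operation"; split_secure_transfer_dict returns a
4-tuple (template, sum_op, min_op, max_op) now that the union operation is gone;
and construct_secure_transfer_dict casts SMPC's float results back to int for
"int"-typed keys. This is not the shipped udfio code, and the helper names
(reconstruct_secure_transfer, _cast, _resolve) are hypothetical; it is only a
minimal, self-contained illustration of the behaviour the tests assert.

# Illustrative sketch -- NOT the shipped udfio implementation. It mirrors
# the semantics the updated tests encode, under the assumption that the
# template stores flat indices into the per-operation value lists.

def _cast(value, type_):
    # SMPC results arrive as floats; restore the declared scalar type.
    return int(value) if type_ == "int" else float(value)


def _resolve(data, values, type_):
    # "data" in a template holds flat indices into the per-operation
    # value list; nested lists keep the original transfer's structure.
    if isinstance(data, list):
        return [_resolve(d, values, type_) for d in data]
    return _cast(values[data], type_)


def reconstruct_secure_transfer(template, sum_op, min_op, max_op):
    # Analogous to construct_secure_transfer_dict: merge the per-operation
    # SMPC results back into the shape described by the template.
    values_per_op = {"sum": sum_op, "min": min_op, "max": max_op}
    return {
        key: _resolve(part["data"], values_per_op[part["operation"]], part["type"])
        for key, part in template.items()
    }


# Mirrors the int-casting behaviour asserted above: float results for
# "int"-typed keys come back as proper ints, nested structure intact.
assert reconstruct_secure_transfer(
    {
        "sum": {"data": [0, 1, 2], "operation": "sum", "type": "int"},
        "max": {"data": 0, "operation": "max", "type": "int"},
    },
    [100.0, 200.0, 300.0],
    [],
    [1.0],
) == {"sum": [100, 200, 300], "max": 1}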