diff --git a/mipengine/controller/algorithm_executor.py b/mipengine/controller/algorithm_executor.py index c6a1ad19105ed9ed8c3c11c0481c9a1f13b0bb1c..751d1faf653ce286669c21b166daa4ef9c9d44ff 100644 --- a/mipengine/controller/algorithm_executor.py +++ b/mipengine/controller/algorithm_executor.py @@ -15,26 +15,26 @@ from mipengine.controller import config as ctrl_config from mipengine.controller import controller_logger as ctrl_logger from mipengine.controller.algorithm_execution_DTOs import AlgorithmExecutionDTO from mipengine.controller.algorithm_execution_DTOs import NodesTasksHandlersDTO -from mipengine.controller.algorithm_executor_node_data_objects import AlgoExecData -from mipengine.controller.algorithm_executor_node_data_objects import GlobalNodeData -from mipengine.controller.algorithm_executor_node_data_objects import ( +from mipengine.controller.algorithm_flow_data_objects import AlgoFlowData +from mipengine.controller.algorithm_flow_data_objects import GlobalNodeData +from mipengine.controller.algorithm_flow_data_objects import ( GlobalNodeSMPCTables, ) -from mipengine.controller.algorithm_executor_node_data_objects import GlobalNodeTable -from mipengine.controller.algorithm_executor_node_data_objects import LocalNodesData -from mipengine.controller.algorithm_executor_node_data_objects import ( +from mipengine.controller.algorithm_flow_data_objects import GlobalNodeTable +from mipengine.controller.algorithm_flow_data_objects import LocalNodesData +from mipengine.controller.algorithm_flow_data_objects import ( LocalNodesSMPCTables, ) -from mipengine.controller.algorithm_executor_node_data_objects import LocalNodesTable -from mipengine.controller.algorithm_executor_node_data_objects import NodeData -from mipengine.controller.algorithm_executor_node_data_objects import NodeSMPCTables -from mipengine.controller.algorithm_executor_node_data_objects import NodeTable -from mipengine.controller.algorithm_executor_node_data_objects import ( +from mipengine.controller.algorithm_flow_data_objects import LocalNodesTable +from mipengine.controller.algorithm_flow_data_objects import ( algoexec_udf_kwargs_to_node_udf_kwargs, ) -from mipengine.controller.algorithm_executor_node_data_objects import ( +from mipengine.controller.algorithm_flow_data_objects import ( algoexec_udf_posargs_to_node_udf_posargs, ) +from mipengine.controller.algorithm_executor_node_data_objects import NodeData +from mipengine.controller.algorithm_executor_node_data_objects import SMPCTableNames +from mipengine.controller.algorithm_executor_node_data_objects import TableName from mipengine.controller.algorithm_executor_nodes import GlobalNode from mipengine.controller.algorithm_executor_nodes import LocalNode from mipengine.controller.algorithm_executor_smpc_helper import get_smpc_results @@ -71,13 +71,13 @@ class NodeDownAlgorithmExecutionException(Exception): class InconsistentTableSchemasException(Exception): - def __init__(self, tables_schemas: Dict[NodeTable, TableSchema]): + def __init__(self, tables_schemas: Dict[TableName, TableSchema]): message = f"Tables: {tables_schemas} do not have a common schema" super().__init__(message) class InconsistentUDFResultSizeException(Exception): - def __init__(self, result_tables: Dict[int, List[Tuple[LocalNode, NodeTable]]]): + def __init__(self, result_tables: Dict[int, List[Tuple[LocalNode, TableName]]]): message = ( f"The following udf execution results on multiple nodes should have " f"the same number of results.\nResults:{result_tables}" @@ -321,7 +321,7 @@ class 
_AlgorithmExecutionInterface: positional_args: Optional[List[Any]] = None, keyword_args: Optional[Dict[str, Any]] = None, share_to_global: Union[None, bool, List[bool]] = None, - ) -> Union[AlgoExecData, List[AlgoExecData]]: + ) -> Union[AlgoFlowData, List[AlgoFlowData]]: # 1. check positional_args and keyword_args tables do not contain _GlobalNodeTable(s) # 2. queues run_udf task on all local nodes # 3. waits for all nodes to complete the tasks execution @@ -394,9 +394,9 @@ class _AlgorithmExecutionInterface: for nodes_result in all_nodes_results: # All nodes' results have the same type so only the first_result is needed to define the type first_result = nodes_result[0][1] - if isinstance(first_result, NodeTable): + if isinstance(first_result, TableName): results.append(LocalNodesTable(dict(nodes_result))) - elif isinstance(first_result, NodeSMPCTables): + elif isinstance(first_result, SMPCTableNames): results.append(LocalNodesSMPCTables(dict(nodes_result))) else: raise NotImplementedError @@ -409,7 +409,7 @@ class _AlgorithmExecutionInterface: ) -> GlobalNodeData: if isinstance(local_nodes_data, LocalNodesTable): return self._share_local_table_to_global( - local_node_table=local_nodes_data, + local_nodes_table=local_nodes_data, command_id=command_id, ) elif isinstance(local_nodes_data, LocalNodesSMPCTables): @@ -419,10 +419,10 @@ class _AlgorithmExecutionInterface: def _share_local_table_to_global( self, - local_node_table: LocalNodesTable, + local_nodes_table: LocalNodesTable, command_id: int, ) -> GlobalNodeTable: - nodes_tables = local_node_table.nodes_tables + nodes_tables = local_nodes_table.nodes_tables # check the tables have the same schema common_schema = self._validate_same_schema_tables(nodes_tables) @@ -444,26 +444,29 @@ class _AlgorithmExecutionInterface: def _share_local_smpc_tables_to_global( self, - nodes_smpc_tables: LocalNodesSMPCTables, + local_nodes_smpc_tables: LocalNodesSMPCTables, command_id: int, ) -> GlobalNodeSMPCTables: global_template_table = self._share_local_table_to_global( - local_node_table=nodes_smpc_tables.template, command_id=command_id + local_nodes_table=local_nodes_smpc_tables.template_local_nodes_table, + command_id=command_id, ) self._global_node.validate_smpc_templates_match( global_template_table.table.full_table_name ) - smpc_clients_per_op = load_data_to_smpc_clients(command_id, nodes_smpc_tables) + smpc_clients_per_op = load_data_to_smpc_clients( + command_id, local_nodes_smpc_tables + ) - (add_op, min_op, max_op, union_op,) = trigger_smpc_computations( + (sum_op, min_op, max_op, union_op,) = trigger_smpc_computations( context_id=self._global_node.context_id, command_id=command_id, smpc_clients_per_op=smpc_clients_per_op, ) ( - add_op_result_table, + sum_op_result_table, min_op_result_table, max_op_result_table, union_op_result_table, @@ -471,18 +474,21 @@ class _AlgorithmExecutionInterface: node=self._global_node, context_id=self._global_node.context_id, command_id=command_id, - add_op=add_op, + sum_op=sum_op, min_op=min_op, max_op=max_op, union_op=union_op, ) return GlobalNodeSMPCTables( - template=global_template_table, - add_op=add_op_result_table, - min_op=min_op_result_table, - max_op=max_op_result_table, - union_op=union_op_result_table, + node=self._global_node, + smpc_tables=SMPCTableNames( + template=global_template_table.table, + sum_op=sum_op_result_table, + min_op=min_op_result_table, + max_op=max_op_result_table, + union_op=union_op_result_table, + ), ) def run_udf_on_global_node( @@ -491,7 +497,7 @@ class 
_AlgorithmExecutionInterface: positional_args: Optional[List[Any]] = None, keyword_args: Optional[Dict[str, Any]] = None, share_to_locals: Union[None, bool, List[bool]] = None, - ) -> Union[AlgoExecData, List[AlgoExecData]]: + ) -> Union[AlgoFlowData, List[AlgoFlowData]]: # 1. check positional_args and keyword_args tables do not contain _LocalNodeTable(s) # 2. queue run_udf on the global node # 3. wait for it to complete @@ -547,7 +553,7 @@ class _AlgorithmExecutionInterface: def _convert_global_udf_results_to_global_node_data( self, - node_tables: List[NodeTable], + node_tables: List[TableName], ) -> List[GlobalNodeTable]: global_tables = [ GlobalNodeTable( @@ -660,7 +666,7 @@ class _AlgorithmExecutionInterface: raise InconsistentShareTablesValueException(share_to, number_of_results) def _validate_same_schema_tables( - self, tables: Dict[LocalNode, NodeTable] + self, tables: Dict[LocalNode, TableName] ) -> TableSchema: """ Returns : TableSchema the common TableSchema, if all tables have the same schema @@ -697,25 +703,17 @@ class _SingleLocalNodeAlgorithmExecutionInterface(_AlgorithmExecutionInterface): ) elif isinstance(local_nodes_data, LocalNodesSMPCTables): return GlobalNodeSMPCTables( - template=GlobalNodeTable( - node=self._global_node, - table=local_nodes_data.template.nodes_tables[self._local_nodes[0]], - ), - add_op=GlobalNodeTable( - node=self._global_node, - table=local_nodes_data.add_op.nodes_tables[self._local_nodes[0]], - ), - min_op=GlobalNodeTable( - node=self._global_node, - table=local_nodes_data.min_op.nodes_tables[self._local_nodes[0]], - ), - max_op=GlobalNodeTable( - node=self._global_node, - table=local_nodes_data.max_op.nodes_tables[self._local_nodes[0]], - ), - union_op=GlobalNodeTable( - node=self._global_node, - table=local_nodes_data.union_op.nodes_tables[self._local_nodes[0]], + node=self._global_node, + smpc_tables=SMPCTableNames( + template=local_nodes_data.nodes_smpc_tables[ + self._global_node + ].template, + sum_op=local_nodes_data.nodes_smpc_tables[self._global_node].sum_op, + min_op=local_nodes_data.nodes_smpc_tables[self._global_node].min_op, + max_op=local_nodes_data.nodes_smpc_tables[self._global_node].max_op, + union_op=local_nodes_data.nodes_smpc_tables[ + self._global_node + ].union_op, ), ) diff --git a/mipengine/controller/algorithm_executor_node_data_objects.py b/mipengine/controller/algorithm_executor_node_data_objects.py index fc70c338f99e652cd63bf4adbe6c1463635aff5d..cadf97b7f4da056e6d59d9bbdb373671ce430ef4 100644 --- a/mipengine/controller/algorithm_executor_node_data_objects.py +++ b/mipengine/controller/algorithm_executor_node_data_objects.py @@ -1,28 +1,15 @@ from abc import ABC -from typing import Any -from typing import Dict -from typing import List -from typing import Union - -from mipengine.node_tasks_DTOs import NodeLiteralDTO -from mipengine.node_tasks_DTOs import NodeSMPCDTO -from mipengine.node_tasks_DTOs import NodeSMPCValueDTO -from mipengine.node_tasks_DTOs import NodeTableDTO -from mipengine.node_tasks_DTOs import NodeUDFDTO -from mipengine.node_tasks_DTOs import TableSchema -from mipengine.node_tasks_DTOs import UDFKeyArguments -from mipengine.node_tasks_DTOs import UDFPosArguments class NodeData(ABC): """ - NodeData are located into one specific Node. + NodeData is an object representing data located into one specific Node. 
""" pass -class NodeTable(NodeData): +class TableName(NodeData): def __init__(self, table_name): self._full_name = table_name full_name_split = self._full_name.split("_") @@ -71,285 +58,16 @@ class NodeTable(NodeData): return self.full_table_name -class NodeSMPCTables(NodeData): - template: NodeTable - add_op: NodeTable - min_op: NodeTable - max_op: NodeTable - union_op: NodeTable +class SMPCTableNames(NodeData): + template: TableName + sum_op: TableName + min_op: TableName + max_op: TableName + union_op: TableName - def __init__(self, template, add_op, min_op, max_op, union_op): + def __init__(self, template, sum_op, min_op, max_op, union_op): self.template = template - self.add_op = add_op + self.sum_op = sum_op self.min_op = min_op self.max_op = max_op self.union_op = union_op - - -class AlgoExecData(ABC): - """ - AlgoExecData are representing one data object but could be located into - more than one node. For example the LocalNodesTable is treated as one table - from the algorithm executor but is located in multiple local nodes. - """ - - pass - - -class LocalNodesData(AlgoExecData, ABC): - pass - - -class GlobalNodeData(AlgoExecData, ABC): - pass - - -# When _AlgorithmExecutionInterface::run_udf_on_local_nodes(..) is called, depending on -# how many local nodes are participating in the current algorithm execution, several -# database tables are created on all participating local nodes. Irrespectevely of the -# number of local nodes participating, the number of tables created on each of these local -# nodes will be the same. -# Class _LocalNodeTable is the structure that represents the concept of these database -# tables, created during the execution of a udf, in the algorithm execution layer. A key -# concept is that a _LocalNodeTable stores 'pointers' to 'relevant' tables existing in -# different local nodes accross the federation. By 'relevant' I mean tables that are -# generated when triggering a udf execution accross several local nodes. By 'pointers' -# I mean mapping between local nodes and table names and the aim is to hide the underline -# complexity from the algorithm flow and exposing a single 'local node table' object that -# stores in the background pointers to several tables in several local nodes. -class LocalNodesTable(LocalNodesData): - def __init__(self, nodes_tables: Dict["LocalNode", NodeTable]): - self._nodes_tables = nodes_tables - self._validate_matching_table_names(list(self._nodes_tables.values())) - - @property - def nodes_tables(self) -> Dict["LocalNode", NodeTable]: - return self._nodes_tables - - # TODO this is redundant, either remove it or overload all node methods here? 
- def get_table_schema(self) -> TableSchema: - node = list(self.nodes_tables.keys())[0] - table = self.nodes_tables[node] - return node.get_table_schema(table) - - def get_table_data(self) -> List[Union[int, float, str]]: - tables_data = [] - for node, table_name in self.nodes_tables.items(): - tables_data.append(node.get_table_data(table_name)) - tables_data_flat = [table_data.columns for table_data in tables_data] - tables_data_flat = [ - elem - for table in tables_data_flat - for column in table - for elem in column.data - ] - return tables_data_flat - - def __repr__(self): - r = f"\n\tLocalNodeTable: {self.get_table_schema()}\n" - for node, table_name in self.nodes_tables.items(): - r += f"\t{node=} {table_name=}\n" - return r - - def _validate_matching_table_names(self, table_names: List[NodeTable]): - table_name_without_node_id = table_names[0].without_node_id() - for table_name in table_names: - if table_name.without_node_id() != table_name_without_node_id: - raise self.MismatchingTableNamesException( - [table_name.full_table_name for table_name in table_names] - ) - - class MismatchingTableNamesException(Exception): - def __init__(self, table_names: List[str]): - message = f"Mismatched table names ->{table_names}" - super().__init__(message) - self.message = message - - -class GlobalNodeTable(GlobalNodeData): - def __init__(self, node: "GlobalNode", table: NodeTable): - self._node = node - self._table = table - - @property - def node(self) -> "GlobalNode": - return self._node - - @property - def table(self) -> NodeTable: - return self._table - - # TODO this is redundant, either remove it or overload all node methods here? - def get_table_schema(self) -> TableSchema: - return self._node.get_table_schema(self.table) - - def get_table_data(self) -> List[List[Any]]: - table_data = [ - column.data for column in self.node.get_table_data(self.table).columns - ] - return table_data - - def __repr__(self): - r = f"\n\tGlobalNodeTable: \n\tschema={self.get_table_schema()}\n \t{self.table=}\n" - return r - - -class LocalNodesSMPCTables(LocalNodesData): - template: LocalNodesTable - add_op: LocalNodesTable - min_op: LocalNodesTable - max_op: LocalNodesTable - union_op: LocalNodesTable - - def __init__(self, nodes_smpc_tables: Dict["LocalNode", NodeSMPCTables]): - template_nodes_tables = {} - add_op_nodes_tables = {} - min_op_nodes_tables = {} - max_op_nodes_tables = {} - union_op_nodes_tables = {} - for node, node_smpc_tables in nodes_smpc_tables.items(): - template_nodes_tables[node] = node_smpc_tables.template - add_op_nodes_tables[node] = node_smpc_tables.add_op - min_op_nodes_tables[node] = node_smpc_tables.min_op - max_op_nodes_tables[node] = node_smpc_tables.max_op - union_op_nodes_tables[node] = node_smpc_tables.union_op - self.template = LocalNodesTable(template_nodes_tables) - self.add_op = create_local_nodes_table_from_nodes_tables(add_op_nodes_tables) - self.min_op = create_local_nodes_table_from_nodes_tables(min_op_nodes_tables) - self.max_op = create_local_nodes_table_from_nodes_tables(max_op_nodes_tables) - self.union_op = create_local_nodes_table_from_nodes_tables( - union_op_nodes_tables - ) - - -class GlobalNodeSMPCTables(GlobalNodeData): - template: GlobalNodeTable - add_op: GlobalNodeTable - min_op: GlobalNodeTable - max_op: GlobalNodeTable - union_op: GlobalNodeTable - - def __init__(self, template, add_op, min_op, max_op, union_op): - self.template = template - self.add_op = add_op - self.min_op = min_op - self.max_op = max_op - self.union_op = union_op - - -def 
algoexec_udf_kwargs_to_node_udf_kwargs( - algoexec_kwargs: Dict[str, Any], - local_node: "LocalNode" = None, -) -> UDFKeyArguments: - if not algoexec_kwargs: - return UDFKeyArguments(args={}) - - args = {} - for key, arg in algoexec_kwargs.items(): - udf_argument = _algoexec_udf_arg_to_node_udf_arg(arg, local_node) - args[key] = udf_argument - return UDFKeyArguments(args=args) - - -def algoexec_udf_posargs_to_node_udf_posargs( - algoexec_posargs: List[Any], - local_node: "LocalNode" = None, -) -> UDFPosArguments: - if not algoexec_posargs: - return UDFPosArguments(args=[]) - - args = [] - for arg in algoexec_posargs: - args.append(_algoexec_udf_arg_to_node_udf_arg(arg, local_node)) - return UDFPosArguments(args=args) - - -def _algoexec_udf_arg_to_node_udf_arg( - algoexec_arg: AlgoExecData, local_node: "LocalNode" = None -) -> NodeUDFDTO: - """ - Converts the algorithm executor run_udf input arguments, coming from the algorithm flow - to node udf pos/key arguments to be send to the NODE. - - Parameters - ---------- - algoexec_arg is the argument to be converted. - local_node is need only when the algoexec_arg is of LocalNodesTable, to know - which local table should be selected. - - Returns - ------- - a NodeUDFDTO - """ - if isinstance(algoexec_arg, LocalNodesTable): - if not local_node: - raise ValueError( - "local_node parameter is required on LocalNodesTable convertion." - ) - return NodeTableDTO(value=algoexec_arg.nodes_tables[local_node].full_table_name) - elif isinstance(algoexec_arg, GlobalNodeTable): - return NodeTableDTO(value=algoexec_arg.table.full_table_name) - elif isinstance(algoexec_arg, LocalNodesSMPCTables): - return NodeSMPCDTO( - value=NodeSMPCValueDTO( - template=algoexec_arg.template.nodes_tables[local_node].full_table_name, - add_op_values=algoexec_arg.add_op.nodes_tables[ - local_node - ].full_table_name, - min_op_values=algoexec_arg.min_op.nodes_tables[ - local_node - ].full_table_name, - max_op_values=algoexec_arg.max_op.nodes_tables[ - local_node - ].full_table_name, - union_op_values=algoexec_arg.union_op.nodes_tables[ - local_node - ].full_table_name, - ) - ) - elif isinstance(algoexec_arg, GlobalNodeSMPCTables): - return NodeSMPCDTO( - value=NodeSMPCValueDTO( - template=NodeTableDTO( - value=algoexec_arg.template.table.full_table_name - ), - add_op_values=create_node_table_dto_from_global_node_table( - algoexec_arg.add_op - ), - min_op_values=create_node_table_dto_from_global_node_table( - algoexec_arg.min_op - ), - max_op_values=create_node_table_dto_from_global_node_table( - algoexec_arg.max_op - ), - union_op_values=create_node_table_dto_from_global_node_table( - algoexec_arg.union_op - ), - ) - ) - else: - return NodeLiteralDTO(value=algoexec_arg) - - -def create_node_table_from_node_table_dto(node_table_dto: NodeTableDTO): - if not node_table_dto: - return None - - return NodeTable(table_name=node_table_dto.value) - - -def create_node_table_dto_from_global_node_table(node_table: GlobalNodeTable): - if not node_table: - return None - - return NodeTableDTO(value=node_table.table.full_table_name) - - -def create_local_nodes_table_from_nodes_tables( - nodes_tables: Dict["LocalNode", Union[NodeTable, None]] -): - for table in nodes_tables.values(): - if not table: - return None - - return LocalNodesTable(nodes_tables) diff --git a/mipengine/controller/algorithm_executor_nodes.py b/mipengine/controller/algorithm_executor_nodes.py index 10d25a4d29372c2fecc172df5074ebac0a950f05..125afb1181c946e33d0b0cc1044f69457e3cdd9f 100644 --- 
a/mipengine/controller/algorithm_executor_nodes.py +++ b/mipengine/controller/algorithm_executor_nodes.py @@ -3,14 +3,12 @@ from abc import abstractmethod from typing import Any from typing import Dict from typing import List +from typing import Optional from typing import Tuple from mipengine.controller.algorithm_executor_node_data_objects import NodeData -from mipengine.controller.algorithm_executor_node_data_objects import NodeSMPCTables -from mipengine.controller.algorithm_executor_node_data_objects import NodeTable -from mipengine.controller.algorithm_executor_node_data_objects import ( - create_node_table_from_node_table_dto, -) +from mipengine.controller.algorithm_executor_node_data_objects import SMPCTableNames +from mipengine.controller.algorithm_executor_node_data_objects import TableName from mipengine.controller.node_tasks_handler_interface import INodeTasksHandler from mipengine.controller.node_tasks_handler_interface import IQueuedUDFAsyncResult from mipengine.node_tasks_DTOs import NodeSMPCDTO @@ -24,23 +22,23 @@ from mipengine.node_tasks_DTOs import UDFPosArguments class _INode(ABC): @abstractmethod - def get_tables(self) -> List[NodeTable]: + def get_tables(self) -> List[TableName]: pass @abstractmethod - def get_table_schema(self, table_name: NodeTable) -> TableSchema: + def get_table_schema(self, table_name: TableName) -> TableSchema: pass @abstractmethod - def get_table_data(self, table_name: NodeTable) -> TableData: + def get_table_data(self, table_name: TableName) -> TableData: pass @abstractmethod - def create_table(self, command_id: str, schema: TableSchema) -> NodeTable: + def create_table(self, command_id: str, schema: TableSchema) -> TableName: pass @abstractmethod - def get_views(self) -> List[NodeTable]: + def get_views(self) -> List[TableName]: pass @abstractmethod @@ -50,11 +48,11 @@ class _INode(ABC): pathology: str, columns: List[str], filters: List[str], - ) -> NodeTable: + ) -> TableName: pass @abstractmethod - def get_merge_tables(self) -> List[NodeTable]: + def get_merge_tables(self) -> List[TableName]: pass @abstractmethod @@ -67,7 +65,7 @@ class _INode(ABC): @abstractmethod def create_remote_table( - self, table_name: str, table_schema: TableSchema, native_node: "_Node" + self, table_name: str, table_schema: TableSchema, native_node: "_INode" ): pass @@ -84,7 +82,7 @@ class _INode(ABC): @abstractmethod def get_queued_udf_result( self, async_result: IQueuedUDFAsyncResult - ) -> List[NodeTable]: + ) -> List[TableName]: pass @abstractmethod @@ -115,12 +113,12 @@ class _Node(_INode, ABC): return f"{self.node_id}" @property - def initial_view_tables(self) -> Dict[str, NodeTable]: + def initial_view_tables(self) -> Dict[str, TableName]: return self._initial_view_tables def _create_initial_view_tables( self, initial_view_tables_params - ) -> Dict[str, NodeTable]: + ) -> Dict[str, TableName]: # will contain the views created from the pathology, datasets. 
Its keys are # the variable sets x, y etc initial_view_tables = {} @@ -157,9 +155,9 @@ class _Node(_INode, ABC): return self._node_tasks_handler.node_data_address # TABLES functionality - def get_tables(self) -> List[NodeTable]: + def get_tables(self) -> List[TableName]: tables = [ - NodeTable(table_name) + TableName(table_name) for table_name in self._node_tasks_handler.get_tables( request_id=self.request_id, context_id=self.context_id, @@ -167,19 +165,19 @@ class _Node(_INode, ABC): ] return tables - def get_table_schema(self, table_name: NodeTable) -> TableSchema: + def get_table_schema(self, table_name: TableName) -> TableSchema: return self._node_tasks_handler.get_table_schema( request_id=self.request_id, table_name=table_name.full_table_name ) - def get_table_data(self, table_name: NodeTable) -> TableData: + def get_table_data(self, table_name: TableName) -> TableData: return self._node_tasks_handler.get_table_data( request_id=self.request_id, table_name=table_name.full_table_name, ) - def create_table(self, command_id: str, schema: TableSchema) -> NodeTable: - return NodeTable( + def create_table(self, command_id: str, schema: TableSchema) -> TableName: + return TableName( self._node_tasks_handler.create_table( request_id=self.request_id, context_id=self.context_id, @@ -189,11 +187,11 @@ class _Node(_INode, ABC): ) # VIEWS functionality - def get_views(self) -> List[NodeTable]: + def get_views(self) -> List[TableName]: result = self._node_tasks_handler.get_views( request_id=self.request_id, context_id=self.context_id ) - return [NodeTable(table_name) for table_name in result] + return [TableName(table_name) for table_name in result] # TODO: this is very specific to mip, very inconsistent with the rest, has to # be abstracted somehow @@ -203,7 +201,7 @@ class _Node(_INode, ABC): pathology: str, columns: List[str], filters: List[str], - ) -> NodeTable: + ) -> TableName: result = self._node_tasks_handler.create_pathology_view( request_id=self.request_id, context_id=self.context_id, @@ -212,23 +210,23 @@ class _Node(_INode, ABC): columns=columns, filters=filters, ) - return NodeTable(result) + return TableName(result) # MERGE TABLES functionality - def get_merge_tables(self) -> List[NodeTable]: + def get_merge_tables(self) -> List[TableName]: result = self._node_tasks_handler.get_merge_tables( request_id=self.request_id, context_id=self.context_id ) - return [NodeTable(table_name) for table_name in result] + return [TableName(table_name) for table_name in result] - def create_merge_table(self, command_id: str, table_names: List[str]) -> NodeTable: + def create_merge_table(self, command_id: str, table_names: List[str]) -> TableName: result = self._node_tasks_handler.create_merge_table( request_id=self.request_id, context_id=self.context_id, command_id=command_id, table_names=table_names, ) - return NodeTable(result) + return TableName(result) # REMOTE TABLES functionality def get_remote_tables(self) -> List[str]: @@ -299,13 +297,13 @@ class LocalNode(_Node): udf_results = [] for result in node_udf_results.results: if isinstance(result, NodeTableDTO): - udf_results.append(NodeTable(result.value)) + udf_results.append(TableName(result.value)) elif isinstance(result, NodeSMPCDTO): udf_results.append( - NodeSMPCTables( - template=NodeTable(result.value.template.value), - add_op=create_node_table_from_node_table_dto( - result.value.add_op_values + SMPCTableNames( + template=TableName(result.value.template.value), + sum_op=create_node_table_from_node_table_dto( + result.value.sum_op_values ), 
min_op=create_node_table_from_node_table_dto( result.value.min_op_values @@ -328,15 +326,22 @@ class LocalNode(_Node): ) +def create_node_table_from_node_table_dto(node_table_dto: NodeTableDTO): + if not node_table_dto: + return None + + return TableName(table_name=node_table_dto.value) + + class GlobalNode(_Node): def get_queued_udf_result( self, async_result: IQueuedUDFAsyncResult - ) -> List[NodeTable]: + ) -> List[TableName]: node_udf_results = self._node_tasks_handler.get_queued_udf_result(async_result) results = [] for result in node_udf_results.results: if isinstance(result, NodeTableDTO): - results.append(NodeTable(result.value)) + results.append(TableName(result.value)) elif isinstance(result, NodeSMPCDTO): raise TypeError("A global node should not return an SMPC DTO.") else: @@ -353,9 +358,13 @@ class GlobalNode(_Node): def get_smpc_result( self, - command_id: int, jobid: str, + command_id: str, + command_subid: Optional[str] = "0", ) -> str: return self._node_tasks_handler.get_smpc_result( - self.context_id, str(command_id), jobid + jobid=jobid, + context_id=self.context_id, + command_id=str(command_id), + command_subid=command_subid, ) diff --git a/mipengine/controller/algorithm_executor_smpc_helper.py b/mipengine/controller/algorithm_executor_smpc_helper.py index 70b18deb0392d0f1873b7749d67a3caac79b7905..a9aca75a9c363e6a7848f410f7aca700a6ee11cf 100644 --- a/mipengine/controller/algorithm_executor_smpc_helper.py +++ b/mipengine/controller/algorithm_executor_smpc_helper.py @@ -2,12 +2,13 @@ from typing import List from typing import Tuple from mipengine.controller import config as ctrl_config -from mipengine.controller.algorithm_executor_node_data_objects import GlobalNodeTable -from mipengine.controller.algorithm_executor_node_data_objects import ( +from mipengine.controller.algorithm_executor_node_data_objects import SMPCTableNames +from mipengine.controller.algorithm_flow_data_objects import GlobalNodeTable +from mipengine.controller.algorithm_flow_data_objects import ( LocalNodesSMPCTables, ) -from mipengine.controller.algorithm_executor_node_data_objects import LocalNodesTable -from mipengine.controller.algorithm_executor_node_data_objects import NodeTable +from mipengine.controller.algorithm_flow_data_objects import LocalNodesTable +from mipengine.controller.algorithm_executor_node_data_objects import TableName from mipengine.controller.algorithm_executor_nodes import GlobalNode from mipengine.smpc_DTOs import SMPCRequestType from mipengine.smpc_cluster_comm_helpers import trigger_smpc_computation @@ -37,8 +38,8 @@ def load_operation_data_to_smpc_clients( def load_data_to_smpc_clients( command_id: int, smpc_tables: LocalNodesSMPCTables ) -> Tuple[List[int], List[int], List[int], List[int]]: - add_op_smpc_clients = load_operation_data_to_smpc_clients( - command_id, smpc_tables.add_op, SMPCRequestType.SUM + sum_op_smpc_clients = load_operation_data_to_smpc_clients( + command_id, smpc_tables.sum_op, SMPCRequestType.SUM ) min_op_smpc_clients = load_operation_data_to_smpc_clients( command_id, smpc_tables.min_op, SMPCRequestType.MIN @@ -50,7 +51,7 @@ def load_data_to_smpc_clients( command_id, smpc_tables.union_op, SMPCRequestType.UNION ) return ( - add_op_smpc_clients, + sum_op_smpc_clients, min_op_smpc_clients, max_op_smpc_clients, union_op_smpc_clients, @@ -83,13 +84,13 @@ def trigger_smpc_computations( smpc_clients_per_op: Tuple[List[int], List[int], List[int], List[int]], ) -> Tuple[bool, bool, bool, bool]: ( - add_op_smpc_clients, + sum_op_smpc_clients, min_op_smpc_clients, 
max_op_smpc_clients, union_op_smpc_clients, ) = smpc_clients_per_op - add_op = trigger_smpc_operation_computation( - context_id, command_id, SMPCRequestType.SUM, add_op_smpc_clients + sum_op = trigger_smpc_operation_computation( + context_id, command_id, SMPCRequestType.SUM, sum_op_smpc_clients ) min_op = trigger_smpc_operation_computation( context_id, command_id, SMPCRequestType.MIN, min_op_smpc_clients @@ -100,80 +101,74 @@ def trigger_smpc_computations( union_op = trigger_smpc_operation_computation( context_id, command_id, SMPCRequestType.UNION, union_op_smpc_clients ) - return add_op, min_op, max_op, union_op + return sum_op, min_op, max_op, union_op def get_smpc_results( node: GlobalNode, context_id: str, command_id: int, - add_op: bool, + sum_op: bool, min_op: bool, max_op: bool, union_op: bool, -) -> Tuple[GlobalNodeTable, GlobalNodeTable, GlobalNodeTable, GlobalNodeTable]: - add_op_result_table = ( +) -> Tuple[TableName, TableName, TableName, TableName]: + sum_op_result_table = ( node.get_smpc_result( - command_id=command_id, jobid=get_smpc_job_id( context_id=context_id, command_id=command_id, operation=SMPCRequestType.SUM, ), + command_id=str(command_id), + command_subid="0", ) - if add_op + if sum_op else None ) min_op_result_table = ( node.get_smpc_result( - command_id=command_id, jobid=get_smpc_job_id( context_id=context_id, command_id=command_id, operation=SMPCRequestType.MIN, ), + command_id=str(command_id), + command_subid="1", ) if min_op else None ) max_op_result_table = ( node.get_smpc_result( - command_id=command_id, jobid=get_smpc_job_id( context_id=context_id, command_id=command_id, operation=SMPCRequestType.MAX, ), + command_id=str(command_id), + command_subid="2", ) if max_op else None ) union_op_result_table = ( node.get_smpc_result( - command_id=command_id, jobid=get_smpc_job_id( context_id=context_id, command_id=command_id, operation=SMPCRequestType.UNION, ), + command_id=str(command_id), + command_subid="3", ) if union_op else None ) - result = ( - GlobalNodeTable(node=node, table=NodeTable(table_name=add_op_result_table)) - if add_op_result_table - else None, - GlobalNodeTable(node=node, table=NodeTable(table_name=min_op_result_table)) - if min_op_result_table - else None, - GlobalNodeTable(node=node, table=NodeTable(table_name=max_op_result_table)) - if max_op_result_table - else None, - GlobalNodeTable(node=node, table=NodeTable(table_name=union_op_result_table)) - if union_op_result_table - else None, + return ( + TableName(sum_op_result_table), + TableName(min_op_result_table), + TableName(max_op_result_table), + TableName(union_op_result_table), ) - - return result diff --git a/mipengine/controller/algorithm_flow_data_objects.py b/mipengine/controller/algorithm_flow_data_objects.py new file mode 100644 index 0000000000000000000000000000000000000000..86145fa4010cb71617798385631fde4cbd42c7e2 --- /dev/null +++ b/mipengine/controller/algorithm_flow_data_objects.py @@ -0,0 +1,273 @@ +from abc import ABC +from typing import Any +from typing import Dict +from typing import List +from typing import Union + +from mipengine.controller.algorithm_executor_node_data_objects import SMPCTableNames +from mipengine.controller.algorithm_executor_node_data_objects import TableName +from mipengine.controller.algorithm_executor_nodes import GlobalNode +from mipengine.controller.algorithm_executor_nodes import LocalNode +from mipengine.node_tasks_DTOs import NodeLiteralDTO +from mipengine.node_tasks_DTOs import NodeSMPCDTO +from mipengine.node_tasks_DTOs import NodeSMPCValueDTO 
+from mipengine.node_tasks_DTOs import NodeTableDTO +from mipengine.node_tasks_DTOs import NodeUDFDTO +from mipengine.node_tasks_DTOs import TableSchema +from mipengine.node_tasks_DTOs import UDFKeyArguments +from mipengine.node_tasks_DTOs import UDFPosArguments + + +class AlgoFlowData(ABC): + """ + AlgoFlowData are representing data objects in the algorithm flow. + These objects are the result of running udfs and are used as input + as well in the udfs. + """ + + pass + + +class LocalNodesData(AlgoFlowData, ABC): + """ + LocalNodesData are representing data objects in the algorithm flow + that are located in many (or one) local nodes. + """ + + pass + + +class GlobalNodeData(AlgoFlowData, ABC): + """ + GlobalNodeData are representing data objects in the algorithm flow + that are located in the global node. + """ + + pass + + +# When _AlgorithmExecutionInterface::run_udf_on_local_nodes(..) is called, depending on +# how many local nodes are participating in the current algorithm execution, several +# database tables are created on all participating local nodes. Irrespectevely of the +# number of local nodes participating, the number of tables created on each of these local +# nodes will be the same. +# Class _LocalNodeTable is the structure that represents the concept of these database +# tables, created during the execution of a udf, in the algorithm execution layer. A key +# concept is that a _LocalNodeTable stores 'pointers' to 'relevant' tables existing in +# different local nodes accross the federation. By 'relevant' I mean tables that are +# generated when triggering a udf execution accross several local nodes. By 'pointers' +# I mean mapping between local nodes and table names and the aim is to hide the underline +# complexity from the algorithm flow and exposing a single 'local node table' object that +# stores in the background pointers to several tables in several local nodes. +class LocalNodesTable(LocalNodesData): + _nodes_tables: Dict[LocalNode, TableName] + + def __init__(self, nodes_tables: Dict[LocalNode, TableName]): + self._nodes_tables = nodes_tables + self._validate_matching_table_names(list(self._nodes_tables.values())) + + @property + def nodes_tables(self) -> Dict[LocalNode, TableName]: + return self._nodes_tables + + # TODO this is redundant, either remove it or overload all node methods here? 
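As a minimal sketch of the idea described in the comment above (the node objects and table names here are illustrative and only assume the usual node-qualified naming convention):

from mipengine.controller.algorithm_executor_node_data_objects import TableName
from mipengine.controller.algorithm_flow_data_objects import LocalNodesTable
from mipengine.controller.algorithm_flow_data_objects import (
    algoexec_udf_posargs_to_node_udf_posargs,
)

# A UDF that ran on two local nodes left one result table on each node.
# local_node_1 / local_node_2 stand in for real LocalNode instances.
nodes_tables = {
    local_node_1: TableName("normal_localnode1_ctx1_cmd4_0"),
    local_node_2: TableName("normal_localnode2_ctx1_cmd4_0"),
}
result = LocalNodesTable(nodes_tables)  # one flow-level object, many node tables

# When the flow passes 'result' into the next UDF, each local node receives
# only the full name of its own table:
args_for_node_1 = algoexec_udf_posargs_to_node_udf_posargs([result], local_node_1)
# -> UDFPosArguments(args=[NodeTableDTO(value="normal_localnode1_ctx1_cmd4_0")])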
+ def get_table_schema(self) -> TableSchema: + node = list(self.nodes_tables.keys())[0] + table = self.nodes_tables[node] + return node.get_table_schema(table) + + def get_table_data(self) -> List[Union[int, float, str]]: + tables_data = [] + for node, table_name in self.nodes_tables.items(): + tables_data.append(node.get_table_data(table_name)) + tables_data_flat = [table_data.columns for table_data in tables_data] + tables_data_flat = [ + elem + for table in tables_data_flat + for column in table + for elem in column.data + ] + return tables_data_flat + + def __repr__(self): + r = f"\n\tLocalNodeTable: {self.get_table_schema()}\n" + for node, table_name in self.nodes_tables.items(): + r += f"\t{node=} {table_name=}\n" + return r + + def _validate_matching_table_names(self, table_names: List[TableName]): + table_name_without_node_id = table_names[0].without_node_id() + for table_name in table_names: + if table_name.without_node_id() != table_name_without_node_id: + raise MismatchingTableNamesException( + [table_name.full_table_name for table_name in table_names] + ) + + +class GlobalNodeTable(GlobalNodeData): + _node: GlobalNode + _table: TableName + + def __init__(self, node: GlobalNode, table: TableName): + self._node = node + self._table = table + + @property + def node(self) -> GlobalNode: + return self._node + + @property + def table(self) -> TableName: + return self._table + + # TODO this is redundant, either remove it or overload all node methods here? + def get_table_schema(self) -> TableSchema: + return self._node.get_table_schema(self.table) + + def get_table_data(self) -> List[List[Any]]: + table_data = [ + column.data for column in self.node.get_table_data(self.table).columns + ] + return table_data + + def __repr__(self): + r = f"\n\tGlobalNodeTable: \n\tschema={self.get_table_schema()}\n \t{self.table=}\n" + return r + + +class LocalNodesSMPCTables(LocalNodesData): + _nodes_smpc_tables: Dict[LocalNode, SMPCTableNames] + + def __init__(self, nodes_smpc_tables: Dict[LocalNode, SMPCTableNames]): + self._nodes_smpc_tables = nodes_smpc_tables + + @property + def nodes_smpc_tables(self) -> Dict[LocalNode, SMPCTableNames]: + return self._nodes_smpc_tables + + @property + def template_local_nodes_table(self) -> LocalNodesTable: + return LocalNodesTable( + {node: tables.template for node, tables in self.nodes_smpc_tables.items()} + ) + + +class GlobalNodeSMPCTables(GlobalNodeData): + _node: GlobalNode + _smpc_tables: SMPCTableNames + + def __init__(self, node: GlobalNode, smpc_tables: SMPCTableNames): + self._node = node + self._smpc_tables = smpc_tables + + @property + def node(self) -> GlobalNode: + return self._node + + @property + def smpc_tables(self) -> SMPCTableNames: + return self._smpc_tables + + +def algoexec_udf_kwargs_to_node_udf_kwargs( + algoexec_kwargs: Dict[str, Any], + local_node: LocalNode = None, +) -> UDFKeyArguments: + if not algoexec_kwargs: + return UDFKeyArguments(args={}) + + args = {} + for key, arg in algoexec_kwargs.items(): + udf_argument = _algoexec_udf_arg_to_node_udf_arg(arg, local_node) + args[key] = udf_argument + return UDFKeyArguments(args=args) + + +def algoexec_udf_posargs_to_node_udf_posargs( + algoexec_posargs: List[Any], + local_node: LocalNode = None, +) -> UDFPosArguments: + if not algoexec_posargs: + return UDFPosArguments(args=[]) + + args = [] + for arg in algoexec_posargs: + args.append(_algoexec_udf_arg_to_node_udf_arg(arg, local_node)) + return UDFPosArguments(args=args) + + +def _algoexec_udf_arg_to_node_udf_arg( + algoexec_arg: 
AlgoFlowData, local_node: LocalNode = None +) -> NodeUDFDTO: + """ + Converts the algorithm executor run_udf input arguments, coming from the algorithm flow + to node udf pos/key arguments to be send to the NODE. + + Parameters + ---------- + algoexec_arg is the argument to be converted. + local_node is need only when the algoexec_arg is of LocalNodesTable, to know + which local table should be selected. + + Returns + ------- + a NodeUDFDTO + """ + if isinstance(algoexec_arg, LocalNodesTable): + if not local_node: + raise ValueError( + "local_node parameter is required on LocalNodesTable conversion." + ) + return NodeTableDTO(value=algoexec_arg.nodes_tables[local_node].full_table_name) + elif isinstance(algoexec_arg, GlobalNodeTable): + return NodeTableDTO(value=algoexec_arg.table.full_table_name) + elif isinstance(algoexec_arg, LocalNodesSMPCTables): + raise ValueError( + "'LocalNodesSMPCTables' cannot be used as argument. It must be shared." + ) + elif isinstance(algoexec_arg, GlobalNodeSMPCTables): + return NodeSMPCDTO( + value=NodeSMPCValueDTO( + template=NodeTableDTO( + value=algoexec_arg.smpc_tables.template.full_table_name + ), + sum_op_values=create_node_table_dto_from_global_node_table( + algoexec_arg.smpc_tables.sum_op + ), + min_op_values=create_node_table_dto_from_global_node_table( + algoexec_arg.smpc_tables.min_op + ), + max_op_values=create_node_table_dto_from_global_node_table( + algoexec_arg.smpc_tables.max_op + ), + union_op_values=create_node_table_dto_from_global_node_table( + algoexec_arg.smpc_tables.union_op + ), + ) + ) + else: + return NodeLiteralDTO(value=algoexec_arg) + + +def create_node_table_dto_from_global_node_table(table: TableName): + if not table: + return None + + return NodeTableDTO(value=table.full_table_name) + + +def create_local_nodes_table_from_nodes_tables( + nodes_tables: Dict[LocalNode, Union[TableName, None]] +): + for table in nodes_tables.values(): + if not table: + return None + + return LocalNodesTable(nodes_tables) + + +class MismatchingTableNamesException(Exception): + def __init__(self, table_names: List[str]): + message = f"Mismatched table names ->{table_names}" + super().__init__(message) + self.message = message diff --git a/mipengine/controller/node_tasks_handler_celery.py b/mipengine/controller/node_tasks_handler_celery.py index b8fc6538cb29aa115d7aa09b7c06638027c1f478..fa498884b7ac15f3445651ba20faafc2b58ff5f6 100644 --- a/mipengine/controller/node_tasks_handler_celery.py +++ b/mipengine/controller/node_tasks_handler_celery.py @@ -8,6 +8,7 @@ from billiard.exceptions import TimeLimitExceeded from celery.exceptions import TimeoutError from celery.result import AsyncResult from kombu.exceptions import OperationalError +from typing import Optional from mipengine.controller.celery_app import get_node_celery_app from mipengine.controller.node_tasks_handler_interface import INodeTasksHandler @@ -388,16 +389,18 @@ class NodeTasksHandlerCelery(INodeTasksHandler): @broker_connection_closed_handler def get_smpc_result( self, + jobid: str, context_id: str, command_id: str, - jobid: str, + command_subid: Optional[str] = "0", ) -> str: task_signature = self._celery_app.signature(TASK_SIGNATURES["get_smpc_result"]) return self._apply_async( task_signature=task_signature, + jobid=jobid, context_id=context_id, command_id=command_id, - jobid=jobid, + command_subid=command_subid, ).get(self._tasks_timeout) # CLEANUP functionality diff --git a/mipengine/controller/node_tasks_handler_interface.py b/mipengine/controller/node_tasks_handler_interface.py 
index 93a9c40c45d3d43becf76f71a1308bf25021bb2c..b3ac0ae12ee20556d839424b1e1d8a691e061114 100644 --- a/mipengine/controller/node_tasks_handler_interface.py +++ b/mipengine/controller/node_tasks_handler_interface.py @@ -171,8 +171,9 @@ class INodeTasksHandler(ABC): @abstractmethod def get_smpc_result( self, + jobid: str, context_id: str, command_id: str, - jobid: str, + command_subid: Optional[str] = "0", ) -> str: pass diff --git a/mipengine/node/tasks/smpc.py b/mipengine/node/tasks/smpc.py index 939d2202672e0abfde5d20b352bf59ad31d2ff68..6dca079cce3d3cee063344bc4ae265e6c7ca0949 100644 --- a/mipengine/node/tasks/smpc.py +++ b/mipengine/node/tasks/smpc.py @@ -4,6 +4,8 @@ from typing import List from celery import shared_task +from typing import Optional + from mipengine import DType from mipengine import smpc_cluster_comm_helpers as smpc_cluster from mipengine.node import config as node_config @@ -91,9 +93,10 @@ def load_data_to_smpc_client(request_id: str, table_name: str, jobid: str) -> in @initialise_logger def get_smpc_result( request_id: str, + jobid: str, context_id: str, command_id: str, - jobid: str, + command_subid: Optional[str] = "0", ) -> str: """ Fetches the results from an SMPC and writes them into a table. @@ -101,8 +104,10 @@ def get_smpc_result( Parameters ---------- request_id: The identifier for the logging + jobid: The identifier for the smpc job. context_id: An identifier of the action. command_id: An identifier for the command, used for naming the result table. + command_subid: An identifier for the command, used for naming the result table. jobid: The jobid of the SMPC. Returns @@ -140,13 +145,16 @@ def get_smpc_result( request_id=request_id, context_id=context_id, command_id=command_id, + command_subid=command_subid, smpc_op_result_data=smpc_response_with_output.computationOutput, ) return results_table_name -def _create_smpc_results_table(request_id, context_id, command_id, smpc_op_result_data): +def _create_smpc_results_table( + request_id, context_id, command_id, command_subid, smpc_op_result_data +): """ Create a table with the SMPC specific schema and insert the results of the SMPC to it. 
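For reference, this is roughly how the controller drives this task through the Celery tasks handler after the signature change; the identifier values below are illustrative, and the operation-to-subid mapping ("0" sum, "1" min, "2" max, "3" union) is the one used by get_smpc_results above.

# Illustrative call; 'tasks_handler' is a NodeTasksHandlerCelery instance and
# 'sum_jobid' would come from get_smpc_job_id() for the SUM operation.
sum_result_table = tasks_handler.get_smpc_result(
    jobid=sum_jobid,
    context_id="ctx1",
    command_id="4",
    command_subid="0",  # "1" -> min, "2" -> max, "3" -> union
)
# The node-side task fetches the SMPC output for that jobid, writes it into a
# result table whose name now also includes the command_subid, and returns the
# table name.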
@@ -157,6 +165,7 @@ def _create_smpc_results_table(request_id, context_id, command_id, smpc_op_resul node_config.identifier, context_id, command_id, + command_subid, ) table_schema = TableSchema( columns=[ @@ -181,7 +190,7 @@ def _create_smpc_results_table(request_id, context_id, command_id, smpc_op_resul def _get_smpc_values_from_table_data(table_data: List[ColumnData], op: SMPCRequestType): if op == SMPCRequestType.SUM: node_id_column, values_column = table_data - add_op_values = values_column.data + sum_op_values = values_column.data else: raise NotImplementedError - return add_op_values + return sum_op_values diff --git a/mipengine/node/tasks/udfs.py b/mipengine/node/tasks/udfs.py index e5c0e84eebf70ea08a60d96bab85786b76fb1cdb..436373705da4b74b424651174c2bde7142ca4937 100644 --- a/mipengine/node/tasks/udfs.py +++ b/mipengine/node/tasks/udfs.py @@ -171,9 +171,9 @@ def _create_table_info_from_tablename(tablename: str): def _convert_smpc_udf2udfgen_arg(udf_argument: NodeSMPCDTO): template = _create_table_info_from_tablename(udf_argument.value.template.value) - add_op = ( - _create_table_info_from_tablename(udf_argument.value.add_op_values.value) - if udf_argument.value.add_op_values + sum_op = ( + _create_table_info_from_tablename(udf_argument.value.sum_op_values.value) + if udf_argument.value.sum_op_values else None ) min_op = ( @@ -193,7 +193,7 @@ def _convert_smpc_udf2udfgen_arg(udf_argument: NodeSMPCDTO): ) return SMPCTablesInfo( template=template, - add_op_values=add_op, + sum_op_values=sum_op, min_op_values=min_op, max_op_values=max_op, union_op_values=union_op, @@ -266,8 +266,8 @@ def _get_all_table_results_from_smpc_result( ) -> List[TableUDFGenResult]: table_results = [smpc_result.template] table_results.append( - smpc_result.add_op_values - ) if smpc_result.add_op_values else None + smpc_result.sum_op_values + ) if smpc_result.sum_op_values else None table_results.append( smpc_result.min_op_values ) if smpc_result.min_op_values else None @@ -351,13 +351,13 @@ def _convert_udfgen2udf_smpc_result_and_mapping( udfgen_result.template, context_id, command_id, command_subid ) - if udfgen_result.add_op_values: - (add_op_udf_result, mapping,) = _convert_udfgen2udf_table_result_and_mapping( - udfgen_result.add_op_values, context_id, command_id, command_subid + 1 + if udfgen_result.sum_op_values: + (sum_op_udf_result, mapping,) = _convert_udfgen2udf_table_result_and_mapping( + udfgen_result.sum_op_values, context_id, command_id, command_subid + 1 ) table_names_tmpl_mapping.update(mapping) else: - add_op_udf_result = None + sum_op_udf_result = None if udfgen_result.min_op_values: (min_op_udf_result, mapping,) = _convert_udfgen2udf_table_result_and_mapping( @@ -386,7 +386,7 @@ def _convert_udfgen2udf_smpc_result_and_mapping( result = NodeSMPCDTO( value=NodeSMPCValueDTO( template=template_udf_result, - add_op_values=add_op_udf_result, + sum_op_values=sum_op_udf_result, min_op_values=min_op_udf_result, max_op_values=max_op_udf_result, union_op_values=union_op_udf_result, diff --git a/mipengine/node_tasks_DTOs.py b/mipengine/node_tasks_DTOs.py index d84dbfbfc0b337a65a3703c5928049b5a467c517..95c772dfda9023b38cfcad33162617db61ec9c19 100644 --- a/mipengine/node_tasks_DTOs.py +++ b/mipengine/node_tasks_DTOs.py @@ -121,7 +121,7 @@ class NodeTableDTO(NodeUDFDTO): class NodeSMPCValueDTO(ImmutableBaseModel): template: NodeTableDTO - add_op_values: NodeTableDTO = None + sum_op_values: NodeTableDTO = None min_op_values: NodeTableDTO = None max_op_values: NodeTableDTO = None union_op_values: NodeTableDTO 
= None diff --git a/mipengine/udfgen/udfgen_DTOs.py b/mipengine/udfgen/udfgen_DTOs.py index ec5ee68a03dacdf8def6741a362d2cab3c33a0ca..5ff9d81e650a7132776e7bd6d46ad2d97c5dea8e 100644 --- a/mipengine/udfgen/udfgen_DTOs.py +++ b/mipengine/udfgen/udfgen_DTOs.py @@ -43,7 +43,7 @@ class TableUDFGenResult(UDFGenResult): class SMPCUDFGenResult(UDFGenResult): template: TableUDFGenResult - add_op_values: Optional[TableUDFGenResult] = None + sum_op_values: Optional[TableUDFGenResult] = None min_op_values: Optional[TableUDFGenResult] = None max_op_values: Optional[TableUDFGenResult] = None union_op_values: Optional[TableUDFGenResult] = None @@ -51,7 +51,7 @@ class SMPCUDFGenResult(UDFGenResult): def __eq__(self, other): if self.template != other.template: return False - if self.add_op_values != other.add_op_values: + if self.sum_op_values != other.sum_op_values: return False if self.min_op_values != other.min_op_values: return False @@ -65,7 +65,7 @@ class SMPCUDFGenResult(UDFGenResult): return ( f"SMPCUDFGenResult(" f"template={self.template}, " - f"add_op_values={self.add_op_values}, " + f"sum_op_values={self.sum_op_values}, " f"min_op_values={self.min_op_values}, " f"max_op_values={self.max_op_values}, " f"union_op_values={self.union_op_values}" @@ -93,7 +93,7 @@ class UDFGenExecutionQueries(UDFGenBaseModel): class SMPCTablesInfo(UDFGenBaseModel): template: TableInfo - add_op_values: Optional[TableInfo] = None + sum_op_values: Optional[TableInfo] = None min_op_values: Optional[TableInfo] = None max_op_values: Optional[TableInfo] = None union_op_values: Optional[TableInfo] = None @@ -102,7 +102,7 @@ class SMPCTablesInfo(UDFGenBaseModel): return ( f"SMPCUDFInput(" f"template={self.template}, " - f"add_op_values={self.add_op_values}, " + f"sum_op_values={self.sum_op_values}, " f"min_op_values={self.min_op_values}, " f"max_op_values={self.max_op_values}, " f"union_op_values={self.union_op_values}" diff --git a/mipengine/udfgen/udfgenerator.py b/mipengine/udfgen/udfgenerator.py index f714c287af940450a046ea3966403f52cbabb842..6bcc011d030b6ed5cea085fbf009ae5c5706005b 100644 --- a/mipengine/udfgen/udfgenerator.py +++ b/mipengine/udfgen/udfgenerator.py @@ -184,7 +184,7 @@ Local UDF step Example ... def local_step(x, y): ... state["x"] = x["key"] ... state["y"] = y["key"] -... transfer["sum"] = {"data": x["key"] + y["key"], "type": "int", "operation": "addition"} +... transfer["sum"] = {"data": x["key"] + y["key"], "operation": "sum"} ... return state, transfer Global UDF step Example @@ -206,15 +206,11 @@ So, the secure_transfer dict sent should be of the format: The data could be an int/float or a list containing other lists or float/int. -The type enumerations are: - - "int" - - "decimal" (Not yet implemented) - The operation enumerations are: - - "addition" - - "min" (Not yet implemented) - - "max" (Not yet implemented) - - "union" (Not yet implemented) + - "sum" (Floats not supported when SMPC is enabled) + - "min" (Floats not supported when SMPC is enabled) + - "max" (Floats not supported when SMPC is enabled) + - "union" (Not yet supported) 2. 
Translating and calling UDFs @@ -364,7 +360,9 @@ def get_smpc_build_template(secure_transfer_type): stmts.append( f'__{operation_name}_values_str = _conn.execute("SELECT secure_transfer from {{{operation_name}_values_table_name}};")["secure_transfer"][0]' ) - stmts.append(f"__{operation_name}_values = json.loads(__add_op_values_str)") + stmts.append( + f"__{operation_name}_values = json.loads(__{operation_name}_values_str)" + ) else: stmts.append(f"__{operation_name}_values = None") return stmts @@ -374,12 +372,12 @@ def get_smpc_build_template(secure_transfer_type): '__template_str = _conn.execute("SELECT secure_transfer from {template_table_name};")["secure_transfer"][0]' ) stmts.append("__template = json.loads(__template_str)") - stmts.extend(get_smpc_op_template(secure_transfer_type.add_op, "add_op")) + stmts.extend(get_smpc_op_template(secure_transfer_type.sum_op, "sum_op")) stmts.extend(get_smpc_op_template(secure_transfer_type.min_op, "min_op")) stmts.extend(get_smpc_op_template(secure_transfer_type.max_op, "max_op")) stmts.extend(get_smpc_op_template(secure_transfer_type.union_op, "union_op")) stmts.append( - "{varname} = udfio.construct_secure_transfer_dict(__template,__add_op_values,__min_op_values,__max_op_values,__union_op_values)" + "{varname} = udfio.construct_secure_transfer_dict(__template,__sum_op_values,__min_op_values,__max_op_values,__union_op_values)" ) return LN.join(stmts) @@ -439,18 +437,18 @@ def _get_secure_transfer_op_return_stmt_template(op_enabled, table_name_tmpl, op def _get_secure_transfer_main_return_stmt_template(output_type, smpc_used): if smpc_used: return_stmts = [ - "template, add_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict({return_name})" + "template, sum_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict({return_name})" ] ( _, - add_op_tmpl, + sum_op_tmpl, min_op_tmpl, max_op_tmpl, union_op_tmpl, ) = _get_smpc_table_template_names(_get_main_table_template_name()) return_stmts.extend( _get_secure_transfer_op_return_stmt_template( - output_type.add_op, add_op_tmpl, "add_op" + output_type.sum_op, sum_op_tmpl, "sum_op" ) ) return_stmts.extend( @@ -502,11 +500,11 @@ def _get_secure_transfer_sec_return_stmt_template( ): if smpc_used: return_stmts = [ - "template, add_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict({return_name})" + "template, sum_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict({return_name})" ] ( template_tmpl, - add_op_tmpl, + sum_op_tmpl, min_op_tmpl, max_op_tmpl, union_op_tmpl, @@ -518,7 +516,7 @@ def _get_secure_transfer_sec_return_stmt_template( ) return_stmts.extend( _get_secure_transfer_op_return_stmt_template( - output_type.add_op, add_op_tmpl, "add_op" + output_type.sum_op, sum_op_tmpl, "sum_op" ) ) return_stmts.extend( @@ -894,20 +892,20 @@ def merge_transfer(): class SecureTransferType(DictType, InputType, LoopbackOutputType): _data_column_name = "secure_transfer" _data_column_type = dt.JSON - _add_op: bool + _sum_op: bool _min_op: bool _max_op: bool _union_op: bool - def __init__(self, add_op=False, min_op=False, max_op=False, union_op=False): - self._add_op = add_op + def __init__(self, sum_op=False, min_op=False, max_op=False, union_op=False): + self._sum_op = sum_op self._min_op = min_op self._max_op = max_op self._union_op = union_op @property - def add_op(self): - return self._add_op + def sum_op(self): + return self._sum_op @property def min_op(self): @@ -922,12 +920,12 @@ class SecureTransferType(DictType, InputType, LoopbackOutputType): return 
self._union_op -def secure_transfer(add_op=False, min_op=False, max_op=False, union_op=False): - if not add_op and not min_op and not max_op and not union_op: +def secure_transfer(sum_op=False, min_op=False, max_op=False, union_op=False): + if not sum_op and not min_op and not max_op and not union_op: raise UDFBadDefinition( "In a secure_transfer at least one operation should be enabled." ) - return SecureTransferType(add_op, min_op, max_op, union_op) + return SecureTransferType(sum_op, min_op, max_op, union_op) class StateType(DictType, InputType, LoopbackOutputType): @@ -1057,7 +1055,7 @@ class SecureTransferArg(DictArg): class SMPCSecureTransferArg(UDFArgument): type: SecureTransferType template_table_name: str - add_op_values_table_name: str + sum_op_values_table_name: str min_op_values_table_name: str max_op_values_table_name: str union_op_values_table_name: str @@ -1065,26 +1063,26 @@ class SMPCSecureTransferArg(UDFArgument): def __init__( self, template_table_name: str, - add_op_values_table_name: str, + sum_op_values_table_name: str, min_op_values_table_name: str, max_op_values_table_name: str, union_op_values_table_name: str, ): - add_op = False + sum_op = False min_op = False max_op = False union_op = False - if add_op_values_table_name: - add_op = True + if sum_op_values_table_name: + sum_op = True if min_op_values_table_name: min_op = True if max_op_values_table_name: max_op = True if union_op_values_table_name: union_op = True - self.type = SecureTransferType(add_op, min_op, max_op, union_op) + self.type = SecureTransferType(sum_op, min_op, max_op, union_op) self.template_table_name = template_table_name - self.add_op_values_table_name = add_op_values_table_name + self.sum_op_values_table_name = sum_op_values_table_name self.min_op_values_table_name = min_op_values_table_name self.max_op_values_table_name = max_op_values_table_name self.union_op_values_table_name = union_op_values_table_name @@ -1242,7 +1240,7 @@ class SMPCBuild(ASTNode): return self.template.format( varname=self.arg_name, template_table_name=self.arg.template_table_name, - add_op_values_table_name=self.arg.add_op_values_table_name, + sum_op_values_table_name=self.arg.sum_op_values_table_name, min_op_values_table_name=self.arg.min_op_values_table_name, max_op_values_table_name=self.arg.max_op_values_table_name, union_op_values_table_name=self.arg.union_op_values_table_name, @@ -2013,12 +2011,12 @@ def convert_udfgenarg_to_udfarg(udfgen_arg, smpc_used) -> UDFArgument: def convert_smpc_udf_input_to_udf_arg(smpc_udf_input: SMPCTablesInfo): - add_op_table_name = None + sum_op_table_name = None min_op_table_name = None max_op_table_name = None union_op_table_name = None - if smpc_udf_input.add_op_values: - add_op_table_name = smpc_udf_input.add_op_values.name + if smpc_udf_input.sum_op_values: + sum_op_table_name = smpc_udf_input.sum_op_values.name if smpc_udf_input.min_op_values: min_op_table_name = smpc_udf_input.min_op_values.name if smpc_udf_input.max_op_values: @@ -2027,7 +2025,7 @@ def convert_smpc_udf_input_to_udf_arg(smpc_udf_input: SMPCTablesInfo): union_op_table_name = smpc_udf_input.union_op_values.name return SMPCSecureTransferArg( template_table_name=smpc_udf_input.template.name, - add_op_values_table_name=add_op_table_name, + sum_op_values_table_name=sum_op_table_name, min_op_values_table_name=min_op_table_name, max_op_values_table_name=max_op_table_name, union_op_values_table_name=union_op_table_name, @@ -2506,7 +2504,7 @@ def _get_smpc_table_template_names(prefix: str): """ return ( prefix, - prefix 
+ "_add_op", + prefix + "_sum_op", prefix + "_min_op", prefix + "_max_op", prefix + "_union_op", @@ -2529,18 +2527,18 @@ def _create_table_udf_output(output_type: OutputType, table_name: str) -> UDFGen def _create_smpc_udf_output(output_type: SecureTransferType, table_name_prefix: str): ( template_tmpl, - add_op_tmpl, + sum_op_tmpl, min_op_tmpl, max_op_tmpl, union_op_tmpl, ) = _get_smpc_table_template_names(table_name_prefix) template = _create_table_udf_output(output_type, template_tmpl) - add_op = None + sum_op = None min_op = None max_op = None union_op = None - if output_type.add_op: - add_op = _create_table_udf_output(output_type, add_op_tmpl) + if output_type.sum_op: + sum_op = _create_table_udf_output(output_type, sum_op_tmpl) if output_type.min_op: min_op = _create_table_udf_output(output_type, min_op_tmpl) if output_type.max_op: @@ -2549,7 +2547,7 @@ def _create_smpc_udf_output(output_type: SecureTransferType, table_name_prefix: union_op = _create_table_udf_output(output_type, union_op_tmpl) return SMPCUDFGenResult( template=template, - add_op_values=add_op, + sum_op_values=sum_op, min_op_values=min_op, max_op_values=max_op, union_op_values=union_op, diff --git a/mipengine/udfgen/udfio.py b/mipengine/udfgen/udfio.py index bb88b9f9b41b1ba8f68e148beb967276d687adcd..c224639e2daed9b0bc38893b7defe7bf79a8c93c 100644 --- a/mipengine/udfgen/udfio.py +++ b/mipengine/udfgen/udfio.py @@ -3,9 +3,10 @@ import os import re from functools import partial from functools import reduce +from typing import Any from typing import List from typing import Tuple -from typing import Union +from typing import Type import numpy as np import pandas as pd @@ -115,30 +116,25 @@ def merge_tensor_to_list(columns): # ~~~~~~~~~~~~~~~~~~~~~~~~ Secure Transfer methods ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# TODO Add validate_secure_transfer_object +numeric_operations = ["sum", "min", "max"] def secure_transfers_to_merged_dict(transfers: List[dict]): """ Converts a list of secure transfer dictionaries to one dictionary that contains the aggregation of all the initial values. + This is used for secure transfer objects when SMPC is disabled. """ - # TODO Should also work for "decimals", "min", "max" and "union" - result = {} # Get all keys from a list of dicts all_keys = set().union(*(d.keys() for d in transfers)) for key in all_keys: operation = transfers[0][key]["operation"] - op_type = transfers[0][key]["type"] - if operation == "addition": - if op_type == "int": - result[key] = _add_secure_transfer_key_integer_data(key, transfers) - else: - raise NotImplementedError( - f"Secure transfer type: {type} not supported for operation: {operation}" - ) + if operation in numeric_operations: + result[key] = _operation_on_secure_transfer_key_data( + key, transfers, operation + ) else: raise NotImplementedError( f"Secure transfer operation not supported: {operation}" @@ -146,47 +142,155 @@ return result -def _add_secure_transfer_key_integer_data(key, transfers: List[dict]): +def _operation_on_secure_transfer_key_data(key, transfers: List[dict], operation: str): """ - Given a list of secure_transfer dicts, it sums the data of the key provided. - The values should be integers. + Given a list of secure_transfer dicts, it applies the appropriate operation to the data of the key provided.
""" result = transfers[0][key]["data"] for transfer in transfers[1:]: - if transfer[key]["operation"] != "addition": + if transfer[key]["operation"] not in numeric_operations: raise ValueError( - f"All secure transfer keys should have the same 'operation' value. 'addition' != {transfer[key]['operation']}" + f"Secure Transfer operation is not supported: {transfer[key]['operation']}" ) - if transfer[key]["type"] != "int": + if transfer[key]["operation"] != operation: raise ValueError( - f"All secure transfer keys should have the same 'type' value. 'int' != {transfer[key]['type']}" + f"All secure transfer keys should have the same 'operation' value. " + f"'{operation}' != {transfer[key]['operation']}" ) - result = _add_integer_type_values(result, transfer[key]["data"]) + result = _calc_values(result, transfer[key]["data"], operation) return result -def _add_integer_type_values(value1: Tuple[int, list], value2: Tuple[int, list]): +def _calc_values(value1: Any, value2: Any, operation: str): """ - The values could be either integers or lists that contain other lists or integers. - The type of the values should not change, only the value. + The values could be either integers/floats or lists that contain other lists or integers/floats. """ - if type(value1) != type(value2): - raise TypeError( - f"Secure transfer data have different types: {type(value1)} != {type(value2)}" - ) + _validate_calc_values(value1, value2) - if isinstance(value1, list): + if isinstance(value1, list) and isinstance(value2, list): result = [] for e1, e2 in zip(value1, value2): - result.append(_add_integer_type_values(e1, e2)) + result.append(_calc_values(e1, e2, operation)) return result - else: + + return _calc_numeric_values(value1, value2, operation) + + +def _validate_calc_values(value1, value2): + allowed_types = [int, float, list] + for value in [value1, value2]: + if type(value) not in allowed_types: + raise TypeError( + f"Secure transfer data must have one of the following types: " + f"{allowed_types}. Type provided: {type(value)}" + ) + if (isinstance(value1, list) or isinstance(value2, list)) and ( + type(value1) != type(value2) + ): + raise ValueError("Secure transfers' data should have the same structure.") + + +def _calc_numeric_values(value1: Any, value2: Any, operation: str): + if operation == "sum": return value1 + value2 + elif operation == "min": + return value1 if value1 < value2 else value2 + elif operation == "max": + return value1 if value1 > value2 else value2 + else: + raise NotImplementedError + + +def split_secure_transfer_dict(dict_: dict) -> Tuple[dict, list, list, list, list]: + """ + When SMPC is used, a secure transfer dict should be split in different parts: + 1) The template of the dict with relative positions instead of values, + 2) flattened lists for each operation, containing the values. + """ + secure_transfer_template = {} + op_flat_data = {"sum": [], "min": [], "max": []} + op_indexes = {"sum": 0, "min": 0, "max": 0} + for key, data_transfer in dict_.items(): + _validate_secure_transfer_item(key, data_transfer) + cur_op = data_transfer["operation"] + try: + ( + data_transfer_tmpl, + cur_flat_data, + op_indexes[cur_op], + ) = _flatten_data_and_keep_relative_positions( + op_indexes[cur_op], data_transfer["data"], [list, int, float] + ) + except TypeError as e: + raise TypeError( + f"Secure Transfer key: '{key}', operation: '{cur_op}'. 
Error: {str(e)}" + ) + op_flat_data[cur_op].extend(cur_flat_data) + + secure_transfer_template[key] = dict_[key] + secure_transfer_template[key]["data"] = data_transfer_tmpl + + return ( + secure_transfer_template, + op_flat_data["sum"], + op_flat_data["min"], + op_flat_data["max"], + [], + ) + + +def _validate_secure_transfer_item(key: str, data_transfer: dict): + if "operation" not in data_transfer.keys(): + raise ValueError( + f"Each Secure Transfer key should contain an operation. Key: {key}" + ) + + if "data" not in data_transfer.keys(): + raise ValueError(f"Each Secure Transfer key should contain data. Key: {key}") + + if data_transfer["operation"] not in numeric_operations: + raise ValueError( + f"Secure Transfer operation is not supported: {data_transfer['operation']}" + ) + + +def construct_secure_transfer_dict( + template: dict, + sum_op_values: List[int] = None, + min_op_values: List[int] = None, + max_op_values: List[int] = None, + union_op_values: List[int] = None, +) -> dict: + """ + When SMPC is used, a secure_transfer dict is broken into template and values. + In order to be used from a udf it needs to take its final key-value form. + """ + final_dict = {} + for key, data_transfer in template.items(): + if data_transfer["operation"] == "sum": + unflattened_data = _unflatten_data_using_relative_positions( + data_transfer["data"], sum_op_values, [int, float] + ) + elif data_transfer["operation"] == "min": + unflattened_data = _unflatten_data_using_relative_positions( + data_transfer["data"], min_op_values, [int, float] + ) + elif data_transfer["operation"] == "max": + unflattened_data = _unflatten_data_using_relative_positions( + data_transfer["data"], max_op_values, [int, float] + ) + else: + raise ValueError(f"Operation not supported: {data_transfer['operation']}") + + final_dict[key] = unflattened_data + return final_dict -def _flatten_int_data_and_keep_relative_positions( - index: int, data: Union[list, int] -) -> Tuple[Union[list, int], List[int], int]: +def _flatten_data_and_keep_relative_positions( + index: int, + data: Any, + allowed_types: List[Type], +) -> Tuple[Any, List[Any], int]: """ Iterates through a nested list structure and: 1) keeps the structure of the data with relative positions in the flat array, @@ -194,11 +298,14 @@ 3) also returns the final index so it can be used again.
For example: - >>> _flatten_int_data_and_keep_relative_positions(0, [[7,6,7],[8,9,10]]) + >>> _flatten_data_and_keep_relative_positions(0, [[7,6,7],[8,9,10]], [list, int, float]) Returns: - [7, 6, 7, 8, 9, 10], [[0, 1, 2], [3, 4, 5]], 6 + [[0, 1, 2], [3, 4, 5]], [7, 6, 7, 8, 9, 10], 6 """ + if type(data) not in allowed_types: + raise TypeError(f"Types allowed: {allowed_types}") + if isinstance(data, list): data_pos_template = [] flat_data = [] @@ -207,20 +314,18 @@ data_template, cur_flat_data, index, - ) = _flatten_int_data_and_keep_relative_positions(index, elem) + ) = _flatten_data_and_keep_relative_positions(index, elem, allowed_types) data_pos_template.append(data_template) flat_data.extend(cur_flat_data) return data_pos_template, flat_data, index - elif isinstance(data, int): - return index, [data], index + 1 - - else: - raise NotImplementedError + return index, [data], index + 1 -def _unflatten_int_data_using_relative_positions( - data_tmpl: Union[list, int], flat_values: List[int] +def _unflatten_data_using_relative_positions( + data_tmpl: Any, + flat_values: List[Any], + allowed_types: List[Type], ): """ - It's doing the exact opposite of _flatten_int_data_and_keep_relative_positions. + It does the exact opposite of _flatten_data_and_keep_relative_positions. @@ -229,73 +334,11 @@ """ if isinstance(data_tmpl, list): return [ - _unflatten_int_data_using_relative_positions(elem, flat_values) + _unflatten_data_using_relative_positions(elem, flat_values, allowed_types) for elem in data_tmpl ] - elif isinstance(data_tmpl, int): - return flat_values[data_tmpl] - - else: - ValueError( - f"Relative position values can only be ints or list. Received: {data_tmpl}" - ) - - -def split_secure_transfer_dict(dict_: dict) -> Tuple[dict, list, list, list, list]: - """ - When SMPC is used a secure transfer dict should be split in different parts: - 1) The template of the dict with relative positions instead of values, - 2) flattened lists for each operation, containing the values. - """ - # TODO Needs to support min, max and union as well - secure_transfer_template = {} - addition_flat_data = [] - addition_index = 0 - for key, data_transfer in dict_.items(): - if data_transfer["operation"] == "addition": - if data_transfer["type"] == "int": - ( - data_transfer_tmpl, - cur_flat_data, - addition_index, - ) = _flatten_int_data_and_keep_relative_positions( - addition_index, data_transfer["data"] - ) - addition_flat_data.extend(cur_flat_data) - else: - raise NotImplementedError - else: - raise NotImplementedError - - secure_transfer_template[key] = dict_[key] - secure_transfer_template[key]["data"] = data_transfer_tmpl - - return secure_transfer_template, addition_flat_data, [], [], [] + if type(data_tmpl) not in allowed_types: + raise TypeError(f"Types allowed: {allowed_types}") -def construct_secure_transfer_dict( - template: dict, - add_op_values: List[int] = None, - min_op_values: List[int] = None, - max_op_values: List[int] = None, - union_op_values: List[int] = None, -) -> dict: - """ - When SMPC is used a secure_transfer dict is broken into template and values. - In order to be used from a udf it needs to take it's final key - value form.
- """ - # TODO Needs to support min, max and union as well - final_dict = {} - for key, data_transfer in template.items(): - if data_transfer["operation"] == "addition": - if data_transfer["type"] == "int": - unflattened_data = _unflatten_int_data_using_relative_positions( - data_transfer["data"], add_op_values - ) - else: - raise NotImplementedError - else: - raise NotImplementedError - - final_dict[key] = unflattened_data - return final_dict + return flat_values[data_tmpl] diff --git a/tests/algorithms/orphan_udfs.py b/tests/algorithms/orphan_udfs.py index 87161af11b5f03282ccf0fed36c73866a4b95f2c..55406ac9daee370330548d625c6ae864d4ab3323 100644 --- a/tests/algorithms/orphan_udfs.py +++ b/tests/algorithms/orphan_udfs.py @@ -29,18 +29,16 @@ def local_step(table: DataFrame): return state_, transfer_ -@udf(table=relation(S), return_type=secure_transfer(add_op=True)) +@udf(table=relation(S), return_type=secure_transfer(sum_op=True)) def smpc_local_step(table: DataFrame): sum_ = 0 for element, *_ in table.values: sum_ += element - secure_transfer_ = { - "sum": {"data": int(sum_), "type": "int", "operation": "addition"} - } + secure_transfer_ = {"sum": {"data": int(sum_), "operation": "sum"}} return secure_transfer_ -@udf(locals_result=secure_transfer(add_op=True), return_type=transfer()) +@udf(locals_result=secure_transfer(sum_op=True), return_type=transfer()) def smpc_global_step(locals_result): result = {"total_sum": locals_result["sum"]} return result diff --git a/tests/algorithms/smpc_standard_deviation.py b/tests/algorithms/smpc_standard_deviation.py index 572d410a0ccb000d4b6b9082b43a9761323db54e..e174bdb3c52971f606b2daf0f12dc437239ead91 100644 --- a/tests/algorithms/smpc_standard_deviation.py +++ b/tests/algorithms/smpc_standard_deviation.py @@ -50,6 +50,8 @@ def run(algo_interface): ) std_deviation = json.loads(global_result.get_table_data()[1][0])["deviation"] + min_value = json.loads(global_result.get_table_data()[1][0])["min_value"] + max_value = json.loads(global_result.get_table_data()[1][0])["max_value"] x_variables = algo_interface.x_variables result = TabularDataResult( @@ -57,6 +59,8 @@ def run(algo_interface): columns=[ ColumnDataStr(name="variable", data=x_variables), ColumnDataFloat(name="std_deviation", data=[std_deviation]), + ColumnDataFloat(name="min_value", data=[min_value]), + ColumnDataFloat(name="max_value", data=[max_value]), ], ) return result @@ -73,25 +77,43 @@ def relation_to_matrix(rel): return rel -@udf(table=tensor(S, 2), return_type=[state(), secure_transfer(add_op=True)]) +@udf( + table=tensor(S, 2), + return_type=[state(), secure_transfer(sum_op=True, min_op=True, max_op=True)], +) def smpc_local_step_1(table): state_ = {"table": table} sum_ = 0 + min_value = table[0][0] + max_value = table[0][0] for (element,) in table: sum_ += element + if element < min_value: + min_value = element + if element > max_value: + max_value = element secure_transfer_ = { - "sum": {"data": int(sum_), "type": "int", "operation": "addition"}, - "count": {"data": len(table), "type": "int", "operation": "addition"}, + "sum": {"data": float(sum_), "operation": "sum"}, + "min": {"data": float(min_value), "operation": "min"}, + "max": {"data": float(max_value), "operation": "max"}, + "count": {"data": len(table), "operation": "sum"}, } return state_, secure_transfer_ -@udf(locals_result=secure_transfer(add_op=True), return_type=[state(), transfer()]) +@udf( + locals_result=secure_transfer(sum_op=True, min_op=True, max_op=True), + return_type=[state(), transfer()], +) def 
smpc_global_step_1(locals_result): total_sum = locals_result["sum"] total_count = locals_result["count"] average = total_sum / total_count - state_ = {"count": total_count} + state_ = { + "count": total_count, + "min_value": locals_result["min"], + "max_value": locals_result["max"], + } transfer_ = {"average": average} return state_, transfer_ @@ -99,7 +121,7 @@ def smpc_global_step_1(locals_result): @udf( prev_state=state(), global_transfer=transfer(), - return_type=secure_transfer(add_op=True), + return_type=secure_transfer(sum_op=True), ) def smpc_local_step_2(prev_state, global_transfer): deviation_sum = 0 @@ -107,9 +129,9 @@ def smpc_local_step_2(prev_state, global_transfer): deviation_sum += pow(element - global_transfer["average"], 2) secure_transfer_ = { "deviation_sum": { - "data": int(deviation_sum), + "data": float(deviation_sum), "type": "int", - "operation": "addition", + "operation": "sum", } } return secure_transfer_ @@ -117,12 +139,16 @@ def smpc_local_step_2(prev_state, global_transfer): @udf( prev_state=state(), - locals_result=secure_transfer(add_op=True), + locals_result=secure_transfer(sum_op=True), return_type=transfer(), ) def smpc_global_step_2(prev_state, locals_result): total_deviation = locals_result["deviation_sum"] from math import sqrt - deviation = {"deviation": sqrt(total_deviation / prev_state["count"])} + deviation = { + "deviation": sqrt(total_deviation / prev_state["count"]), + "min_value": prev_state["min_value"], + "max_value": prev_state["max_value"], + } return deviation diff --git a/tests/dev_env_tests/test_post_algorithms.py b/tests/dev_env_tests/test_post_algorithms.py index b5c81096055dc39a59d9472ae152ddafb34b0341..af11afbda41e2cb127dde4f39828a8f451affa44 100644 --- a/tests/dev_env_tests/test_post_algorithms.py +++ b/tests/dev_env_tests/test_post_algorithms.py @@ -52,7 +52,9 @@ def get_parametrization_list_success_cases(): "title": "Standard Deviation", "columns": [ {"name": "variable", "data": ["lefthippocampus"], "type": "STR"}, - {"name": "std_deviation", "data": [0.3611575592573076], "type": "FLOAT"}, + {"name": "std_deviation", "data": [0.3634506955662605], "type": "FLOAT"}, + {"name": "min_value", "data": [1.3047], "type": "FLOAT"}, + {"name": "max_value", "data": [4.4519], "type": "FLOAT"}, ], } parametrization_list.append((algorithm_name, request_dict, expected_response)) diff --git a/tests/dev_env_tests/test_single_local_node_algorithm_execution.py b/tests/dev_env_tests/test_single_local_node_algorithm_execution.py index 5f601cd84f4e1a230f41428fec3dd8bb3f9d948e..4e3ae2ed084760989c94ee854553677f905b7c66 100644 --- a/tests/dev_env_tests/test_single_local_node_algorithm_execution.py +++ b/tests/dev_env_tests/test_single_local_node_algorithm_execution.py @@ -120,6 +120,8 @@ def mock_algorithms_modules(): ): yield + mipengine.ALGORITHM_FOLDERS = "./mipengine/algorithms" + def get_parametrization_list_success_cases(): parametrization_list = [] diff --git a/tests/smpc_env_tests/test_smpc_algorithms.py b/tests/smpc_env_tests/test_smpc_algorithms.py index 7295f866cf7d22a4cba22fcbf3e0142363963b12..427f3d66ebc56ad62efffdb5f4cf63750d62ee8d 100644 --- a/tests/smpc_env_tests/test_smpc_algorithms.py +++ b/tests/smpc_env_tests/test_smpc_algorithms.py @@ -53,7 +53,9 @@ def get_parametrization_list_success_cases(): "title": "Standard Deviation", "columns": [ {"name": "variable", "data": ["lefthippocampus"], "type": "STR"}, - {"name": "std_deviation", "data": [0.3611575592573076], "type": "FLOAT"}, + {"name": "std_deviation", "data": [0.3634506955662605], 
"type": "FLOAT"}, + {"name": "min_value", "data": [1.3047], "type": "FLOAT"}, + {"name": "max_value", "data": [4.4519], "type": "FLOAT"}, ], } parametrization_list.append((algorithm_name, request_dict, expected_response)) @@ -103,7 +105,9 @@ def get_parametrization_list_success_cases(): "title": "Standard Deviation", "columns": [ {"name": "variable", "data": ["lefthippocampus"], "type": "STR"}, - {"name": "std_deviation", "data": [0.3611575592573076], "type": "FLOAT"}, + {"name": "std_deviation", "data": [0.3634506955662605], "type": "FLOAT"}, + {"name": "min_value", "data": [1.3047], "type": "FLOAT"}, + {"name": "max_value", "data": [4.4519], "type": "FLOAT"}, ], } parametrization_list.append((algorithm_name, request_dict, expected_response)) @@ -111,9 +115,9 @@ def get_parametrization_list_success_cases(): return parametrization_list -# @pytest.mark.skip( -# reason="SMPC is not deployed in the CI yet. https://team-1617704806227.atlassian.net/browse/MIP-344" -# ) +@pytest.mark.skip( + reason="SMPC is not deployed in the CI yet. https://team-1617704806227.atlassian.net/browse/MIP-344" +) @pytest.mark.parametrize( "algorithm_name, request_dict, expected_response", get_parametrization_list_success_cases(), @@ -183,9 +187,9 @@ def get_parametrization_list_exception_cases(): return parametrization_list -# @pytest.mark.skip( -# reason="SMPC is not deployed in the CI yet. https://team-1617704806227.atlassian.net/browse/MIP-344" -# ) +@pytest.mark.skip( + reason="SMPC is not deployed in the CI yet. https://team-1617704806227.atlassian.net/browse/MIP-344" +) @pytest.mark.parametrize( "algorithm_name, request_dict, expected_response", get_parametrization_list_exception_cases(), diff --git a/tests/standalone_tests/test_algorithm_executor.py b/tests/standalone_tests/test_algorithm_executor.py index 36bdab7bd8af22536ac44f120554d6626b3a6f97..a218393797222d4146f585fd767372357ba66780 100644 --- a/tests/standalone_tests/test_algorithm_executor.py +++ b/tests/standalone_tests/test_algorithm_executor.py @@ -4,7 +4,7 @@ from typing import List import pytest -from mipengine.controller.algorithm_executor_node_data_objects import NodeTable +from mipengine.controller.algorithm_executor_node_data_objects import TableName from mipengine.controller.algorithm_executor_nodes import _INode from mipengine.controller.node_tasks_handler_interface import IQueuedUDFAsyncResult from mipengine.node_tasks_DTOs import ColumnInfo @@ -20,22 +20,22 @@ class NodeMock(_INode): def __init__(self): self.tables: Dict[str, TableSchema] = {} - def get_tables(self) -> List[NodeTable]: + def get_tables(self) -> List[TableName]: pass - def get_table_schema(self, table_name: NodeTable): + def get_table_schema(self, table_name: TableName): return self.tables[table_name] - def get_table_data(self, table_name: NodeTable) -> TableData: + def get_table_data(self, table_name: TableName) -> TableData: pass - def create_table(self, command_id: str, schema: TableSchema) -> NodeTable: + def create_table(self, command_id: str, schema: TableSchema) -> TableName: table_name = f"normal_testnode_cntxtid1_cmdid{randint(0,999)}_cmdsubid1" - table_name = NodeTable(table_name) + table_name = TableName(table_name) self.tables[table_name] = schema return table_name - def get_views(self) -> List[NodeTable]: + def get_views(self) -> List[TableName]: pass def create_pathology_view( @@ -44,13 +44,13 @@ class NodeMock(_INode): pathology: str, columns: List[str], filters: List[str], - ) -> NodeTable: + ) -> TableName: pass - def get_merge_tables(self) -> 
List[NodeTable]: + def get_merge_tables(self) -> List[TableName]: pass - def create_merge_table(self, command_id: str, table_names: List[NodeTable]): + def create_merge_table(self, command_id: str, table_names: List[TableName]): pass def get_remote_tables(self) -> List[str]: @@ -68,7 +68,7 @@ class NodeMock(_INode): def get_queued_udf_result( self, async_result: IQueuedUDFAsyncResult - ) -> List[NodeTable]: + ) -> List[TableName]: pass def get_udfs(self, algorithm_name) -> List[str]: diff --git a/tests/standalone_tests/test_smpc_node_tasks.py b/tests/standalone_tests/test_smpc_node_tasks.py index 8f2b08c017d31d309abff4faf024684f5d3102c7..8e433573d7f2eb25ef1869c693958a329ac921c7 100644 --- a/tests/standalone_tests/test_smpc_node_tasks.py +++ b/tests/standalone_tests/test_smpc_node_tasks.py @@ -97,12 +97,8 @@ def create_table_with_secure_transfer_results_with_smpc_off( secure_transfer_1_value = 100 secure_transfer_2_value = 11 - secure_transfer_1 = { - "sum": {"data": secure_transfer_1_value, "type": "int", "operation": "addition"} - } - secure_transfer_2 = { - "sum": {"data": secure_transfer_2_value, "type": "int", "operation": "addition"} - } + secure_transfer_1 = {"sum": {"data": secure_transfer_1_value, "operation": "sum"}} + secure_transfer_2 = {"sum": {"data": secure_transfer_2_value, "operation": "sum"}} values = [ ["localnode1", json.dumps(secure_transfer_1)], ["localnode2", json.dumps(secure_transfer_2)], @@ -123,12 +119,8 @@ def create_table_with_multiple_secure_transfer_templates( table_name = create_secure_transfer_table(celery_app) - secure_transfer_template = { - "sum": {"data": [0, 1, 2, 3], "type": "int", "operation": "addition"} - } - different_secure_transfer_template = { - "sum": {"data": 0, "type": "int", "operation": "addition"} - } + secure_transfer_template = {"sum": {"data": [0, 1, 2, 3], "operation": "sum"}} + different_secure_transfer_template = {"sum": {"data": 0, "operation": "sum"}} if similar: values = [ @@ -148,22 +140,22 @@ def create_table_with_multiple_secure_transfer_templates( return table_name -def create_table_with_smpc_add_op_values(celery_app) -> Tuple[str, str]: +def create_table_with_smpc_sum_op_values(celery_app) -> Tuple[str, str]: insert_data_to_table_task = get_celery_task_signature( celery_app, "insert_data_to_table" ) table_name = create_secure_transfer_table(celery_app) - add_op_values = [0, 1, 2, 3, 4, 5] + sum_op_values = [0, 1, 2, 3, 4, 5] values = [ - ["localnode1", json.dumps(add_op_values)], + ["localnode1", json.dumps(sum_op_values)], ] insert_data_to_table_task.delay( request_id=request_id, table_name=table_name, values=values ).get() - return table_name, json.dumps(add_op_values) + return table_name, json.dumps(sum_op_values) def validate_dict_table_data_match_expected( @@ -208,9 +200,7 @@ def test_secure_transfer_output_with_smpc_off( secure_transfer_result = results[0] assert isinstance(secure_transfer_result, NodeTableDTO) - expected_result = { - "sum": {"data": input_table_name_sum, "type": "int", "operation": "addition"} - } + expected_result = {"sum": {"data": input_table_name_sum, "operation": "sum"}} validate_dict_table_data_match_expected( get_table_data_task, secure_transfer_result.value, @@ -334,19 +324,19 @@ def test_secure_transfer_run_udf_flow_with_smpc_on( assert isinstance(smpc_result, NodeSMPCDTO) assert smpc_result.value.template is not None - expected_template = {"sum": {"data": 0, "type": "int", "operation": "addition"}} + expected_template = {"sum": {"data": 0, "operation": "sum"}} 
validate_dict_table_data_match_expected( get_table_data_task, smpc_result.value.template.value, expected_template, ) - assert smpc_result.value.add_op_values is not None - expected_add_op_values = [input_table_name_sum] + assert smpc_result.value.sum_op_values is not None + expected_sum_op_values = [input_table_name_sum] validate_dict_table_data_match_expected( get_table_data_task, - smpc_result.value.add_op_values.value, - expected_add_op_values, + smpc_result.value.sum_op_values.value, + expected_sum_op_values, ) # ----------------------- SECURE TRANSFER INPUT---------------------- @@ -355,7 +345,7 @@ def test_secure_transfer_run_udf_flow_with_smpc_on( smpc_arg = NodeSMPCDTO( value=NodeSMPCValueDTO( template=NodeTableDTO(value=smpc_result.value.template.value), - add_op_values=NodeTableDTO(value=smpc_result.value.add_op_values.value), + sum_op_values=NodeTableDTO(value=smpc_result.value.sum_op_values.value), ) ) @@ -410,7 +400,7 @@ def test_load_data_to_smpc_client( use_smpc_localnode1_database, smpc_localnode1_celery_app, ): - table_name, add_op_values_str = create_table_with_smpc_add_op_values( + table_name, sum_op_values_str = create_table_with_smpc_sum_op_values( smpc_localnode1_celery_app ) @@ -439,10 +429,10 @@ def test_load_data_to_smpc_client( # TODO Remove when smpc cluster call is fixed # The problem is that the call returns the integers as string - # assert response.text == add_op_values_str + # assert response.text == sum_op_values_str result = json.loads(response.text) result = [int(elem) for elem in result] - assert json.dumps(result) == add_op_values_str + assert json.dumps(result) == sum_op_values_str def test_get_smpc_result_from_localnode_fails( @@ -652,13 +642,13 @@ def test_orchestrate_SMPC_between_two_localnodes_and_the_globalnode( smpc_client_1 = load_data_to_smpc_client_task_localnode1.delay( request_id=request_id, context_id=context_id, - table_name=local_1_smpc_result.value.add_op_values.value, + table_name=local_1_smpc_result.value.sum_op_values.value, jobid=smpc_job_id, ).get() smpc_client_2 = load_data_to_smpc_client_task_localnode2.delay( request_id=request_id, context_id=context_id, - table_name=local_2_smpc_result.value.add_op_values.value, + table_name=local_2_smpc_result.value.sum_op_values.value, jobid=smpc_job_id, ).get() @@ -682,7 +672,7 @@ def test_orchestrate_SMPC_between_two_localnodes_and_the_globalnode( assert response.status_code == 200 # --------- Get Results of SMPC in globalnode ----------------- - add_op_values_tablename = get_smpc_result_task_globalnode.delay( + sum_op_values_tablename = get_smpc_result_task_globalnode.delay( request_id=request_id, context_id=context_id, command_id="4", @@ -693,7 +683,7 @@ def test_orchestrate_SMPC_between_two_localnodes_and_the_globalnode( smpc_arg = NodeSMPCDTO( value=NodeSMPCValueDTO( template=NodeTableDTO(value=globalnode_template_tablename), - add_op_values=NodeTableDTO(value=add_op_values_tablename), + sum_op_values=NodeTableDTO(value=sum_op_values_tablename), ) ) pos_args_str = UDFPosArguments(args=[smpc_arg]).json() diff --git a/tests/standalone_tests/test_udfgenerator.py b/tests/standalone_tests/test_udfgenerator.py index d314cf4dec2c39a523776585f369abc789975d42..32c83d0d061f1568cdade057972dbce7ccf21b5b 100644 --- a/tests/standalone_tests/test_udfgenerator.py +++ b/tests/standalone_tests/test_udfgenerator.py @@ -384,7 +384,7 @@ class TestUDFValidation: def test_validate_func_as_valid_udf_with_secure_transfer_output(self): @udf( y=state(), - return_type=secure_transfer(add_op=True), + 
return_type=secure_transfer(sum_op=True), ) def f(y): y = {"num": 1} @@ -394,7 +394,7 @@ class TestUDFValidation: def test_validate_func_as_valid_udf_with_secure_transfer_input(self): @udf( - y=secure_transfer(add_op=True), + y=secure_transfer(sum_op=True), return_type=transfer(), ) def f(y): @@ -1035,9 +1035,9 @@ class TestUDFGenBase: elif isinstance(udf_output, SMPCUDFGenResult): tablename_placeholder = udf_output.template.tablename_placeholder template_mapping[tablename_placeholder] = tablename_placeholder - if udf_output.add_op_values: + if udf_output.sum_op_values: tablename_placeholder = ( - udf_output.add_op_values.tablename_placeholder + udf_output.sum_op_values.tablename_placeholder ) template_mapping[tablename_placeholder] = tablename_placeholder if udf_output.min_op_values: @@ -1112,9 +1112,9 @@ class TestUDFGenBase: queries.extend(self._concrete_table_udf_outputs(udf_output)) elif isinstance(udf_output, SMPCUDFGenResult): queries.extend(self._concrete_table_udf_outputs(udf_output.template)) - if udf_output.add_op_values: + if udf_output.sum_op_values: queries.extend( - self._concrete_table_udf_outputs(udf_output.add_op_values) + self._concrete_table_udf_outputs(udf_output.sum_op_values) ) if udf_output.min_op_values: queries.extend( @@ -1173,31 +1173,50 @@ class TestUDFGenBase: "CREATE TABLE test_secure_transfer_table(node_id VARCHAR(500), secure_transfer CLOB)" ) globalnode_db_cursor.execute( - 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": 1, "type": "int", "operation": "addition"}}\')' + 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": 1, "operation": "sum"}}\')' ) globalnode_db_cursor.execute( - 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(2, \'{"sum": {"data": 10, "type": "int", "operation": "addition"}}\')' + 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(2, \'{"sum": {"data": 10, "operation": "sum"}}\')' ) globalnode_db_cursor.execute( - 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(3, \'{"sum": {"data": 100, "type": "int", "operation": "addition"}}\')' + 'INSERT INTO test_secure_transfer_table(node_id, secure_transfer) VALUES(3, \'{"sum": {"data": 100, "operation": "sum"}}\')' ) @pytest.fixture(scope="function") - def create_smpc_template_table(self, globalnode_db_cursor): + def create_smpc_template_table_with_sum(self, globalnode_db_cursor): globalnode_db_cursor.execute( "CREATE TABLE test_smpc_template_table(node_id VARCHAR(500), secure_transfer CLOB)" ) globalnode_db_cursor.execute( - 'INSERT INTO test_smpc_template_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": [0,1,2], "type": "int", "operation": "addition"}}\')' + 'INSERT INTO test_smpc_template_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": [0,1,2], "operation": "sum"}}\')' ) @pytest.fixture(scope="function") - def create_smpc_add_op_values_table(self, globalnode_db_cursor): + def create_smpc_sum_op_values_table(self, globalnode_db_cursor): globalnode_db_cursor.execute( - "CREATE TABLE test_smpc_add_op_values_table(node_id VARCHAR(500), secure_transfer CLOB)" + "CREATE TABLE test_smpc_sum_op_values_table(node_id VARCHAR(500), secure_transfer CLOB)" ) globalnode_db_cursor.execute( - "INSERT INTO test_smpc_add_op_values_table(node_id, secure_transfer) VALUES(1, '[100,200,300]')" + "INSERT INTO test_smpc_sum_op_values_table(node_id, secure_transfer) VALUES(1, '[100,200,300]')" + ) + + 
@pytest.fixture(scope="function") + def create_smpc_template_table_with_sum_and_max(self, globalnode_db_cursor): + globalnode_db_cursor.execute( + "CREATE TABLE test_smpc_template_table(node_id VARCHAR(500), secure_transfer CLOB)" + ) + globalnode_db_cursor.execute( + 'INSERT INTO test_smpc_template_table(node_id, secure_transfer) VALUES(1, \'{"sum": {"data": [0,1,2], "operation": "sum"}, ' + '"max": {"data": 0, "operation": "max"}}\')' + ) + + @pytest.fixture(scope="function") + def create_smpc_max_op_values_table(self, globalnode_db_cursor): + globalnode_db_cursor.execute( + "CREATE TABLE test_smpc_max_op_values_table(node_id VARCHAR(500), secure_transfer CLOB)" + ) + globalnode_db_cursor.execute( + "INSERT INTO test_smpc_max_op_values_table(node_id, secure_transfer) VALUES(1, '[58]')" ) # TODO Should become more dynamic in the future. @@ -1426,8 +1445,8 @@ class TestUDFGen_Invalid_SMPCUDFInput_To_Transfer_Type(TestUDFGenBase): ), type_=TableType.NORMAL, ), - add_op_values=TableInfo( - name="test_smpc_add_op_values_table", + sum_op_values=TableInfo( + name="test_smpc_sum_op_values_table", schema_=TableSchema( columns=[ ColumnInfo(name="secure_transfer", dtype=DType.JSON), @@ -1452,7 +1471,7 @@ class TestUDFGen_Invalid_TableInfoArgs_To_SecureTransferType(TestUDFGenBase): @pytest.fixture(scope="class") def udfregistry(self): @udf( - transfer=secure_transfer(add_op=True), + transfer=secure_transfer(sum_op=True), return_type=transfer(), ) def f(transfer): @@ -1499,7 +1518,7 @@ class TestUDFGen_Invalid_SMPCUDFInput_with_SMPC_off(TestUDFGenBase): @pytest.fixture(scope="class") def udfregistry(self): @udf( - transfer=secure_transfer(add_op=True), + transfer=secure_transfer(sum_op=True), return_type=transfer(), ) def f(transfer): @@ -1520,8 +1539,8 @@ class TestUDFGen_Invalid_SMPCUDFInput_with_SMPC_off(TestUDFGenBase): ), type_=TableType.NORMAL, ), - add_op_values=TableInfo( - name="test_smpc_add_op_values_table", + sum_op_values=TableInfo( + name="test_smpc_sum_op_values_table", schema_=TableSchema( columns=[ ColumnInfo(name="secure_transfer", dtype=DType.JSON), @@ -4469,11 +4488,13 @@ class TestUDFGen_SecureTransferOutput_with_SMPC_off( def udfregistry(self): @udf( state=state(), - return_type=secure_transfer(add_op=True), + return_type=secure_transfer(sum_op=True, min_op=True, max_op=True), ) def f(state): result = { - "sum": {"data": state["num"], "type": "int", "operation": "addition"} + "sum": {"data": state["num"], "operation": "sum"}, + "min": {"data": state["num"], "operation": "min"}, + "max": {"data": state["num"], "operation": "max"}, } return result @@ -4508,8 +4529,9 @@ LANGUAGE PYTHON import json __state_str = _conn.execute("SELECT state from test_state_table;")["state"][0] state = pickle.loads(__state_str) - result = {'sum': {'data': state['num'], 'type': 'int', 'operation': 'addition'} - } + result = {'sum': {'data': state['num'], 'operation': 'sum'}, 'min': {'data': + state['num'], 'operation': 'min'}, 'max': {'data': state['num'], + 'operation': 'max'}} return json.dumps(result) }""" @@ -4554,7 +4576,11 @@ FROM "SELECT secure_transfer FROM main_output_table_name" ).fetchone() result = json.loads(secure_transfer_) - assert result == {"sum": {"data": 5, "type": "int", "operation": "addition"}} + assert result == { + "sum": {"data": 5, "operation": "sum"}, + "min": {"data": 5, "operation": "min"}, + "max": {"data": 5, "operation": "max"}, + } class TestUDFGen_SecureTransferOutput_with_SMPC_on( @@ -4564,11 +4590,12 @@ class TestUDFGen_SecureTransferOutput_with_SMPC_on( def 
udfregistry(self): @udf( state=state(), - return_type=secure_transfer(add_op=True), + return_type=secure_transfer(sum_op=True, max_op=True), ) def f(state): result = { - "sum": {"data": state["num"], "type": "int", "operation": "addition"} + "sum": {"data": state["num"], "operation": "sum"}, + "max": {"data": state["num"], "operation": "max"}, } return result @@ -4603,10 +4630,11 @@ LANGUAGE PYTHON import json __state_str = _conn.execute("SELECT state from test_state_table;")["state"][0] state = pickle.loads(__state_str) - result = {'sum': {'data': state['num'], 'type': 'int', 'operation': 'addition'} - } - template, add_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict(result) - _conn.execute(f"INSERT INTO $main_output_table_name_add_op VALUES ('$node_id', '{json.dumps(add_op)}');") + result = {'sum': {'data': state['num'], 'operation': 'sum'}, 'max': {'data': + state['num'], 'operation': 'max'}} + template, sum_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict(result) + _conn.execute(f"INSERT INTO $main_output_table_name_sum_op VALUES ('$node_id', '{json.dumps(sum_op)}');") + _conn.execute(f"INSERT INTO $main_output_table_name_max_op VALUES ('$node_id', '{json.dumps(max_op)}');") return json.dumps(template) }""" @@ -4633,13 +4661,22 @@ FROM 'CREATE TABLE $main_output_table_name("node_id" VARCHAR(500),"secure_transfer" CLOB);' ), ), - add_op_values=TableUDFGenResult( - tablename_placeholder="main_output_table_name_add_op", + sum_op_values=TableUDFGenResult( + tablename_placeholder="main_output_table_name_sum_op", + drop_query=Template( + "DROP TABLE IF EXISTS $main_output_table_name_sum_op;" + ), + create_query=Template( + 'CREATE TABLE $main_output_table_name_sum_op("node_id" VARCHAR(500),"secure_transfer" CLOB);' + ), + ), + max_op_values=TableUDFGenResult( + tablename_placeholder="main_output_table_name_max_op", drop_query=Template( - "DROP TABLE IF EXISTS $main_output_table_name_add_op;" + "DROP TABLE IF EXISTS $main_output_table_name_max_op;" ), create_query=Template( - 'CREATE TABLE $main_output_table_name_add_op("node_id" VARCHAR(500),"secure_transfer" CLOB);' + 'CREATE TABLE $main_output_table_name_max_op("node_id" VARCHAR(500),"secure_transfer" CLOB);' ), ), ) @@ -4668,13 +4705,22 @@ FROM "SELECT secure_transfer FROM main_output_table_name" ).fetchone() template = json.loads(template_str) - assert template == {"sum": {"data": 0, "type": "int", "operation": "addition"}} + assert template == { + "max": {"data": 0, "operation": "max"}, + "sum": {"data": 0, "operation": "sum"}, + } - add_op_values_str, *_ = globalnode_db_cursor.execute( - "SELECT secure_transfer FROM main_output_table_name_add_op" + sum_op_values_str, *_ = globalnode_db_cursor.execute( + "SELECT secure_transfer FROM main_output_table_name_sum_op" ).fetchone() - add_op_values = json.loads(add_op_values_str) - assert add_op_values == [5] + sum_op_values = json.loads(sum_op_values_str) + assert sum_op_values == [5] + + max_op_values_str, *_ = globalnode_db_cursor.execute( + "SELECT secure_transfer FROM main_output_table_name_max_op" + ).fetchone() + max_op_values = json.loads(max_op_values_str) + assert max_op_values == [5] class TestUDFGen_SecureTransferOutputAs2ndOutput_with_SMPC_off( @@ -4684,11 +4730,16 @@ class TestUDFGen_SecureTransferOutputAs2ndOutput_with_SMPC_off( def udfregistry(self): @udf( state=state(), - return_type=[state(), secure_transfer(add_op=True)], + return_type=[ + state(), + secure_transfer(sum_op=True, min_op=True, max_op=True), + ], ) def f(state): result = { - 
"sum": {"data": state["num"], "type": "int", "operation": "addition"} + "sum": {"data": state["num"], "operation": "sum"}, + "min": {"data": state["num"], "operation": "min"}, + "max": {"data": state["num"], "operation": "max"}, } return state, result @@ -4723,8 +4774,9 @@ LANGUAGE PYTHON import json __state_str = _conn.execute("SELECT state from test_state_table;")["state"][0] state = pickle.loads(__state_str) - result = {'sum': {'data': state['num'], 'type': 'int', 'operation': 'addition'} - } + result = {'sum': {'data': state['num'], 'operation': 'sum'}, 'min': {'data': + state['num'], 'operation': 'min'}, 'max': {'data': state['num'], + 'operation': 'max'}} _conn.execute(f"INSERT INTO $loopback_table_name_0 VALUES ('$node_id', '{json.dumps(result)}');") return pickle.dumps(state) }""" @@ -4778,7 +4830,11 @@ FROM "SELECT secure_transfer FROM loopback_table_name_0" ).fetchone() result = json.loads(secure_transfer_) - assert result == {"sum": {"data": 5, "type": "int", "operation": "addition"}} + assert result == { + "sum": {"data": 5, "operation": "sum"}, + "min": {"data": 5, "operation": "min"}, + "max": {"data": 5, "operation": "max"}, + } class TestUDFGen_SecureTransferOutputAs2ndOutput_with_SMPC_on( @@ -4788,11 +4844,16 @@ class TestUDFGen_SecureTransferOutputAs2ndOutput_with_SMPC_on( def udfregistry(self): @udf( state=state(), - return_type=[state(), secure_transfer(add_op=True)], + return_type=[ + state(), + secure_transfer(sum_op=True, min_op=True, max_op=True), + ], ) def f(state): result = { - "sum": {"data": state["num"], "type": "int", "operation": "addition"} + "sum": {"data": state["num"], "operation": "sum"}, + "min": {"data": state["num"], "operation": "min"}, + "max": {"data": state["num"], "operation": "max"}, } return state, result @@ -4827,11 +4888,14 @@ LANGUAGE PYTHON import json __state_str = _conn.execute("SELECT state from test_state_table;")["state"][0] state = pickle.loads(__state_str) - result = {'sum': {'data': state['num'], 'type': 'int', 'operation': 'addition'} - } - template, add_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict(result) + result = {'sum': {'data': state['num'], 'operation': 'sum'}, 'min': {'data': + state['num'], 'operation': 'min'}, 'max': {'data': state['num'], + 'operation': 'max'}} + template, sum_op, min_op, max_op, union_op = udfio.split_secure_transfer_dict(result) _conn.execute(f"INSERT INTO $loopback_table_name_0 VALUES ('$node_id', '{json.dumps(template)}');") - _conn.execute(f"INSERT INTO $loopback_table_name_0_add_op VALUES ('$node_id', '{json.dumps(add_op)}');") + _conn.execute(f"INSERT INTO $loopback_table_name_0_sum_op VALUES ('$node_id', '{json.dumps(sum_op)}');") + _conn.execute(f"INSERT INTO $loopback_table_name_0_min_op VALUES ('$node_id', '{json.dumps(min_op)}');") + _conn.execute(f"INSERT INTO $loopback_table_name_0_max_op VALUES ('$node_id', '{json.dumps(max_op)}');") return pickle.dumps(state) }""" @@ -4863,13 +4927,31 @@ FROM 'CREATE TABLE $loopback_table_name_0("node_id" VARCHAR(500),"secure_transfer" CLOB);' ), ), - add_op_values=TableUDFGenResult( - tablename_placeholder="loopback_table_name_0_add_op", + sum_op_values=TableUDFGenResult( + tablename_placeholder="loopback_table_name_0_sum_op", + drop_query=Template( + "DROP TABLE IF EXISTS $loopback_table_name_0_sum_op;" + ), + create_query=Template( + 'CREATE TABLE $loopback_table_name_0_sum_op("node_id" VARCHAR(500),"secure_transfer" CLOB);' + ), + ), + min_op_values=TableUDFGenResult( + tablename_placeholder="loopback_table_name_0_min_op", + 
drop_query=Template( + "DROP TABLE IF EXISTS $loopback_table_name_0_min_op;" + ), + create_query=Template( + 'CREATE TABLE $loopback_table_name_0_min_op("node_id" VARCHAR(500),"secure_transfer" CLOB);' + ), + ), + max_op_values=TableUDFGenResult( + tablename_placeholder="loopback_table_name_0_max_op", drop_query=Template( - "DROP TABLE IF EXISTS $loopback_table_name_0_add_op;" + "DROP TABLE IF EXISTS $loopback_table_name_0_max_op;" ), create_query=Template( - 'CREATE TABLE $loopback_table_name_0_add_op("node_id" VARCHAR(500),"secure_transfer" CLOB);' + 'CREATE TABLE $loopback_table_name_0_max_op("node_id" VARCHAR(500),"secure_transfer" CLOB);' ), ), ), @@ -4898,13 +4980,29 @@ FROM "SELECT secure_transfer FROM loopback_table_name_0" ).fetchone() template = json.loads(template_str) - assert template == {"sum": {"data": 0, "type": "int", "operation": "addition"}} + assert template == { + "sum": {"data": 0, "operation": "sum"}, + "min": {"data": 0, "operation": "min"}, + "max": {"data": 0, "operation": "max"}, + } + + sum_op_values_str, *_ = globalnode_db_cursor.execute( + "SELECT secure_transfer FROM loopback_table_name_0_sum_op" + ).fetchone() + sum_op_values = json.loads(sum_op_values_str) + assert sum_op_values == [5] - add_op_values_str, *_ = globalnode_db_cursor.execute( - "SELECT secure_transfer FROM loopback_table_name_0_add_op" + min_op_values_str, *_ = globalnode_db_cursor.execute( + "SELECT secure_transfer FROM loopback_table_name_0_min_op" ).fetchone() - add_op_values = json.loads(add_op_values_str) - assert add_op_values == [5] + min_op_values = json.loads(min_op_values_str) + assert min_op_values == [5] + + max_op_values_str, *_ = globalnode_db_cursor.execute( + "SELECT secure_transfer FROM loopback_table_name_0_max_op" + ).fetchone() + max_op_values = json.loads(max_op_values_str) + assert max_op_values == [5] class TestUDFGen_SecureTransferInput_with_SMPC_off( @@ -4913,7 +5011,7 @@ class TestUDFGen_SecureTransferInput_with_SMPC_off( @pytest.fixture(scope="class") def udfregistry(self): @udf( - transfer=secure_transfer(add_op=True), + transfer=secure_transfer(sum_op=True), return_type=transfer(), ) def f(transfer): @@ -5003,7 +5101,7 @@ class TestUDFGen_SecureTransferInput_with_SMPC_on( @pytest.fixture(scope="class") def udfregistry(self): @udf( - transfer=secure_transfer(add_op=True), + transfer=secure_transfer(sum_op=True, max_op=True), return_type=transfer(), ) def f(transfer): @@ -5024,8 +5122,17 @@ class TestUDFGen_SecureTransferInput_with_SMPC_on( ), type_=TableType.NORMAL, ), - add_op_values=TableInfo( - name="test_smpc_add_op_values_table", + sum_op_values=TableInfo( + name="test_smpc_sum_op_values_table", + schema_=TableSchema( + columns=[ + ColumnInfo(name="secure_transfer", dtype=DType.JSON), + ] + ), + type_=TableType.NORMAL, + ), + max_op_values=TableInfo( + name="test_smpc_max_op_values_table", schema_=TableSchema( columns=[ ColumnInfo(name="secure_transfer", dtype=DType.JSON), @@ -5050,12 +5157,13 @@ LANGUAGE PYTHON import json __template_str = _conn.execute("SELECT secure_transfer from test_smpc_template_table;")["secure_transfer"][0] __template = json.loads(__template_str) - __add_op_values_str = _conn.execute("SELECT secure_transfer from test_smpc_add_op_values_table;")["secure_transfer"][0] - __add_op_values = json.loads(__add_op_values_str) + __sum_op_values_str = _conn.execute("SELECT secure_transfer from test_smpc_sum_op_values_table;")["secure_transfer"][0] + __sum_op_values = json.loads(__sum_op_values_str) __min_op_values = None - __max_op_values = None 
+ __max_op_values_str = _conn.execute("SELECT secure_transfer from test_smpc_max_op_values_table;")["secure_transfer"][0] + __max_op_values = json.loads(__max_op_values_str) __union_op_values = None - transfer = udfio.construct_secure_transfer_dict(__template,__add_op_values,__min_op_values,__max_op_values,__union_op_values) + transfer = udfio.construct_secure_transfer_dict(__template,__sum_op_values,__min_op_values,__max_op_values,__union_op_values) return json.dumps(transfer) }""" @@ -5088,8 +5196,9 @@ FROM @pytest.mark.database @pytest.mark.usefixtures( "use_globalnode_database", - "create_smpc_template_table", - "create_smpc_add_op_values_table", + "create_smpc_template_table_with_sum_and_max", + "create_smpc_sum_op_values_table", + "create_smpc_max_op_values_table", ) def test_udf_with_db( self, @@ -5105,7 +5214,7 @@ FROM "SELECT transfer FROM main_output_table_name" ).fetchone() result = json.loads(transfer) - assert result == {"sum": [100, 200, 300]} + assert result == {"sum": [100, 200, 300], "max": 58} class TestUDFGen_LoggerArgument(TestUDFGenBase, _TestGenerateUDFQueries): diff --git a/tests/standalone_tests/test_udfio.py b/tests/standalone_tests/test_udfio.py index 89a8cbb6a9931a3de1381dcfeab969266afc037a..cf859a1b355034a297ffa766a111eef1781329a0 100644 --- a/tests/standalone_tests/test_udfio.py +++ b/tests/standalone_tests/test_udfio.py @@ -76,10 +76,10 @@ def get_secure_transfers_to_merged_dict_success_cases(): ( [ { - "a": {"data": 2, "type": "int", "operation": "addition"}, + "a": {"data": 2, "operation": "sum"}, }, { - "a": {"data": 3, "type": "int", "operation": "addition"}, + "a": {"data": 3, "operation": "sum"}, }, ], {"a": 5}, @@ -87,12 +87,12 @@ def get_secure_transfers_to_merged_dict_success_cases(): ( [ { - "a": {"data": 2, "type": "int", "operation": "addition"}, - "b": {"data": 5, "type": "int", "operation": "addition"}, + "a": {"data": 2, "operation": "sum"}, + "b": {"data": 5, "operation": "sum"}, }, { - "a": {"data": 3, "type": "int", "operation": "addition"}, - "b": {"data": 7, "type": "int", "operation": "addition"}, + "a": {"data": 3, "operation": "sum"}, + "b": {"data": 7, "operation": "sum"}, }, ], {"a": 5, "b": 12}, @@ -100,10 +100,10 @@ def get_secure_transfers_to_merged_dict_success_cases(): ( [ { - "a": {"data": [1, 2, 3], "type": "int", "operation": "addition"}, + "a": {"data": [1, 2, 3], "operation": "sum"}, }, { - "a": {"data": [9, 8, 7], "type": "int", "operation": "addition"}, + "a": {"data": [9, 8, 7], "operation": "sum"}, }, ], { @@ -113,29 +113,25 @@ def get_secure_transfers_to_merged_dict_success_cases(): ( [ { - "a": {"data": 10, "type": "int", "operation": "addition"}, + "a": {"data": 10, "operation": "sum"}, "b": { "data": [10, 20, 30, 40, 50, 60], - "type": "int", - "operation": "addition", + "operation": "sum", }, "c": { "data": [[10, 20, 30, 40, 50, 60], [70, 80, 90]], - "type": "int", - "operation": "addition", + "operation": "sum", }, }, { - "a": {"data": 100, "type": "int", "operation": "addition"}, + "a": {"data": 100, "operation": "sum"}, "b": { "data": [100, 200, 300, 400, 500, 600], - "type": "int", - "operation": "addition", + "operation": "sum", }, "c": { "data": [[100, 200, 300, 400, 500, 600], [700, 800, 900]], - "type": "int", - "operation": "addition", + "operation": "sum", }, }, ], @@ -145,102 +141,57 @@ def get_secure_transfers_to_merged_dict_success_cases(): "c": [[110, 220, 330, 440, 550, 660], [770, 880, 990]], }, ), - ] - return secure_transfers_cases - - -@pytest.mark.parametrize( - "transfers, result", 
get_secure_transfers_to_merged_dict_success_cases() -) -def test_secure_transfer_to_merged_dict(transfers, result): - assert secure_transfers_to_merged_dict(transfers) == result - - -def get_secure_transfers_merged_to_dict_fail_cases(): - secure_transfers_fail_cases = [ - ( - [ - { - "a": {"data": 2, "type": "int", "operation": "addition"}, - }, - { - "a": {"data": 3, "type": "int", "operation": "whatever"}, - }, - ], - ( - ValueError, - "All secure transfer keys should have the same 'operation' .*", - ), - ), - ( - [ - { - "a": {"data": 2, "type": "int", "operation": "addition"}, - }, - { - "a": {"data": 3, "type": "decimal", "operation": "addition"}, - }, - ], - (ValueError, "All secure transfer keys should have the same 'type' .*"), - ), - ( - [ - { - "a": {"data": 2, "type": "int", "operation": "addition"}, - }, - { - "a": {"data": [3], "type": "int", "operation": "addition"}, - }, - ], - (TypeError, "Secure transfer data have different types: .*"), - ), ( [ { - "a": {"data": 2, "type": "whatever", "operation": "addition"}, - }, - { - "a": {"data": 3, "type": "whatever", "operation": "addition"}, - }, - ], - ( - NotImplementedError, - "Secure transfer type: .* not supported for operation: .*", - ), - ), - ( - [ - { - "a": {"data": 2, "type": "int", "operation": "whatever"}, + "sum": {"data": 10, "operation": "sum"}, + "min": { + "data": [10, 200, 30, 400, 50, 600], + "operation": "min", + }, + "max": { + "data": [[100, 20, 300, 40, 500, 60], [700, 80, 900]], + "operation": "max", + }, }, { - "a": {"data": 3, "type": "int", "operation": "addition"}, + "sum": {"data": 100, "operation": "sum"}, + "min": { + "data": [100, 20, 300, 40, 500, 60], + "operation": "min", + }, + "max": { + "data": [[10, 200, 30, 400, 50, 600], [70, 800, 90]], + "operation": "max", + }, }, ], - (NotImplementedError, "Secure transfer operation not supported: .*"), + { + "sum": 110, + "min": [10, 20, 30, 40, 50, 60], + "max": [[100, 200, 300, 400, 500, 600], [700, 800, 900]], + }, ), ] - return secure_transfers_fail_cases + return secure_transfers_cases @pytest.mark.parametrize( - "transfers, exception", get_secure_transfers_merged_to_dict_fail_cases() + "transfers, result", get_secure_transfers_to_merged_dict_success_cases() ) -def test_secure_transfers_to_merged_dict_fail_cases(transfers, exception): - exception_type, exception_message = exception - with pytest.raises(exception_type, match=exception_message): - secure_transfers_to_merged_dict(transfers) +def test_secure_transfer_to_merged_dict(transfers, result): + assert secure_transfers_to_merged_dict(transfers) == result def get_secure_transfer_dict_success_cases(): secure_transfer_cases = [ ( { - "a": {"data": 2, "type": "int", "operation": "addition"}, + "a": {"data": 2, "operation": "sum"}, }, ( { - "a": {"data": 0, "type": "int", "operation": "addition"}, + "a": {"data": 0, "operation": "sum"}, }, [2], [], @@ -253,13 +204,13 @@ def get_secure_transfer_dict_success_cases(): ), ( { - "a": {"data": 2, "type": "int", "operation": "addition"}, - "b": {"data": 5, "type": "int", "operation": "addition"}, + "a": {"data": 2, "operation": "sum"}, + "b": {"data": 5, "operation": "sum"}, }, ( { - "a": {"data": 0, "type": "int", "operation": "addition"}, - "b": {"data": 1, "type": "int", "operation": "addition"}, + "a": {"data": 0, "operation": "sum"}, + "b": {"data": 1, "operation": "sum"}, }, [2, 5], [], @@ -270,11 +221,11 @@ def get_secure_transfer_dict_success_cases(): ), ( { - "a": {"data": [1, 2, 3], "type": "int", "operation": "addition"}, + "a": {"data": [1, 2, 
3], "operation": "sum"}, }, ( { - "a": {"data": [0, 1, 2], "type": "int", "operation": "addition"}, + "a": {"data": [0, 1, 2], "operation": "sum"}, }, [1, 2, 3], [], @@ -287,30 +238,26 @@ def get_secure_transfer_dict_success_cases(): ), ( { - "a": {"data": 10, "type": "int", "operation": "addition"}, + "a": {"data": 10, "operation": "sum"}, "b": { "data": [10, 20, 30, 40, 50, 60], - "type": "int", - "operation": "addition", + "operation": "sum", }, "c": { "data": [[10, 20, 30, 40, 50, 60], [70, 80, 90]], - "type": "int", - "operation": "addition", + "operation": "sum", }, }, ( { - "a": {"data": 0, "type": "int", "operation": "addition"}, + "a": {"data": 0, "operation": "sum"}, "b": { "data": [1, 2, 3, 4, 5, 6], - "type": "int", - "operation": "addition", + "operation": "sum", }, "c": { "data": [[7, 8, 9, 10, 11, 12], [13, 14, 15]], - "type": "int", - "operation": "addition", + "operation": "sum", }, }, [10, 10, 20, 30, 40, 50, 60, 10, 20, 30, 40, 50, 60, 70, 80, 90], @@ -324,6 +271,92 @@ def get_secure_transfer_dict_success_cases(): "c": [[10, 20, 30, 40, 50, 60], [70, 80, 90]], }, ), + ( + { + "min": {"data": [2, 5.6], "operation": "min"}, + }, + ( + { + "min": {"data": [0, 1], "operation": "min"}, + }, + [], + [2, 5.6], + [], + [], + ), + { + "min": [2, 5.6], + }, + ), + ( + { + "max": {"data": [2, 5.6], "operation": "max"}, + }, + ( + { + "max": {"data": [0, 1], "operation": "max"}, + }, + [], + [], + [2, 5.6], + [], + ), + { + "max": [2, 5.6], + }, + ), + ( + { + "sum1": {"data": [1, 2, 3, 4.5], "operation": "sum"}, + "sum2": {"data": [6, 7.8], "operation": "sum"}, + "min1": {"data": [6, 7.8], "operation": "min"}, + "min2": {"data": [1.5, 2.0], "operation": "min"}, + "max1": {"data": [6.8, 7], "operation": "max"}, + "max2": {"data": [1.5, 2], "operation": "max"}, + }, + ( + { + "sum1": {"data": [0, 1, 2, 3], "operation": "sum"}, + "sum2": {"data": [4, 5], "operation": "sum"}, + "min1": {"data": [0, 1], "operation": "min"}, + "min2": {"data": [2, 3], "operation": "min"}, + "max1": {"data": [0, 1], "operation": "max"}, + "max2": {"data": [2, 3], "operation": "max"}, + }, + [1, 2, 3, 4.5, 6, 7.8], + [6, 7.8, 1.5, 2.0], + [6.8, 7, 1.5, 2], + [], + ), + { + "sum1": [1, 2, 3, 4.5], + "sum2": [6, 7.8], + "min1": [6, 7.8], + "min2": [1.5, 2.0], + "max1": [6.8, 7], + "max2": [1.5, 2], + }, + ), + ( + { + "sum": {"data": [100, 200, 300], "operation": "sum"}, + "max": {"data": 58, "operation": "max"}, + }, + ( + { + "sum": {"data": [0, 1, 2], "operation": "sum"}, + "max": {"data": 0, "operation": "max"}, + }, + [100, 200, 300], + [], + [58], + [], + ), + { + "sum": [100, 200, 300], + "max": 58, + }, + ), ] return secure_transfer_cases @@ -342,3 +375,122 @@ def test_split_secure_transfer_dict(secure_transfer, smpc_parts, final_result): ) def test_construct_secure_transfer_dict(secure_transfer, smpc_parts, final_result): assert construct_secure_transfer_dict(*smpc_parts) == final_result + + +def get_secure_transfers_merged_to_dict_fail_cases(): + secure_transfers_fail_cases = [ + ( + [ + { + "a": {"data": 2, "operation": "sum"}, + }, + { + "a": {"data": 3, "operation": "whatever"}, + }, + ], + ( + ValueError, + "Secure Transfer operation is not supported: .*", + ), + ), + ( + [ + { + "a": {"data": 2, "operation": "sum"}, + }, + { + "a": {"data": 3, "operation": "min"}, + }, + ], + ( + ValueError, + "All secure transfer keys should have the same 'operation' .*", + ), + ), + ( + [ + { + "a": {"data": 2, "operation": "sum"}, + }, + { + "a": {"data": [3], "operation": "sum"}, + }, + ], + (ValueError, 
"Secure transfers' data should have the same structure."), + ), + ( + [ + { + "a": {"data": "tet", "operation": "sum"}, + }, + { + "a": {"data": "tet", "operation": "sum"}, + }, + ], + ( + TypeError, + "Secure transfer data must have one of the following types: .*", + ), + ), + ] + return secure_transfers_fail_cases + + +@pytest.mark.parametrize( + "transfers, exception", get_secure_transfers_merged_to_dict_fail_cases() +) +def test_secure_transfers_to_merged_dict_fail_cases(transfers, exception): + exception_type, exception_message = exception + with pytest.raises(exception_type, match=exception_message): + secure_transfers_to_merged_dict(transfers) + + +def get_split_secure_transfer_dict_fail_cases(): + split_secure_transfer_dict_fail_cases = [ + ( + { + "a": {"data": 3, "operation": "whatever"}, + }, + ( + ValueError, + "Secure Transfer operation is not supported: .*", + ), + ), + ( + { + "a": {"data": "tet", "operation": "sum"}, + }, + ( + TypeError, + "Secure Transfer key: 'a', operation: 'sum'. Error: Types allowed: .*", + ), + ), + ( + { + "a": {"llalal": 0, "operation": "sum"}, + }, + ( + ValueError, + "Each Secure Transfer key should contain data.", + ), + ), + ( + { + "a": {"data": 0, "sdfs": "sum"}, + }, + ( + ValueError, + "Each Secure Transfer key should contain an operation.", + ), + ), + ] + return split_secure_transfer_dict_fail_cases + + +@pytest.mark.parametrize( + "result, exception", get_split_secure_transfer_dict_fail_cases() +) +def test_split_secure_transfer_dict_fail_cases(result, exception): + exception_type, exception_message = exception + with pytest.raises(exception_type, match=exception_message): + split_secure_transfer_dict(result)
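
Usage sketch (not part of the patch): a minimal round trip through the new operation-based secure_transfer format touched by this diff. The import path and function signatures come from mipengine/udfgen/udfio.py as modified above; the example values and variable names are invented for illustration.

# Each secure_transfer key now carries "data" plus an "operation" in
# {"sum", "min", "max"}; the old "type"/"addition" fields are gone.
from mipengine.udfgen.udfio import (
    construct_secure_transfer_dict,
    secure_transfers_to_merged_dict,
    split_secure_transfer_dict,
)

local_transfer_1 = {
    "sum": {"data": [1, 2.5], "operation": "sum"},
    "min": {"data": 4, "operation": "min"},
    "max": {"data": 7, "operation": "max"},
}
local_transfer_2 = {
    "sum": {"data": [10, 20], "operation": "sum"},
    "min": {"data": 3, "operation": "min"},
    "max": {"data": 9, "operation": "max"},
}

# SMPC disabled: the local transfers are merged directly, per operation.
assert secure_transfers_to_merged_dict([local_transfer_1, local_transfer_2]) == {
    "sum": [11, 22.5],
    "min": 3,
    "max": 9,
}

# SMPC enabled: a transfer is split into a template (relative positions) plus
# flat value lists per operation, and rebuilt after the SMPC computation.
template, sum_vals, min_vals, max_vals, union_vals = split_secure_transfer_dict(
    local_transfer_1
)
assert template == {
    "sum": {"data": [0, 1], "operation": "sum"},
    "min": {"data": 0, "operation": "min"},
    "max": {"data": 0, "operation": "max"},
}
assert (sum_vals, min_vals, max_vals, union_vals) == ([1, 2.5], [4], [7], [])
assert construct_secure_transfer_dict(
    template, sum_vals, min_vals, max_vals, union_vals
) == {"sum": [1, 2.5], "min": 4, "max": 7}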