diff --git a/mipengine/controller/data_model_registry.py b/mipengine/controller/data_model_registry.py
index dfc75d12cfc231e8e83540119ce620d399edb6cb..9655cb5f44fb0742eda24ceb2c7c3c9057001bcf 100644
--- a/mipengine/controller/data_model_registry.py
+++ b/mipengine/controller/data_model_registry.py
@@ -11,8 +11,8 @@ def _have_common_elements(a: List[Any], b: List[Any]):

 class DataModelRegistry:
     def __init__(self):
-        self.data_models: Dict[str, CommonDataElements] = {}
-        self.datasets_location: Dict[str, Dict[str, List[str]]] = {}
+        self._data_models: Dict[str, CommonDataElements] = {}
+        self._datasets_location: Dict[str, Dict[str, List[str]]] = {}

     @property
     def data_models(self):
diff --git a/mipengine/controller/node_landscape_aggregator.py b/mipengine/controller/node_landscape_aggregator.py
index 8e8f510af711967e1499a3247e4316eaf1b99bf0..98028a097597ad2aa75d247eee2f3fd45f06cea5 100644
--- a/mipengine/controller/node_landscape_aggregator.py
+++ b/mipengine/controller/node_landscape_aggregator.py
@@ -44,9 +44,7 @@ async def _get_nodes_info(nodes_socket_addr: List[str]) -> List[NodeInfo]:
     }

     tasks_coroutines = [
-        _task_to_async(task, connection=app.broker_connection())(
-            request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID
-        )
+        _task_to_async(task, app=app)(request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID)
         for app, task in nodes_task_signature.items()
     ]
     results = await asyncio.gather(*tasks_coroutines, return_exceptions=True)
@@ -55,6 +53,10 @@ async def _get_nodes_info(nodes_socket_addr: List[str]) -> List[NodeInfo]:
         for result in results
         if not isinstance(result, Exception)
     ]
+
+    for app in celery_apps:
+        app.close()
+
     return nodes_info


@@ -64,9 +66,9 @@ async def _get_node_datasets_per_data_model(
     celery_app = get_node_celery_app(node_socket_addr)
     task_signature = celery_app.signature(GET_NODE_DATASETS_PER_DATA_MODEL_SIGNATURE)

-    result = await _task_to_async(
-        task_signature, connection=celery_app.broker_connection()
-    )(request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID)
+    result = await _task_to_async(task_signature, app=celery_app)(
+        request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID
+    )

     datasets_per_data_model = {}
     if not isinstance(result, Exception):
@@ -80,15 +82,15 @@ async def _get_node_cdes(node_socket_addr: str, data_model: str) -> CommonDataEl
     celery_app = get_node_celery_app(node_socket_addr)
     task_signature = celery_app.signature(GET_DATA_MODEL_CDES_SIGNATURE)

-    result = await _task_to_async(
-        task_signature, connection=celery_app.broker_connection()
-    )(data_model=data_model, request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID)
+    result = await _task_to_async(task_signature, app=celery_app)(
+        data_model=data_model, request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID
+    )

     if not isinstance(result, Exception):
         return CommonDataElements.parse_raw(result)


-def _task_to_async(task, connection):
+def _task_to_async(task, app):
     """
     Converts a Celery task to an async function
     Celery doesn't currently support asyncio "await" while "getting" a result
@@ -106,18 +108,19 @@ def _task_to_async(task, connection):
         delay = 0.1
         # Since apply_async is used instead of delay so that we can pass the connection as an argument,
         # the args and kwargs need to be passed as named arguments.
-        async_result = await sync_to_async(task.apply_async)(
-            args=args, kwargs=kwargs, connection=connection
-        )
-        while not async_result.ready():
-            total_delay += delay
-            if total_delay > CELERY_TASKS_TIMEOUT:
-                raise TimeoutError(
-                    f"Celery task: {task} didn't respond in {CELERY_TASKS_TIMEOUT}s."
-                )
-            await asyncio.sleep(delay)
-            delay = min(delay * 1.5, 2)  # exponential backoff, max 2 seconds
-        return async_result.get(timeout=CELERY_TASKS_TIMEOUT - total_delay)
+        with app.broker_connection() as conn:
+            async_result = await sync_to_async(task.apply_async)(
+                args=args, kwargs=kwargs, connection=conn
+            )
+            while not async_result.ready():
+                total_delay += delay
+                if total_delay > CELERY_TASKS_TIMEOUT:
+                    raise TimeoutError(
+                        f"Celery task: {task} didn't respond in {CELERY_TASKS_TIMEOUT}s."
+                    )
+                await asyncio.sleep(delay)
+                delay = min(delay * 1.5, 2)  # exponential backoff, max 2 seconds
+            return async_result.get(timeout=CELERY_TASKS_TIMEOUT - total_delay)

     return wrapper

@@ -157,25 +160,26 @@ class NodeLandscapeAggregator(metaclass=Singleton):
             local_nodes = [
                 node for node in nodes_info if node.role == NodeRole.LOCALNODE
             ]

-            datasets_locations = await _get_datasets_locations(local_nodes)
-            datasets_labels = await _get_datasets_labels(local_nodes)
-            data_model_cdes_across_nodes = await _get_cdes_across_nodes(local_nodes)
+            (
+                dataset_locations,
+                aggregated_datasets,
+            ) = await _gather_all_dataset_infos(local_nodes)
+            data_model_cdes_per_node = await _get_cdes_across_nodes(local_nodes)
             compatible_data_models = _get_compatible_data_models(
-                data_model_cdes_across_nodes
+                data_model_cdes_per_node
             )
-            data_models = get_updated_data_model_with_dataset_enumerations(
-                compatible_data_models, datasets_labels
+            _update_data_models_with_aggregated_datasets(
+                compatible_data_models, aggregated_datasets
+            )
+            datasets_locations = _get_dataset_locations_of_compatible_data_models(
+                compatible_data_models, dataset_locations
             )
-            datasets_locations = {
-                common_data_model: datasets_locations[common_data_model]
-                for common_data_model in data_models
-            }
-            self._node_registry._nodes = {
+            self._node_registry.nodes = {
                 node_info.id: node_info for node_info in nodes_info
             }
-            self._data_model_registry._data_models = data_models
-            self._data_model_registry._datasets_location = datasets_locations
+            self._data_model_registry.data_models = compatible_data_models
+            self._data_model_registry.datasets_location = datasets_locations
             logger.debug(f"Nodes:{[node for node in self._node_registry.nodes]}")
         except Exception as exc:
             logger.error(f"Node Landscape Aggregator exception: {type(exc)}:{exc}")
@@ -234,42 +238,50 @@ class NodeLandscapeAggregator(metaclass=Singleton):
     )


-async def _get_datasets_locations(nodes: List[NodeInfo]) -> Dict[str, Dict[str, str]]:
-    datasets_locations = {}
+async def _gather_all_dataset_infos(
+    nodes: List[NodeInfo],
+) -> Tuple[Dict[str, Dict[str, str]], Dict[str, Dict[str, str]]]:
+    """
+
+    Args:
+        nodes: The nodes available in the system
+
+    Returns:
+        A tuple with:
+        1. The location of each dataset.
+        2. The aggregated datasets, existing in all nodes
+    """
+    dataset_locations = {}
+    aggregated_datasets = {}
+
     for node_info in nodes:
         node_socket_addr = _get_node_socket_addr(node_info)
         datasets_per_data_model = await _get_node_datasets_per_data_model(
             node_socket_addr
         )
         for data_model, datasets in datasets_per_data_model.items():
-            current_datasets = (
-                datasets_locations[data_model]
-                if data_model in datasets_locations
+
+            current_labels = (
+                aggregated_datasets[data_model]
+                if data_model in aggregated_datasets
                 else {}
             )
+            current_datasets = (
+                dataset_locations[data_model] if data_model in dataset_locations else {}
+            )

             for dataset in datasets:
+                current_labels[dataset] = datasets[dataset]
+
                 if dataset in current_datasets:
                     current_datasets[dataset].append(node_info.id)
                 else:
                     current_datasets[dataset] = [node_info.id]
-            datasets_locations[data_model] = current_datasets
-    return datasets_locations
+            aggregated_datasets[data_model] = current_labels
+            dataset_locations[data_model] = current_datasets

-
-async def _get_datasets_labels(nodes: List[NodeInfo]) -> Dict[str, Dict[str, str]]:
-    datasets_labels = {}
-    for node_info in nodes:
-        node_socket_addr = _get_node_socket_addr(node_info)
-        datasets_per_data_model = await _get_node_datasets_per_data_model(
-            node_socket_addr
-        )
-        for data_model, datasets in datasets_per_data_model.items():
-            datasets_labels[data_model] = {}
-            for dataset in datasets:
-                datasets_labels[data_model][dataset] = datasets[dataset]
-    return datasets_labels
+    return dataset_locations, aggregated_datasets


 async def _get_cdes_across_nodes(
@@ -289,6 +301,15 @@ async def _get_cdes_across_nodes(
     return nodes_cdes


+def _get_dataset_locations_of_compatible_data_models(
+    compatible_data_models, dataset_locations
+):
+    return {
+        compatible_data_model: dataset_locations[compatible_data_model]
+        for compatible_data_model in compatible_data_models
+    }
+
+
 def _get_compatible_data_models(
     data_model_cdes_across_nodes: Dict[str, List[Tuple[str, CommonDataElements]]]
 ) -> Dict[str, CommonDataElements]:
@@ -323,10 +344,13 @@ def _get_compatible_data_models(
     return data_models


-def get_updated_data_model_with_dataset_enumerations(
+def _update_data_models_with_aggregated_datasets(
     data_models: Dict[str, CommonDataElements],
-    datasets_labels: Dict[str, Dict[str, str]],
-) -> Dict[str, CommonDataElements]:
+    aggregated_datasets: Dict[str, Dict[str, str]],
+):
+    """
+    Updates each data_model's 'dataset' enumerations with the aggregated datasets
+    """
     for data_model in data_models:
         dataset_cde = data_models[data_model].values["dataset"]
         new_dataset_cde = CommonDataElement(
@@ -334,9 +358,8 @@ def get_updated_data_model_with_dataset_enumerations(
             label=dataset_cde.label,
             sql_type=dataset_cde.sql_type,
             is_categorical=dataset_cde.is_categorical,
-            enumerations=datasets_labels[data_model],
+            enumerations=aggregated_datasets[data_model],
             min=dataset_cde.min,
             max=dataset_cde.max,
         )
         data_models[data_model].values["dataset"] = new_dataset_cde
-    return data_models
diff --git a/mipengine/controller/node_registry.py b/mipengine/controller/node_registry.py
index de01b1f5814ae2e9eaaf672d3d9358e61e325427..c4091702fe412b8af2b9e41debf96dcb1b675c41 100644
--- a/mipengine/controller/node_registry.py
+++ b/mipengine/controller/node_registry.py
@@ -6,7 +6,7 @@ from mipengine.node_info_DTOs import NodeRole

 class NodeRegistry:
     def __init__(self):
-        self.nodes: Dict[str, NodeInfo] = {}
+        self._nodes: Dict[str, NodeInfo] = {}

     @property
     def nodes(self) -> Dict[str, NodeInfo]:
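For reviewers, here is a minimal, self-contained sketch of the awaitable-task pattern `_task_to_async` follows after this change, mirroring the hunk above. The broker/backend URLs, the `CELERY_TASKS_TIMEOUT` value, and the task name are illustrative placeholders, not the project's real configuration:

```python
import asyncio

from asgiref.sync import sync_to_async
from celery import Celery

CELERY_TASKS_TIMEOUT = 10  # placeholder; the real value comes from controller config

# Placeholder broker/backend; any reachable Celery app behaves the same way.
app = Celery(broker="amqp://guest:guest@localhost:5672//", backend="rpc://")


def task_to_async(task, app):
    """Wrap a Celery signature so asyncio code can await its result."""

    async def wrapper(*args, **kwargs):
        total_delay = 0.0
        delay = 0.1
        # The broker connection is now opened per call and released by the
        # context manager, instead of being passed in by the caller.
        with app.broker_connection() as conn:
            async_result = await sync_to_async(task.apply_async)(
                args=args, kwargs=kwargs, connection=conn
            )
            # Poll with capped exponential backoff instead of blocking on get().
            while not async_result.ready():
                total_delay += delay
                if total_delay > CELERY_TASKS_TIMEOUT:
                    raise TimeoutError(
                        f"Celery task: {task} didn't respond in {CELERY_TASKS_TIMEOUT}s."
                    )
                await asyncio.sleep(delay)
                delay = min(delay * 1.5, 2)  # exponential backoff, max 2 seconds
            return async_result.get(timeout=CELERY_TASKS_TIMEOUT - total_delay)

    return wrapper


async def main():
    # "tasks.get_node_info" is a hypothetical task name for the demo.
    signature = app.signature("tasks.get_node_info")
    print(await task_to_async(signature, app=app)(request_id="demo-request"))


if __name__ == "__main__":
    asyncio.run(main())
```

Scoping the connection to the `with` block guarantees it is released even when the polling loop raises `TimeoutError`, which appears to be the motivation for dropping the `connection=` parameter at the call sites; the new `app.close()` loop in `_get_nodes_info` disposes of the per-node apps themselves.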
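On the registry side, the diff shows only the getters, but the aggregator's new assignments (`self._node_registry.nodes = ...`, `self._data_model_registry.data_models = ...`) only work if matching property setters exist. A sketch of the implied shape for `NodeRegistry` follows; the setter body is an assumption, not code visible in this diff:

```python
from typing import Dict

from mipengine.node_info_DTOs import NodeInfo


class NodeRegistry:
    def __init__(self):
        # Backing field is private; callers go through the property.
        self._nodes: Dict[str, NodeInfo] = {}

    @property
    def nodes(self) -> Dict[str, NodeInfo]:
        return self._nodes

    # Assumed setter: the aggregator's `self._node_registry.nodes = {...}`
    # assignment requires something like this on the property.
    @nodes.setter
    def nodes(self, nodes: Dict[str, NodeInfo]):
        self._nodes = nodes
```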