diff --git a/PyNutil/io/atlas_loader.py b/PyNutil/io/atlas_loader.py index 0a58f18d917d792a2a464bd9e8cd68c977c2a1fa..deec555dac3faf022fcc68fe2b469b5672541095 100644 --- a/PyNutil/io/atlas_loader.py +++ b/PyNutil/io/atlas_loader.py @@ -3,6 +3,25 @@ import pandas as pd import numpy as np import nrrd +def load_atlas_labels(atlas=None, atlas_name=None): + if atlas_name: + atlas = brainglobe_atlasapi.BrainGlobeAtlas(atlas_name=atlas_name) + if not atlas_name and not atlas: + raise Exception("Either atlas or atlas name must be specified") + atlas_structures = { + "idx": [i["id"] for i in atlas.structures_list], + "name": [i["name"] for i in atlas.structures_list], + "r": [i["rgb_triplet"][0] for i in atlas.structures_list], + "g": [i["rgb_triplet"][1] for i in atlas.structures_list], + "b": [i["rgb_triplet"][2] for i in atlas.structures_list], + } + atlas_structures["idx"].insert(0, 0) + atlas_structures["name"].insert(0, "Clear Label") + atlas_structures["r"].insert(0, 0) + atlas_structures["g"].insert(0, 0) + atlas_structures["b"].insert(0, 0) + atlas_labels = pd.DataFrame(atlas_structures) + return atlas_labels def load_atlas_data(atlas_name): """ @@ -23,20 +42,7 @@ def load_atlas_data(atlas_name): A dataframe containing atlas labels and RGB information. """ atlas = brainglobe_atlasapi.BrainGlobeAtlas(atlas_name=atlas_name) - atlas_structures = { - "idx": [i["id"] for i in atlas.structures_list], - "name": [i["name"] for i in atlas.structures_list], - "r": [i["rgb_triplet"][0] for i in atlas.structures_list], - "g": [i["rgb_triplet"][1] for i in atlas.structures_list], - "b": [i["rgb_triplet"][2] for i in atlas.structures_list], - } - atlas_structures["idx"].insert(0, 0) - atlas_structures["name"].insert(0, "Clear Label") - atlas_structures["r"].insert(0, 0) - atlas_structures["g"].insert(0, 0) - atlas_structures["b"].insert(0, 0) - - atlas_labels = pd.DataFrame(atlas_structures) + atlas_labels = load_atlas_labels(atlas) atlas_volume = process_atlas_volume(atlas.annotation) hemi_map = process_atlas_volume(atlas.hemispheres) print("atlas labels loaded ✅") diff --git a/PyNutil/io/file_operations.py b/PyNutil/io/file_operations.py index 3d67649f89b59612866447f0f8d22546cf3a2783..12c1bd57a42453085d59b5346b19199048c3ceef 100644 --- a/PyNutil/io/file_operations.py +++ b/PyNutil/io/file_operations.py @@ -1,6 +1,6 @@ import os import json -from .read_and_write import write_points_to_meshview +from .read_and_write import write_hemi_points_to_meshview def save_analysis_output( @@ -212,14 +212,14 @@ def _save_per_section_meshview( output_folder, prepend, ): - write_points_to_meshview( + write_hemi_points_to_meshview( pixel_points[prev_pl : pl + prev_pl], labeled_points[prev_pl : pl + prev_pl], points_hemi_labels[prev_pl : pl + prev_pl], f"{output_folder}/per_section_meshview/{prepend}{split_fn}_pixels.json", atlas_labels, ) - write_points_to_meshview( + write_hemi_points_to_meshview( centroids[prev_cl : cl + prev_cl], labeled_points_centroids[prev_cl : cl + prev_cl], centroids_hemi_labels[prev_cl : cl + prev_cl], @@ -239,14 +239,14 @@ def _save_whole_series_meshview( output_folder, prepend, ): - write_points_to_meshview( + write_hemi_points_to_meshview( pixel_points, labeled_points, points_hemi_labels, f"{output_folder}/whole_series_meshview/{prepend}pixels_meshview.json", atlas_labels, ) - write_points_to_meshview( + write_hemi_points_to_meshview( centroids, labeled_points_centroids, centroids_hemi_labels, diff --git a/PyNutil/io/read_and_write.py b/PyNutil/io/read_and_write.py index 2389eaa62ba086accbc07ce39638fff4d75b582e..2e007dc95bad952af3804898bceb04451af3b937 100644 --- a/PyNutil/io/read_and_write.py +++ b/PyNutil/io/read_and_write.py @@ -9,7 +9,7 @@ import numpy as np import struct import cv2 from .reconstruct_dzi import reconstruct_dzi - +from .atlas_loader import load_atlas_labels def open_custom_region_file(path): """ @@ -283,7 +283,7 @@ def write_points(points_dict, filename, info_file): # related to read and write: write_points_to_meshview # this function combines create_region_dict and write_points functions -def write_points_to_meshview(points, point_names, hemi_label, filename, info_file): +def write_hemi_points_to_meshview(points, point_names, hemi_label, filename, info_file): """ Combines point data and region information into MeshView JSON files. @@ -307,16 +307,32 @@ def write_points_to_meshview(points, point_names, hemi_label, filename, info_fil split_fn_left = filename.split("/") split_fn_left[-1] = "left_hemisphere_" + split_fn_left[-1] outname_left = os.sep.join(split_fn_left) - left_region_dict = create_region_dict( - points[hemi_label == 1], point_names[hemi_label == 1] - ) - write_points(left_region_dict, outname_left, info_file) + write_points_to_meshview(points[hemi_label == 1], point_names[hemi_label == 1], outname_left, info_file) split_fn_right = filename.split("/") split_fn_right[-1] = "right_hemisphere_" + split_fn_right[-1] outname_right = os.sep.join(split_fn_right) - right_region_dict = create_region_dict( - points[hemi_label == 2], point_names[hemi_label == 2] - ) - write_points(right_region_dict, outname_right, info_file) - region_dict = create_region_dict(points, point_names) + write_points_to_meshview(points[hemi_label == 2], point_names[hemi_label == 2], outname_right, info_file) + write_points_to_meshview(points, point_names, filename, info_file) + +# related to read and write: write_points_to_meshview +# this function combines create_region_dict and write_points functions +def write_points_to_meshview(points, point_ids, filename, info_file): + """ + Combines point data and region information into MeshView JSON files. + + Parameters + ---------- + points : numpy.ndarray + 2D array containing [N, 3] point coordinates. + point_ids : numpy.ndarray + 1D array of region labels corresponding to each point. + filename : str + Base path for output JSON. Separate hemispheres use prefixed filenames. + info_file : pandas.DataFrame or string + A table with region IDs, names, and color data (r, g, b) for each region. + If string, this should correspond to the relevant brainglobe atlas + """ + if isinstance(info_file, str): + info_file = load_atlas_labels(info_file) + region_dict = create_region_dict(points, point_ids) write_points(region_dict, filename, info_file) diff --git a/PyNutil/processing/transformations.py b/PyNutil/processing/transformations.py index 240fc5903e64c661922d039ae1104eafa21a097c..8bc8fa4921dadadced594394c623f92a83b8fb92 100644 --- a/PyNutil/processing/transformations.py +++ b/PyNutil/processing/transformations.py @@ -46,6 +46,25 @@ def transform_to_atlas_space(anchoring, y, x, reg_height, reg_width): o = np.reshape(o, (3, 1)) return (o + xyz_u + xyz_v).T +def image_to_atlas_space(image, anchoring): + """ + Transforms to atlas space an image using the QuickNII anchoring vector. + + + Args: + image (ndarray): An Image which you would like to apply the anchoring vector to + anchoring (list): Anchoring vector. + + Returns: + ndarray: Transformed coordinates for every pixel in the image. + """ + width = image.shape[1] + height = image.shape[0] + x = np.arange(width) + y = np.arange(height) + x_coords, y_coords = np.meshgrid(x, y) + coordinates = transform_to_atlas_space(anchoring, y_coords.flatten(), x_coords.flatten(), height, width) + return coordinates def get_transformed_coordinates( non_linear,