From 3c8f1173c2d54328e6a81998fad4dba1cc595ee0 Mon Sep 17 00:00:00 2001 From: polarbean <harry.carey95@gmail.com> Date: Tue, 25 Mar 2025 20:38:39 +0100 Subject: [PATCH] update docstrings --- PyNutil/io/atlas_loader.py | 51 ++++++ PyNutil/io/read_and_write.py | 146 ++++++++++++--- PyNutil/io/reconstruct_dzi.py | 13 +- PyNutil/main.py | 60 ++++--- PyNutil/processing/coordinate_extraction.py | 189 ++++++++++++-------- PyNutil/processing/counting_and_load.py | 104 +++++------ PyNutil/processing/data_analysis.py | 34 ++-- PyNutil/processing/generate_target_slice.py | 8 +- PyNutil/processing/transformations.py | 18 +- PyNutil/processing/utils.py | 66 ++++--- 10 files changed, 463 insertions(+), 226 deletions(-) diff --git a/PyNutil/io/atlas_loader.py b/PyNutil/io/atlas_loader.py index 19ce6fc..21ac5d2 100644 --- a/PyNutil/io/atlas_loader.py +++ b/PyNutil/io/atlas_loader.py @@ -5,6 +5,23 @@ import nrrd def load_atlas_data(atlas_name): + """ + Loads atlas data using the brainglobe_atlasapi. + + Parameters + ---------- + atlas_name : str + Name of the atlas to load. + + Returns + ------- + numpy.ndarray + The atlas volume array. + numpy.ndarray + The hemisphere data array. + pandas.DataFrame + A dataframe containing atlas labels and RGB information. + """ atlas = brainglobe_atlasapi.BrainGlobeAtlas(atlas_name=atlas_name) atlas_structures = { "idx": [i["id"] for i in atlas.structures_list], @@ -27,10 +44,44 @@ def load_atlas_data(atlas_name): def process_atlas_volume(vol): + """ + Processes the atlas volume by transposing and reversing axes. + + Parameters + ---------- + vol : numpy.ndarray + The atlas volume to process. + + Returns + ------- + numpy.ndarray + The processed atlas volume. + """ return np.transpose(vol, [2, 0, 1])[::-1, ::-1, ::-1] def load_custom_atlas(atlas_path, hemi_path, label_path): + """ + Loads a custom atlas from provided file paths. + + Parameters + ---------- + atlas_path : str + Path to the custom atlas volume file. 
+ hemi_path : str or None + Path to the hemisphere file, if any. + label_path : str + Path to the label CSV file for region info. + + Returns + ------- + numpy.ndarray + The loaded atlas volume. + numpy.ndarray or None + The hemisphere array, or None if hemi_path is not provided. + pandas.DataFrame + A dataframe containing atlas labels. + """ atlas_volume, _ = nrrd.read(atlas_path) if hemi_path: hemi_volume, _ = nrrd.read(hemi_path) diff --git a/PyNutil/io/read_and_write.py b/PyNutil/io/read_and_write.py index adbe2dd..c33ab17 100644 --- a/PyNutil/io/read_and_write.py +++ b/PyNutil/io/read_and_write.py @@ -13,20 +13,31 @@ from .reconstruct_dzi import reconstruct_dzi def open_custom_region_file(path): """ - Opens a custom region file created by QCAlign or manually by the user. + Opens a custom region file (TSV or XLSX) and returns both a dictionary + of region information and a corresponding pandas DataFrame. + + The dictionary contains: + - 'custom_ids': The unique IDs for each region. + - 'custom_names': The region names. + - 'rgb_values': The RGB values for each region. + - 'subregion_ids': Lists of underlying atlas IDs for each region. + + The returned DataFrame has columns: + - idx: The unique IDs for each region. + - name: The region names. + - r, g, b: The RGB values for each region. Parameters ---------- path : str - the path to the TSV or XLSX file containing the custom region mappings. - If the file extension is not XLSX we will assume it is TSV. By default - QCAlign exports TSV files with a TXT extension. + Path to the TSV or XLSX file containing region information. Returns - ---------- - custom_region_to_dict : dict - - + ------- + dict + A dictionary with region data fields ('custom_ids', 'custom_names', 'rgb_values', and 'subregion_ids'). + pandas.DataFrame + A DataFrame summarizing the region IDs, names, and RGB information. 
""" if path.lower().endswith(".xlsx"): df = pd.read_excel(path) @@ -78,13 +89,20 @@ def open_custom_region_file(path): def read_flat_file(file): """ - Reads a flat file and returns an image array. + Reads a custom 'flat' file format and returns its contents as a 2D NumPy array. - Args: - file (str): Path to the flat file. + This format includes a header encoding bit-depth (B), width (W), and height (H). + Pixel data follows in a sequence that is unpacked into a 2D array. - Returns: - ndarray: Image array. + Parameters + ---------- + file : str + Path to the flat file to read. + + Returns + ------- + numpy.ndarray + A 2D NumPy array containing image data. """ with open(file, "rb") as f: b, w, h = struct.unpack(">BII", f.read(9)) @@ -99,13 +117,21 @@ def read_flat_file(file): def read_seg_file(file): """ - Reads a segmentation file and returns an image array. + Reads a segmentation file encoded with a specialized format (SegRLEv1), + decodes it, and returns a 2D NumPy array representing segment labels. - Args: - file (str): Path to the segmentation file. + The file contains a header with atlas information and compression codes + that are used to rebuild the segmentation data. - Returns: - ndarray: Image array. + Parameters + ---------- + file : str + Path to the segmentation file. + + Returns + ------- + numpy.ndarray + A 2D NumPy array with segmentation labels. """ with open(file, "rb") as f: @@ -135,13 +161,18 @@ def read_seg_file(file): def load_segmentation(segmentation_path: str): """ - Loads a segmentation from a file. + Loads segmentation data either from a '.dzip' file or a standard image file. + If the file ends with '.dzip', it will be processed as a DZI using 'reconstruct_dzi'. - Args: - segmentation_path (str): Path to the segmentation file. + Parameters + ---------- + segmentation_path : str + Path to the segmentation file (supports '.dzip' or common image formats). - Returns: - ndarray: Segmentation array. 
+ Returns + ------- + numpy.ndarray + A 2D or 3D segmentation array, depending on the file contents. """ if segmentation_path.endswith(".dzip"): return reconstruct_dzi(segmentation_path) @@ -152,6 +183,24 @@ def load_segmentation(segmentation_path: str): # related to read and write # this function reads a VisuAlign JSON and returns the slices def load_visualign_json(filename, apply_damage_mask): + """ + Reads a VisuAlign JSON file (.waln or .wwrp) and extracts slice information. + Slices may include anchoring, grid spacing, and other image metadata. + + Parameters + ---------- + filename : str + The path to the VisuAlign JSON file. + apply_damage_mask : bool + If True, retains 'grid' data in slices; if False, removes it. + + Returns + ------- + list + A list of slice dictionaries containing anchoring and other metadata. + float or None + Grid spacing if found, otherwise None. + """ with open(filename) as f: vafile = json.load(f) if filename.endswith(".waln") or filename.endswith("wwrp"): @@ -177,7 +226,22 @@ def load_visualign_json(filename, apply_damage_mask): # related to read_and_write, used in write_points_to_meshview # this function returns a dictionary of region names def create_region_dict(points, regions): - """points is a list of points and regions is an id for each point""" + """ + Groups point coordinates by their region labels and + returns a dictionary mapping each region to its 3D point list. + + Parameters + ---------- + points : numpy.ndarray + A (N, 3) array of 3D coordinates for all points. + regions : numpy.ndarray + A 1D array of integer region labels for each point. + + Returns + ------- + dict + Keys are unique region labels, and values are the flattened [x, y, z, ...] coordinates. 
+ """ region_dict = { region: points[regions == region].flatten().tolist() for region in np.unique(regions) @@ -188,6 +252,21 @@ def create_region_dict(points, regions): # related to read and write: write_points # this function writes the region dictionary to a meshview json def write_points(points_dict, filename, info_file): + """ + Saves a region-based point dictionary to a MeshView-compatible JSON layout. + + Each region is recorded with: index (idx), name, color components (r, g, b), + and a count of how many points belong to that region. + + Parameters + ---------- + points_dict : dict + Keys are region IDs, values are flattened 3D coordinates. + filename : str + Destination JSON file to be written. + info_file : pandas.DataFrame + A table with region IDs, names, and color data (r, g, b) for each region. + """ meshview = [ { "idx": idx, @@ -208,6 +287,25 @@ def write_points(points_dict, filename, info_file): # related to read and write: write_points_to_meshview # this function combines create_region_dict and write_points functions def write_points_to_meshview(points, point_names, hemi_label, filename, info_file): + """ + Combines point data and region information into MeshView JSON files. + + If hemisphere labels are provided (1, 2), separate outputs are saved for + left and right hemispheres, as well as one file containing all points. + + Parameters + ---------- + points : numpy.ndarray + 2D array containing [N, 3] point coordinates. + point_names : numpy.ndarray + 1D array of region labels corresponding to each point. + hemi_label : numpy.ndarray + 1D array with hemisphere labels (1 for left, 2 for right), or None. + filename : str + Base path for output JSON. Separate hemispheres use prefixed filenames. + info_file : pandas.DataFrame + A table with region IDs, names, and color data (r, g, b) for each region. 
+ """ if not (hemi_label == None).all(): split_fn_left = filename.split('/') split_fn_left[-1] = "left_hemisphere_" + split_fn_left[-1] diff --git a/PyNutil/io/reconstruct_dzi.py b/PyNutil/io/reconstruct_dzi.py index bbdff40..17ed267 100644 --- a/PyNutil/io/reconstruct_dzi.py +++ b/PyNutil/io/reconstruct_dzi.py @@ -8,8 +8,17 @@ import xmltodict def reconstruct_dzi(zip_file_path): """ Reconstructs a Deep Zoom Image (DZI) from a zip file containing the tiles. - :param zip_file_path: Path to the zip file containing the tiles. - :return: The reconstructed DZI. + Parameters + ---------- + zip_file_path : str + Path to the zip file containing the tiles. + apply_damage_mask : bool + Whether to apply the damage mask. + + Returns + ------- + ndarray + The reconstructed DZI. """ with zipfile.ZipFile(zip_file_path, "r") as zip_file: # Get the highest level of the pyramid diff --git a/PyNutil/main.py b/PyNutil/main.py index 0c1fc8e..4bee193 100644 --- a/PyNutil/main.py +++ b/PyNutil/main.py @@ -11,18 +11,18 @@ from .processing.coordinate_extraction import folder_to_atlas_space class PyNutil: """ - A class used to perform brain-wide quantification and spatial analysis of features in serial section images. + A class to perform brain-wide quantification and spatial analysis of serial section images. Methods ------- - __init__(self, segmentation_folder=None, alignment_json=None, colour=None, atlas_name=None, atlas_path=None, label_path=None, settings_file=None) - Initializes the PyNutil class with the given parameters. - get_coordinates(self, non_linear=True, object_cutoff=0, use_flat=False) - Extracts pixel coordinates from the segmentation data. - quantify_coordinates(self) - Quantifies the pixel coordinates by region. - save_analysis(self, output_folder) - Saves the pixel coordinates and pixel counts to different files in the specified output folder. + __init__(...) + Initialize the PyNutil class with segmentation, alignment, atlas and region settings. + get_coordinates(...) 
+ Extract and transform pixel coordinates from segmentation files. + quantify_coordinates() + Quantify pixel and centroid counts by atlas regions. + save_analysis(output_folder) + Save the analysis output to the specified directory. """ def __init__( @@ -132,20 +132,18 @@ class PyNutil: def get_coordinates(self, non_linear=True, object_cutoff=0, use_flat=False, apply_damage_mask=True): """ - Extracts pixel coordinates from the segmentation data. + Retrieves pixel and centroid coordinates from segmentation data, + applies atlas-space transformations, and optionally uses a damage + mask if specified. - Parameters - ---------- - non_linear : bool, optional - Whether to use non-linear transformation from the VisuAlign markers (default is True). - object_cutoff : int, optional - The minimum size of objects to be considered (default is 0). - use_flat : bool, optional - Whether to use flat file atlas maps exported from QuickNII or VisuAlign. This is usually not needed since we can calculate them automatically. This setting is for testing and compatibility purposes (default is False). + Args: + non_linear (bool, optional): Enable non-linear transformation. + object_cutoff (int, optional): Minimum object size. + use_flat (bool, optional): Use flat maps if True. + apply_damage_mask (bool, optional): Apply damage mask if True. - Returns - ------- - None + Returns: + None: Results are stored in class attributes. """ try: ( @@ -190,16 +188,20 @@ class PyNutil: def quantify_coordinates(self): """ - Quantifies the pixel coordinates by region. + Quantifies and summarizes pixel and centroid coordinates by atlas region, + storing the aggregated results in class attributes. - Returns - ------- - None + Attributes: + label_df (pd.DataFrame): Contains aggregated label information. + per_section_df (list of pd.DataFrame): DataFrames with section-wise statistics. + custom_label_df (pd.DataFrame): Label data enriched with custom regions if custom regions is set. 
+ custom_per_section_df (list of pd.DataFrame): Section-wise stats for custom regions if custom regions is set. - Raises - ------ - ValueError - If get_coordinates has not been run before running quantify_coordinates. + Raises: + ValueError: If required attributes are missing or computation fails. + + Returns: + None """ if not hasattr(self, "pixel_points") and not hasattr(self, "centroids"): raise ValueError( diff --git a/PyNutil/processing/coordinate_extraction.py b/PyNutil/processing/coordinate_extraction.py index cdd260c..6a44490 100644 --- a/PyNutil/processing/coordinate_extraction.py +++ b/PyNutil/processing/coordinate_extraction.py @@ -27,14 +27,14 @@ from .utils import ( def get_centroids_and_area(segmentation, pixel_cut_off=0): """ - Returns the center coordinate of each object in the segmentation. + Retrieves centroids, areas, and pixel coordinates of labeled regions. Args: - segmentation (ndarray): Segmentation array. - pixel_cut_off (int, optional): Pixel cutoff to remove small objects. Defaults to 0. + segmentation (ndarray): Binary segmentation array. + pixel_cut_off (int, optional): Minimum object size threshold. Returns: - tuple: Centroids, area, and coordinates of objects. + tuple: (centroids, area, coords) of retained objects. """ labels = measure.label(segmentation) labels_info = measure.regionprops(labels) @@ -46,6 +46,18 @@ def get_centroids_and_area(segmentation, pixel_cut_off=0): def update_spacing(anchoring, width, height, grid_spacing): + """ + Calculates spacing along width and height from slice anchoring. + + Args: + anchoring (list): Anchoring transformation parameters. + width (int): Image width. + height (int): Image height. + grid_spacing (int): Grid spacing in image units. 
+ + Returns: + tuple: (xspacing, yspacing) + """ if len(anchoring) != 9: print("Anchoring does not have 9 elements.") ow = np.sqrt(sum([anchoring[i + 3] ** 2 for i in range(3)])) @@ -56,6 +68,16 @@ def update_spacing(anchoring, width, height, grid_spacing): def create_damage_mask(section, grid_spacing): + """ + Creates a binary damage mask from grid information in the given section. + + Args: + section (dict): Dictionary with slice and grid data. + grid_spacing (int): Space between grid marks. + + Returns: + ndarray: Binary mask with damaged areas marked as 0. + """ width = section["width"] height = section["height"] anchoring = section["anchoring"] @@ -95,20 +117,22 @@ def folder_to_atlas_space( apply_damage_mask=True ): """ - Applies segmentation to atlas space for all segmentations in a folder. + Processes all segmentation files in a folder, mapping each one to atlas space. Args: - folder (str): Path to the folder. - quint_alignment (str): Path to the QuickNII alignment file. + folder (str): Path to segmentation files. + quint_alignment (str): Path to alignment JSON. atlas_labels (DataFrame): DataFrame with atlas labels. - pixel_id (list, optional): Pixel ID to match. Defaults to [0, 0, 0]. - non_linear (bool, optional): Whether to use non-linear transformation. Defaults to True. - object_cutoff (int, optional): Pixel cutoff to remove small objects. Defaults to 0. - atlas_volume (ndarray, optional): Volume with atlas labels. Defaults to None. - use_flat (bool, optional): Whether to use flat files. Defaults to False. + pixel_id (list, optional): Pixel color to match. + non_linear (bool, optional): Apply non-linear transform. + object_cutoff (int, optional): Minimum object size. + atlas_volume (ndarray, optional): Atlas volume data. + hemi_map (ndarray, optional): Hemisphere mask data. + use_flat (bool, optional): If True, load flat files. + apply_damage_mask (bool, optional): If True, apply damage mask. 
Returns: - tuple: Points, centroids, region areas list, points length, centroids length, segmentations. + tuple: Various arrays and lists containing transformed coordinates and labels. """ slices, gridspacing = load_visualign_json(quint_alignment, apply_damage_mask) segmentations = get_segmentations(folder) @@ -222,27 +246,33 @@ def create_threads( gridspacing, ): """ - Creates threads for processing segmentations. + Creates threads to transform each segmentation into atlas space. Args: - segmentations (list): List of segmentation files. - slices (list): List of slices. - flat_files (list): List of flat files. - flat_file_nrs (list): List of flat file section numbers. - atlas_labels (DataFrame): DataFrame with atlas labels. - pixel_id (list): Pixel ID to match. - non_linear (bool): Whether to use non-linear transformation. - points_list (list): List to store points. - centroids_list (list): List to store centroids. - centroids_labels(list): List to store centroids labels. - points_labels(list): List to store points labels. - region_areas_list (list): List to store region areas. - object_cutoff (int): Pixel cutoff to remove small objects. - atlas_volume (ndarray): Volume with atlas labels. - use_flat (bool): Whether to use flat files. + segmentations (list): Paths to segmentation files. + slices (list): Slice metadata from alignment JSON. + flat_files (list): Flat file paths for optional flat maps. + flat_file_nrs (list): Numeric indices for flat files. + atlas_labels (DataFrame): Atlas labels. + pixel_id (list): Pixel color defined as [R, G, B]. + non_linear (bool): Enable non-linear transformation if True. + points_list (list): Stores point coordinates per segmentation. + centroids_list (list): Stores centroid coordinates per segmentation. + centroids_labels (list): Stores labels for each centroid array. + points_labels (list): Stores labels for each point array. + region_areas_list (list): Stores region area data per segmentation. 
+ per_point_undamaged_list (list): Track undamaged points per segmentation. + per_centroid_undamaged_list (list): Track undamaged centroids per segmentation. + point_hemi_labels (list): Hemisphere labels for points. + centroid_hemi_labels (list): Hemisphere labels for centroids. + object_cutoff (int): Minimum object size threshold. + atlas_volume (ndarray): 3D atlas volume (optional). + hemi_map (ndarray): Hemisphere mask (optional). + use_flat (bool): Use flat files if True. + gridspacing (int): Spacing value from alignment data. Returns: - list: List of threads. + list: A list of threads for parallel execution. """ threads = [] for segmentation_path, index in zip(segmentations, range(len(segmentations))): @@ -291,13 +321,13 @@ def create_threads( def load_segmentation(segmentation_path: str): """ - Loads a segmentation from a file. + Loads segmentation data, handling .dzip files if necessary. Args: - segmentation_path (str): Path to the segmentation file. + segmentation_path (str): File path. Returns: - ndarray: Segmentation array. + ndarray: Image array of the segmentation. """ if segmentation_path.endswith(".dzip"): return reconstruct_dzi(segmentation_path) @@ -307,13 +337,13 @@ def load_segmentation(segmentation_path: str): def detect_pixel_id(segmentation: np.array): """ - Removes the background from the segmentation and returns the pixel ID. + Infers pixel color from the first non-background region. Args: segmentation (ndarray): Segmentation array. Returns: - ndarray: Pixel ID. + ndarray: Identified pixel color (RGB). """ segmentation_no_background = segmentation[~np.all(segmentation == 0, axis=2)] pixel_id = segmentation_no_background[0] @@ -334,20 +364,22 @@ def get_region_areas( damage_mask, ): """ - Gets the region areas. + Builds the atlas map for a slice and calculates the region areas. Args: - use_flat (bool): Whether to use flat files. - atlas_labels (DataFrame): DataFrame with atlas labels. - flat_file_atlas (str): Path to the flat file atlas. 
- seg_width (int): Segmentation width. - seg_height (int): Segmentation height. - slice_dict (dict): Dictionary with slice information. - atlas_volume (ndarray): Volume with atlas labels. - triangulation (ndarray): Triangulation data. + use_flat (bool): If True, uses flat files. + atlas_labels (DataFrame): DataFrame containing atlas labels. + flat_file_atlas (str): Path to the flat atlas file. + seg_width (int): Segmentation image width. + seg_height (int): Segmentation image height. + slice_dict (dict): Dictionary with slice metadata (anchoring, etc.). + atlas_volume (ndarray): 3D atlas volume. + hemi_mask (ndarray): Hemisphere mask. + triangulation (ndarray): Triangulation data for non-linear transforms. + damage_mask (ndarray): Binary damage mask. Returns: - DataFrame: DataFrame with region areas. + tuple: (DataFrame of region areas, atlas map array). """ atlas_map = load_image( flat_file_atlas, @@ -385,22 +417,33 @@ def segmentation_to_atlas_space( grid_spacing=None, ): """ - Converts a segmentation to atlas space. + Transforms a single segmentation file into atlas space. Args: - slice_dict (dict): Dictionary with slice information. + slice_dict (dict): Slice information from alignment JSON. segmentation_path (str): Path to the segmentation file. - atlas_labels (DataFrame): DataFrame with atlas labels. - flat_file_atlas (str, optional): Path to the flat file atlas. Defaults to None. - pixel_id (str, optional): Pixel ID to match. Defaults to "auto". - non_linear (bool, optional): Whether to use non-linear transformation. Defaults to True. - points_list (list, optional): List to store points. Defaults to None. - centroids_list (list, optional): List to store centroids. Defaults to None. - region_areas_list (list, optional): List to store region areas. Defaults to None. - index (int, optional): Index of the current segmentation. Defaults to None. - object_cutoff (int, optional): Pixel cutoff to remove small objects. Defaults to 0. 
- atlas_volume (ndarray, optional): Volume with atlas labels. Defaults to None. - use_flat (bool, optional): Whether to use flat files. Defaults to False. + atlas_labels (DataFrame): Atlas labels. + flat_file_atlas (str, optional): Path to flat atlas, if using flat files. + pixel_id (str or list, optional): Pixel color or 'auto'. + non_linear (bool, optional): Use non-linear transforms if True. + points_list (list, optional): Storage for transformed point coordinates. + centroids_list (list, optional): Storage for transformed centroid coordinates. + points_labels (list, optional): Storage for assigned point labels. + centroids_labels (list, optional): Storage for assigned centroid labels. + region_areas_list (list, optional): Storage for region area data. + per_point_undamaged_list (list, optional): Track undamaged points. + per_centroid_undamaged_list (list, optional): Track undamaged centroids. + points_hemi_labels (list, optional): Hemisphere labels for points. + centroids_hemi_labels (list, optional): Hemisphere labels for centroids. + index (int, optional): Index in the lists. + object_cutoff (int, optional): Minimum object size. + atlas_volume (ndarray, optional): 3D atlas volume. + hemi_map (ndarray, optional): Hemisphere mask. + use_flat (bool, optional): Indicates use of flat files. + grid_spacing (int, optional): Spacing value for damage mask. + + Returns: + None """ segmentation = load_segmentation(segmentation_path) if pixel_id == "auto": @@ -521,16 +564,16 @@ def segmentation_to_atlas_space( def get_triangulation(slice_dict, reg_width, reg_height, non_linear): """ - Gets the triangulation for the slice. + Generates triangulation data if non-linear markers exist. Args: - slice_dict (dict): Dictionary with slice information. + slice_dict (dict): Slice metadata from alignment JSON. reg_width (int): Registration width. reg_height (int): Registration height. - non_linear (bool): Whether to use non-linear transformation. 
+ non_linear (bool): Whether to use non-linear transform. Returns: - list: Triangulation data. + list or None: Triangulation info or None if not applicable. """ if non_linear and "markers" in slice_dict: return triangulate(reg_width, reg_height, slice_dict["markers"]) @@ -539,17 +582,17 @@ def get_triangulation(slice_dict, reg_width, reg_height, non_linear): def get_centroids(segmentation, pixel_id, y_scale, x_scale, object_cutoff=0): """ - Gets the centroids of objects in the segmentation. + Finds object centroids for a given pixel color and applies scaling. Args: segmentation (ndarray): Segmentation array. - pixel_id (int): Pixel ID to match. - y_scale (float): Y scaling factor. - x_scale (float): X scaling factor. - object_cutoff (int, optional): Pixel cutoff to remove small objects. Defaults to 0. + pixel_id (int): Pixel color to match. + y_scale (float): Vertical scaling factor. + x_scale (float): Horizontal scaling factor. + object_cutoff (int, optional): Minimum object size. Returns: - tuple: Centroids, scaled X coordinates, and scaled Y coordinates. + tuple: (centroids, scaled_centroidsX, scaled_centroidsY) """ binary_seg = segmentation == pixel_id binary_seg = np.all(binary_seg, axis=2) @@ -568,16 +611,16 @@ def get_centroids(segmentation, pixel_id, y_scale, x_scale, object_cutoff=0): def get_scaled_pixels(segmentation, pixel_id, y_scale, x_scale): """ - Gets the scaled pixel coordinates. + Retrieves pixel coordinates for a specified color and scales them. Args: segmentation (ndarray): Segmentation array. - pixel_id (int): Pixel ID to match. - y_scale (float): Y scaling factor. - x_scale (float): X scaling factor. + pixel_id (int): Pixel color to match. + y_scale (float): Vertical scaling factor. + x_scale (float): Horizontal scaling factor. Returns: - tuple: Scaled Y and X coordinates. 
+ tuple: (scaled_y, scaled_x) """ id_pixels = find_matching_pixels(segmentation, pixel_id) if len(id_pixels[0]) == 0: diff --git a/PyNutil/processing/counting_and_load.py b/PyNutil/processing/counting_and_load.py index e4b68b1..e7b7107 100644 --- a/PyNutil/processing/counting_and_load.py +++ b/PyNutil/processing/counting_and_load.py @@ -7,13 +7,14 @@ from .visualign_deformations import transform_vec def create_base_counts_dict(with_hemisphere=False, with_damage=False): """ - Creates the base dictionary structure for counts. + Creates and returns a base dictionary structure for tracking counts. Args: - with_hemisphere (bool): Whether to include hemisphere-specific fields + with_hemisphere (bool): If True, include hemisphere fields. + with_damage (bool): If True, include damage fields. Returns: - dict: Base dictionary with count fields + dict: Structure containing count lists for pixels/objects. """ counts = { "idx": [], @@ -66,15 +67,20 @@ def pixel_count_per_region( with_damage=False ): """ - Counts the number of pixels per region and writes to a DataFrame. + Tally object counts by region, optionally tracking damage and hemispheres. Args: - labels_dict_points (dict): Dictionary with region as key and points as value. - labeled_dict_centroids (dict): Dictionary with region as key and centroids as value. - df_label_colours (DataFrame): DataFrame with label colours. + labels_dict_points (dict): Maps points to region labels. + labeled_dict_centroids (dict): Maps centroids to region labels. + current_points_undamaged (ndarray): Undamaged-state flags for points. + current_centroids_undamaged (ndarray): Undamaged-state flags for centroids. + current_points_hemi (ndarray): Hemisphere tags for points. + current_centroids_hemi (ndarray): Hemisphere tags for centroids. + df_label_colours (DataFrame): Region label colors. + with_damage (bool, optional): Track damage counts if True. Returns: - DataFrame: DataFrame with counts and colours per region. 
+ DataFrame: Summed counts per region. """ with_hemi = None not in current_points_hemi counts_per_label = create_base_counts_dict(with_hemisphere=with_hemi, with_damage=with_damage) @@ -336,19 +342,16 @@ def pixel_count_per_region( return df_counts_per_label -"""Read flat file and write into an np array""" -"""Read flat file, write into an np array, assign label file values, return array""" - def read_flat_file(file): """ - Reads a flat file and returns an image array. + Reads a flat file and produces an image array. Args: file (str): Path to the flat file. Returns: - ndarray: Image array. + ndarray: Image array extracted from the file. """ with open(file, "rb") as f: b, w, h = struct.unpack(">BII", f.read(9)) @@ -363,13 +366,13 @@ def read_flat_file(file): def read_seg_file(file): """ - Reads a segmentation file and returns an image array. + Reads a segmentation file into an image array. Args: file (str): Path to the segmentation file. Returns: - ndarray: Image array. + ndarray: The segmentation image. """ with open(file, "rb") as f: @@ -399,14 +402,14 @@ def read_seg_file(file): def rescale_image(image, rescaleXY): """ - Rescales an image. + Rescales an image to the specified dimensions. Args: - image (ndarray): Image array. - rescaleXY (tuple): Tuple with new width and height. + image (ndarray): Input image array. + rescaleXY (tuple): (width, height) as new size. Returns: - ndarray: Rescaled image. + ndarray: The rescaled image. """ w, h = rescaleXY return cv2.resize(image, (h, w), interpolation=cv2.INTER_NEAREST) @@ -414,11 +417,11 @@ def rescale_image(image, rescaleXY): def assign_labels_to_image(image, labelfile): """ - Assigns labels to an image based on a label file. + Assigns atlas or region labels to an image array. Args: - image (ndarray): Image array. - labelfile (DataFrame): DataFrame with label information. + image (ndarray): Image array to label. + labelfile (DataFrame): Contains label IDs in the 'idx' column. 
Returns: ndarray: Image with assigned labels. @@ -436,14 +439,14 @@ def assign_labels_to_image(image, labelfile): def count_pixels_per_label(image, scale_factor=False): """ - Counts the number of pixels per label in an image. + Counts the pixels associated with each label in an image. Args: - image (ndarray): Image array. - scale_factor (bool, optional): Whether to apply a scaling factor. Defaults to False. + image (ndarray): Image array containing labels. + scale_factor (bool, optional): Apply scaling if True. Returns: - DataFrame: DataFrame with pixel counts per label. + DataFrame: Table of label IDs and pixel counts. """ unique_ids, counts = np.unique(image, return_counts=True) if scale_factor: @@ -455,15 +458,15 @@ def count_pixels_per_label(image, scale_factor=False): def warp_image(image, triangulation, rescaleXY): """ - Warps an image based on triangulation. + Warps an image using triangulation, applying optional resizing. Args: - image (ndarray): Image array. - triangulation (ndarray): Triangulation data. - rescaleXY (tuple, optional): Tuple with new dimensions. Defaults to None. + image (ndarray): Image array to be warped. + triangulation (ndarray): Triangulation data for remapping. + rescaleXY (tuple, optional): (width, height) for resizing. Returns: - ndarray: Warped image. + ndarray: The warped image array. """ if rescaleXY is not None: w, h = rescaleXY @@ -500,19 +503,17 @@ def warp_image(image, triangulation, rescaleXY): def flat_to_dataframe(image, damage_mask, hemi_mask, rescaleXY=None): """ - Converts a flat file to a DataFrame. + Builds a DataFrame from an image, incorporating optional damage/hemisphere masks. Args: - labelfile (DataFrame): DataFrame with label information. - file (str, optional): Path to the flat file. Defaults to None. - rescaleXY (tuple, optional): Tuple with new dimensions. Defaults to None. - image_vector (ndarray, optional): Image vector. Defaults to None. - volume (ndarray, optional): Volume data. Defaults to None. 
- triangulation (ndarray, optional): Triangulation data. Defaults to None. + image (ndarray): Source image with label IDs. + damage_mask (ndarray): Binary mask indicating damaged areas. + hemi_mask (ndarray): Binary mask for hemisphere assignment. + rescaleXY (tuple, optional): (width, height) for resizing. Returns: - DataFrame: DataFrame with area per label. - np.array: array in shape of alignment XY scaled by rescaleXY with allen ID for each point + DataFrame: Pixel counts grouped by label. + ndarray: Scaled label map of the image. """ scale_factor = calculate_scale_factor(image, rescaleXY) df_area_per_label = pd.DataFrame(columns=["idx"]) @@ -599,17 +600,18 @@ def flat_to_dataframe(image, damage_mask, hemi_mask, rescaleXY=None): def load_image(file, image_vector, volume, triangulation, rescaleXY, labelfile=None): """ - Loads an image from a file or generates it from a vector and volume. + Loads an image from file or transforms a preloaded array, optionally applying warping. Args: - file (str): Path to the file. - image_vector (ndarray): Image vector. - volume (ndarray): Volume data. - triangulation (ndarray): Triangulation data. - rescaleXY (tuple): Tuple with new dimensions. + file (str): File path for the source image. + image_vector (ndarray): Preloaded image data array. + volume (ndarray): Atlas volume or similar data. + triangulation (ndarray): Triangulation data for warping. + rescaleXY (tuple): (width, height) for resizing. + labelfile (DataFrame, optional): Label definitions. Returns: - ndarray: Loaded or generated image. + ndarray: The loaded or transformed image. """ if image_vector is not None and volume is not None: image = generate_target_slice(image_vector, volume) @@ -628,14 +630,14 @@ def load_image(file, image_vector, volume, triangulation, rescaleXY, labelfile=N def calculate_scale_factor(image, rescaleXY): """ - Calculates the scale factor for an image. + Computes a factor for resizing if needed. Args: - image (ndarray): Image array. 
-        rescaleXY (tuple): Tuple with new dimensions.
+        image (ndarray): Original image array.
+        rescaleXY (tuple): (width, height) for potential resizing.
 
     Returns:
-        float: Scale factor.
+        float or bool: Scale factor or False if not applicable.
     """
     if rescaleXY:
         image_shapeY, image_shapeX = image.shape[0], image.shape[1]
diff --git a/PyNutil/processing/data_analysis.py b/PyNutil/processing/data_analysis.py
index a18bdb7..bbe1ba4 100644
--- a/PyNutil/processing/data_analysis.py
+++ b/PyNutil/processing/data_analysis.py
@@ -4,6 +4,16 @@ import numpy as np
 
 
 def map_to_custom_regions(custom_regions_dict, points_labels):
+    """
+    Reassigns atlas-region labels into user-defined custom regions.
+
+    Args:
+        custom_regions_dict (dict): Custom region data with 'subregion_ids' and 'custom_ids'.
+        points_labels (ndarray): Atlas region label for each point.
+
+    Returns:
+        ndarray: Custom region ID assigned to each point.
+    """
     custom_points_labels = np.zeros_like(points_labels)
     for i in np.unique(points_labels):
         new_id = np.where([i in r for r in custom_regions_dict["subregion_ids"]])[0]
@@ -18,6 +28,16 @@
 
 
 def apply_custom_regions(df, custom_regions_dict):
+    """
+    Applies custom region definitions to a quantification DataFrame.
+
+    Args:
+        df (DataFrame): Quantification results whose region labels are being remapped.
+        custom_regions_dict (dict): Custom region definitions (IDs, names, RGB values).
+
+    Returns:
+        The quantification data with custom region labels applied.
+    """
     # Create mappings
     id_mapping = {}
     name_mapping = {}
@@ -129,21 +149,15 @@ def quantify_labeled_points(
     apply_damage_mask
 ):
     """
-    Quantifies labeled points and returns various DataFrames.
+    Aggregates labeled points into a summary table.
 
     Args:
-        pixel_points (ndarray): Array of pixel points.
-        centroids (ndarray): Array of centroids.
-        points_len (list): List of lengths of points per section.
-        centroids_len (list): List of lengths of centroids per section.
-        region_areas_list (list): List of region areas per section.
-        atlas_labels (DataFrame): DataFrame with atlas labels.
-        atlas_volume (ndarray): Volume with atlas labels.
+        labeled_points (ndarray): Atlas label assigned to each point.
+        atlas_labels (DataFrame): DataFrame containing atlas region labels.
 
     Returns:
-        tuple: Labeled points, labeled centroids, label DataFrame, per section DataFrame.
+        tuple: Aggregated results, including label and per-section DataFrames.
     """
-
     per_section_df = _quantify_per_section(
         labeled_points,
         labeled_points_centroids,
diff --git a/PyNutil/processing/generate_target_slice.py b/PyNutil/processing/generate_target_slice.py
index a8bc9c5..c07f526 100644
--- a/PyNutil/processing/generate_target_slice.py
+++ b/PyNutil/processing/generate_target_slice.py
@@ -4,14 +4,14 @@ import math
 
 
 def generate_target_slice(ouv, atlas):
     """
-    Generates a target slice from the atlas using the given orientation and vectors.
+    Generate a 2D slice from a 3D atlas based on orientation vectors.
 
     Args:
-        ouv (list): List containing origin and vectors [ox, oy, oz, ux, uy, uz, vx, vy, vz].
-        atlas (ndarray): 3D atlas array.
+        ouv (list): Orientation vector [ox, oy, oz, ux, uy, uz, vx, vy, vz].
+        atlas (ndarray): 3D atlas volume.
 
     Returns:
-        ndarray: 2D slice generated from the atlas.
+        ndarray: 2D slice extracted from the atlas.
     """
     ox, oy, oz, ux, uy, uz, vx, vy, vz = ouv
     width = np.floor(math.hypot(ux, uy, uz)).astype(int) + 1
diff --git a/PyNutil/processing/transformations.py b/PyNutil/processing/transformations.py
index 62b8ddb..240fc59 100644
--- a/PyNutil/processing/transformations.py
+++ b/PyNutil/processing/transformations.py
@@ -57,18 +57,18 @@ def get_transformed_coordinates(
     triangulation,
 ):
     """
-    Gets the transformed coordinates.
+    Compute transformed coordinates for both points and centroids.
 
     Args:
-        non_linear (bool): Whether to use non-linear transformation.
-        slice_dict (dict): Dictionary with slice information.
-        scaled_x (ndarray): Scaled X coordinates.
-        scaled_y (ndarray): Scaled Y coordinates.
-        centroids (ndarray): Centroids.
-        scaled_centroidsX (ndarray): Scaled X coordinates of centroids.
-        scaled_centroidsY (ndarray): Scaled Y coordinates of centroids.
-        triangulation (ndarray): Triangulation data.
+        non_linear (bool): Flag to indicate if non-linear transformation is applied.
+        slice_dict (dict): Slice metadata including markers.
+        scaled_x (ndarray): Scaled x coordinates for points.
+        scaled_y (ndarray): Scaled y coordinates for points.
+        centroids (ndarray): Centroid coordinates.
+        scaled_centroidsX (ndarray): Scaled x coordinates for centroids.
+        scaled_centroidsY (ndarray): Scaled y coordinates for centroids.
+        triangulation (list): Triangulation structure to be used if non_linear is True.
+
     Returns:
         tuple: Transformed coordinates.
     """
diff --git a/PyNutil/processing/utils.py b/PyNutil/processing/utils.py
index 344f84d..697ee6a 100644
--- a/PyNutil/processing/utils.py
+++ b/PyNutil/processing/utils.py
@@ -6,14 +6,14 @@ from glob import glob
 
 
 def number_sections(filenames, legacy=False):
     """
-    Returns the section numbers of filenames.
+    Extract section numbers from a list of filenames.
 
     Args:
-        filenames (list): List of filenames.
-        legacy (bool, optional): Whether to use legacy mode. Defaults to False.
+        filenames (list): List of file paths.
+        legacy (bool, optional): Use a legacy extraction mode if True. Defaults to False.
 
     Returns:
-        list: List of section numbers.
+        list: List of section numbers as integers.
     """
     filenames = [filename.split("\\")[-1] for filename in filenames]
     section_numbers = []
@@ -92,13 +92,13 @@ def calculate_scale_factor(image, rescaleXY):
 
 
 def get_segmentations(folder):
     """
-    Gets the list of segmentation files in the folder.
+    Collects segmentation file paths from the specified folder.
 
     Args:
-        folder (str): Path to the folder.
+        folder (str): Path to the folder containing segmentations.
 
     Returns:
-        list: List of segmentation files.
+        list: List of segmentation file paths.
""" segmentation_file_types = [".png", ".tif", ".tiff", ".jpg", ".jpeg", ".dzip"] segmentations = [ @@ -114,16 +114,16 @@ def get_segmentations(folder): return segmentations -def get_flat_files(folder, use_flat): +def get_flat_files(folder, use_flat=False): """ - Gets the list of flat files in the folder. + Retrieves flat file paths from the given folder. Args: - folder (str): Path to the folder. - use_flat (bool): Whether to use flat files. + folder (str): Path to the folder containing flat files. + use_flat (bool, optional): If True, filter only flat files. Returns: - tuple: List of flat files and their section numbers. + tuple: A list of flat file paths and their numeric indices. """ if use_flat: flat_files = [ @@ -169,16 +169,16 @@ def initialize_lists(length): def get_current_flat_file(seg_nr, flat_files, flat_file_nrs, use_flat): """ - Gets the current flat file for the given section number. + Determines the correct flat file for a given section number. Args: - seg_nr (int): Section number. - flat_files (list): List of flat files. - flat_file_nrs (list): List of flat file section numbers. - use_flat (bool): Whether to use flat files. + seg_nr (int): Numeric index of the segmentation. + flat_files (list): List of flat file paths. + flat_file_nrs (list): Numeric indices for each flat file. + use_flat (bool): If True, attempts to match flat files to segments. Returns: - str: Path to the current flat file. + str or None: The matched flat file path, or None if not found or unused. """ if use_flat: current_flat_file_index = np.where([f == seg_nr for f in flat_file_nrs]) @@ -188,10 +188,13 @@ def get_current_flat_file(seg_nr, flat_files, flat_file_nrs, use_flat): def start_and_join_threads(threads): """ - Starts and joins the threads. + Starts a list of threads and joins them to ensure completion. Args: - threads (list): List of threads. + threads (list): A list of threading.Thread objects. 
+ + Returns: + None """ [t.start() for t in threads] [t.join() for t in threads] @@ -208,14 +211,29 @@ def process_results( centroids_undamaged_list, ): """ - Processes the results from the threads. + Consolidates and organizes results from multiple segmentations. Args: - points_list (list): List of points. - centroids_list (list): List of centroids. + points_list (list): A list of arrays containing point coordinates. + centroids_list (list): A list of arrays containing centroid coordinates. + points_labels (list): A list of arrays with labels for each point. + centroids_labels (list): A list of arrays with labels for each centroid. + points_hemi_labels (list): A list of arrays storing hemisphere info per point. + centroids_hemi_labels (list): A list of arrays storing hemisphere info per centroid. + points_undamaged_list (list): Tracks undamaged status of each point. + centroids_undamaged_list (list): Tracks undamaged status of each centroid. Returns: - tuple: Processed points, centroids, points length, and centroids length. + points (ndarray): Consolidated point coordinates. + centroids (ndarray): Consolidated centroid coordinates. + points_labels (ndarray): Combined labels for points. + centroids_labels (ndarray): Combined labels for centroids. + points_hemi_labels (ndarray): Combined hemisphere info for points. + centroids_hemi_labels (ndarray): Combined hemisphere info for centroids. + points_len (int): Total number of points. + centroids_len (int): Total number of centroids. + points_undamaged (ndarray): Updated track of undamaged status for points. + centroids_undamaged (ndarray): Updated track of undamaged status for centroids. """ points_len = [len(points) if None not in points else 0 for points in points_list] centroids_len = [ -- GitLab