Skip to content
Snippets Groups Projects
Commit 94b5c48d authored by polarbean's avatar polarbean
Browse files

modularise main.py

parent d6bb97ff
No related branches found
No related tags found
No related merge requests found
from .read_and_write import read_atlas_volume, write_points_to_meshview
from .coordinate_extraction import folder_to_atlas_space
from .counting_and_load import label_points, pixel_count_per_region
import json
import pandas as pd
import os
from datetime import datetime
import numpy as np
import brainglobe_atlasapi
import os
import numpy as np
import pandas as pd
from .coordinate_extraction import folder_to_atlas_space
from .counting_and_load import label_points, pixel_count_per_region
from .read_and_write import read_atlas_volume, write_points_to_meshview
class PyNutil:
    """A utility class for working with brain atlases and segmentation data.

    Parameters
    ----------
    segmentation_folder : str
        The path to the folder containing the segmentation data.
    alignment_json : str
        The path to the alignment JSON file.
    colour : int
        The colour of the segmentation data to extract.
    atlas_name : str
        The name of the atlas volume to use. Uses BrainGlobe API name.
    atlas_path : str, optional
        Path to a custom atlas volume; only valid together with label_path.
    label_path : str, optional
        Path to a custom atlas label CSV; only valid together with atlas_path.
    settings_file : str, optional
        The path to a JSON file containing the above parameters.

    Raises
    ------
    ValueError
        If both a custom atlas (atlas_path/label_path) and an atlas_name are
        specified, or if no atlas name is given when no custom atlas is used.
    KeyError
        If the settings file is missing a required key.

    Attributes
    ----------
    segmentation_folder : str
        The path to the folder containing the segmentation data.
    alignment_json : str
        The path to the alignment JSON file.
    colour : int
        The colour of the segmentation data to extract.
    atlas_name : str
        The name of the atlas volume being used.
    atlas_volume : numpy.ndarray
        The 3D array representing the atlas volume.
    atlas_labels : pandas.DataFrame
        A DataFrame containing the labels for the atlas volume.
    pixel_points : numpy.ndarray
        An array of pixel coordinates extracted from the segmentation data.
    labeled_points : numpy.ndarray
        An array of labeled pixel coordinates.
    label_df : pandas.DataFrame
        A DataFrame containing the pixel counts per region.
    """

    def __init__(self, segmentation_folder=None, alignment_json=None, colour=None,
                 atlas_name=None, atlas_path=None, label_path=None, settings_file=None):
        # A settings file, when given, overrides the individual arguments.
        if settings_file is not None:
            with open(settings_file, "r") as f:
                settings = json.load(f)
            # NOTE(review): these assignments were reconstructed from
            # truncated diff context -- confirm against the original commit.
            try:
                segmentation_folder = settings["segmentation_folder"]
                alignment_json = settings["alignment_json"]
                colour = settings["colour"]
                atlas_name = settings["atlas_name"]
            except KeyError as exc:
                raise KeyError(
                    "settings file must contain segmentation_folder, alignment_json, colour, and atlas_name"
                ) from exc
        self.segmentation_folder = segmentation_folder
        self.alignment_json = alignment_json
        self.colour = colour
        self.atlas_name = atlas_name
        # Custom atlas paths and a BrainGlobe atlas name are mutually exclusive.
        if (atlas_path or label_path) and atlas_name:
            raise ValueError(
                "Please only specify an atlas_path and a label_path or an atlas_name, atlas and label paths are only used for loading custom atlases"
            )
        if atlas_path and label_path:
            self.atlas_volume, self.atlas_labels = self.load_custom_atlas(atlas_path, label_path)
        else:
            self._check_atlas_name()
            self.atlas_volume, self.atlas_labels = self.load_atlas_data(atlas_name=atlas_name)
def _check_atlas_name(self):
    """Raise if no BrainGlobe atlas name was supplied (custom atlases skip this)."""
    if self.atlas_name:
        return
    raise ValueError("Atlas name must be specified")

def _load_settings(self, settings_file):
    """Populate instance parameters from a JSON settings file, if one is given."""
    if settings_file:
        with open(settings_file, "r") as f:
            settings = json.load(f)
        # Missing keys fall back to None rather than raising.
        for key in ("segmentation_folder", "alignment_json", "colour", "atlas_name"):
            setattr(self, key, settings.get(key))
def load_atlas_data(self, atlas_name):
    """Loads the atlas volume and labels from disk.

    Returns
    -------
    tuple
        A tuple containing the atlas volume as a numpy.ndarray and the
        atlas labels as a pandas.DataFrame.
    """
    print("loading atlas volume")
    atlas = brainglobe_atlasapi.BrainGlobeAtlas(atlas_name=atlas_name)
    # NOTE(review): the construction of atlas_structures was lost in
    # truncated diff context; reconstructed from the visible insert(0, ...)
    # calls and the BrainGlobe structures list -- confirm against the
    # original commit.
    atlas_structures = {
        "idx": [s["id"] for s in atlas.structures_list],
        "name": [s["name"] for s in atlas.structures_list],
        "r": [s["rgb_triplet"][0] for s in atlas.structures_list],
        "g": [s["rgb_triplet"][1] for s in atlas.structures_list],
        "b": [s["rgb_triplet"][2] for s in atlas.structures_list],
    }
    # Prepend a background entry so voxel value 0 resolves to a label row.
    atlas_structures["idx"].insert(0, 0)
    atlas_structures["name"].insert(0, "Clear Label")
    atlas_structures["r"].insert(0, 0)
    atlas_structures["g"].insert(0, 0)
    atlas_structures["b"].insert(0, 0)
    atlas_labels = pd.DataFrame(atlas_structures)
    atlas_volume = self._process_atlas_volume(atlas)
    print("atlas labels loaded ✅")
    return atlas_volume, atlas_labels
def _process_atlas_volume(self, atlas):
if "allen_mouse_" in self.atlas_name:
print("reorienting allen atlas into quicknii space...")
return np.transpose(atlas.annotation, [2, 0, 1])[:, ::-1, ::-1]
else:
return atlas.annotation
def load_custom_atlas(self, atlas_path, label_path):
    """Load a user-supplied atlas volume and its label table from disk."""
    volume = read_atlas_volume(atlas_path)
    labels = pd.read_csv(label_path)
    return volume, labels
def get_coordinates(self, non_linear=True, method="all", object_cutoff=0, use_flat=False):
    """Extracts pixel coordinates from the segmentation data.

    Parameters
    ----------
    non_linear : bool, optional
        Whether to use non-linear registration. Default is True.
    method : str, optional
        The method to use for extracting coordinates. Valid options are
        'per_pixel', 'per_object', or 'all'. Default is 'all'.
    object_cutoff : int, optional
        The minimum number of pixels per object to be included in the
        analysis. Default is 0.
    use_flat : bool, optional
        Presumably selects flat-file based registration -- TODO confirm.

    Raises
    ------
    ValueError
        If the specified method is not recognised, or the atlas volume has
        not been loaded.
    """
    if not hasattr(self, "atlas_volume"):
        raise ValueError("Please run build_quantifier before running get_coordinates")
    self._validate_method(method)
    print("extracting coordinates with method:", method)
    # NOTE(review): the trailing arguments to folder_to_atlas_space were
    # lost in truncated diff context; reconstructed from this method's own
    # parameters -- confirm against the original commit.
    (
        pixel_points,
        centroids,
        region_areas_list,
        points_len,
        centroids_len,
        segmentation_filenames,
    ) = folder_to_atlas_space(
        self.segmentation_folder,
        self.alignment_json,
        self.atlas_labels,
        pixel_id=self.colour,
        non_linear=non_linear,
        method=method,
        object_cutoff=object_cutoff,
        use_flat=use_flat,
    )
    self.pixel_points = pixel_points
    self.centroids = centroids
    # points_len / centroids_len record how many points came from each
    # section; used later to split results into per-section files.
    self.points_len = points_len
    self.centroids_len = centroids_len
    self.segmentation_filenames = segmentation_filenames
    self.region_areas_list = region_areas_list
    self.method = method
def _validate_method(self, method):
valid_methods = ["per_pixel", "per_object", "all"]
if method not in valid_methods:
raise ValueError(f"method {method} not recognised, valid methods are: {', '.join(valid_methods)}")
def quantify_coordinates(self):
    """Quantifies the pixel coordinates by region.

    Raises
    ------
    ValueError
        If the pixel coordinates have not been extracted yet.
    """
    self._check_coordinates_extracted()
    print("quantifying coordinates")
    # Label centroids and/or pixels depending on the chosen method.
    labeled_points_centroids = (
        self._label_points(self.centroids)
        if self.method in ["per_object", "all"]
        else None
    )
    labeled_points = (
        self._label_points(self.pixel_points)
        if self.method in ["per_pixel", "all"]
        else None
    )
    self._quantify_per_section(labeled_points, labeled_points_centroids)
    self._combine_slice_reports()
    self.labeled_points = labeled_points
    self.labeled_points_centroids = labeled_points_centroids
    print("quantification complete ✅")

def _check_coordinates_extracted(self):
    """Raise if get_coordinates has not been run yet."""
    if not hasattr(self, "pixel_points") and not hasattr(self, "centroids"):
        raise ValueError("Please run get_coordinates before running quantify_coordinates")

def _label_points(self, points):
    """Look up the atlas region for each point coordinate."""
    return label_points(points, self.atlas_volume, scale_factor=1)
def _quantify_per_section(self, labeled_points, labeled_points_centroids):
    """Build one per-region count DataFrame per section and store the list."""
    prev_pl = 0  # running offset into the flat pixel-point array
    prev_cl = 0  # running offset into the flat centroid array
    per_section_df = []
    for pl, cl, ra in zip(self.points_len, self.centroids_len, self.region_areas_list):
        current_centroids = (
            labeled_points_centroids[prev_cl : prev_cl + cl]
            if self.method in ["per_object", "all"]
            else None
        )
        current_points = (
            labeled_points[prev_pl : prev_pl + pl]
            if self.method in ["per_pixel", "all"]
            else None
        )
        current_df = pixel_count_per_region(current_points, current_centroids, self.atlas_labels)
        per_section_df.append(self._merge_dataframes(current_df, ra))
        prev_pl += pl
        prev_cl += cl
    self.per_section_df = per_section_df

def _merge_dataframes(self, current_df, ra):
    """Merge region areas and per-region counts onto the atlas label table.

    Drops the duplicated name/colour columns introduced by the second merge,
    adds an area_fraction column (pixel_count / region_area) and fills
    missing regions with 0.
    """
    all_region_df = self.atlas_labels.merge(ra, on="idx", how="left")
    current_df_new = all_region_df.merge(
        current_df, on="idx", how="left", suffixes=(None, "_y")
    ).drop(columns=["name_y", "r_y", "g_y", "b_y"])
    current_df_new["area_fraction"] = current_df_new["pixel_count"] / current_df_new["region_area"]
    current_df_new.fillna(0, inplace=True)
    return current_df_new
def _combine_slice_reports(self):
    """Combine the per-section reports into a whole-series report.

    Groups by region identity and sums pixel counts and region areas, then
    recomputes area_fraction (summing per-section fractions would be wrong).

    NOTE(review): if the label file contains duplicate idx values the
    grouped sums will double-count -- confirm label files are unique.
    """
    self.label_df = (
        pd.concat(self.per_section_df)
        .groupby(["idx", "name", "r", "g", "b"])
        .sum()
        .reset_index()
        .drop(columns=["area_fraction"])
    )
    self.label_df["area_fraction"] = self.label_df["pixel_count"] / self.label_df["region_area"]
    self.label_df.fillna(0, inplace=True)
    # Reorder rows to match the idx order of the atlas label table.
    self.label_df = (
        self.label_df.set_index("idx")
        .reindex(index=self.atlas_labels["idx"])
        .reset_index()
    )
def save_analysis(self, output_folder):
    """Saves the pixel coordinates and pixel counts to different files in the
    specified output folder.

    Parameters
    ----------
    output_folder : str
        The path to the output folder (created if it does not exist).
    """
    self._create_output_dirs(output_folder)
    self._save_quantification(output_folder)
    self._save_per_section_reports(output_folder)
    # NOTE(review): the meshview writers read self.labeled_points, which is
    # only set by quantify_coordinates -- saving without quantification may
    # fail; confirm intended behaviour.
    self._save_whole_series_meshview(output_folder)
    print("analysis saved ✅")

def _create_output_dirs(self, output_folder):
    """Create the output folder tree; idempotent via exist_ok."""
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(f"{output_folder}/whole_series_report", exist_ok=True)
    os.makedirs(f"{output_folder}/per_section_meshview", exist_ok=True)
    os.makedirs(f"{output_folder}/per_section_reports", exist_ok=True)
    os.makedirs(f"{output_folder}/whole_series_meshview", exist_ok=True)
def _save_quantification(self, output_folder):
    """Write the whole-series counts CSV, if quantification has been run.

    Without quantification there is no label_df, so only a notice is printed
    (the garbled original called label_df.to_csv in the else branch, which
    would raise AttributeError).
    """
    if hasattr(self, "label_df"):
        self.label_df.to_csv(
            f"{output_folder}/whole_series_report/counts.csv",
            sep=";",
            na_rep="",
            index=False,
        )
    else:
        print("no quantification found so we will only save the coordinates")
        print("if you want to save the quantification please run quantify_coordinates")
def _save_per_section_reports(self, output_folder):
    """Write a per-section counts CSV and meshview JSONs for every section."""
    prev_pl = 0  # running offset into the flat pixel-point array
    prev_cl = 0  # running offset into the flat centroid array
    for pl, cl, fn, df in zip(
        self.points_len,
        self.centroids_len,
        self.segmentation_filenames,
        self.per_section_df,
    ):
        # Section filename without directory or extension.
        split_fn = fn.split(os.sep)[-1].split(".")[0]
        df.to_csv(
            f"{output_folder}/per_section_reports/{split_fn}.csv",
            sep=";",
            na_rep="",
            index=False,
        )
        self._save_per_section_meshview(output_folder, split_fn, pl, cl, prev_pl, prev_cl)
        prev_cl += cl
        prev_pl += pl
def _save_per_section_meshview(self, output_folder, split_fn, pl, cl, prev_pl, prev_cl):
    """Write meshview JSONs for one section's pixels and/or centroids."""
    if self.method in ["per_pixel", "all"]:
        write_points_to_meshview(
            self.pixel_points[prev_pl : pl + prev_pl],
            self.labeled_points[prev_pl : pl + prev_pl],
            f"{output_folder}/per_section_meshview/{split_fn}_pixels.json",
            self.atlas_labels,
        )
    if self.method in ["per_object", "all"]:
        write_points_to_meshview(
            self.centroids[prev_cl : cl + prev_cl],
            self.labeled_points_centroids[prev_cl : cl + prev_cl],
            f"{output_folder}/per_section_meshview/{split_fn}_centroids.json",
            self.atlas_labels,
        )
def _save_whole_series_meshview(self, output_folder):
    """Write whole-series meshview JSONs for pixels and/or object centroids."""
    if self.method in ["per_pixel", "all"]:
        write_points_to_meshview(
            self.pixel_points,
            self.labeled_points,
            f"{output_folder}/whole_series_meshview/pixels_meshview.json",
            self.atlas_labels,
        )
    if self.method in ["per_object", "all"]:
        write_points_to_meshview(
            self.centroids,
            self.labeled_points_centroids,
            f"{output_folder}/whole_series_meshview/objects_meshview.json",
            self.atlas_labels,
        )
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment