-
Harry Carey authored97fff9e0
coordinate_extraction.py 11.01 KiB
import json
import re
import threading
from concurrent.futures import ThreadPoolExecutor
from glob import glob
from typing import List

import cv2
import numpy as np
import pandas as pd
from skimage import measure
from tqdm import tqdm

from .read_and_write import load_visualign_json
from .counting_and_load import label_points, flat_to_dataframe
from .visualign_deformations import triangulate, transform_vec
def number_sections(filenames: List[str], legacy=False) -> List[int]:
    """
    Return the section number encoded in each filename.

    In the default mode the final path component of every filename must
    contain exactly one ``_s###`` tag, where ``###`` is the section number.
    In legacy mode every non-digit character of the final path component is
    discarded and the last three remaining digits are the section number.

    :param filenames: list of filenames (bare names or full paths)
    :type filenames: list[str]
    :param legacy: use the legacy "last three digits" naming scheme
    :type legacy: bool
    :return: list of section numbers, one per filename, in input order
    :rtype: list[int]
    :raises ValueError: if ``filenames`` is empty, or a filename holds no
        section number, or (default mode) more than one ``_s###`` tag
    """
    section_numbers = []
    for path in filenames:
        # Only the last path component may carry the section tag. Split on
        # both separators so that directory names never contribute digits or
        # spurious "_s###" matches, whichever platform produced the path.
        filename = re.split(r"[\\/]", path)[-1]
        if not legacy:
            match = re.findall(r"_s\d+", filename)
            if len(match) == 0:
                raise ValueError(f"No section number found in filename: {filename}")
            if len(match) > 1:
                raise ValueError(
                    "Multiple section numbers found in filename, ensure only one instance of _s### is present, where ### is the section number"
                )
            # Drop the leading "_s" and keep the digits.
            section_numbers.append(int(match[0][2:]))
        else:
            digits = re.sub("[^0-9]", "", filename)
            if not digits:
                raise ValueError(f"No section number found in filename: {filename}")
            # The three digits closest to the end of the name are the number.
            section_numbers.append(int(digits[-3:]))
    if len(section_numbers) == 0:
        raise ValueError("No section numbers found in filenames")
    return section_numbers
# related to coordinate_extraction
def get_centroids_and_area(segmentation, pixel_cut_off=0):
    """Return the centroid, area and pixel coordinates of each object.

    Connected components of ``segmentation`` are labelled with
    ``skimage.measure.label`` and described with ``measure.regionprops``.
    Objects whose area is not strictly greater than ``pixel_cut_off`` are
    discarded.

    Returns a tuple ``(centroids, area, coords)``:

    * ``centroids`` -- array of shape ``(n_objects, ndim)``; it keeps this
      2-D shape even when no object is found, so column indexing is safe.
    * ``area`` -- array of length ``n_objects`` with each object's area.
    * ``coords`` -- 1-D object array of length ``n_objects``; element ``i``
      is the coordinate array of object ``i``.
    """
    labels = measure.label(segmentation)
    # Describe every labelled object, then keep only the large enough ones.
    regions = [
        region
        for region in measure.regionprops(labels)
        if region.area > pixel_cut_off
    ]
    # reshape is a no-op when objects exist; with none it turns the empty
    # 1-D array into (0, ndim) so callers can still slice columns.
    centroids = np.array([region.centroid for region in regions]).reshape(
        -1, labels.ndim
    )
    area = np.array([region.area for region in regions])
    # Fill the object array element by element: np.array(..., dtype=object)
    # would collapse into a 3-D array whenever all objects happen to have
    # the same pixel count (always the case for a single object).
    coords = np.empty(len(regions), dtype=object)
    for position, region in enumerate(regions):
        coords[position] = region.coords
    return centroids, area, coords
# related to coordinate extraction
def transform_to_registration(seg_height, seg_width, reg_height, reg_width):
    """Compute the factors that map segmentation pixels to registration space.

    Returns ``(y_scale, x_scale)``: the registration height divided by the
    segmentation height, and the registration width divided by the
    segmentation width.
    """
    return reg_height / seg_height, reg_width / seg_width
# related to coordinate extraction
def find_matching_pixels(segmentation, id):
    """Locate every pixel of the segmentation whose colour equals ``id``.

    A pixel matches only when all of its channels (axis 2) equal the
    corresponding entries of ``id``. Returns ``(id_y, id_x)``, the row and
    column indices of the matching pixels.
    """
    # Collapse the per-channel comparison into one boolean per pixel.
    all_channels_match = np.all(segmentation == id, axis=2)
    rows, cols = np.nonzero(all_channels_match)
    return rows, cols
# related to coordinate extraction
def scale_positions(id_y, id_x, y_scale, x_scale):
    """Multiply Y and X coordinates by their respective scaling factors.

    ``y_scale`` and ``x_scale`` are the values produced by
    transform_to_registration. Returns the scaled ``(y, x)`` pair.
    """
    scaled_y = id_y * y_scale
    scaled_x = id_x * x_scale
    return scaled_y, scaled_x
# related to coordinate extraction
def transform_to_atlas_space(anchoring, y, x, reg_height, reg_width):
    """Transform to atlas space using the QuickNII anchoring vector.

    The anchoring vector is read as three consecutive 3-vectors: an origin
    ``o`` (elements 0-2), ``u`` (elements 3-5) and ``v`` (elements 6-8).
    Each point is mapped to ``o + (x / reg_width) * u + (y / reg_height) * v``.

    :param anchoring: sequence of at least nine numbers (o, u, v)
    :param y: Y coordinate(s) in registration space; scalar, list or 1-D array
    :param x: X coordinate(s) in registration space; scalar, list or 1-D array
    :param reg_height: height of the registration image
    :param reg_width: width of the registration image
    :return: array of shape ``(n_points, 3)``; a scalar ``y``/``x`` pair
        yields shape ``(1, 3)``
    """
    # Column vectors, so that multiplying by an (n,) array broadcasts to (3, n).
    origin = np.asarray(anchoring[0:3], dtype=float).reshape(3, 1)
    u = np.asarray(anchoring[3:6], dtype=float).reshape(3, 1)
    v = np.asarray(anchoring[6:9], dtype=float).reshape(3, 1)
    # Scale X and Y to between 0 and 1 using the registration width and
    # height. atleast_1d keeps a scalar input from broadcasting against the
    # (3, 1) origin into a meaningless (3, 3) result.
    y_fraction = np.atleast_1d(np.asarray(y) / reg_height)
    x_fraction = np.atleast_1d(np.asarray(x) / reg_width)
    return (origin + u * x_fraction + v * y_fraction).T
# points.append would make list of lists, keeping sections separate.
# related to coordinate extraction
# This function returns an array of points
def folder_to_atlas_space(
    folder,
    quint_alignment,
    atlas_labels,
    pixel_id=None,
    non_linear=True,
    method="all",
    object_cutoff=0,
):
    """Apply Segmentation to atlas space to all segmentations in a folder.

    Every image in ``folder`` (.png, .tif, .tiff, .jpg, .jpeg) is matched, by
    section number, to a slice of the alignment loaded from
    ``quint_alignment`` and to a ``.flat`` file in ``folder/flat_files``, and
    is then converted by segmentation_to_atlas_space on a worker thread.

    :param folder: directory holding the segmentation images
    :param quint_alignment: path passed to load_visualign_json
    :param atlas_labels: forwarded unchanged to segmentation_to_atlas_space
    :param pixel_id: colour of the segmented pixels; defaults to ``[0, 0, 0]``
    :param non_linear: forwarded to segmentation_to_atlas_space
    :param method: "per_pixel", "per_object" or "all"
    :param object_cutoff: forwarded to segmentation_to_atlas_space
    :return: tuple ``(points, centroids, region_areas_list, points_len,
        centroids_len, segmentations)`` where ``points`` and ``centroids`` are
        the per-section results concatenated in the order of
        ``segmentations``, and ``points_len`` / ``centroids_len`` hold the
        number of rows each section contributed
    :raises ValueError: if the folder contains no segmentation images
    :raises IndexError: if a segmentation has no matching slice or flat file
    """
    # None stands in for the default colour so that no list object is shared
    # between calls.
    if pixel_id is None:
        pixel_id = [0, 0, 0]
    # This should be loaded above and passed as an argument
    slices = load_visualign_json(quint_alignment)
    segmentation_file_types = (".png", ".tif", ".tiff", ".jpg", ".jpeg")
    segmentations = [
        file for file in glob(folder + "/*") if file.endswith(segmentation_file_types)
    ]
    if not segmentations:
        raise ValueError(f"No segmentation files found in folder: {folder}")
    flat_files = [
        file for file in glob(folder + "/flat_files/*") if file.endswith(".flat")
    ]

    # Index slices and flat files by section number; setdefault keeps the
    # first entry when a number occurs more than once.
    slice_by_nr = {}
    for candidate in slices:
        slice_by_nr.setdefault(candidate["nr"], candidate)
    flat_by_nr = {}
    for flat_file in flat_files:
        flat_by_nr.setdefault(int(number_sections([flat_file])[0]), flat_file)

    # Workers write their results into these lists at their own index, which
    # keeps the output in the same order as `segmentations`.
    points_list = [None] * len(segmentations)
    centroids_list = [None] * len(segmentations)
    region_areas_list = [None] * len(segmentations)

    jobs = []
    for index, segmentation_path in enumerate(segmentations):
        seg_nr = int(number_sections([segmentation_path])[0])
        if seg_nr not in slice_by_nr:
            raise IndexError(
                f"No slice with section number {seg_nr} found in {quint_alignment} "
                f"for segmentation {segmentation_path}"
            )
        if seg_nr not in flat_by_nr:
            raise IndexError(
                f"No flat file with section number {seg_nr} found in "
                f"{folder}/flat_files for segmentation {segmentation_path}"
            )
        jobs.append(
            (
                slice_by_nr[seg_nr],
                segmentation_path,
                atlas_labels,
                flat_by_nr[seg_nr],
                pixel_id,
                non_linear,
                points_list,
                centroids_list,
                region_areas_list,
                index,
                method,
                object_cutoff,
            )
        )

    # This converts each segmentation to a point cloud. Calling result() on
    # every future re-raises a worker's exception here instead of leaving a
    # None entry behind that would fail later with an unrelated error.
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(segmentation_to_atlas_space, *job) for job in jobs]
        for future in futures:
            future.result()

    # Flatten the per-section results, remembering each section's length.
    points_len = [len(points) for points in points_list]
    centroids_len = [len(centroids) for centroids in centroids_list]
    points = np.concatenate(points_list)
    centroids = np.concatenate(centroids_list)
    return (
        points,
        centroids,
        region_areas_list,
        points_len,
        centroids_len,
        segmentations,
    )
def segmentation_to_atlas_space(
    slice,
    segmentation_path,
    atlas_labels,
    flat_file_atlas,
    pixel_id="auto",
    non_linear=True,
    points_list=None,
    centroids_list=None,
    region_areas_list=None,
    index=None,
    method="per_pixel",
    object_cutoff=0,
):
    """Combines many functions to convert a segmentation to atlas space. It takes care
    of deformations.

    The results are not returned; they are written to ``points_list[index]``,
    ``centroids_list[index]`` and ``region_areas_list[index]``, so the three
    lists and ``index`` must be supplied by the caller.

    :param slice: alignment entry for this section; the keys "height",
        "width" and "anchoring" are read, "markers" is used when present and
        "filename" is used in the no-markers message
    :param segmentation_path: path of the segmentation image (read with cv2)
    :param atlas_labels: forwarded to flat_to_dataframe
    :param flat_file_atlas: flat file forwarded to flat_to_dataframe
    :param pixel_id: colour of the segmented pixels, or "auto" to pick a
        non-white colour found in the image
    :param non_linear: apply the marker-based deformation when the slice has
        markers
    :param method: "per_pixel" (every matching pixel), "per_object" (object
        centroids) or "all" (both)
    :param object_cutoff: forwarded to get_centroids as the object cutoff
    :raises ValueError: if ``method`` is unknown, the image cannot be read,
        or "auto" finds no non-white pixel
    """
    if method not in ("per_pixel", "per_object", "all"):
        raise ValueError(
            f"Unknown method: {method!r}, expected 'per_pixel', 'per_object' or 'all'"
        )
    per_pixel = method in ("per_pixel", "all")
    per_object = method in ("per_object", "all")

    segmentation = cv2.imread(segmentation_path)
    if segmentation is None:
        # cv2.imread signals an unreadable file by returning None.
        raise ValueError(f"Could not read segmentation image: {segmentation_path}")

    # isinstance guard: comparing an array-valued pixel_id with "auto" would
    # be element-wise and cannot be used as an `if` condition.
    if isinstance(pixel_id, str) and pixel_id == "auto":
        # Remove the background (pure white) from the segmentation
        segmentation_no_background = segmentation[~np.all(segmentation == 255, axis=2)]
        unique_colours = np.unique(segmentation_no_background.reshape(-1, 3), axis=0)
        if len(unique_colours) == 0:
            raise ValueError(
                f"No non-background pixels found in segmentation: {segmentation_path}"
            )
        # Currently only works for a single label; np.unique sorts the
        # colours, so the choice is deterministic.
        pixel_id = unique_colours[0]

    # Transform pixels to registration space (the registered image and segmentation have different dimensions)
    seg_height = segmentation.shape[0]
    seg_width = segmentation.shape[1]
    reg_height = slice["height"]
    reg_width = slice["width"]
    region_areas = flat_to_dataframe(
        flat_file_atlas, atlas_labels, (seg_width, seg_height)
    )
    # This calculates reg/seg
    y_scale, x_scale = transform_to_registration(
        seg_height, seg_width, reg_height, reg_width
    )

    deform = False
    triangulation = None
    if non_linear:
        if "markers" in slice:
            # This creates a triangulation using the reg width
            triangulation = triangulate(reg_width, reg_height, slice["markers"])
            deform = True
        else:
            print(
                f"No markers found for {slice['filename']}, result for section will be linear."
            )

    # A result that `method` does not request stays an empty (0, 3) array, so
    # it still has a length and can be concatenated with other sections.
    points = np.empty((0, 3))
    centroids = np.empty((0, 3))

    if per_object:
        _, scaled_centroidsX, scaled_centroidsY = get_centroids(
            segmentation, pixel_id, y_scale, x_scale, object_cutoff
        )
        if deform:
            centroids_new_x, centroids_new_y = transform_vec(
                triangulation, scaled_centroidsX, scaled_centroidsY
            )
        else:
            centroids_new_x, centroids_new_y = scaled_centroidsX, scaled_centroidsY
        # Scale U by Uxyz/RegWidth and V by Vxyz/RegHeight
        centroids = transform_to_atlas_space(
            slice["anchoring"], centroids_new_y, centroids_new_x, reg_height, reg_width
        )

    if per_pixel:
        scaled_y, scaled_x = get_scaled_pixels(segmentation, pixel_id, y_scale, x_scale)
        if deform:
            new_x, new_y = transform_vec(triangulation, scaled_x, scaled_y)
        else:
            new_x, new_y = scaled_x, scaled_y
        # Scale U by Uxyz/RegWidth and V by Vxyz/RegHeight
        points = transform_to_atlas_space(
            slice["anchoring"], new_y, new_x, reg_height, reg_width
        )

    print(
        f"Finished and points len is: {len(points)} and centroids len is: {len(centroids)}"
    )
    points_list[index] = np.array(points)
    centroids_list[index] = np.array(centroids)
    region_areas_list[index] = region_areas
def get_centroids(segmentation, pixel_id, y_scale, x_scale, object_cutoff=0):
    """Find the objects coloured ``pixel_id`` and scale their centroids.

    The segmentation is reduced to a 2-D mask of the pixels whose channels
    all equal ``pixel_id``; get_centroids_and_area then extracts the
    centroids of the objects in that mask, using ``object_cutoff`` as its
    pixel cutoff.

    :return: tuple ``(centroids, scaled_centroidsX, scaled_centroidsY)``:
        the unscaled centroids as an ``(n_objects, 2)`` array of (row, column)
        pairs, followed by the X and Y centroid coordinates multiplied by
        ``x_scale`` and ``y_scale``. All three are empty when no object
        passes the cutoff.
    """
    binary_seg = np.all(segmentation == pixel_id, axis=2)
    centroids, _area, _coords = get_centroids_and_area(
        binary_seg, pixel_cut_off=object_cutoff
    )
    # The mask is 2-D, so every centroid is a (row, column) pair. Forcing the
    # (n, 2) shape keeps the column slicing below valid when the centroid
    # array is empty and one-dimensional because no object was found.
    centroids = np.asarray(centroids).reshape(-1, 2)
    centroidsY = centroids[:, 0]
    centroidsX = centroids[:, 1]
    scaled_centroidsY, scaled_centroidsX = scale_positions(
        centroidsY, centroidsX, y_scale, x_scale
    )
    return centroids, scaled_centroidsX, scaled_centroidsY
def get_scaled_pixels(segmentation, pixel_id, y_scale, x_scale):
    """Return the scaled (y, x) positions of all pixels coloured ``pixel_id``.

    The matching pixels are located with find_matching_pixels and then
    multiplied by ``y_scale`` / ``x_scale`` through scale_positions.
    """
    pixel_rows, pixel_cols = find_matching_pixels(segmentation, pixel_id)
    # Scale the seg coordinates to reg/seg
    return scale_positions(pixel_rows, pixel_cols, y_scale, x_scale)