import numpy as np
import pandas as pd
from .read_and_write import load_visualign_json
from .counting_and_load import flat_to_dataframe
from .visualign_deformations import triangulate, transform_vec
from glob import glob
import cv2
from skimage import measure
import threading
import re
from .reconstruct_dzi import reconstruct_dzi


def number_sections(filenames, legacy=False):
    """
    returns the section numbers of filenames

    :param filenames: list of filenames
    :type filenames: list[str]
    :return: list of section numbers
    :rtype: list[int]
    """
    filenames = [filename.split("\\")[-1] for filename in filenames]
    section_numbers = []
    for filename in filenames:
        if not legacy:
            match = re.findall(r"_s\d+", filename)
            if len(match) == 0:
                raise ValueError(f"No section number found in filename: {filename}")
            if len(match) > 1:
                raise ValueError(
                    "Multiple section numbers found in filename, ensure only one instance of _s### is present, where ### is the section number"
                )
            section_numbers.append(int(match[0][2:]))
        else:
            digits = re.sub("[^0-9]", "", filename)
            # In legacy mode, the last three digits in the filename are the section number
            section_numbers.append(int(digits[-3:]))
    if len(section_numbers) == 0:
        raise ValueError("No section numbers found in filenames")
    return section_numbers
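
# Illustrative example (hypothetical filenames):
#   number_sections(["brain_s001.png", "brain_s012.png"]) -> [1, 12]
#   number_sections(["slide014.png"], legacy=True)        -> [14]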


# related to coordinate_extraction
def get_centroids_and_area(segmentation, pixel_cut_off=0):
    """This function returns the center coordinate of each object in the segmentation.
    You can set a pixel_cut_off to remove objects that are smaller than that number of pixels.
    """
    labels = measure.label(segmentation)
    # This finds all the objects in the image
    labels_info = measure.regionprops(labels)
    # Remove objects that are less than pixel_cut_off
    labels_info = [label for label in labels_info if label.area > pixel_cut_off]
    # Get the centre points of the objects
    centroids = np.array([label.centroid for label in labels_info])
    # Get the area of the objects
    area = np.array([label.area for label in labels_info])
    # Get the coordinates for all the pixels in each object
    coords = np.array([label.coords for label in labels_info], dtype=object)
    return centroids, area, coords
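
# Minimal sketch of the expected behaviour (hypothetical 5x5 binary mask):
#   seg = np.zeros((5, 5), dtype=np.uint8)
#   seg[1:3, 1:3] = 1                      # one 2x2 object, area 4
#   centroids, area, coords = get_centroids_and_area(seg)
#   centroids -> array([[1.5, 1.5]]); area -> array([4])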


# related to coordinate extraction
def transform_to_registration(seg_height, seg_width, reg_height, reg_width):
    """This function returns the scaling factors to transform the segmentation to the registration space."""
    y_scale = reg_height / seg_height
    x_scale = reg_width / seg_width
    return y_scale, x_scale
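
# Example: a 2000x1000 (width x height) segmentation registered to a 500x500
# image gives transform_to_registration(1000, 2000, 500, 500) -> (0.5, 0.25),
# i.e. y coordinates are halved and x coordinates are quartered.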


# related to coordinate extraction
def find_matching_pixels(segmentation, id):
    """This function returns the Y and X coordinates of all the pixels in the segmentation that match the id provided."""
    mask = segmentation == id
    mask = np.all(mask, axis=2)
    id_positions = np.where(mask)
    id_y, id_x = id_positions[0], id_positions[1]
    return id_y, id_x
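
# Minimal sketch (hypothetical 2x2 RGB segmentation, object colour red):
#   seg = np.zeros((2, 2, 3), dtype=np.uint8)
#   seg[1, 0] = [255, 0, 0]
#   find_matching_pixels(seg, [255, 0, 0]) -> (array([1]), array([0]))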


# related to coordinate extraction
def scale_positions(id_y, id_x, y_scale, x_scale):
    """This function scales the Y and X coordinates to the registration space.
    (The y_scale and x_scale are the output of transform_to_registration.)
    """
    id_y = id_y * y_scale
    id_x = id_x * x_scale
    return id_y, id_x


# related to coordinate extraction
def transform_to_atlas_space(anchoring, y, x, reg_height, reg_width):
    """Transform to atlas space using the QuickNII anchoring vector."""
    o = anchoring[0:3]
    u = anchoring[3:6]
    # Swap order of U
    u = np.array([u[0], u[1], u[2]])
    v = anchoring[6:9]
    # Swap order of V
    v = np.array([v[0], v[1], v[2]])
    # Scale X and Y to between 0 and 1 using the registration width and height
    y_scale = y / reg_height
    x_scale = x / reg_width
    xyz_v = np.array([y_scale * v[0], y_scale * v[1], y_scale * v[2]])
    xyz_u = np.array([x_scale * u[0], x_scale * u[1], x_scale * u[2]])
    o = np.reshape(o, (3, 1))
    return (o + xyz_u + xyz_v).T
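
# Worked example (illustrative values): with anchoring o=(0, 0, 0), u=(10, 0, 0),
# v=(0, 0, 10) and a 100x100 registration image, the pixel at (y=50, x=25) maps to
#   o + (25/100) * u + (50/100) * v = (2.5, 0, 5.0)
# i.e. each point is o + x_fraction * u + y_fraction * v in atlas coordinates.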


# points.append would make list of lists, keeping sections separate.


# related to coordinate extraction
# This function returns an array of points
def folder_to_atlas_space(
    folder,
    quint_alignment,
    atlas_labels,
    pixel_id=[0, 0, 0],
    non_linear=True,
    method="all",
    object_cutoff=0,
    atlas_volume=None,
    use_flat=False,
):
    """Apply Segmentation to atlas space to all segmentations in a folder."""
    """Return pixel_points, centroids, points_len, centroids_len, segmentation_filenames, """
    # This should be loaded above and passed as an argument
    slices = load_visualign_json(quint_alignment)

    segmentation_file_types = [".png", ".tif", ".tiff", ".jpg", ".jpeg", ".dzip"]
    segmentations = [
        file
        for file in glob(folder + "/segmentations/*")
        if any(file.endswith(file_type) for file_type in segmentation_file_types)
    ]
    if len(segmentations) == 0:
        raise ValueError(
            f"No segmentations found in folder {folder}. Make sure the folder contains a segmentations folder with segmentations."
        )
    print(f"Found {len(segmentations)} segmentations in folder {folder}")
    if use_flat:
        flat_files = [
            file
            for file in glob(folder + "/flat_files/*")
            if file.endswith((".flat", ".seg"))
        ]
        print(f"Found {len(flat_files)} flat files in folder {folder}")
        flat_file_nrs = [int(number_sections([ff])[0]) for ff in flat_files]

    # Order segmentations and section_numbers
    # segmentations = [x for _,x in sorted(zip(section_numbers,segmentations))]
    # section_numbers.sort()
    points_list = [np.array([])] * len(segmentations)
    centroids_list = [np.array([])] * len(segmentations)
    region_areas_list = [
        pd.DataFrame(
            {
                "idx": [],
                "name": [],
                "r": [],
                "g": [],
                "b": [],
                "region_area": [],
                "pixel_count": [],
                "object_count": [],
                "area_fraction": [],
            }
        )
    ] * len(segmentations)
    threads = []
    for index, segmentation_path in enumerate(segmentations):
        seg_nr = int(number_sections([segmentation_path])[0])
        current_slice_index = np.where([s["nr"] == seg_nr for s in slices])
        current_slice = slices[current_slice_index[0][0]]
        if current_slice["anchoring"] == []:
            continue
        if use_flat:
            current_flat_file_index = np.where([f == seg_nr for f in flat_file_nrs])
            current_flat = flat_files[current_flat_file_index[0][0]]
        else:
            current_flat = None

        x = threading.Thread(
            target=segmentation_to_atlas_space,
            args=(
                current_slice,
                segmentation_path,
                atlas_labels,
                current_flat,
                pixel_id,
                non_linear,
                points_list,
                centroids_list,
                region_areas_list,
                index,
                method,
                object_cutoff,
                atlas_volume,
                use_flat,
            ),
        )
        threads.append(x)
        # Each thread converts one segmentation to a point cloud
    # Start threads
    for thread in threads:
        thread.start()
    # Wait for threads to finish
    for thread in threads:
        thread.join()
    # Flatten points_list

    points_len = [len(points) if None not in points else 0 for points in points_list]
    centroids_len = [
        len(centroids) if None not in centroids else 0 for centroids in centroids_list
    ]
    points_list = [points for points in points_list if None not in points]
    centroids_list = [
        centroids for centroids in centroids_list if None not in centroids
    ]
    if len(points_list) == 0:
        points = np.array([])
    else:
        points = np.concatenate(points_list)
    if len(centroids_list) == 0:
        centroids = np.array([])
    else:
        centroids = np.concatenate(centroids_list)

    return (
        np.array(points),
        np.array(centroids),
        region_areas_list,
        points_len,
        centroids_len,
        segmentations,
    )
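
# Usage sketch (paths and the atlas label dataframe are placeholders):
#   points, centroids, region_areas, points_len, centroids_len, seg_files = \
#       folder_to_atlas_space(
#           "path/to/experiment",       # must contain a "segmentations" subfolder
#           "path/to/alignment.json",   # QuickNII / VisuAlign alignment file
#           atlas_labels,               # DataFrame of atlas region labels
#           method="all",
#       )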


def load_segmentation(segmentation_path: str):
    """Load a segmentation from a file."""
    print(f"working on {segmentation_path}")
    if segmentation_path.endswith(".dzip"):
        print("Reconstructing dzi")
        return reconstruct_dzi(segmentation_path)
    else:
        return cv2.imread(segmentation_path)


def detect_pixel_id(segmentation: np.ndarray):
    """Return the colour (pixel id) of the first non-background pixel in the segmentation."""
    segmentation_no_background = segmentation[~np.all(segmentation == 0, axis=2)]
    pixel_id = segmentation_no_background[0]
    print("detected pixel_id: ", pixel_id)
    return pixel_id
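
# Minimal sketch (hypothetical 2x2 RGB segmentation with one red pixel):
#   seg = np.zeros((2, 2, 3), dtype=np.uint8)
#   seg[0, 1] = [255, 0, 0]
#   detect_pixel_id(seg) -> array([255, 0, 0], dtype=uint8)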


def get_region_areas(
    use_flat,
    atlas_labels,
    flat_file_atlas,
    seg_width,
    seg_height,
    slice_dict,
    atlas_volume,
    triangulation,
):
    if use_flat:
        region_areas = flat_to_dataframe(
            atlas_labels, flat_file_atlas, (seg_width, seg_height)
        )
    else:
        region_areas = flat_to_dataframe(
            atlas_labels,
            flat_file_atlas,
            (seg_width, seg_height),
            slice_dict["anchoring"],
            atlas_volume,
            triangulation,
        )
    return region_areas


def get_transformed_coordinates(
    non_linear,
    slice_dict,
    method,
    scaled_x,
    scaled_y,
    centroids,
    scaled_centroidsX,
    scaled_centroidsY,
    triangulation,
):
    new_x, new_y, centroids_new_x, centroids_new_y = None, None, None, None
    if non_linear and "markers" in slice_dict:
        if method in ["per_pixel", "all"] and scaled_x is not None:
            new_x, new_y = transform_vec(triangulation, scaled_x, scaled_y)
        if method in ["per_object", "all"] and centroids is not None:
            centroids_new_x, centroids_new_y = transform_vec(
                triangulation, scaled_centroidsX, scaled_centroidsY
            )
    else:
        if method in ["per_pixel", "all"]:
            new_x, new_y = scaled_x, scaled_y
        if method in ["per_object", "all"]:
            centroids_new_x, centroids_new_y = scaled_centroidsX, scaled_centroidsY
    return new_x, new_y, centroids_new_x, centroids_new_y


def segmentation_to_atlas_space(
    slice_dict,
    segmentation_path,
    atlas_labels,
    flat_file_atlas=None,
    pixel_id="auto",
    non_linear=True,
    points_list=None,
    centroids_list=None,
    region_areas_list=None,
    index=None,
    method="per_pixel",
    object_cutoff=0,
    atlas_volume=None,
    use_flat=False,
):
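    """Convert one segmentation section to atlas-space points and centroids.

    Broadly: the segmentation image is loaded and its object colour detected
    (unless pixel_id is given), per-region areas are computed via the flat file
    or atlas volume, coordinates are scaled to registration space, the VisuAlign
    non-linear markers are applied when available, and the result is projected
    into atlas space with the QuickNII anchoring vector. Results are written
    into points_list, centroids_list and region_areas_list at the given index.
    """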
    segmentation = load_segmentation(segmentation_path)
    if pixel_id == "auto":
        pixel_id = detect_pixel_id(segmentation)
    seg_height, seg_width = segmentation.shape[:2]
    reg_height, reg_width = slice_dict["height"], slice_dict["width"]
    if non_linear and "markers" in slice_dict:
        triangulation = triangulate(reg_width, reg_height, slice_dict["markers"])
    else:
        triangulation = None
    region_areas = get_region_areas(
        use_flat,
        atlas_labels,
        flat_file_atlas,
        seg_width,
        seg_height,
        slice_dict,
        atlas_volume,
        triangulation,
    )
    y_scale, x_scale = transform_to_registration(
        seg_height, seg_width, reg_height, reg_width
    )
    centroids, points = None, None
    scaled_centroidsX, scaled_centroidsY, scaled_x, scaled_y = None, None, None, None
    if method in ["per_object", "all"]:
        centroids, scaled_centroidsX, scaled_centroidsY = get_centroids(
            segmentation, pixel_id, y_scale, x_scale, object_cutoff
        )
    if method in ["per_pixel", "all"]:
        scaled_y, scaled_x = get_scaled_pixels(segmentation, pixel_id, y_scale, x_scale)

    new_x, new_y, centroids_new_x, centroids_new_y = get_transformed_coordinates(
        non_linear,
        slice_dict,
        method,
        scaled_x,
        scaled_y,
        centroids,
        scaled_centroidsX,
        scaled_centroidsY,
        triangulation,
    )
    if method in ["per_pixel", "all"] and new_x is not None:
        points = transform_to_atlas_space(
            slice_dict["anchoring"], new_y, new_x, reg_height, reg_width
        )
    if method in ["per_object", "all"] and centroids_new_x is not None:
        centroids = transform_to_atlas_space(
            slice_dict["anchoring"],
            centroids_new_y,
            centroids_new_x,
            reg_height,
            reg_width,
        )
    points_list[index] = np.array(points if points is not None else [])
    centroids_list[index] = np.array(centroids if centroids is not None else [])
    region_areas_list[index] = region_areas


def get_centroids(segmentation, pixel_id, y_scale, x_scale, object_cutoff=0):
    binary_seg = segmentation == pixel_id
    binary_seg = np.all(binary_seg, axis=2)
    centroids, area, coords = get_centroids_and_area(
        binary_seg, pixel_cut_off=object_cutoff
    )

    print(f"using pixel id {pixel_id}")
    print(f"Found {len(centroids)} objects in the segmentation")
    if len(centroids) == 0:
        return None, None, None
    centroidsX = centroids[:, 1]
    centroidsY = centroids[:, 0]
    scaled_centroidsY, scaled_centroidsX = scale_positions(
        centroidsY, centroidsX, y_scale, x_scale
    )
    return centroids, scaled_centroidsX, scaled_centroidsY


def get_scaled_pixels(segmentation, pixel_id, y_scale, x_scale):
    id_pixels = find_matching_pixels(segmentation, pixel_id)
    if len(id_pixels[0]) == 0:
        return None, None
    # Scale the seg coordinates to reg/seg
    scaled_y, scaled_x = scale_positions(id_pixels[0], id_pixels[1], y_scale, x_scale)
    return scaled_y, scaled_x