Skip to content
Snippets Groups Projects
Commit 71a5ca9b authored by Harry Carey's avatar Harry Carey
Browse files

including @tevemadar s propagation fix

parent 53cdbadd
No related branches found
No related tags found
No related merge requests found
...@@ -135,14 +135,21 @@ def folder_to_atlas_space( ...@@ -135,14 +135,21 @@ def folder_to_atlas_space(
for file in glob(folder + "/segmentations/*") for file in glob(folder + "/segmentations/*")
if any([file.endswith(type) for type in segmentation_file_types]) if any([file.endswith(type) for type in segmentation_file_types])
] ]
flat_files = [ if len(segmentations) == 0:
file raise ValueError(
for file in glob(folder + "/flat_files/*") f"No segmentations found in folder {folder}. Make sure the folder contains a segmentations folder with segmentations."
if any([file.endswith('.flat'), file.endswith('.seg')]) )
]
print(f"Found {len(segmentations)} segmentations in folder {folder}") print(f"Found {len(segmentations)} segmentations in folder {folder}")
print(f"Found {len(flat_files)} flat files in folder {folder}") if use_flat == True:
flat_files = [
file
for file in glob(folder + "/flat_files/*")
if any([file.endswith('.flat'), file.endswith('.seg')])
]
print(f"Found {len(flat_files)} flat files in folder {folder}")
flat_file_nrs = [int(number_sections([ff])[0]) for ff in flat_files]
# Order segmentations and section_numbers # Order segmentations and section_numbers
# segmentations = [x for _,x in sorted(zip(section_numbers,segmentations))] # segmentations = [x for _,x in sorted(zip(section_numbers,segmentations))]
# section_numbers.sort() # section_numbers.sort()
...@@ -150,11 +157,12 @@ def folder_to_atlas_space( ...@@ -150,11 +157,12 @@ def folder_to_atlas_space(
centroids_list = [None] * len(segmentations) centroids_list = [None] * len(segmentations)
region_areas_list = [None] * len(segmentations) region_areas_list = [None] * len(segmentations)
threads = [] threads = []
flat_file_nrs = [int(number_sections([ff])[0]) for ff in flat_files]
for segmentation_path, index in zip(segmentations, range(len(segmentations))): for segmentation_path, index in zip(segmentations, range(len(segmentations))):
seg_nr = int(number_sections([segmentation_path])[0]) seg_nr = int(number_sections([segmentation_path])[0])
current_slice_index = np.where([s["nr"] == seg_nr for s in slices]) current_slice_index = np.where([s["nr"] == seg_nr for s in slices])
current_slice = slices[current_slice_index[0][0]] current_slice = slices[current_slice_index[0][0]]
if current_slice["anchoring"] == []:
continue
if use_flat == True: if use_flat == True:
current_flat_file_index = np.where([f == seg_nr for f in flat_file_nrs]) current_flat_file_index = np.where([f == seg_nr for f in flat_file_nrs])
current_flat = flat_files[current_flat_file_index[0][0]] current_flat = flat_files[current_flat_file_index[0][0]]
...@@ -187,6 +195,7 @@ def folder_to_atlas_space( ...@@ -187,6 +195,7 @@ def folder_to_atlas_space(
# Wait for threads to finish # Wait for threads to finish
[t.join() for t in threads] [t.join() for t in threads]
# Flatten points_list # Flatten points_list
points_len = [len(points) for points in points_list] points_len = [len(points) for points in points_list]
centroids_len = [len(centroids) for centroids in centroids_list] centroids_len = [len(centroids) for centroids in centroids_list]
......
...@@ -100,9 +100,9 @@ def pixel_count_per_region( ...@@ -100,9 +100,9 @@ def pixel_count_per_region(
current_region_blue = current_region_row["b"].values current_region_blue = current_region_row["b"].values
row["name"] = current_region_name[0] row["name"] = current_region_name[0]
row["r"] = current_region_red[0] row["r"] = int(current_region_red[0])
row["g"] = current_region_green[0] row["g"] = int(current_region_green[0])
row["b"] = current_region_blue[0] row["b"] = int(current_region_blue[0])
new_rows.append(row) new_rows.append(row)
......
import math,re
def propagate(arr):
for slice in arr:
if "nr" not in slice:
slice["nr"]=int(re.search(r"_s(\d+)",slice["filename"]).group(1))
arr.sort(key=lambda slice:slice["nr"])
linregs=[LinReg() for i in range(11)]
count=0
for slice in arr:
if "anchoring" in slice:
a=slice["anchoring"]
for i in range(3):
a[i]+=(a[i+3]+a[i+6])/2
a.extend([normalize(a,3)/slice["width"],normalize(a,6)/slice["height"]])
for i in range(len(linregs)):
linregs[i].add(slice["nr"],a[i])
count+=1
if count>=2:
l=len(arr)
if not "anchoring" in arr[0]:
nr=arr[0]["nr"]
a=[linreg.get(nr) for linreg in linregs]
orthonormalize(a)
arr[0]["anchoring"]=a
count+=1
if not "anchoring" in arr[l-1]:
nr=arr[l-1]["nr"]
a=[linreg.get(nr) for linreg in linregs]
orthonormalize(a)
arr[l-1]["anchoring"]=a
count+=1
start=1
while count<l:
while "anchoring" in arr[start]:
start+=1
next=start+1
while not "anchoring" in arr[next]:
next+=1
pnr=arr[start-1]["nr"]
nnr=arr[next]["nr"]
panch=arr[start-1]["anchoring"]
nanch=arr[next]["anchoring"]
linints=[LinInt(pnr,panch[i],nnr,nanch[i]) for i in range(len(panch))]
for i in range(start,next):
nr=arr[i]["nr"]
arr[i]["anchoring"]=[linint.get(nr) for linint in linints]
count+=1
start=next+1
for slice in arr:
a=slice["anchoring"]
orthonormalize(a)
v=a.pop()
u=a.pop()
for i in range(3):
a[i+3]*=u*slice["width"]
a[i+6]*=v*slice["height"]
a[i]-=(a[i+3]+a[i+6])/2
return arr
class LinInt:
def __init__(self,x1,y1,x2,y2):
self.x1=x1
self.x2=x2
self.y1=y1
self.y2=y2
def get(self,x):
return self.y1+(self.y2-self.y1)*(x-self.x1)/(self.x2-self.x1)
class LinReg:
def __init__(self):
self.n=0
self.Sx=0
self.Sy=0
self.Sxx=0
self.Sxy=0
def add(self,x,y):
self.n+=1
self.Sx+=x
self.Sy+=y
self.Sxx+=x*x
self.Sxy+=x*y
if self.n>=2:
self.b=(self.n*self.Sxy-self.Sx*self.Sy)/(self.n*self.Sxx-self.Sx*self.Sx)
self.a=self.Sy/self.n-self.b*self.Sx/self.n
def get(self,x):
return self.a+self.b*x
def normalize(arr,idx):
l=0
for i in range(3):
l+=arr[idx+i]*arr[idx+i]
l=math.sqrt(l)
for i in range(3):
arr[idx+i]/=l
return l
def orthonormalize(arr):
normalize(arr,3)
dot=0
for i in range(3):
dot+=arr[i+3]*arr[i+6]
for i in range(3):
arr[i+6]-=arr[i+3]*dot
normalize(arr,6)
if __name__=="__main__":
import json,sys
with open(sys.argv[1]) as f:
series=json.load(f)
propagate(series["slices"])
with open(sys.argv[2],"w") as f:
json.dump(series,f)
...@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt ...@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
import os import os
import nrrd import nrrd
import re import re
from .propagation import propagate
# related to read and write # related to read and write
# this function reads a VisuAlign JSON and returns the slices # this function reads a VisuAlign JSON and returns the slices
...@@ -19,7 +19,11 @@ def load_visualign_json(filename): ...@@ -19,7 +19,11 @@ def load_visualign_json(filename):
for slice in slices: for slice in slices:
print(slice) print(slice)
slice["nr"] = int(re.search(r"_s(\d+)", slice["filename"]).group(1)) slice["nr"] = int(re.search(r"_s(\d+)", slice["filename"]).group(1))
slice["anchoring"] = slice["ouv"] if "ouv" in slice:
slice["anchoring"] = slice["ouv"]
else:
slice["anchoring"] = []
name = os.path.basename(filename)
lz_compat_file = { lz_compat_file = {
"name":name, "name":name,
"target":vafile["atlas"], "target":vafile["atlas"],
...@@ -31,16 +35,10 @@ def load_visualign_json(filename): ...@@ -31,16 +35,10 @@ def load_visualign_json(filename):
json.dump(lz_compat_file, f, indent=4) json.dump(lz_compat_file, f, indent=4)
else: else:
slices = vafile["slices"] slices = vafile["slices"]
# overwrite the existing file slices = propagate(slices)
name = os.path.basename(filename)
return slices return slices
def load_visualign_json(filename):
with open(filename) as f:
vafile = json.load(f)
slices = vafile["slices"]
return slices
# related to read_and_write, used in write_points_to_meshview # related to read_and_write, used in write_points_to_meshview
# this function returns a dictionary of region names # this function returns a dictionary of region names
def create_region_dict(points, regions): def create_region_dict(points, regions):
...@@ -61,9 +59,9 @@ def write_points(points_dict, filename, info_file): ...@@ -61,9 +59,9 @@ def write_points(points_dict, filename, info_file):
"count": len(points_dict[name]) // 3, "count": len(points_dict[name]) // 3,
"name": str(info_file["name"].values[info_file["idx"] == name][0]), "name": str(info_file["name"].values[info_file["idx"] == name][0]),
"triplets": points_dict[name], "triplets": points_dict[name],
"r": str(info_file["r"].values[info_file["idx"] == name][0]), "r": int(info_file["r"].values[info_file["idx"] == name][0]),
"g": str(info_file["g"].values[info_file["idx"] == name][0]), "g": int(info_file["g"].values[info_file["idx"] == name][0]),
"b": str(info_file["b"].values[info_file["idx"] == name][0]), "b": int(info_file["b"].values[info_file["idx"] == name][0]),
} }
for name, idx in zip(points_dict.keys(), range(len(points_dict.keys()))) for name, idx in zip(points_dict.keys(), range(len(points_dict.keys())))
] ]
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment