diff --git a/PyNutil/coordinate_extraction.py b/PyNutil/coordinate_extraction.py index 277c3b75ea0abce1183d44cfa6d94ee172df9a0b..359835a06dbdcdb86edda6651452ee02fa87a5d4 100644 --- a/PyNutil/coordinate_extraction.py +++ b/PyNutil/coordinate_extraction.py @@ -135,14 +135,21 @@ def folder_to_atlas_space( for file in glob(folder + "/segmentations/*") if any([file.endswith(type) for type in segmentation_file_types]) ] - flat_files = [ - file - for file in glob(folder + "/flat_files/*") - if any([file.endswith('.flat'), file.endswith('.seg')]) - ] + if len(segmentations) == 0: + raise ValueError( + f"No segmentations found in folder {folder}. Make sure the folder contains a segmentations folder with segmentations." + ) print(f"Found {len(segmentations)} segmentations in folder {folder}") - print(f"Found {len(flat_files)} flat files in folder {folder}") - + if use_flat == True: + flat_files = [ + file + for file in glob(folder + "/flat_files/*") + if any([file.endswith('.flat'), file.endswith('.seg')]) + ] + print(f"Found {len(flat_files)} flat files in folder {folder}") + flat_file_nrs = [int(number_sections([ff])[0]) for ff in flat_files] + + # Order segmentations and section_numbers # segmentations = [x for _,x in sorted(zip(section_numbers,segmentations))] # section_numbers.sort() @@ -150,11 +157,12 @@ def folder_to_atlas_space( centroids_list = [None] * len(segmentations) region_areas_list = [None] * len(segmentations) threads = [] - flat_file_nrs = [int(number_sections([ff])[0]) for ff in flat_files] for segmentation_path, index in zip(segmentations, range(len(segmentations))): seg_nr = int(number_sections([segmentation_path])[0]) current_slice_index = np.where([s["nr"] == seg_nr for s in slices]) current_slice = slices[current_slice_index[0][0]] + if current_slice["anchoring"] == []: + continue if use_flat == True: current_flat_file_index = np.where([f == seg_nr for f in flat_file_nrs]) current_flat = flat_files[current_flat_file_index[0][0]] @@ -187,6 +195,7 @@ def folder_to_atlas_space( # Wait for threads to finish [t.join() for t in threads] # Flatten points_list + points_len = [len(points) for points in points_list] centroids_len = [len(centroids) for centroids in centroids_list] diff --git a/PyNutil/counting_and_load.py b/PyNutil/counting_and_load.py index 87adbe2d683a7565e1e45ed818e791a41dedc280..22dda979be3e1ae1e4d6f10b04c46d93c2d0c518 100644 --- a/PyNutil/counting_and_load.py +++ b/PyNutil/counting_and_load.py @@ -100,9 +100,9 @@ def pixel_count_per_region( current_region_blue = current_region_row["b"].values row["name"] = current_region_name[0] - row["r"] = current_region_red[0] - row["g"] = current_region_green[0] - row["b"] = current_region_blue[0] + row["r"] = int(current_region_red[0]) + row["g"] = int(current_region_green[0]) + row["b"] = int(current_region_blue[0]) new_rows.append(row) diff --git a/PyNutil/propagation.py b/PyNutil/propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..d4af1e81db50556b52da144d180eb41a6b85ee76 --- /dev/null +++ b/PyNutil/propagation.py @@ -0,0 +1,121 @@ +import math,re + +def propagate(arr): + for slice in arr: + if "nr" not in slice: + slice["nr"]=int(re.search(r"_s(\d+)",slice["filename"]).group(1)) + + arr.sort(key=lambda slice:slice["nr"]) + + linregs=[LinReg() for i in range(11)] + count=0 + for slice in arr: + if "anchoring" in slice: + a=slice["anchoring"] + for i in range(3): + a[i]+=(a[i+3]+a[i+6])/2 + a.extend([normalize(a,3)/slice["width"],normalize(a,6)/slice["height"]]) + for i in range(len(linregs)): + linregs[i].add(slice["nr"],a[i]) + count+=1 + + if count>=2: + l=len(arr) + if not "anchoring" in arr[0]: + nr=arr[0]["nr"] + a=[linreg.get(nr) for linreg in linregs] + orthonormalize(a) + arr[0]["anchoring"]=a + count+=1 + if not "anchoring" in arr[l-1]: + nr=arr[l-1]["nr"] + a=[linreg.get(nr) for linreg in linregs] + orthonormalize(a) + arr[l-1]["anchoring"]=a + count+=1 + + start=1 + while count<l: + while "anchoring" in arr[start]: + start+=1 + next=start+1 + while not "anchoring" in arr[next]: + next+=1 + pnr=arr[start-1]["nr"] + nnr=arr[next]["nr"] + panch=arr[start-1]["anchoring"] + nanch=arr[next]["anchoring"] + linints=[LinInt(pnr,panch[i],nnr,nanch[i]) for i in range(len(panch))] + for i in range(start,next): + nr=arr[i]["nr"] + arr[i]["anchoring"]=[linint.get(nr) for linint in linints] + count+=1 + start=next+1 + + for slice in arr: + a=slice["anchoring"] + orthonormalize(a) + v=a.pop() + u=a.pop() + for i in range(3): + a[i+3]*=u*slice["width"] + a[i+6]*=v*slice["height"] + a[i]-=(a[i+3]+a[i+6])/2 + return arr + +class LinInt: + def __init__(self,x1,y1,x2,y2): + self.x1=x1 + self.x2=x2 + self.y1=y1 + self.y2=y2 + + def get(self,x): + return self.y1+(self.y2-self.y1)*(x-self.x1)/(self.x2-self.x1) + +class LinReg: + def __init__(self): + self.n=0 + self.Sx=0 + self.Sy=0 + self.Sxx=0 + self.Sxy=0 + + def add(self,x,y): + self.n+=1 + self.Sx+=x + self.Sy+=y + self.Sxx+=x*x + self.Sxy+=x*y + if self.n>=2: + self.b=(self.n*self.Sxy-self.Sx*self.Sy)/(self.n*self.Sxx-self.Sx*self.Sx) + self.a=self.Sy/self.n-self.b*self.Sx/self.n + + def get(self,x): + return self.a+self.b*x + +def normalize(arr,idx): + l=0 + for i in range(3): + l+=arr[idx+i]*arr[idx+i] + l=math.sqrt(l) + for i in range(3): + arr[idx+i]/=l + return l + +def orthonormalize(arr): + normalize(arr,3) + dot=0 + for i in range(3): + dot+=arr[i+3]*arr[i+6] + for i in range(3): + arr[i+6]-=arr[i+3]*dot + normalize(arr,6) + +if __name__=="__main__": + import json,sys + with open(sys.argv[1]) as f: + series=json.load(f) + propagate(series["slices"]) + with open(sys.argv[2],"w") as f: + json.dump(series,f) diff --git a/PyNutil/read_and_write.py b/PyNutil/read_and_write.py index 5a7e26f8503caf60dd3495c3c57b83d464e8f7a9..728a6abe410d4cba04440d1888eda510f3a6b6ae 100644 --- a/PyNutil/read_and_write.py +++ b/PyNutil/read_and_write.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt import os import nrrd import re - +from .propagation import propagate # related to read and write # this function reads a VisuAlign JSON and returns the slices @@ -19,7 +19,11 @@ def load_visualign_json(filename): for slice in slices: print(slice) slice["nr"] = int(re.search(r"_s(\d+)", slice["filename"]).group(1)) - slice["anchoring"] = slice["ouv"] + if "ouv" in slice: + slice["anchoring"] = slice["ouv"] + else: + slice["anchoring"] = [] + name = os.path.basename(filename) lz_compat_file = { "name":name, "target":vafile["atlas"], @@ -31,16 +35,10 @@ def load_visualign_json(filename): json.dump(lz_compat_file, f, indent=4) else: slices = vafile["slices"] - # overwrite the existing file - name = os.path.basename(filename) - + slices = propagate(slices) return slices -def load_visualign_json(filename): - with open(filename) as f: - vafile = json.load(f) - slices = vafile["slices"] - return slices + # related to read_and_write, used in write_points_to_meshview # this function returns a dictionary of region names def create_region_dict(points, regions): @@ -61,9 +59,9 @@ def write_points(points_dict, filename, info_file): "count": len(points_dict[name]) // 3, "name": str(info_file["name"].values[info_file["idx"] == name][0]), "triplets": points_dict[name], - "r": str(info_file["r"].values[info_file["idx"] == name][0]), - "g": str(info_file["g"].values[info_file["idx"] == name][0]), - "b": str(info_file["b"].values[info_file["idx"] == name][0]), + "r": int(info_file["r"].values[info_file["idx"] == name][0]), + "g": int(info_file["g"].values[info_file["idx"] == name][0]), + "b": int(info_file["b"].values[info_file["idx"] == name][0]), } for name, idx in zip(points_dict.keys(), range(len(points_dict.keys()))) ]