Encoder decoder isolated cells #41

Open · wants to merge 13 commits into master
Changes from all commits
126 changes: 58 additions & 68 deletions DATASET_smFISH/MERFISH_create_dataset_isolated_cells.ipynb

Large diffs are not rendered by default.

444 changes: 444 additions & 0 deletions DATASET_smFISH/MERFISH_train_VAE_isolated_cells.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ML_parameters.json
Binary file modified MODULES/__pycache__/cropper_uncropper.cpython-37.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/cropper_uncropper.cpython-38.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/encoders_decoders.cpython-37.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/encoders_decoders.cpython-38.pyc
Binary file not shown.
Binary file not shown.
Binary file modified MODULES/__pycache__/graph_clustering.cpython-38.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/namedtuple.cpython-37.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/namedtuple.cpython-38.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/non_max_suppression.cpython-37.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/non_max_suppression.cpython-38.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/unet_model.cpython-37.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/unet_model.cpython-38.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/unet_parts.cpython-37.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/unet_parts.cpython-38.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/utilities.cpython-37.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/utilities.cpython-38.pyc
Binary file not shown.
Binary file added MODULES/__pycache__/utilities_ml.cpython-37.pyc
Binary file not shown.
Binary file added MODULES/__pycache__/utilities_ml.cpython-38.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified MODULES/__pycache__/vae_model.cpython-37.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/vae_model.cpython-38.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/vae_parts.cpython-37.pyc
Binary file not shown.
Binary file modified MODULES/__pycache__/vae_parts.cpython-38.pyc
Binary file not shown.
1 change: 1 addition & 0 deletions MODULES/cropper_uncropper.py
@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F

from .namedtuple import BB


5 changes: 3 additions & 2 deletions MODULES/encoders_decoders.py
@@ -1,11 +1,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from .namedtuple import ZZ
from typing import List, Optional

EPS_STD = 1E-3 # standard_deviation = F.softplus(x) + EPS_STD >= EPS_STD


class MLP_1by1(nn.Module):
""" Use 1x1 convolution, if ch_hidden <= 0 there is NO hidden layer """
def __init__(self, ch_in: int, ch_out: int, ch_hidden: int):
@@ -228,4 +230,3 @@ def forward(self, x: torch.Tensor) -> ZZ: # this is right
mu = self.compute_mu(x2).view(independent_dim + [self.dim_z])
std = F.softplus(self.compute_std(x2)).view(independent_dim + [self.dim_z])
return ZZ(mu=mu, std=std + EPS_STD)
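
For context on the encoders_decoders.py change above: EPS_STD is a floor added to the softplus-transformed network output so the predicted standard deviation can never collapse to zero. Below is a minimal, self-contained sketch of that trick; the ZZ namedtuple here is a hypothetical stand-in for MODULES.namedtuple.ZZ and is not taken from the diff.

import torch
import torch.nn.functional as F
from collections import namedtuple

# Hypothetical stand-in for MODULES.namedtuple.ZZ (illustration only).
ZZ = namedtuple("ZZ", ["mu", "std"])

EPS_STD = 1E-3  # standard_deviation = F.softplus(x) + EPS_STD >= EPS_STD > 0

def raw_to_zz(mu_raw: torch.Tensor, std_raw: torch.Tensor) -> ZZ:
    # softplus maps any real value into (0, inf); adding EPS_STD keeps the
    # standard deviation bounded away from zero, which protects the Gaussian
    # log-likelihood and the reparameterization trick from numerical blow-ups.
    return ZZ(mu=mu_raw, std=F.softplus(std_raw) + EPS_STD)

zz = raw_to_zz(torch.randn(4, 8), torch.randn(4, 8))
assert torch.all(zz.std >= EPS_STD)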

112 changes: 62 additions & 50 deletions MODULES/graph_clustering.py
@@ -5,16 +5,18 @@
import leidenalg as la
import igraph as ig
from MODULES.namedtuple import Segmentation, Partition, SparseSimilarity, Suggestion, ConcordancePartition
from typing import Optional, List, Union
from typing import Optional, List
from matplotlib import pyplot as plt
import time
import neptune
from MODULES.utilities_neptune import log_img_and_chart


# I HAVE LEARNED:
# 1. If I use a lot of negihbours then all methods are roughyl equivalent b/c graph becomes ALL-TO-ALL
# 1. If I use a lot of neighbours then all methods are roughly equivalent b/c graph becomes ALL-TO-ALL
# 2. Radius=10 means each pixel has 121 neighbours
# 3. CPM does not suffer from the resolution limit which means that it tends to shave off small part from a cell.
# 4. For now I prefer to use a graph with normalized edges, modularity and single gigantic cluster (i.e. each_cc_component=False)
# 4. For now I prefer to use a graph with normalized edges,
# modularity and single gigantic cluster (i.e. each_cc_component=False)

with torch.no_grad():
class GraphSegmentation(object):
@@ -47,7 +49,7 @@ def __init__(self, segmentation: Segmentation,
self.device = torch.device("cpu")

self.raw_image = segmentation.raw_image[0].to(self.device)
self.example_integer_mask = segmentation.integer_mask[0, 0].to(self.device) # set batch=0, ch=0
self.example_integer_mask = segmentation.integer_mask[0, 0].to(self.device) # set batch=0, ch=0

# it should be able to handle both DenseSimilarity and SparseSimilarity
b, c, ni, nj = segmentation.integer_mask.shape
@@ -68,8 +70,8 @@ def __init__(self, segmentation: Segmentation,
self._partition_connected_components = None
self._partition_sample_segmask = None

#TODO: Compute median density of connected components so that resolution parameter is about 1
#self.reference_density = AUCH
# TODO: Compute median density of connected components so that resolution parameter is about 1
# self.reference_density = AUCH

@property
def partition_connected_components(self):
@@ -80,31 +82,26 @@ def partition_connected_components(self):
dtype=torch.long,
device=self.device)[self.i_coordinate_fg_pixel,
self.j_coordinate_fg_pixel]
self._partition_connected_components = Partition(which="connected",
membership=membership_from_cc,
sizes=torch.bincount(membership_from_cc),
params={})
self._partition_connected_components = Partition(membership=membership_from_cc,
sizes=torch.bincount(membership_from_cc))
return self._partition_connected_components

@property
def partition_sample_segmask(self):
if self._partition_sample_segmask is None:
membership_from_example_segmask = self.example_integer_mask[self.i_coordinate_fg_pixel,
self.j_coordinate_fg_pixel].long()
self._partition_sample_segmask = Partition(which="one_sample",
membership=membership_from_example_segmask,
sizes=torch.bincount(membership_from_example_segmask),
params={})
self._partition_sample_segmask = Partition(membership=membership_from_example_segmask,
sizes=torch.bincount(membership_from_example_segmask))
return self._partition_sample_segmask

def similarity_2_graph(self, similarity: SparseSimilarity,
fg_prob: torch.tensor,
min_fg_prob: float,
min_edge_weight: float,
normalize_graph_edges: bool = True) -> ig.Graph:
normalize_graph_edges: bool) -> ig.Graph:
""" Create the graph from the sparse similarity matrix """

if ~normalize_graph_edges:
if not normalize_graph_edges:
print("WARNING! You are going to create a graph without normalizing the edges by the sqrt of the node degree. \
Are you sure you know what you are doing?!")

@@ -174,10 +171,12 @@ def similarity_2_graph(self, similarity: SparseSimilarity,
"total_nodes": self.n_fg_pixel},
directed=False)

def partition_2_mask(self, partition: Partition):
segmask = torch.zeros_like(self.index_matrix)
segmask[self.i_coordinate_fg_pixel, self.j_coordinate_fg_pixel] = partition.membership
return segmask
def partition_2_label(self, partition: Partition):
label = torch.zeros_like(self.index_matrix,
dtype=partition.membership.dtype,
device=partition.membership.device)
label[self.i_coordinate_fg_pixel, self.j_coordinate_fg_pixel] = partition.membership
return label

def is_vertex_in_window(self, window: tuple):
""" Same convention as scikit image:
@@ -263,16 +262,13 @@ def suggest_resolution_parameter(self,
resolutions = numpy.arange(0.5, 10, 0.5) if sweep_range is None else sweep_range
iou = numpy.zeros(resolutions.shape[0], dtype=float)
mi = numpy.zeros_like(iou)
seg = numpy.zeros((resolutions.shape[0], window[2]-window[0], window[3]-window[1]), dtype=numpy.int)
label = numpy.zeros((resolutions.shape[0], window[2]-window[0], window[3]-window[1]), dtype=numpy.int)
delta_n_cells = numpy.zeros(resolutions.shape[0], dtype=numpy.int)
n_cells = numpy.zeros_like(delta_n_cells)
sizes_list = list()

t0 = time.time()
for n, res in enumerate(resolutions):
print("n, res",n,res,time.time()-t0)

if (n%10 == 0) or (n == resolutions.shape[0]-1) :
if (n % 10 == 0) or (n == resolutions.shape[0]-1) :
print("resolution sweep, {0:3d} out of {1:3d}".format(n, resolutions.shape[0]-1))

p_tmp = self.find_partition_leiden(resolution=res,
@@ -281,30 +277,26 @@
max_size=max_size,
cpm_or_modularity=cpm_or_modularity,
each_cc_separately=each_cc_separately)
#print("AAAA",time.time()-t0, p_tmp.membership.device, p_tmp.sizes.device)


#TODO: the following lines are very slow for a large graph
#FROM AAAA to CCCC
n_cells[n] = len(p_tmp.sizes)-1
seg[n] = self.partition_2_mask(p_tmp)[window[0]:window[2], window[1]:window[3]].cpu().numpy()
label[n] = self.partition_2_label(p_tmp)[window[0]:window[2], window[1]:window[3]].cpu().numpy()
sizes_list.append(p_tmp.sizes.cpu().numpy())
#print("BBBB",time.time()-t0)


# Compute concordance
c_tmp: ConcordancePartition = p_tmp.concordance_with_partition(other_partition=other_partition)
delta_n_cells[n] = c_tmp.delta_n
iou[n] = c_tmp.iou
mi[n] = c_tmp.mutual_information
#print("CCCC",time.time()-t0)


i_max = numpy.argmax(iou)
return Suggestion(best_resolution=resolutions[i_max],
best_index=i_max.item(),
sweep_resolution=resolutions,
sweep_mi=mi,
sweep_iou=iou,
sweep_delta_n=delta_n_cells,
sweep_seg_mask=seg,
sweep_seg_mask=label,
sweep_sizes=sizes_list,
sweep_n_cells=n_cells)

@@ -322,7 +314,6 @@ def find_partition_leiden(self,
The graph can have both normalized and un-normalized weight.
The strong recommendation is to use CPM with normalized edge weight.


The metric can be both cpm or modularity
The results are all similar (provided the resolution parameter is tuned correctly).

@@ -338,6 +329,8 @@
To speed up the calculation the graph partitioning can be done separately for each connected components.
This is absolutely ok for CPM metric while a bit questionable for Modularity metric.
It is not likely to make much difference either way.

window has the same convention as scikit image, i.e. window = (min_row, min_col, max_row, max_col)
"""

if cpm_or_modularity == "cpm":
@@ -354,8 +347,7 @@
else:
raise Exception("Warning!! Argument not recognized. \
CPM_or_modularity can only be 'CPM' or 'modularity'")



# Subset graph by connected components and windows if necessary
max_label = 0
membership = torch.zeros(self.n_fg_pixel, dtype=torch.long, device=self.device)
@@ -374,14 +366,12 @@
# Only if the graph has node I tried to find the partition

print("find partition internal")
start_time=time.time()
p = la.find_partition(graph=g,
partition_type=partition_type,
initial_membership=initial_membership,
weights=g.es['weight'],
n_iterations=n_iterations,
resolution_parameter=resolution)
print("end find partition internal",time.time()-start_time)

labels = torch.tensor(p.membership, device=self.device, dtype=torch.long) + 1
shifted_labels = labels + max_label
@@ -392,9 +382,23 @@
return Partition(sizes=torch.bincount(membership),
membership=membership).filter_by_size(min_size=min_size, max_size=max_size)

def QC_on_label(self, old_label, min_area):
""" This function filter the labels by some criteria. For example by min size"""
labels = skimage.measure.label(old_label.cpu(), background=0, return_num=False, connectivity=2)
mydict = skimage.measure.regionprops_table(labels, properties=['label', 'area'])
my_filter = mydict["area"] > min_area

bad_labels = mydict["label"][~my_filter]
old2new = numpy.arange(mydict["label"][-1]+1)
old2new[bad_labels] = 0
new_labels = old2new[labels].astype(numpy.int32)
return new_labels

def plot_partition(self, partition: Optional[Partition] = None,
figsize: Optional[tuple] = (12, 12),
window: Optional[tuple] = None,
experiment: Optional[neptune.experiments.Experiment] = None,
neptune_name: Optional[str] = None,
**kargs) -> torch.tensor:
"""
If partition is None it prints the connected components
@@ -415,23 +419,31 @@
sizes_fg = sizes_fg[sizes_fg > 0] # since I am filtering the vertex some sizes might become zero
w = window

segmask = self.partition_2_mask(partition)[w[0]:w[2], w[1]:w[3]].cpu().numpy()
raw_img = self.raw_image[0, w[0]:w[2], w[1]:w[3]].cpu().numpy()
label = self.partition_2_label(partition)[w[0]:w[2], w[1]:w[3]].cpu().long().numpy() # shape: w, h
image = self.raw_image[:, w[0]:w[2], w[1]:w[3]].permute(1, 2, 0).cpu().float().numpy() # shape: w, h, ch
if len(image.shape) == 3 and (image.shape[-1] != 3):
image = image[..., 0]

figure, axes = plt.subplots(ncols=2, nrows=2, figsize=figsize)
axes[0, 0].imshow(skimage.color.label2rgb(label=segmask,
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=figsize)
axes[0, 0].imshow(skimage.color.label2rgb(label=label,
bg_label=0))
axes[0, 1].imshow(skimage.color.label2rgb(label=segmask,
image=raw_img,
axes[0, 1].imshow(skimage.color.label2rgb(label=label,
image=image,
alpha=0.25,
bg_label=0))
axes[1, 0].imshow(raw_img, cmap='gray')
axes[1, 0].imshow(image)
axes[1, 1].hist(sizes_fg.cpu(), **kargs)


title_partition = '{0:s}, #cells -> {1:3d}'.format(partition.which, sizes_fg.shape[0])
title_partition = 'Partition, #cells -> '+str(sizes_fg.shape[0])
axes[0, 0].set_title(title_partition)
axes[0, 1].set_title(title_partition)
axes[1, 0].set_title("raw image")
axes[1, 1].set_title("size distribution")

fig.tight_layout()
if neptune_name is not None:
log_img_and_chart(name=neptune_name, fig=fig, experiment=experiment)
plt.close(fig)
return fig
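
To make the updated GraphSegmentation API above easier to follow, here is a rough usage sketch. It is an illustration only: the constructor call, the upstream segmentation / sparse_similarity / fg_prob objects, and the keyword values are assumptions inferred from the call sites visible in the hunks, not code shipped in this PR.

# Illustration only; how the returned igraph object is attached to the class is not shown in the diff.
g_seg = GraphSegmentation(segmentation)                         # assumed minimal constructor call
graph = g_seg.similarity_2_graph(similarity=sparse_similarity,  # assumed SparseSimilarity from the VAE
                                 fg_prob=fg_prob,               # assumed foreground-probability tensor
                                 min_fg_prob=0.1,
                                 min_edge_weight=0.01,
                                 normalize_graph_edges=True)    # flag is now required (default removed)

partition = g_seg.find_partition_leiden(resolution=1.0,         # keyword names partly inferred
                                        min_size=30,
                                        max_size=None,
                                        cpm_or_modularity="modularity",
                                        each_cc_separately=False)

label = g_seg.partition_2_label(partition)      # renamed from partition_2_mask; returns a label image
fig = g_seg.plot_partition(partition=partition,
                           window=(0, 0, 80, 80),
                           neptune_name=None)    # optional Neptune logging added in this PR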

