diff --git a/notebooks/training/panopticnets/Nuclear Segmentation - DeepWatershed.ipynb b/notebooks/training/panopticnets/Nuclear Segmentation - DeepWatershed.ipynb index 7258e8c5..3f06ed5c 100644 --- a/notebooks/training/panopticnets/Nuclear Segmentation - DeepWatershed.ipynb +++ b/notebooks/training/panopticnets/Nuclear Segmentation - DeepWatershed.ipynb @@ -30,6 +30,7 @@ "import numpy as np\n", "from skimage.feature import peak_local_max\n", "import tensorflow as tf\n", + "import tempfile\n", "\n", "from deepcell.applications import NuclearSegmentation\n", "from deepcell.image_generators import CroppingDataGenerator\n", @@ -54,7 +55,7 @@ "metadata": {}, "outputs": [], "source": [ - "data_dir = '/notebooks/data'\n", + "data_dir = '/data'\n", "model_path = 'NuclearSegmentation'\n", "metrics_path = 'metrics.yaml'\n", "train_log = 'train_log.csv'" @@ -146,6 +147,7 @@ "outputs": [], "source": [ "# Post processing parameters\n", + "radius = 10\n", "maxima_threshold = 0.1\n", "interior_threshold = 0.01\n", "exclude_border = False\n", @@ -187,6 +189,7 @@ "X_train = histogram_normalization(X_train)\n", "X_val = histogram_normalization(X_val)\n", "\n", + "\n", "# use augmentation for training but not validation\n", "datagen = CroppingDataGenerator(\n", " rotation_range=180,\n", @@ -452,15 +455,72 @@ "## Predict on test data" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prepare data by using validation data generator" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "X_test = histograph_normalization(X_test)\n", + "X_test = histogram_normalization(X_test)\n", "\n", - "test_images = prediction_model.predict(X_test)" + "# Generator used to crop data and transform y\n", + "test_data = datagen_val.flow(\n", + " {'X': X_test, 'y': y_test},\n", + " seed=seed,\n", + " min_objects=min_objects,\n", + " transforms=transforms,\n", + " transforms_kwargs=transforms_kwargs,\n", + " batch_size=batch_size,\n", + ")\n", + "\n", + "# Generator used to crop y without transform\n", + "test_data_y = datagen_val.flow(\n", + " {'X': y_test, 'y': y_test},\n", + " seed=seed,\n", + " min_objects=min_objects,\n", + " transforms=[],\n", + " transforms_kwargs={},\n", + " batch_size=batch_size,\n", + ")\n", + "\n", + "X_crop, y_crop, y_crop_t = None, None, None\n", + "for i, j in test_data:\n", + " \n", + " X_crop = np.concatenate((X_crop, i), axis=0) if X_crop is not None else i\n", + " # select needed transform as y_crop_t\n", + " y_crop_t = np.concatenate((y_crop_t, j[2]), axis=0) if y_crop_t is not None else j[2]\n", + "\n", + " if len(X_crop)>=len(X_test):\n", + " print(X_crop.shape)\n", + " \n", + " break\n", + "\n", + "for i, _ in test_data_y:\n", + " if y_crop is None:\n", + " y_crop = i\n", + "\n", + " elif len(y_crop)>=len(y_test):\n", + " print(y_crop.shape)\n", + " break\n", + " \n", + " else:\n", + " y_crop = np.concatenate((y_crop, i), axis=0)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Predict and visualize" ] }, { @@ -471,7 +531,8 @@ }, "outputs": [], "source": [ - "index = np.random.choice(X_test.shape[0])\n", + "test_images = prediction_model.predict(X_crop)\n", + "index = np.random.choice(X_crop.shape[0])\n", "print(index)\n", "\n", "fig, axes = plt.subplots(1, 4, figsize=(20, 20))\n", @@ -496,7 +557,7 @@ ")\n", "\n", "# raw image with centroid\n", - "axes[0].imshow(X_test[index, ..., 0])\n", + "axes[0].imshow(X_crop[index, ..., 0])\n", "axes[0].scatter(coords[..., 1], coords[..., 0],\n", " color='r', marker='.', s=10)\n", "\n", @@ -524,7 +585,7 @@ }, "outputs": [], "source": [ - "outputs = model.predict(X_test)\n", + "outputs = model.predict(X_crop)\n", "\n", "y_pred = []\n", "\n", @@ -542,12 +603,36 @@ " y_pred.append(mask[0])\n", "\n", "y_pred = np.stack(y_pred, axis=0)\n", - "y_pred = np.expand_dims(y_pred, axis=-1)\n", - "y_true = y_test.copy()\n", + "y_true = y_crop_t[:, :, :, 0].copy().astype(int)\n", + "y_true = np.expand_dims(y_true, axis=-1)\n", "\n", "m = Metrics('DeepWatershed', seg=False)\n", "m.calc_object_stats(y_true, y_pred)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visual check" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for _ in range(5):\n", + " index = np.random.choice(X_crop.shape[0])\n", + " print(index)\n", + " fig, axes = plt.subplots(1, 3, figsize=(20, 20))\n", + " \n", + " axes[0].imshow(X_crop[index, ..., 0], cmap='jet')\n", + " axes[1].imshow(y_true[index, ..., 0], cmap='jet')\n", + " axes[2].imshow(y_pred[index, ..., 0], cmap='jet')\n", + " plt.show()" + ] } ], "metadata": {