|
| 1 | +from typing import Iterable, List, Optional, TYPE_CHECKING |
| 2 | +import os |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +from skimage.color import label2rgb |
| 6 | +from skimage.color.colorlabel import DEFAULT_COLORS |
| 7 | + |
| 8 | +import torch |
| 9 | + |
| 10 | +from catalyst import utils |
| 11 | +from catalyst.callbacks import ILoggerCallback |
| 12 | +from catalyst.contrib.tools.tensorboard import SummaryWriter |
| 13 | +from catalyst.contrib.utils.cv.tensor import tensor_to_ndimage |
| 14 | +from catalyst.core.callback import CallbackNode, CallbackOrder |
| 15 | + |
| 16 | +if TYPE_CHECKING: |
| 17 | + from catalyst.core.runner import IRunner |
| 18 | + |
| 19 | +DEFAULT_COLORS = np.array(DEFAULT_COLORS) |
| 20 | + |
| 21 | + |
| 22 | +class DrawMasksCallback(ILoggerCallback): |
| 23 | + """ |
| 24 | + Logger callback draw masks for common segmentation task: image -> masks |
| 25 | + """ |
| 26 | + |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + output_key: str, |
| 30 | + input_image_key: Optional[str] = None, |
| 31 | + input_mask_key: Optional[str] = None, |
| 32 | + mask2show: Optional[Iterable[int]] = None, |
| 33 | + activation: Optional[str] = "Sigmoid", |
| 34 | + log_name: str = "images", |
| 35 | + summary_step: int = 50, |
| 36 | + threshold: float = 0.5, |
| 37 | + ): |
| 38 | + """ |
| 39 | +
|
| 40 | + Args: |
| 41 | + output_key: predicted mask key |
| 42 | + input_image_key: input image key. If None mask will be drawn on |
| 43 | + white background |
| 44 | + input_mask_key: ground truth mask key. If None, will not be drawn |
| 45 | + mask2show: mask indexes to show, if None all mask will be drawn. By |
| 46 | + this parameter you can change the mask order |
| 47 | + activation: An torch.nn activation applied to the outputs. |
| 48 | + Must be one of ``'none'``, ``'Sigmoid'``, ``'Softmax'`` |
| 49 | + log_name: logging name. If you use several such "callbacks", they |
| 50 | + must have different logging names |
| 51 | + summary_step: logging frequency |
| 52 | + threshold: threshold for predicted masks, must be in (0, 1) |
| 53 | + """ |
| 54 | + assert 0 < threshold < 1 |
| 55 | + assert activation in ["none", "Sigmoid", "Softmax2d"] |
| 56 | + super().__init__(order=CallbackOrder.logging, node=CallbackNode.master) |
| 57 | + |
| 58 | + self.input_image_key = input_image_key |
| 59 | + self.input_mask_key = input_mask_key |
| 60 | + self.output_key = output_key |
| 61 | + |
| 62 | + self.mask2show = mask2show |
| 63 | + self.summary_step = summary_step |
| 64 | + self.threshold = threshold |
| 65 | + self.log_name = log_name |
| 66 | + |
| 67 | + if activation == "Sigmoid": |
| 68 | + self.activation = torch.nn.Sigmoid() |
| 69 | + elif activation == "Softmax": |
| 70 | + self.activation = torch.nn.Softmax2d() |
| 71 | + else: |
| 72 | + self.activation = torch.nn.Identity() |
| 73 | + |
| 74 | + self.loggers = {} |
| 75 | + self.step = None # initialization |
| 76 | + |
| 77 | + def on_loader_start(self, runner: "IRunner"): |
| 78 | + """Loader start hook. |
| 79 | +
|
| 80 | + Args: |
| 81 | + runner: current runner |
| 82 | + """ |
| 83 | + if runner.loader_key not in self.loggers: |
| 84 | + log_dir = os.path.join( |
| 85 | + runner.logdir, f"{runner.loader_key}_log/images/" |
| 86 | + ) |
| 87 | + self.loggers[runner.loader_key] = SummaryWriter(log_dir) |
| 88 | + self.step = 0 |
| 89 | + |
| 90 | + def _draw_masks( |
| 91 | + self, |
| 92 | + writer: SummaryWriter, |
| 93 | + global_step: int, |
| 94 | + image_over_predicted_mask: np.ndarray, |
| 95 | + image_over_gt_mask: Optional[np.ndarray] = None, |
| 96 | + ) -> None: |
| 97 | + """ |
| 98 | + Draw image over mask to tensorboard |
| 99 | +
|
| 100 | + Args: |
| 101 | + writer: loader writer |
| 102 | + global_step: global step |
| 103 | + image_over_predicted_mask: image over predicted mask |
| 104 | + image_over_gt_mask: image over ground truth mask |
| 105 | + """ |
| 106 | + if image_over_gt_mask is not None: |
| 107 | + writer.add_image( |
| 108 | + f"{self.log_name} Ground Truth", |
| 109 | + image_over_gt_mask, |
| 110 | + global_step=global_step, |
| 111 | + dataformats="HWC", |
| 112 | + ) |
| 113 | + |
| 114 | + writer.add_image( |
| 115 | + f"{self.log_name} Prediction", |
| 116 | + image_over_predicted_mask, |
| 117 | + global_step=global_step, |
| 118 | + dataformats="HWC", |
| 119 | + ) |
| 120 | + |
| 121 | + def _prob2mask(self, prob_masks: np.ndarray) -> np.ndarray: |
| 122 | + """ |
| 123 | + Convert probability masks into label mask |
| 124 | +
|
| 125 | + Args: |
| 126 | + prob_masks: [n_classes, H, W], probability masks for each class |
| 127 | +
|
| 128 | + Returns: [H, W] label mask |
| 129 | + """ |
| 130 | + mask = np.zeros_like(prob_masks[0], dtype=np.uint8) |
| 131 | + n_classes = prob_masks.shape[0] |
| 132 | + if self.mask2show is not None: |
| 133 | + assert max(self.mask2show) < n_classes |
| 134 | + mask2show = self.mask2show |
| 135 | + else: |
| 136 | + mask2show = range(n_classes) |
| 137 | + |
| 138 | + for i in mask2show: |
| 139 | + prob_mask = prob_masks[i] |
| 140 | + mask[prob_mask >= self.threshold] = i + 1 |
| 141 | + return mask |
| 142 | + |
| 143 | + @staticmethod |
| 144 | + def _get_colors(mask: np.ndarray) -> List[str]: |
| 145 | + """ |
| 146 | + Select colors for mask labels |
| 147 | +
|
| 148 | + Args: |
| 149 | + mask: [H, W] label mask |
| 150 | +
|
| 151 | + Returns: colors for labels |
| 152 | + """ |
| 153 | + colors_labels = np.unique(mask) |
| 154 | + colors_labels = colors_labels[colors_labels > 0] - 1 |
| 155 | + colors = DEFAULT_COLORS[colors_labels % len(DEFAULT_COLORS)] |
| 156 | + return colors |
| 157 | + |
| 158 | + def on_batch_end(self, runner: "IRunner"): |
| 159 | + """Batch end hook. |
| 160 | +
|
| 161 | + Args: |
| 162 | + runner: current runner |
| 163 | + """ |
| 164 | + if self.step % self.summary_step == 0: |
| 165 | + pred_mask = runner.output[self.output_key][0] |
| 166 | + pred_mask = self.activation(pred_mask) |
| 167 | + pred_mask = utils.detach(pred_mask) |
| 168 | + pred_mask = self._prob2mask(pred_mask) |
| 169 | + |
| 170 | + if self.input_mask_key is not None: |
| 171 | + gt_mask = runner.input[self.input_mask_key][0] |
| 172 | + gt_mask = utils.detach(gt_mask) |
| 173 | + gt_mask = self._prob2mask(gt_mask) |
| 174 | + else: |
| 175 | + gt_mask = None |
| 176 | + |
| 177 | + if self.input_image_key is not None: |
| 178 | + image = runner.input[self.input_image_key][0].cpu() |
| 179 | + image = tensor_to_ndimage(image) |
| 180 | + else: |
| 181 | + # white background |
| 182 | + image = np.ones_like(pred_mask, dtype=np.uint8) * 255 |
| 183 | + |
| 184 | + pred_colors = self._get_colors(pred_mask) |
| 185 | + image_over_predicted_mask = label2rgb( |
| 186 | + pred_mask, image, bg_label=0, colors=pred_colors |
| 187 | + ) |
| 188 | + if gt_mask is not None: |
| 189 | + gt_colors = self._get_colors(gt_mask) |
| 190 | + image_over_gt_mask = label2rgb( |
| 191 | + gt_mask, image, bg_label=0, colors=gt_colors |
| 192 | + ) |
| 193 | + else: |
| 194 | + image_over_gt_mask = None |
| 195 | + |
| 196 | + self._draw_masks( |
| 197 | + self.loggers[runner.loader_key], |
| 198 | + runner.global_sample_step, |
| 199 | + image_over_predicted_mask, |
| 200 | + image_over_gt_mask, |
| 201 | + ) |
| 202 | + self.step += 1 |
| 203 | + |
| 204 | + |
| 205 | +__all__ = ["DrawMasksCallback"] |
0 commit comments