Skip to content

Commit 1aeabb7

Browse files
DokholyanScitator
andauthored
Draw mask callback (#999)
* add DynamicBalanceClassSampler * add DynamicBalanceClassSampler: add usage example * add DynamicBalanceClassSampler: add tests * Update catalyst/data/tests/test_sampler.py * Update catalyst/data/tests/test_sampler.py * add DynamicBalanceClassSampler: debag tests * update sampler: add mode * add example notebook * sampler: fixes * samler: docs * DynamicBalanceClassSampler: fixes * change import order * change import order * add draw_masks_callback * fix legacy * fix import * fix import * fixes + white background * fix codestyle * fix bag * add draw_masks_callback * fix color selection * fix tensorboard * fix tensorboard * fix imports * fix catalyst import * add draw_masks_callback * fix init * add draw_casks callback to docs * add draw_masks_callack to pipeline * fix changelog * rename keys * fix keys * fix activation keys Co-authored-by: Sergey Kolesnikov <[email protected]>
1 parent ace4e96 commit 1aeabb7

File tree

6 files changed

+232
-2
lines changed

6 files changed

+232
-2
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010
### Added
1111

1212
- CVS Logger ([#1005](https://github.com/catalyst-team/catalyst/pull/1005))
13+
- DrawMasksCallback ([#999](https://github.com/catalyst-team/catalyst/pull/999))
1314
- ([#1002](https://github.com/catalyst-team/catalyst/pull/1002))
1415
- a few docs
1516
- ([#998](https://github.com/catalyst-team/catalyst/pull/998))

catalyst/contrib/callbacks/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@
7171
try:
7272
import imageio
7373
from catalyst.contrib.callbacks.mask_inference import InferMaskCallback
74+
from catalyst.contrib.callbacks.draw_masks_callback import (
75+
DrawMasksCallback,
76+
)
7477
except ModuleNotFoundError as ex:
7578
if SETTINGS.cv_required:
7679
logger.warning(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
from typing import Iterable, List, Optional, TYPE_CHECKING
2+
import os
3+
4+
import numpy as np
5+
from skimage.color import label2rgb
6+
from skimage.color.colorlabel import DEFAULT_COLORS
7+
8+
import torch
9+
10+
from catalyst import utils
11+
from catalyst.callbacks import ILoggerCallback
12+
from catalyst.contrib.tools.tensorboard import SummaryWriter
13+
from catalyst.contrib.utils.cv.tensor import tensor_to_ndimage
14+
from catalyst.core.callback import CallbackNode, CallbackOrder
15+
16+
if TYPE_CHECKING:
17+
from catalyst.core.runner import IRunner
18+
19+
DEFAULT_COLORS = np.array(DEFAULT_COLORS)
20+
21+
22+
class DrawMasksCallback(ILoggerCallback):
23+
"""
24+
Logger callback draw masks for common segmentation task: image -> masks
25+
"""
26+
27+
def __init__(
28+
self,
29+
output_key: str,
30+
input_image_key: Optional[str] = None,
31+
input_mask_key: Optional[str] = None,
32+
mask2show: Optional[Iterable[int]] = None,
33+
activation: Optional[str] = "Sigmoid",
34+
log_name: str = "images",
35+
summary_step: int = 50,
36+
threshold: float = 0.5,
37+
):
38+
"""
39+
40+
Args:
41+
output_key: predicted mask key
42+
input_image_key: input image key. If None mask will be drawn on
43+
white background
44+
input_mask_key: ground truth mask key. If None, will not be drawn
45+
mask2show: mask indexes to show, if None all mask will be drawn. By
46+
this parameter you can change the mask order
47+
activation: An torch.nn activation applied to the outputs.
48+
Must be one of ``'none'``, ``'Sigmoid'``, ``'Softmax'``
49+
log_name: logging name. If you use several such "callbacks", they
50+
must have different logging names
51+
summary_step: logging frequency
52+
threshold: threshold for predicted masks, must be in (0, 1)
53+
"""
54+
assert 0 < threshold < 1
55+
assert activation in ["none", "Sigmoid", "Softmax2d"]
56+
super().__init__(order=CallbackOrder.logging, node=CallbackNode.master)
57+
58+
self.input_image_key = input_image_key
59+
self.input_mask_key = input_mask_key
60+
self.output_key = output_key
61+
62+
self.mask2show = mask2show
63+
self.summary_step = summary_step
64+
self.threshold = threshold
65+
self.log_name = log_name
66+
67+
if activation == "Sigmoid":
68+
self.activation = torch.nn.Sigmoid()
69+
elif activation == "Softmax":
70+
self.activation = torch.nn.Softmax2d()
71+
else:
72+
self.activation = torch.nn.Identity()
73+
74+
self.loggers = {}
75+
self.step = None # initialization
76+
77+
def on_loader_start(self, runner: "IRunner"):
78+
"""Loader start hook.
79+
80+
Args:
81+
runner: current runner
82+
"""
83+
if runner.loader_key not in self.loggers:
84+
log_dir = os.path.join(
85+
runner.logdir, f"{runner.loader_key}_log/images/"
86+
)
87+
self.loggers[runner.loader_key] = SummaryWriter(log_dir)
88+
self.step = 0
89+
90+
def _draw_masks(
91+
self,
92+
writer: SummaryWriter,
93+
global_step: int,
94+
image_over_predicted_mask: np.ndarray,
95+
image_over_gt_mask: Optional[np.ndarray] = None,
96+
) -> None:
97+
"""
98+
Draw image over mask to tensorboard
99+
100+
Args:
101+
writer: loader writer
102+
global_step: global step
103+
image_over_predicted_mask: image over predicted mask
104+
image_over_gt_mask: image over ground truth mask
105+
"""
106+
if image_over_gt_mask is not None:
107+
writer.add_image(
108+
f"{self.log_name} Ground Truth",
109+
image_over_gt_mask,
110+
global_step=global_step,
111+
dataformats="HWC",
112+
)
113+
114+
writer.add_image(
115+
f"{self.log_name} Prediction",
116+
image_over_predicted_mask,
117+
global_step=global_step,
118+
dataformats="HWC",
119+
)
120+
121+
def _prob2mask(self, prob_masks: np.ndarray) -> np.ndarray:
122+
"""
123+
Convert probability masks into label mask
124+
125+
Args:
126+
prob_masks: [n_classes, H, W], probability masks for each class
127+
128+
Returns: [H, W] label mask
129+
"""
130+
mask = np.zeros_like(prob_masks[0], dtype=np.uint8)
131+
n_classes = prob_masks.shape[0]
132+
if self.mask2show is not None:
133+
assert max(self.mask2show) < n_classes
134+
mask2show = self.mask2show
135+
else:
136+
mask2show = range(n_classes)
137+
138+
for i in mask2show:
139+
prob_mask = prob_masks[i]
140+
mask[prob_mask >= self.threshold] = i + 1
141+
return mask
142+
143+
@staticmethod
144+
def _get_colors(mask: np.ndarray) -> List[str]:
145+
"""
146+
Select colors for mask labels
147+
148+
Args:
149+
mask: [H, W] label mask
150+
151+
Returns: colors for labels
152+
"""
153+
colors_labels = np.unique(mask)
154+
colors_labels = colors_labels[colors_labels > 0] - 1
155+
colors = DEFAULT_COLORS[colors_labels % len(DEFAULT_COLORS)]
156+
return colors
157+
158+
def on_batch_end(self, runner: "IRunner"):
159+
"""Batch end hook.
160+
161+
Args:
162+
runner: current runner
163+
"""
164+
if self.step % self.summary_step == 0:
165+
pred_mask = runner.output[self.output_key][0]
166+
pred_mask = self.activation(pred_mask)
167+
pred_mask = utils.detach(pred_mask)
168+
pred_mask = self._prob2mask(pred_mask)
169+
170+
if self.input_mask_key is not None:
171+
gt_mask = runner.input[self.input_mask_key][0]
172+
gt_mask = utils.detach(gt_mask)
173+
gt_mask = self._prob2mask(gt_mask)
174+
else:
175+
gt_mask = None
176+
177+
if self.input_image_key is not None:
178+
image = runner.input[self.input_image_key][0].cpu()
179+
image = tensor_to_ndimage(image)
180+
else:
181+
# white background
182+
image = np.ones_like(pred_mask, dtype=np.uint8) * 255
183+
184+
pred_colors = self._get_colors(pred_mask)
185+
image_over_predicted_mask = label2rgb(
186+
pred_mask, image, bg_label=0, colors=pred_colors
187+
)
188+
if gt_mask is not None:
189+
gt_colors = self._get_colors(gt_mask)
190+
image_over_gt_mask = label2rgb(
191+
gt_mask, image, bg_label=0, colors=gt_colors
192+
)
193+
else:
194+
image_over_gt_mask = None
195+
196+
self._draw_masks(
197+
self.loggers[runner.loader_key],
198+
runner.global_sample_step,
199+
image_over_predicted_mask,
200+
image_over_gt_mask,
201+
)
202+
self.step += 1
203+
204+
205+
__all__ = ["DrawMasksCallback"]

docs/api/callbacks.rst

+7
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,13 @@ InferMaskCallback
262262
:undoc-members:
263263
:show-inheritance:
264264

265+
DrawMasksCallback
266+
~~~~~~~~~~~~~~~~~~~~~~
267+
.. automodule:: catalyst.contrib.callbacks.draw_masks_callback
268+
:members:
269+
:undoc-members:
270+
:show-inheritance:
271+
265272
MixupCallback
266273
~~~~~~~~~~~~~~~~~~~~~~
267274
.. automodule:: catalyst.contrib.callbacks.mixup_callback

examples/notebooks/segmentation-tutorial.ipynb

+9-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
"outputs": [],
5858
"source": [
5959
"# Catalyst\n",
60-
"!pip install catalyst==20.10.1\n",
60+
"!pip install catalyst==20.12\n",
6161
"\n",
6262
"# for augmentations\n",
6363
"!pip install albumentations==0.4.3\n",
@@ -804,6 +804,7 @@
804804
"source": [
805805
"from catalyst.dl import DiceCallback, IouCallback, \\\n",
806806
" CriterionCallback, MetricAggregationCallback\n",
807+
"from catalyst.contrib.callbacks import DrawMasksCallback\n",
807808
"\n",
808809
"callbacks = [\n",
809810
" # Each criterion is calculated separately.\n",
@@ -834,6 +835,12 @@
834835
" # metrics\n",
835836
" DiceCallback(input_key=\"mask\"),\n",
836837
" IouCallback(input_key=\"mask\"),\n",
838+
" # visualization\n",
839+
" DrawMasksCallback(output_key='logits',\n",
840+
" input_image_key='image',\n",
841+
" input_mask_key='mask',\n",
842+
" summary_step=50\n",
843+
" )\n",
837844
"]\n",
838845
"\n",
839846
"runner.train(\n",
@@ -1225,7 +1232,7 @@
12251232
"name": "python",
12261233
"nbconvert_exporter": "python",
12271234
"pygments_lexer": "ipython3",
1228-
"version": "3.7.7"
1235+
"version": "3.7.4"
12291236
}
12301237
},
12311238
"nbformat": 4,

tests/_tests_cv_segmentation/config.yml

+7
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,12 @@ stages:
7878
callback: IouCallback
7979
input_key: mask
8080

81+
visualise:
82+
callback: DrawMasksCallback
83+
input_image_key: "image"
84+
input_mask_key: "mask"
85+
output_key: "logits"
86+
summary_step: 300
87+
8188
saver:
8289
callback: CheckpointCallback

0 commit comments

Comments
 (0)