Skip to content

Commit 467057a

Browse files
Set use_batchnorm default to None so that default use_norm
Make warning visible by changing filter and add a test for it Fix test before after so that the value is looked and not the shape of tensor
1 parent 10d496a commit 467057a

File tree

8 files changed

+22
-9
lines changed

8 files changed

+22
-9
lines changed

segmentation_models_pytorch/base/modules.py

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def normalize_decoder_norm(decoder_use_batchnorm: Union[bool, str, None], decode
4040
warnings.warn(
4141
"The usage of use_batchnorm is deprecated. Please modify your code for use_norm",
4242
DeprecationWarning,
43+
stacklevel=2
4344
)
4445
if decoder_use_batchnorm is True:
4546
decoder_use_norm = {"type": "batchnorm"}

segmentation_models_pytorch/decoders/linknet/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(
7979
encoder_name: str = "resnet34",
8080
encoder_depth: int = 5,
8181
encoder_weights: Optional[str] = "imagenet",
82-
decoder_use_batchnorm: Union[bool, str, None] = True,
82+
decoder_use_batchnorm: Union[bool, str, None] = None,
8383
decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
8484
in_channels: int = 3,
8585
classes: int = 1,

segmentation_models_pytorch/decoders/manet/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
encoder_name: str = "resnet34",
8383
encoder_depth: int = 5,
8484
encoder_weights: Optional[str] = "imagenet",
85-
decoder_use_batchnorm: Union[bool, str, None] = True,
85+
decoder_use_batchnorm: Union[bool, str, None] = None,
8686
decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
8787
decoder_channels: List[int] = (256, 128, 64, 32, 16),
8888
decoder_pab_channels: int = 64,

segmentation_models_pytorch/decoders/pspnet/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
encoder_weights: Optional[str] = "imagenet",
8282
encoder_depth: int = 3,
8383
psp_out_channels: int = 512,
84-
psp_use_batchnorm: Union[bool, str, None] = True,
84+
psp_use_batchnorm: Union[bool, str, None] = None,
8585
decoder_use_norm: Union[bool, str, Dict[str, Any], None] = "batchnorm",
8686
psp_dropout: float = 0.2,
8787
in_channels: int = 3,

segmentation_models_pytorch/decoders/unet/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(
114114
encoder_name: str = "resnet34",
115115
encoder_depth: int = 5,
116116
encoder_weights: Optional[str] = "imagenet",
117-
decoder_use_batchnorm: Union[bool, str, None] = True,
117+
decoder_use_batchnorm: Union[bool, str, None] = None,
118118
decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
119119
decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
120120
decoder_attention_type: Optional[str] = None,

segmentation_models_pytorch/decoders/unetplusplus/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
encoder_name: str = "resnet34",
8484
encoder_depth: int = 5,
8585
encoder_weights: Optional[str] = "imagenet",
86-
decoder_use_batchnorm: Union[bool, str, None] = True,
86+
decoder_use_batchnorm: Union[bool, str, None] = None,
8787
decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
8888
decoder_channels: List[int] = (256, 128, 64, 32, 16),
8989
decoder_attention_type: Optional[str] = None,

tests/encoders/test_batchnorm_deprecation.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
@pytest.mark.parametrize("decoder_option", [True, False, "inplace"])
1111
def test_seg_models_before_after_use_norm(model_name, decoder_option):
1212
torch.manual_seed(42)
13-
model_decoder_batchnorm = create_model(model_name, "mobilenet_v2", None, decoder_use_batchnorm=decoder_option)
13+
with pytest.warns(DeprecationWarning):
14+
model_decoder_batchnorm = create_model(
15+
model_name,
16+
"mobilenet_v2",
17+
None,
18+
decoder_use_batchnorm=decoder_option
19+
)
1420
torch.manual_seed(42)
1521
model_decoder_norm = create_model(model_name, "mobilenet_v2", None, decoder_use_batchnorm=None, decoder_use_norm=decoder_option)
1622

@@ -21,7 +27,13 @@ def test_seg_models_before_after_use_norm(model_name, decoder_option):
2127
@pytest.mark.parametrize("decoder_option", [True, False, "inplace"])
2228
def test_pspnet_before_after_use_norm(decoder_option):
2329
torch.manual_seed(42)
24-
model_decoder_batchnorm = create_model("pspnet", "mobilenet_v2", None, psp_use_batchnorm=decoder_option)
30+
with pytest.warns(DeprecationWarning):
31+
model_decoder_batchnorm = create_model(
32+
"pspnet",
33+
"mobilenet_v2",
34+
None,
35+
psp_use_batchnorm=decoder_option
36+
)
2537
torch.manual_seed(42)
2638
model_decoder_norm = create_model("pspnet", "mobilenet_v2", None, psp_use_batchnorm=None, decoder_use_norm=decoder_option)
2739

tests/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def check_run_test_on_diff_or_main(filepath_patterns: List[str]):
6060
return False
6161

6262

63-
def check_two_models_strictly_equal(model_a, model_b):
63+
def check_two_models_strictly_equal(model_a: torch.nn.Module, model_b: torch.nn.Module) -> None:
6464
for (k1, v1), (k2, v2) in zip(model_a.state_dict().items(),
6565
model_b.state_dict().items()):
6666
assert k1 == k2, f"Key mismatch: {k1} != {k2}"
67-
assert v1.shape == v2.shape, f"Shape mismatch in {k1}: {v1.shape} != {v2.shape}"
67+
assert (v1 == v2).all(), f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}"

0 commit comments

Comments
 (0)