
Commit 93b19d3

Enable any resolution for Unet (#1029)

* Fix type hint for models
* Use inference mode in tests
* Add test for any resolution (not divisible by 32)
* Use inference mode in tests
* Enable any res for Unet and better docs
* Fix check_input_shape condition
* Interpolation for unet

1 parent: eaf8be6
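In short: Unet inputs no longer need side lengths divisible by 32. A minimal sketch of the new behavior (the 250x333 input size is an arbitrary illustration; weights are disabled only to avoid a download):

import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet34", encoder_weights=None, classes=2)
model.eval()

# after this commit, height/width need not be divisible by 32
images = torch.rand(1, 3, 250, 333)

with torch.inference_mode():
    mask = model(images)

# the decoder interpolates back to the recorded input shape
print(mask.shape)  # expected: torch.Size([1, 2, 250, 333])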

File tree: 7 files changed, +187 -75 lines

segmentation_models_pytorch/base/model.py (+9 -1)

@@ -1,8 +1,11 @@
 import torch
+from typing import TypeVar, Type
 
 from . import initialization as init
 from .hub_mixin import SMPHubMixin
 
+T = TypeVar("T", bound="SegmentationModel")
+
 
 class SegmentationModel(torch.nn.Module, SMPHubMixin):
     """Base class for all segmentation models."""
@@ -11,6 +14,11 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin):
     # set to False
     requires_divisible_input_shape = True
 
+    # Fix type-hint for models, to avoid HubMixin signature
+    def __new__(cls: Type[T], *args, **kwargs) -> T:
+        instance = super().__new__(cls, *args, **kwargs)
+        return instance
+
     def initialize(self):
         init.initialize_decoder(self.decoder)
         init.initialize_head(self.segmentation_head)
@@ -42,7 +50,7 @@ def check_input_shape(self, x):
     def forward(self, x):
         """Sequentially pass `x` trough model`s encoder, decoder and heads"""
 
-        if not torch.jit.is_tracing() or self.requires_divisible_input_shape:
+        if not torch.jit.is_tracing() and self.requires_divisible_input_shape:
            self.check_input_shape(x)

        features = self.encoder(x)
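A note on the __new__ override: per the commit message it exists only to fix static type hints, keeping type checkers from picking up the HubMixin constructor signature. A self-contained sketch of the pattern under illustrative names (Base/Sub are not from the repo):

from typing import Type, TypeVar

T = TypeVar("T", bound="Base")

class Base:
    # annotating cls as Type[T] and returning T lets a type checker
    # infer the concrete subclass for Sub(), not a looser mixin type
    def __new__(cls: Type[T], *args, **kwargs) -> T:
        return super().__new__(cls)

class Sub(Base):
    pass

instance = Sub()  # statically inferred as Sub, not Base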

segmentation_models_pytorch/decoders/unet/decoder.py (+74 -36)

@@ -2,19 +2,24 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from typing import Optional, Sequence
 from segmentation_models_pytorch.base import modules as md
 
 
-class DecoderBlock(nn.Module):
+class UnetDecoderBlock(nn.Module):
+    """A decoder block in the U-Net architecture that performs upsampling and feature fusion."""
+
     def __init__(
         self,
-        in_channels,
-        skip_channels,
-        out_channels,
-        use_batchnorm=True,
-        attention_type=None,
+        in_channels: int,
+        skip_channels: int,
+        out_channels: int,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        interpolation_mode: str = "nearest",
     ):
         super().__init__()
+        self.interpolation_mode = interpolation_mode
         self.conv1 = md.Conv2dReLU(
             in_channels + skip_channels,
             out_channels,
@@ -34,19 +39,31 @@ def __init__(
         )
         self.attention2 = md.Attention(attention_type, in_channels=out_channels)
 
-    def forward(self, x, skip=None):
-        x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if skip is not None:
-            x = torch.cat([x, skip], dim=1)
-        x = self.attention1(x)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.attention2(x)
-        return x
+    def forward(
+        self,
+        feature_map: torch.Tensor,
+        target_height: int,
+        target_width: int,
+        skip_connection: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        feature_map = F.interpolate(
+            feature_map,
+            size=(target_height, target_width),
+            mode=self.interpolation_mode,
+        )
+        if skip_connection is not None:
+            feature_map = torch.cat([feature_map, skip_connection], dim=1)
+        feature_map = self.attention1(feature_map)
+        feature_map = self.conv1(feature_map)
+        feature_map = self.conv2(feature_map)
+        feature_map = self.attention2(feature_map)
+        return feature_map
+
 
+class UnetCenterBlock(nn.Sequential):
+    """Center block of the Unet decoder. Applied to the last feature map of the encoder."""
 
-class CenterBlock(nn.Sequential):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
         conv1 = md.Conv2dReLU(
             in_channels,
             out_channels,
@@ -65,14 +82,21 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True):
 
 
 class UnetDecoder(nn.Module):
+    """The decoder part of the U-Net architecture.
+
+    Takes encoded features from different stages of the encoder and progressively upsamples them while
+    combining with skip connections. This helps preserve fine-grained details in the final segmentation.
+    """
+
     def __init__(
         self,
-        encoder_channels,
-        decoder_channels,
-        n_blocks=5,
-        use_batchnorm=True,
-        attention_type=None,
-        center=False,
+        encoder_channels: Sequence[int],
+        decoder_channels: Sequence[int],
+        n_blocks: int = 5,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        add_center_block: bool = False,
+        interpolation_mode: str = "nearest",
     ):
         super().__init__()
 
@@ -94,31 +118,45 @@ def __init__(
         skip_channels = list(encoder_channels[1:]) + [0]
         out_channels = decoder_channels
 
-        if center:
-            self.center = CenterBlock(
+        if add_center_block:
+            self.center = UnetCenterBlock(
                 head_channels, head_channels, use_batchnorm=use_batchnorm
             )
         else:
             self.center = nn.Identity()
 
         # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
-        blocks = [
-            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
-            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
-        ]
-        self.blocks = nn.ModuleList(blocks)
-
-    def forward(self, *features):
+        self.blocks = nn.ModuleList()
+        for block_in_channels, block_skip_channels, block_out_channels in zip(
+            in_channels, skip_channels, out_channels
+        ):
+            block = UnetDecoderBlock(
+                block_in_channels,
+                block_skip_channels,
+                block_out_channels,
+                use_batchnorm=use_batchnorm,
+                attention_type=attention_type,
+                interpolation_mode=interpolation_mode,
+            )
+            self.blocks.append(block)
+
+    def forward(self, *features: torch.Tensor) -> torch.Tensor:
+        # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
+        spatial_shapes = [feature.shape[2:] for feature in features]
+        spatial_shapes = spatial_shapes[::-1]
+
         features = features[1:]  # remove first skip with same spatial resolution
         features = features[::-1]  # reverse channels to start from head of encoder
 
         head = features[0]
-        skips = features[1:]
+        skip_connections = features[1:]
 
         x = self.center(head)
+
         for i, decoder_block in enumerate(self.blocks):
-            skip = skips[i] if i < len(skips) else None
-            x = decoder_block(x, skip)
+            # upsample to the next spatial shape
+            height, width = spatial_shapes[i + 1]
+            skip_connection = skip_connections[i] if i < len(skip_connections) else None
+            x = decoder_block(x, height, width, skip_connection=skip_connection)
 
         return x
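Why the interpolation change matters: with a fixed scale_factor=2, an odd skip dimension gets overshot by one pixel on the way back up and the channel concatenation fails; interpolating to the recorded skip shape keeps the fusion valid at any resolution. A standalone sketch with illustrative shapes (a 125-px input downsampled to 63 and then 32):

import torch
import torch.nn.functional as F

deep = torch.rand(1, 8, 32, 32)  # deepest encoder feature
skip = torch.rand(1, 4, 63, 63)  # skip connection one stage up

# fixed 2x upsampling overshoots: 64 px cannot be concatenated with 63 px
up = F.interpolate(deep, scale_factor=2, mode="nearest")
print(up.shape[-1], skip.shape[-1])  # 64 63 -> torch.cat would raise

# interpolating to the recorded skip shape keeps torch.cat valid
up = F.interpolate(deep, size=skip.shape[2:], mode="nearest")
print(torch.cat([up, skip], dim=1).shape)  # torch.Size([1, 12, 63, 63])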

segmentation_models_pytorch/decoders/unet/model.py (+43 -7)

@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union, Tuple, Callable
+from typing import Any, Optional, Union, Callable, Sequence
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -12,10 +12,21 @@
 
 
 class Unet(SegmentationModel):
-    """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder*
-    and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial
-    resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation*
-    for fusing decoder blocks with skip connections.
+    """
+    U-Net is a fully convolutional neural network architecture designed for semantic image segmentation.
+
+    It consists of two main parts:
+
+    1. An encoder (downsampling path) that extracts increasingly abstract features
+    2. A decoder (upsampling path) that gradually recovers spatial details
+
+    The key is the use of skip connections between corresponding encoder and decoder layers.
+    These connections allow the decoder to access fine-grained details from earlier encoder layers,
+    which helps produce more precise segmentation masks.
+
+    The skip connections work by concatenating feature maps from the encoder directly into the decoder
+    at corresponding resolutions. This helps preserve important spatial information that would
+    otherwise be lost during the encoding process.
 
     Args:
         encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
@@ -33,6 +44,8 @@ class Unet(SegmentationModel):
             Available options are **True, False, "inplace"**
         decoder_attention_type: Attention module used in decoder of the model. Available options are
             **None** and **scse** (https://arxiv.org/abs/1808.08127).
+        decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are
+            **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
@@ -51,20 +64,41 @@ class Unet(SegmentationModel):
     Returns:
         ``torch.nn.Module``: Unet
 
+    Example:
+        .. code-block:: python
+
+            import torch
+            import segmentation_models_pytorch as smp
+
+            model = smp.Unet("resnet18", encoder_weights="imagenet", classes=5)
+            model.eval()
+
+            # generate random images
+            images = torch.rand(2, 3, 256, 256)
+
+            with torch.inference_mode():
+                mask = model(images)
+
+            print(mask.shape)
+            # torch.Size([2, 5, 256, 256])
+
     .. _Unet:
         https://arxiv.org/abs/1505.04597
 
     """
 
+    requires_divisible_input_shape = False
+
     @supports_config_loading
     def __init__(
         self,
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
         decoder_use_batchnorm: bool = True,
-        decoder_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
+        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
         decoder_attention_type: Optional[str] = None,
+        decoder_interpolation_mode: str = "nearest",
         in_channels: int = 3,
         classes: int = 1,
         activation: Optional[Union[str, Callable]] = None,
@@ -81,13 +115,15 @@ def __init__(
             **kwargs,
         )
 
+        add_center_block = encoder_name.startswith("vgg")
         self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
-            center=True if encoder_name.startswith("vgg") else False,
+            add_center_block=add_center_block,
            attention_type=decoder_attention_type,
+            interpolation_mode=decoder_interpolation_mode,
        )

        self.segmentation_head = SegmentationHead(
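A short usage sketch of the new decoder_interpolation_mode argument (encoder_weights=None here only to avoid a weights download):

import segmentation_models_pytorch as smp

# swap the decoder upsampling from the default "nearest" to "bilinear"
model = smp.Unet(
    "resnet34",
    encoder_weights=None,
    decoder_interpolation_mode="bilinear",
    classes=1,
)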

tests/encoders/base.py (+3 -3)

@@ -82,7 +82,7 @@ def test_in_channels(self):
         encoder.eval()
 
         # forward
-        with torch.no_grad():
+        with torch.inference_mode():
             encoder.forward(sample)
 
     def test_depth(self):
@@ -110,7 +110,7 @@ def test_depth(self):
         encoder.eval()
 
         # forward
-        with torch.no_grad():
+        with torch.inference_mode():
             features = encoder.forward(sample)
 
         # check number of features
@@ -187,7 +187,7 @@ def test_dilated(self):
         encoder.eval()
 
         # forward
-        with torch.no_grad():
+        with torch.inference_mode():
             features = encoder.forward(sample)
 
         height_strides, width_strides = self.get_features_output_strides(
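On the no_grad -> inference_mode switch: inference mode is a stricter, usually faster variant of no_grad that also disables view tracking and version counting, at the cost that its outputs cannot later participate in autograd. A generic PyTorch sketch (not from this repo):

import torch

x = torch.ones(2, 2, requires_grad=True)

with torch.no_grad():
    a = x * 2  # plain tensor, reusable in later autograd graphs

with torch.inference_mode():
    b = x * 2  # "inference tensor", excluded from autograd entirely

print(a.requires_grad, b.requires_grad)  # False False
print(torch.is_inference(a), torch.is_inference(b))  # False True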
