
Adding DPT #1079


Merged: 45 commits, Apr 7, 2025
Changes shown below are from the first 5 commits.

Commits (45)
78ba0e8
Initial timm vit encoder commit
vedantdalimkar Feb 28, 2025
2c38de6
Add DPT model and update logic for TimmViTEncoder class
vedantdalimkar Mar 2, 2025
5599409
Removed redundant documentation
vedantdalimkar Mar 2, 2025
c47bdfb
Added initial test and some minor code modifications
vedantdalimkar Mar 5, 2025
71e2acb
Code refactor
vedantdalimkar Mar 8, 2025
e85836d
Added weight conversion script
vedantdalimkar Mar 22, 2025
35cb060
Moved conversion script to appropriate location
vedantdalimkar Mar 22, 2025
aa84f4e
Added logic to timm table generation to include ViT encoders for DPT
Mar 22, 2025
67c4a75
Ruff formatting
vedantdalimkar Mar 22, 2025
85f22fb
Code revision
vedantdalimkar Mar 26, 2025
ef48032
Remove unnecessary comment
vedantdalimkar Mar 27, 2025
28204ad
Simplify ViT encoder
qubvel Apr 5, 2025
1b9a6f6
Refactor ProjectionReadout
qubvel Apr 5, 2025
334cfbb
Refactor modeling DPT
qubvel Apr 6, 2025
7e1ef3b
Support more encoders
qubvel Apr 6, 2025
d65c0f7
Refactor conversion a bit, add validation
qubvel Apr 6, 2025
0a62fe0
Fixup
qubvel Apr 6, 2025
e3238ae
Split forward for timm_vit
qubvel Apr 6, 2025
df4d087
Rename readout, remove feature_dim
qubvel Apr 6, 2025
8bcb0ed
refactor + add transform
qubvel Apr 6, 2025
6ba6746
Fixup
qubvel Apr 6, 2025
8fd8c77
Refine docs a bit
qubvel Apr 6, 2025
9bf1fd2
Refine docs
qubvel Apr 6, 2025
0e9170f
Refine model size and docs a bit
qubvel Apr 6, 2025
a0aa5a8
Add to docs
qubvel Apr 6, 2025
6cfd3be
Add note
qubvel Apr 6, 2025
d4b162d
Remove txt
qubvel Apr 6, 2025
5fe80a5
Fix doc
qubvel Apr 6, 2025
0a14972
Fix docstring
qubvel Apr 6, 2025
5b28978
Fixing list in activation
qubvel Apr 6, 2025
0ed621c
Fixing list
qubvel Apr 6, 2025
6207310
Fixing list
qubvel Apr 6, 2025
19eeebe
Fixup, fix type hint
qubvel Apr 6, 2025
f2e3f89
Merge branch 'main' into pr/vedantdalimkar/1079
qubvel Apr 6, 2025
1257c4b
Add to README
qubvel Apr 6, 2025
21a164a
Add example
qubvel Apr 6, 2025
8d3ed4f
Add decoder_readout according to initial impl
qubvel Apr 7, 2025
4eb6ec3
Tests update
vedantdalimkar Apr 7, 2025
165b9c0
Fix encoder tests
qubvel Apr 7, 2025
5603707
Fix DPT tests
qubvel Apr 7, 2025
9518964
Refactor a bit
qubvel Apr 7, 2025
38cb944
Tests
qubvel Apr 7, 2025
17d3328
Update gen test models
qubvel Apr 7, 2025
83b9655
Revert gitignore
qubvel Apr 7, 2025
343fbe0
Fix test
qubvel Apr 7, 2025
2 changes: 2 additions & 0 deletions encoders_table.md
Review comment (Collaborator): We should remove this file

@@ -0,0 +1,2 @@
|Encoder |Pretrained weights |Params, M |Script |Compile |Export |
|--------------------------------|:------------------------------:|:------------------------------:|:------------------------------:|:------------------------------:|:------------------------------:|
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/__init__.py
@@ -14,6 +14,7 @@
from .decoders.pan import PAN
from .decoders.upernet import UPerNet
from .decoders.segformer import Segformer
from .decoders.dpt import DPT
from .base.hub_mixin import from_pretrained

from .__version__ import __version__
@@ -34,6 +35,7 @@
PAN,
UPerNet,
Segformer,
DPT,
]
MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES}

@@ -84,6 +86,7 @@ def create_model(
"PAN",
"UPerNet",
"Segformer",
"DPT",
"from_pretrained",
"create_model",
"__version__",
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/decoders/dpt/__init__.py
@@ -0,0 +1,3 @@
from .model import DPT

__all__ = ["DPT"]
267 changes: 267 additions & 0 deletions segmentation_models_pytorch/decoders/dpt/decoder.py
@@ -0,0 +1,267 @@
import torch
import torch.nn as nn


def _get_feature_processing_out_channels(encoder_name: str) -> list[int]:
"""
Get the output embedding dimensions for the features after decoder processing
"""

encoder_name = encoder_name.lower()
# Output channels for hybrid ViT encoder after feature processing
if "vit" in encoder_name and "resnet" in encoder_name:
return [256, 512, 768, 768]

[Codecov warning] Added line #L13 was not covered by tests.

# Output channels for ViT-large, ViT-huge, and ViT-giant encoders after feature processing
if "vit" in encoder_name and any(
[variant in encoder_name for variant in ["huge", "large", "giant"]]
):
return [256, 512, 1024, 1024]

[Codecov warning] Added line #L19 was not covered by tests.

# Output channels for ViT-base and other encoders after feature processing
return [96, 192, 384, 768]
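
# Illustrative examples (not part of the PR) of how the name matching resolves:
#   _get_feature_processing_out_channels("vit_base_resnet50_384")  -> [256, 512, 768, 768]    (hybrid ViT)
#   _get_feature_processing_out_channels("vit_large_patch16_384")  -> [256, 512, 1024, 1024]  (large/huge/giant)
#   _get_feature_processing_out_channels("vit_base_patch16_224")   -> [96, 192, 384, 768]     (default)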


class Transpose(nn.Module):
def __init__(self, dim0: int, dim1: int):
super().__init__()
self.dim0 = dim0
self.dim1 = dim1

def forward(self, x: torch.Tensor):
return torch.transpose(x, dim0=self.dim0, dim1=self.dim1)


class ProjectionReadout(nn.Module):
"""
Concatenates the cls token with the features to make use of the global information aggregated in the cls token.
Projects the combined feature map back to the original embedding dimension using an MLP.
"""

def __init__(self, in_features: int, encoder_output_stride: int):
super().__init__()
self.project = nn.Sequential(
nn.Linear(in_features=2 * in_features, out_features=in_features), nn.GELU()
)

self.flatten = nn.Flatten(start_dim=2)
self.transpose = Transpose(dim0=1, dim1=2)
self.encoder_output_stride = encoder_output_stride

def forward(self, feature: torch.Tensor, cls_token: torch.Tensor):
batch_size, _, height_dim, width_dim = feature.shape
feature = self.flatten(feature)
feature = self.transpose(feature)

cls_token = cls_token.expand_as(feature)

features = torch.cat([feature, cls_token], dim=2)
features = self.project(features)
features = self.transpose(features)

features = features.view(batch_size, -1, height_dim, width_dim)
return features
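
# Shape walkthrough (illustrative; assumes a cls_token of shape (B, 1, C)):
#   feature: (B, C, H, W) --flatten--> (B, C, H*W) --transpose--> (B, H*W, C)
#   cls_token expanded to (B, H*W, C); concatenated on dim 2 -> (B, H*W, 2C)
#   Linear(2C -> C) + GELU -> (B, H*W, C) --transpose + view--> (B, C, H, W)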


class IgnoreReadout(nn.Module):
def __init__(self):
super().__init__()

[Codecov warning] Added line #L68 was not covered by tests.

def forward(self, feature: torch.Tensor, cls_token: torch.Tensor):
return feature

[Codecov warning] Added line #L71 was not covered by tests.


class FeatureProcessBlock(nn.Module):
"""
Processes the features so that they have progressively increasing embedding size and progressively decreasing
spatial dimensions.
"""

def __init__(
self, embed_dim: int, feature_dim: int, out_channel: int, upsample_factor: int
):
super().__init__()

self.project_to_out_channel = nn.Conv2d(
in_channels=embed_dim, out_channels=out_channel, kernel_size=1
)

if upsample_factor > 1.0:
self.upsample = nn.ConvTranspose2d(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=int(upsample_factor),
stride=int(upsample_factor),
)

elif upsample_factor == 1.0:
self.upsample = nn.Identity()

else:
self.upsample = nn.Conv2d(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=3,
stride=int(1 / upsample_factor),
padding=1,
)

self.project_to_feature_dim = nn.Conv2d(
in_channels=out_channel, out_channels=feature_dim, kernel_size=3, padding=1
)

def forward(self, x: torch.Tensor):
x = self.project_to_out_channel(x)
x = self.upsample(x)
x = self.project_to_feature_dim(x)

return x
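
# Illustrative shapes (assuming embed_dim=768, out_channel=96, feature_dim=256,
# upsample_factor=4, i.e. the first block for a stride-16 ViT-base encoder):
#   (B, 768, H/16, W/16) --1x1 conv--> (B, 96, H/16, W/16)
#   --ConvTranspose2d(kernel=4, stride=4)--> (B, 96, H/4, W/4)
#   --3x3 conv--> (B, 256, H/4, W/4)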


class ResidualConvBlock(nn.Module):
def __init__(self, feature_dim: int):
super().__init__()
self.conv_block = nn.Sequential(
nn.ReLU(),
nn.Conv2d(
in_channels=feature_dim,
out_channels=feature_dim,
kernel_size=3,
padding=1,
bias=False,
),
nn.BatchNorm2d(num_features=feature_dim),
nn.ReLU(),
nn.Conv2d(
in_channels=feature_dim,
out_channels=feature_dim,
kernel_size=3,
padding=1,
bias=False,
),
nn.BatchNorm2d(num_features=feature_dim),
)

def forward(self, x: torch.Tensor):
return x + self.conv_block(x)


class FusionBlock(nn.Module):
"""
Fuses the processed encoder features in a residual manner and upsamples them
"""

def __init__(self, feature_dim: int):
super().__init__()
self.residual_conv_block1 = ResidualConvBlock(feature_dim=feature_dim)
self.residual_conv_block2 = ResidualConvBlock(feature_dim=feature_dim)
self.project = nn.Conv2d(
in_channels=feature_dim, out_channels=feature_dim, kernel_size=1
)
self.activation = nn.ReLU()

def forward(self, feature: torch.Tensor, preceding_layer_feature: torch.Tensor):
feature = self.residual_conv_block1(feature)

if preceding_layer_feature is not None:
feature += preceding_layer_feature

feature = self.residual_conv_block2(feature)

feature = nn.functional.interpolate(
feature, scale_factor=2, align_corners=True, mode="bilinear"
)
feature = self.project(feature)
feature = self.activation(feature)

return feature
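
# Note (illustrative): preceding_layer_feature arrives already upsampled by the
# previous FusionBlock, so both inputs share the same spatial size; each block
# then doubles the resolution, walking the feature pyramid from 1/32 up to 1/2.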


class DPTDecoder(nn.Module):
"""
Decoder part for DPT

Processes the encoder features and class tokens (if the encoder has class tokens) so that they have spatial
downsampling ratios of [1/32, 1/16, 1/8, 1/4] relative to the input image spatial dimensions.

The decoder then fuses these features in a residual manner and progressively upsamples them by a factor of 2, so that
the output has a downsampling ratio of 1/2 relative to the input image spatial dimensions.

"""

def __init__(
self,
encoder_name: str,
transformer_embed_dim: int,
encoder_output_stride: int,
feature_dim: int = 256,
encoder_depth: int = 4,
cls_token_supported: bool = False,
):
super().__init__()

self.cls_token_supported = cls_token_supported

# If the encoder has a cls token, concatenate it with the features along the embedding dimension and project
# back to the transformer embedding dimension. Otherwise, ignore the non-existent cls token.

if cls_token_supported:
self.readout_blocks = nn.ModuleList(
[
ProjectionReadout(
in_features=transformer_embed_dim,
encoder_output_stride=encoder_output_stride,
)
for _ in range(encoder_depth)
]
)
else:
# nn.ModuleList (rather than a plain list) so the blocks register as submodules
self.readout_blocks = nn.ModuleList(
    [IgnoreReadout() for _ in range(encoder_depth)]
)

[Codecov warning] Added line #L219 was not covered by tests.

upsample_factors = [
(encoder_output_stride / 2 ** (index + 2))
for index in range(0, encoder_depth)
]
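# Worked example (not in the PR): with encoder_output_stride=16 and encoder_depth=4,
# upsample_factors = [16/4, 16/8, 16/16, 16/32] = [4.0, 2.0, 1.0, 0.5], which yields
# processed features at 1/4, 1/8, 1/16 and 1/32 of the input resolution.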
feature_processing_out_channels = _get_feature_processing_out_channels(
encoder_name
)
if encoder_depth < len(feature_processing_out_channels):
feature_processing_out_channels = feature_processing_out_channels[
:encoder_depth
]

self.feature_processing_blocks = nn.ModuleList(
[
FeatureProcessBlock(
transformer_embed_dim, feature_dim, out_channel, upsample_factor
)
for upsample_factor, out_channel in zip(
upsample_factors, feature_processing_out_channels
)
]
)

self.fusion_blocks = nn.ModuleList(
[FusionBlock(feature_dim=feature_dim) for _ in range(encoder_depth)]
)

def forward(
self, features: list[torch.Tensor], cls_tokens: list[torch.Tensor]
) -> torch.Tensor:
processed_features = []

# Process the encoder features to scale of [1/32,1/16,1/8,1/4]
for index, (feature, cls_token) in enumerate(zip(features, cls_tokens)):
readout_feature = self.readout_blocks[index](feature, cls_token)
processed_feature = self.feature_processing_blocks[index](readout_feature)
processed_features.append(processed_feature)

preceding_layer_feature = None

# Fusion and progressive upsampling starting from the last processed feature
processed_features = processed_features[::-1]
for fusion_block, feature in zip(self.fusion_blocks, processed_features):
out = fusion_block(feature, preceding_layer_feature)
preceding_layer_feature = out

return out
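
The decoder can be exercised end to end without an encoder. A minimal sketch, assuming ViT-base-like dummy features for a 224x224 input at patch stride 16; the encoder name, batch size, and shapes are illustrative, not taken from the PR:

import torch

decoder = DPTDecoder(
    encoder_name="vit_base",      # hypothetical name; selects the default [96, 192, 384, 768]
    transformer_embed_dim=768,
    encoder_output_stride=16,
    cls_token_supported=True,
)
# Four feature maps of shape (B, embed_dim, 224/16, 224/16) plus matching cls tokens
features = [torch.randn(2, 768, 14, 14) for _ in range(4)]
cls_tokens = [torch.randn(2, 1, 768) for _ in range(4)]

out = decoder(features, cls_tokens)
print(out.shape)  # torch.Size([2, 256, 112, 112]) -> 1/2 of the input resolution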