| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from .utils import _SimpleSegmentationModel |
| from ._deeplab import ASPPConv, ASPPPooling, ASPP, AtrousSeparableConvolution |
| from .enhanced_modules import EOANetModule |
|
|
|
|
class EnhancedDeepLabV3(_SimpleSegmentationModel):
    """Enhanced DeepLabV3 segmentation model.

    Described by its authors as DeepLabV3 combined with Normalized
    Multi-Scale Attention and Entropy-Optimized Gating. This class is only a
    named subclass of ``_SimpleSegmentationModel``: it defines no attributes
    or methods of its own, so construction and the forward pass are entirely
    inherited from the base class. The attention/gating logic itself is not
    implemented in this class body.
    """
|
|
|
|
class EnhancedDeepLabHeadV3Plus(nn.Module):
    """DeepLabV3+ style decoder head with an optional EOANet stage.

    Pipeline (see ``forward``): the ``'low_level'`` feature map is projected
    to 48 channels with a 1x1 conv; the ``'out'`` feature map goes through
    ASPP and, when enabled, through ``EOANetModule``; that result is
    bilinearly resized to the low-level spatial size, concatenated with the
    projected low-level map along dim 1, and classified.

    Args:
        in_channels: Input channel count handed to ``ASPP``.
        low_level_channels: Channel count of ``feature['low_level']``.
        num_classes: Output channels of the final 1x1 convolution.
        aspp_dilate: Dilation rates forwarded to ``ASPP``. ``None`` (the
            default) means ``[12, 24, 36]``.
        use_eoaNet: When true, an ``EOANetModule`` is created and applied to
            the ASPP output.
        msa_scales: Forwarded to ``EOANetModule`` as ``scales``. ``None``
            (the default) means ``[1, 2, 4]``.
        eog_beta: Forwarded to ``EOANetModule`` as ``beta``.
    """

    # Channel width of the projected low-level branch.
    _LOW_LEVEL_PROJ_CHANNELS = 48
    # Width used for EOANetModule and the classifier's hidden layer.
    # NOTE(review): assumes ASPP emits this many channels -- confirm against
    # ._deeplab.ASPP.
    _ASPP_OUT_CHANNELS = 256

    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=None,
                 use_eoaNet=True, msa_scales=None, eog_beta=0.5):
        super().__init__()
        # Build fresh lists per instance instead of sharing one mutable
        # default object between every head (and the submodules that
        # receive it).
        if aspp_dilate is None:
            aspp_dilate = [12, 24, 36]
        if msa_scales is None:
            msa_scales = [1, 2, 4]

        self.use_eoaNet = use_eoaNet

        low_width = self._LOW_LEVEL_PROJ_CHANNELS
        aspp_width = self._ASPP_OUT_CHANNELS

        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, low_width, 1, bias=False),
            nn.BatchNorm2d(low_width),
            nn.ReLU(inplace=True),
        )

        self.aspp = ASPP(in_channels, aspp_dilate)

        if self.use_eoaNet:
            self.eoaNet = EOANetModule(aspp_width, scales=msa_scales, beta=eog_beta)

        # The classifier consumes the concatenation of both branches:
        # 48 + 256 = 304 input channels.
        self.classifier = nn.Sequential(
            nn.Conv2d(low_width + aspp_width, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        """Compute class logits from a backbone feature mapping.

        Args:
            feature: Mapping that provides the keys ``'low_level'`` and
                ``'out'``. Both values are presumably 4-D image tensors
                (they feed ``nn.Conv2d`` layers) -- confirm with the
                backbone.

        Returns:
            The classifier output, at the spatial size of the projected
            low-level feature map.
        """
        low_level_feature = self.project(feature['low_level'])
        output_feature = self.aspp(feature['out'])

        if self.use_eoaNet:
            output_feature = self.eoaNet(output_feature)

        # Bring the high-level branch to the low-level resolution so the two
        # can be concatenated.
        output_feature = F.interpolate(
            output_feature,
            size=low_level_feature.shape[2:],
            mode='bilinear',
            align_corners=False,
        )
        fused = torch.cat([low_level_feature, output_feature], dim=1)
        return self.classifier(fused)

    def _init_weight(self):
        """Initialise every conv and norm layer reachable from this head.

        ``self.modules()`` walks all descendants, so layers inside the ASPP
        and EOANet submodules are re-initialised as well. Conv weights get
        Kaiming-normal init; norm layers get weight 1 and bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False expose weight/bias as
                # None; skip them instead of crashing inside nn.init.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
|
|
|
class EnhancedDeepLabHead(nn.Module):
    """DeepLabV3 style head (no low-level branch) with an optional EOANet stage.

    The ``'out'`` feature map goes through ASPP, optionally through
    ``EOANetModule``, and then through a 3x3 conv / BatchNorm / ReLU / 1x1
    conv classifier.

    Args:
        in_channels: Input channel count handed to ``ASPP``.
        num_classes: Output channels of the final 1x1 convolution.
        aspp_dilate: Dilation rates forwarded to ``ASPP``. ``None`` (the
            default) means ``[12, 24, 36]``.
        use_eoaNet: When true, an ``EOANetModule`` is created and applied to
            the ASPP output.
        msa_scales: Forwarded to ``EOANetModule`` as ``scales``. ``None``
            (the default) means ``[1, 2, 4]``.
        eog_beta: Forwarded to ``EOANetModule`` as ``beta``.
    """

    def __init__(self, in_channels, num_classes, aspp_dilate=None,
                 use_eoaNet=True, msa_scales=None, eog_beta=0.5):
        super().__init__()
        # Build fresh lists per instance instead of sharing one mutable
        # default object between every head (and the submodules that
        # receive it).
        if aspp_dilate is None:
            aspp_dilate = [12, 24, 36]
        if msa_scales is None:
            msa_scales = [1, 2, 4]

        self.use_eoaNet = use_eoaNet

        self.aspp = ASPP(in_channels, aspp_dilate)

        if self.use_eoaNet:
            self.eoaNet = EOANetModule(256, scales=msa_scales, beta=eog_beta)

        # NOTE(review): the 256 input channels assume ASPP (and EOANetModule)
        # emit 256 channels -- confirm against their definitions.
        self.classifier = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        """Compute class logits from a backbone feature mapping.

        Args:
            feature: Mapping that provides the key ``'out'``; its value is
                presumably a 4-D image tensor -- confirm with the backbone.

        Returns:
            The classifier output for the (optionally EOANet-refined) ASPP
            features.
        """
        refined = self.aspp(feature['out'])

        if self.use_eoaNet:
            refined = self.eoaNet(refined)

        return self.classifier(refined)

    def _init_weight(self):
        """Initialise every conv and norm layer reachable from this head.

        ``self.modules()`` walks all descendants, so layers inside the ASPP
        and EOANet submodules are re-initialised as well. Conv weights get
        Kaiming-normal init; norm layers get weight 1 and bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False expose weight/bias as
                # None; skip them instead of crashing inside nn.init.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
|
|
|
def convert_to_separable_conv(module):
    """Recursively swap spatial ``nn.Conv2d`` layers for separable ones.

    An ``nn.Conv2d`` whose first kernel dimension is greater than 1 is
    replaced by a newly constructed ``AtrousSeparableConvolution`` that
    reuses the original's channel counts, kernel size, stride, padding and
    dilation. Nothing is copied from the old layer's parameters, so the
    replacement starts from its own initialisation. Every other module is
    kept as-is and its children are converted and re-registered under their
    original names (i.e. the input module tree is modified in place).

    Args:
        module: Root ``nn.Module`` to convert.

    Returns:
        ``module`` itself when it is not a convertible convolution,
        otherwise the new ``AtrousSeparableConvolution``.
    """
    new_module = module
    if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1:
        # Pass a plain bool, not the bias Parameter itself: a truth test on
        # a multi-element tensor raises, and a one-element tensor would be
        # judged by its value rather than its presence.
        # NOTE(review): assumes AtrousSeparableConvolution treats ``bias``
        # as an on/off flag like nn.Conv2d does -- confirm in ._deeplab.
        has_bias = module.bias is not None
        new_module = AtrousSeparableConvolution(
            module.in_channels,
            module.out_channels,
            module.kernel_size,
            module.stride,
            module.padding,
            module.dilation,
            has_bias,
        )
    for name, child in module.named_children():
        new_module.add_module(name, convert_to_separable_conv(child))
    return new_module