stdcnet.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import paddle
  16. import paddle.nn as nn
  17. from paddleseg.utils import utils
  18. from paddleseg.cvlibs import manager, param_init
  19. from paddleseg.models.layers.layer_libs import SyncBatchNorm
# Names exported by `from <this module> import *`: the registered backbone
# factories defined at the bottom of the file (the STDCNet class is not listed).
__all__ = ["STDC1", "STDC2", "STDC_Small", "STDC_Tiny"]
  21. class STDCNet(nn.Layer):
  22. """
  23. The STDCNet implementation based on PaddlePaddle.
  24. The original article refers to Meituan
  25. Fan, Mingyuan, et al. "Rethinking BiSeNet For Real-time Semantic Segmentation."
  26. (https://arxiv.org/abs/2104.13188)
  27. Args:
  28. base(int, optional): base channels. Default: 64.
  29. layers(list, optional): layers numbers list. It determines STDC block numbers of STDCNet's stage3\4\5. Defualt: [4, 5, 3].
  30. block_num(int,optional): block_num of features block. Default: 4.
  31. type(str,optional): feature fusion method "cat"/"add". Default: "cat".
  32. pretrained(str, optional): the path of pretrained model.
  33. """
  34. def __init__(self,
  35. input_channels=3,
  36. channels=[32, 64, 256, 512, 1024],
  37. layers=[4, 5, 3],
  38. block_num=4,
  39. type="cat",
  40. pretrained=None):
  41. super(STDCNet, self).__init__()
  42. if type == "cat":
  43. block = CatBottleneck
  44. elif type == "add":
  45. block = AddBottleneck
  46. self.input_channels = input_channels
  47. self.layers = layers
  48. self.feat_channels = channels
  49. self.features = self._make_layers(channels, layers, block_num, block)
  50. self.pretrained = pretrained
  51. self.init_weight()
  52. def forward(self, x):
  53. """
  54. forward function for feature extract.
  55. """
  56. out_feats = []
  57. x = self.features[0](x)
  58. out_feats.append(x)
  59. x = self.features[1](x)
  60. out_feats.append(x)
  61. idx = [[2, 2 + self.layers[0]],
  62. [2 + self.layers[0], 2 + sum(self.layers[0:2])],
  63. [2 + sum(self.layers[0:2]), 2 + sum(self.layers)]]
  64. for start_idx, end_idx in idx:
  65. for i in range(start_idx, end_idx):
  66. x = self.features[i](x)
  67. out_feats.append(x)
  68. return out_feats
  69. def _make_layers(self, channels, layers, block_num, block):
  70. features = []
  71. features += [ConvBNRelu(self.input_channels, channels[0], 3, 2)]
  72. features += [ConvBNRelu(channels[0], channels[1], 3, 2)]
  73. for i, layer in enumerate(layers):
  74. for j in range(layer):
  75. if i == 0 and j == 0:
  76. features.append(
  77. block(channels[i + 1], channels[i + 2], block_num, 2))
  78. elif j == 0:
  79. features.append(
  80. block(channels[i + 1], channels[i + 2], block_num, 2))
  81. else:
  82. features.append(
  83. block(channels[i + 2], channels[i + 2], block_num, 1))
  84. return nn.Sequential(*features)
  85. def init_weight(self):
  86. for layer in self.sublayers():
  87. if isinstance(layer, nn.Conv2D):
  88. param_init.normal_init(layer.weight, std=0.001)
  89. elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
  90. param_init.constant_init(layer.weight, value=1.0)
  91. param_init.constant_init(layer.bias, value=0.0)
  92. if self.pretrained is not None:
  93. utils.load_pretrained_model(self, self.pretrained)
  94. class ConvBNRelu(nn.Layer):
  95. def __init__(self, in_planes, out_planes, kernel=3, stride=1):
  96. super(ConvBNRelu, self).__init__()
  97. self.conv = nn.Conv2D(
  98. in_planes,
  99. out_planes,
  100. kernel_size=kernel,
  101. stride=stride,
  102. padding=kernel // 2,
  103. bias_attr=False)
  104. self.bn = SyncBatchNorm(out_planes, data_format='NCHW')
  105. self.relu = nn.ReLU()
  106. def forward(self, x):
  107. out = self.relu(self.bn(self.conv(x)))
  108. return out
  109. class AddBottleneck(nn.Layer):
  110. def __init__(self, in_planes, out_planes, block_num=3, stride=1):
  111. super(AddBottleneck, self).__init__()
  112. assert block_num > 1, "block number should be larger than 1."
  113. self.conv_list = nn.LayerList()
  114. self.stride = stride
  115. if stride == 2:
  116. self.avd_layer = nn.Sequential(
  117. nn.Conv2D(
  118. out_planes // 2,
  119. out_planes // 2,
  120. kernel_size=3,
  121. stride=2,
  122. padding=1,
  123. groups=out_planes // 2,
  124. bias_attr=False),
  125. nn.BatchNorm2D(out_planes // 2), )
  126. self.skip = nn.Sequential(
  127. nn.Conv2D(
  128. in_planes,
  129. in_planes,
  130. kernel_size=3,
  131. stride=2,
  132. padding=1,
  133. groups=in_planes,
  134. bias_attr=False),
  135. nn.BatchNorm2D(in_planes),
  136. nn.Conv2D(
  137. in_planes, out_planes, kernel_size=1, bias_attr=False),
  138. nn.BatchNorm2D(out_planes), )
  139. stride = 1
  140. for idx in range(block_num):
  141. if idx == 0:
  142. self.conv_list.append(
  143. ConvBNRelu(
  144. in_planes, out_planes // 2, kernel=1))
  145. elif idx == 1 and block_num == 2:
  146. self.conv_list.append(
  147. ConvBNRelu(
  148. out_planes // 2, out_planes // 2, stride=stride))
  149. elif idx == 1 and block_num > 2:
  150. self.conv_list.append(
  151. ConvBNRelu(
  152. out_planes // 2, out_planes // 4, stride=stride))
  153. elif idx < block_num - 1:
  154. self.conv_list.append(
  155. ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes
  156. // int(math.pow(2, idx + 1))))
  157. else:
  158. self.conv_list.append(
  159. ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes
  160. // int(math.pow(2, idx))))
  161. def forward(self, x):
  162. out_list = []
  163. out = x
  164. for idx, conv in enumerate(self.conv_list):
  165. if idx == 0 and self.stride == 2:
  166. out = self.avd_layer(conv(out))
  167. else:
  168. out = conv(out)
  169. out_list.append(out)
  170. if self.stride == 2:
  171. x = self.skip(x)
  172. return paddle.concat(out_list, axis=1) + x
  173. class CatBottleneck(nn.Layer):
  174. def __init__(self, in_planes, out_planes, block_num=3, stride=1):
  175. super(CatBottleneck, self).__init__()
  176. assert block_num > 1, "block number should be larger than 1."
  177. self.conv_list = nn.LayerList()
  178. self.stride = stride
  179. if stride == 2:
  180. self.avd_layer = nn.Sequential(
  181. nn.Conv2D(
  182. out_planes // 2,
  183. out_planes // 2,
  184. kernel_size=3,
  185. stride=2,
  186. padding=1,
  187. groups=out_planes // 2,
  188. bias_attr=False),
  189. nn.BatchNorm2D(out_planes // 2), )
  190. self.skip = nn.AvgPool2D(kernel_size=3, stride=2, padding=1)
  191. stride = 1
  192. for idx in range(block_num):
  193. if idx == 0:
  194. self.conv_list.append(
  195. ConvBNRelu(
  196. in_planes, out_planes // 2, kernel=1))
  197. elif idx == 1 and block_num == 2:
  198. self.conv_list.append(
  199. ConvBNRelu(
  200. out_planes // 2, out_planes // 2, stride=stride))
  201. elif idx == 1 and block_num > 2:
  202. self.conv_list.append(
  203. ConvBNRelu(
  204. out_planes // 2, out_planes // 4, stride=stride))
  205. elif idx < block_num - 1:
  206. self.conv_list.append(
  207. ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes
  208. // int(math.pow(2, idx + 1))))
  209. else:
  210. self.conv_list.append(
  211. ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes
  212. // int(math.pow(2, idx))))
  213. def forward(self, x):
  214. out_list = []
  215. out1 = self.conv_list[0](x)
  216. for idx, conv in enumerate(self.conv_list[1:]):
  217. if idx == 0:
  218. if self.stride == 2:
  219. out = conv(self.avd_layer(out1))
  220. else:
  221. out = conv(out1)
  222. else:
  223. out = conv(out)
  224. out_list.append(out)
  225. if self.stride == 2:
  226. out1 = self.skip(out1)
  227. out_list.insert(0, out1)
  228. out = paddle.concat(out_list, axis=1)
  229. return out
  230. @manager.BACKBONES.add_component
  231. def STDC2(**kwargs):
  232. model = STDCNet(
  233. channels=[32, 64, 256, 512, 1024], layers=[4, 5, 3], **kwargs)
  234. return model
  235. @manager.BACKBONES.add_component
  236. def STDC1(**kwargs):
  237. model = STDCNet(
  238. channels=[32, 64, 256, 512, 1024], layers=[2, 2, 2], **kwargs)
  239. return model
  240. @manager.BACKBONES.add_component
  241. def STDC_Small(**kwargs):
  242. model = STDCNet(channels=[32, 32, 64, 128, 256], layers=[4, 5, 3], **kwargs)
  243. return model
  244. @manager.BACKBONES.add_component
  245. def STDC_Tiny(**kwargs):
  246. model = STDCNet(channels=[32, 32, 64, 128, 256], layers=[2, 2, 2], **kwargs)
  247. return model