
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.cvlibs import manager, param_init
from paddleseg.models import layers

import ppmatting

__all__ = [
    "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
    "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", "HRNet_W60", "HRNet_W64"
]


class HRNet(nn.Layer):
    """
    The HRNet implementation based on PaddlePaddle.

    The original article refers to
    Jingdong Wang, et al. "HRNet: Deep High-Resolution Representation Learning for Visual Recognition"
    (https://arxiv.org/pdf/1908.07919.pdf).

    Args:
        input_channels (int, optional): The channels of the input image. Default 3.
        pretrained (str, optional): The path of pretrained model.
        stage1_num_modules (int, optional): Number of modules for stage1. Default 1.
        stage1_num_blocks (list, optional): Number of blocks per module for stage1. Default (4,).
        stage1_num_channels (list, optional): Number of channels per branch for stage1. Default (64,).
        stage2_num_modules (int, optional): Number of modules for stage2. Default 1.
        stage2_num_blocks (list, optional): Number of blocks per module for stage2. Default (4, 4).
        stage2_num_channels (list, optional): Number of channels per branch for stage2. Default (18, 36).
        stage3_num_modules (int, optional): Number of modules for stage3. Default 4.
        stage3_num_blocks (list, optional): Number of blocks per module for stage3. Default (4, 4, 4).
        stage3_num_channels (list, optional): Number of channels per branch for stage3. Default (18, 36, 72).
        stage4_num_modules (int, optional): Number of modules for stage4. Default 3.
        stage4_num_blocks (list, optional): Number of blocks per module for stage4. Default (4, 4, 4, 4).
        stage4_num_channels (list, optional): Number of channels per branch for stage4. Default (18, 36, 72, 144).
        has_se (bool, optional): Whether to use Squeeze-and-Excitation module. Default False.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
            e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
        padding_same (bool, optional): Whether to use 'same' padding for the 3x3 convolutions. Default True.
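
    Examples:
        A minimal usage sketch (illustrative; it assumes paddle is installed,
        and the 512x512 input size is an arbitrary choice):

            import paddle
            model = HRNet_W18()
            feats = model(paddle.rand([1, 3, 512, 512]))
            # feats is a list: the stem feature plus the four stage-4 branch
            # outputs, so len(feats) == 5 and widths follow model.feat_channels.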
  48. """
  49. def __init__(self,
  50. input_channels=3,
  51. pretrained=None,
  52. stage1_num_modules=1,
  53. stage1_num_blocks=(4, ),
  54. stage1_num_channels=(64, ),
  55. stage2_num_modules=1,
  56. stage2_num_blocks=(4, 4),
  57. stage2_num_channels=(18, 36),
  58. stage3_num_modules=4,
  59. stage3_num_blocks=(4, 4, 4),
  60. stage3_num_channels=(18, 36, 72),
  61. stage4_num_modules=3,
  62. stage4_num_blocks=(4, 4, 4, 4),
  63. stage4_num_channels=(18, 36, 72, 144),
  64. has_se=False,
  65. align_corners=False,
  66. padding_same=True):
  67. super(HRNet, self).__init__()
  68. self.pretrained = pretrained
  69. self.stage1_num_modules = stage1_num_modules
  70. self.stage1_num_blocks = stage1_num_blocks
  71. self.stage1_num_channels = stage1_num_channels
  72. self.stage2_num_modules = stage2_num_modules
  73. self.stage2_num_blocks = stage2_num_blocks
  74. self.stage2_num_channels = stage2_num_channels
  75. self.stage3_num_modules = stage3_num_modules
  76. self.stage3_num_blocks = stage3_num_blocks
  77. self.stage3_num_channels = stage3_num_channels
  78. self.stage4_num_modules = stage4_num_modules
  79. self.stage4_num_blocks = stage4_num_blocks
  80. self.stage4_num_channels = stage4_num_channels
  81. self.has_se = has_se
  82. self.align_corners = align_corners
  83. self.feat_channels = [i for i in stage4_num_channels]
  84. self.feat_channels = [64] + self.feat_channels
  85. self.conv_layer1_1 = layers.ConvBNReLU(
  86. in_channels=input_channels,
  87. out_channels=64,
  88. kernel_size=3,
  89. stride=2,
  90. padding=1 if not padding_same else 'same',
  91. bias_attr=False)
  92. self.conv_layer1_2 = layers.ConvBNReLU(
  93. in_channels=64,
  94. out_channels=64,
  95. kernel_size=3,
  96. stride=2,
  97. padding=1 if not padding_same else 'same',
  98. bias_attr=False)
  99. self.la1 = Layer1(
  100. num_channels=64,
  101. num_blocks=self.stage1_num_blocks[0],
  102. num_filters=self.stage1_num_channels[0],
  103. has_se=has_se,
  104. name="layer2",
  105. padding_same=padding_same)
  106. self.tr1 = TransitionLayer(
  107. in_channels=[self.stage1_num_channels[0] * 4],
  108. out_channels=self.stage2_num_channels,
  109. name="tr1",
  110. padding_same=padding_same)
  111. self.st2 = Stage(
  112. num_channels=self.stage2_num_channels,
  113. num_modules=self.stage2_num_modules,
  114. num_blocks=self.stage2_num_blocks,
  115. num_filters=self.stage2_num_channels,
  116. has_se=self.has_se,
  117. name="st2",
  118. align_corners=align_corners,
  119. padding_same=padding_same)
  120. self.tr2 = TransitionLayer(
  121. in_channels=self.stage2_num_channels,
  122. out_channels=self.stage3_num_channels,
  123. name="tr2",
  124. padding_same=padding_same)
  125. self.st3 = Stage(
  126. num_channels=self.stage3_num_channels,
  127. num_modules=self.stage3_num_modules,
  128. num_blocks=self.stage3_num_blocks,
  129. num_filters=self.stage3_num_channels,
  130. has_se=self.has_se,
  131. name="st3",
  132. align_corners=align_corners,
  133. padding_same=padding_same)
  134. self.tr3 = TransitionLayer(
  135. in_channels=self.stage3_num_channels,
  136. out_channels=self.stage4_num_channels,
  137. name="tr3",
  138. padding_same=padding_same)
  139. self.st4 = Stage(
  140. num_channels=self.stage4_num_channels,
  141. num_modules=self.stage4_num_modules,
  142. num_blocks=self.stage4_num_blocks,
  143. num_filters=self.stage4_num_channels,
  144. has_se=self.has_se,
  145. name="st4",
  146. align_corners=align_corners,
  147. padding_same=padding_same)
  148. self.init_weight()
  149. def forward(self, x):
  150. feat_list = []
  151. conv1 = self.conv_layer1_1(x)
  152. feat_list.append(conv1)
  153. conv2 = self.conv_layer1_2(conv1)
  154. la1 = self.la1(conv2)
  155. tr1 = self.tr1([la1])
  156. st2 = self.st2(tr1)
  157. tr2 = self.tr2(st2)
  158. st3 = self.st3(tr2)
  159. tr3 = self.tr3(st3)
  160. st4 = self.st4(tr3)
  161. feat_list = feat_list + st4
  162. return feat_list
  163. def init_weight(self):
  164. for layer in self.sublayers():
  165. if isinstance(layer, nn.Conv2D):
  166. param_init.normal_init(layer.weight, std=0.001)
  167. elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
  168. param_init.constant_init(layer.weight, value=1.0)
  169. param_init.constant_init(layer.bias, value=0.0)
  170. if self.pretrained is not None:
  171. ppmatting.utils.load_pretrained_model(self, self.pretrained)


class Layer1(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 num_blocks,
                 has_se=False,
                 name=None,
                 padding_same=True):
        super(Layer1, self).__init__()

        self.bottleneck_block_list = []
        for i in range(num_blocks):
            bottleneck_block = self.add_sublayer(
                "bb_{}_{}".format(name, i + 1),
                BottleneckBlock(
                    num_channels=num_channels if i == 0 else num_filters * 4,
                    num_filters=num_filters,
                    has_se=has_se,
                    stride=1,
                    downsample=True if i == 0 else False,
                    name=name + '_' + str(i + 1),
                    padding_same=padding_same))
            self.bottleneck_block_list.append(bottleneck_block)

    def forward(self, x):
        conv = x
        for block_func in self.bottleneck_block_list:
            conv = block_func(conv)
        return conv


class TransitionLayer(nn.Layer):
    def __init__(self, in_channels, out_channels, name=None, padding_same=True):
        super(TransitionLayer, self).__init__()

        num_in = len(in_channels)
        num_out = len(out_channels)
        self.conv_bn_func_list = []
        for i in range(num_out):
            residual = None
            if i < num_in:
                # Existing branch: only add a conv if the width changes.
                if in_channels[i] != out_channels[i]:
                    residual = self.add_sublayer(
                        "transition_{}_layer_{}".format(name, i + 1),
                        layers.ConvBNReLU(
                            in_channels=in_channels[i],
                            out_channels=out_channels[i],
                            kernel_size=3,
                            padding=1 if not padding_same else 'same',
                            bias_attr=False))
            else:
                # New branch: downsample the lowest-resolution input by 2.
                residual = self.add_sublayer(
                    "transition_{}_layer_{}".format(name, i + 1),
                    layers.ConvBNReLU(
                        in_channels=in_channels[-1],
                        out_channels=out_channels[i],
                        kernel_size=3,
                        stride=2,
                        padding=1 if not padding_same else 'same',
                        bias_attr=False))
            self.conv_bn_func_list.append(residual)

    def forward(self, x):
        outs = []
        for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
            if conv_bn_func is None:
                outs.append(x[idx])
            else:
                if idx < len(x):
                    outs.append(conv_bn_func(x[idx]))
                else:
                    outs.append(conv_bn_func(x[-1]))
        return outs


class Branches(nn.Layer):
    def __init__(self,
                 num_blocks,
                 in_channels,
                 out_channels,
                 has_se=False,
                 name=None,
                 padding_same=True):
        super(Branches, self).__init__()

        self.basic_block_list = []
        for i in range(len(out_channels)):
            self.basic_block_list.append([])
            for j in range(num_blocks[i]):
                in_ch = in_channels[i] if j == 0 else out_channels[i]
                basic_block_func = self.add_sublayer(
                    "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
                    BasicBlock(
                        num_channels=in_ch,
                        num_filters=out_channels[i],
                        has_se=has_se,
                        name=name + '_branch_layer_' + str(i + 1) + '_' +
                        str(j + 1),
                        padding_same=padding_same))
                self.basic_block_list[i].append(basic_block_func)

    def forward(self, x):
        outs = []
        for idx, input in enumerate(x):
            conv = input
            for basic_block_func in self.basic_block_list[idx]:
                conv = basic_block_func(conv)
            outs.append(conv)
        return outs


class BottleneckBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se,
                 stride=1,
                 downsample=False,
                 name=None,
                 padding_same=True):
        super(BottleneckBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = layers.ConvBNReLU(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=1,
            bias_attr=False)

        self.conv2 = layers.ConvBNReLU(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            padding=1 if not padding_same else 'same',
            bias_attr=False)

        self.conv3 = layers.ConvBN(
            in_channels=num_filters,
            out_channels=num_filters * 4,
            kernel_size=1,
            bias_attr=False)

        if self.downsample:
            # 1x1 projection so the residual matches the 4x-expanded width.
            self.conv_down = layers.ConvBN(
                in_channels=num_channels,
                out_channels=num_filters * 4,
                kernel_size=1,
                bias_attr=False)

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters * 4,
                num_filters=num_filters * 4,
                reduction_ratio=16,
                name=name + '_fc')

        self.add = layers.Add()
        self.relu = layers.Activation("relu")

    def forward(self, x):
        residual = x
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        if self.downsample:
            residual = self.conv_down(x)
        if self.has_se:
            conv3 = self.se(conv3)

        y = self.add(conv3, residual)
        y = self.relu(y)
        return y


class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride=1,
                 has_se=False,
                 downsample=False,
                 name=None,
                 padding_same=True):
        super(BasicBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = layers.ConvBNReLU(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            padding=1 if not padding_same else 'same',
            bias_attr=False)

        self.conv2 = layers.ConvBN(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            padding=1 if not padding_same else 'same',
            bias_attr=False)

        if self.downsample:
            self.conv_down = layers.ConvBNReLU(
                in_channels=num_channels,
                out_channels=num_filters,
                kernel_size=1,
                bias_attr=False)

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters,
                num_filters=num_filters,
                reduction_ratio=16,
                name=name + '_fc')

        self.add = layers.Add()
        self.relu = layers.Activation("relu")

    def forward(self, x):
        residual = x
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)

        if self.downsample:
            residual = self.conv_down(x)
        if self.has_se:
            conv2 = self.se(conv2)

        y = self.add(conv2, residual)
        y = self.relu(y)
        return y


class SELayer(nn.Layer):
    def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
        super(SELayer, self).__init__()

        self.pool2d_gap = nn.AdaptiveAvgPool2D(1)

        self._num_channels = num_channels

        med_ch = int(num_channels / reduction_ratio)
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self.squeeze = nn.Linear(
            num_channels,
            med_ch,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.Uniform(-stdv, stdv)))

        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = nn.Linear(
            med_ch,
            num_filters,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.Uniform(-stdv, stdv)))

    def forward(self, x):
        # Squeeze: global average pool to a per-channel descriptor.
        pool = self.pool2d_gap(x)
        pool = paddle.reshape(pool, shape=[-1, self._num_channels])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        # Excite: per-channel gates in (0, 1) rescale the input features.
        excitation = self.excitation(squeeze)
        excitation = F.sigmoid(excitation)
        excitation = paddle.reshape(
            excitation, shape=[-1, self._num_channels, 1, 1])
        out = x * excitation
        return out


class Stage(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_modules,
                 num_blocks,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False,
                 padding_same=True):
        super(Stage, self).__init__()

        self._num_modules = num_modules

        self.stage_func_list = []
        for i in range(num_modules):
            if i == num_modules - 1 and not multi_scale_output:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_blocks=num_blocks,
                        num_filters=num_filters,
                        has_se=has_se,
                        multi_scale_output=False,
                        name=name + '_' + str(i + 1),
                        align_corners=align_corners,
                        padding_same=padding_same))
            else:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_blocks=num_blocks,
                        num_filters=num_filters,
                        has_se=has_se,
                        name=name + '_' + str(i + 1),
                        align_corners=align_corners,
                        padding_same=padding_same))
            self.stage_func_list.append(stage_func)

    def forward(self, x):
        out = x
        for idx in range(self._num_modules):
            out = self.stage_func_list[idx](out)
        return out


class HighResolutionModule(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_blocks,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False,
                 padding_same=True):
        super(HighResolutionModule, self).__init__()

        self.branches_func = Branches(
            num_blocks=num_blocks,
            in_channels=num_channels,
            out_channels=num_filters,
            has_se=has_se,
            name=name,
            padding_same=padding_same)

        self.fuse_func = FuseLayers(
            in_channels=num_filters,
            out_channels=num_filters,
            multi_scale_output=multi_scale_output,
            name=name,
            align_corners=align_corners,
            padding_same=padding_same)

    def forward(self, x):
        out = self.branches_func(x)
        out = self.fuse_func(out)
        return out


class FuseLayers(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False,
                 padding_same=True):
        super(FuseLayers, self).__init__()

        self._actual_ch = len(in_channels) if multi_scale_output else 1
        self._in_channels = in_channels
        self.align_corners = align_corners

        self.residual_func_list = []
        for i in range(self._actual_ch):
            for j in range(len(in_channels)):
                if j > i:
                    # Lower-resolution branch j feeds higher-resolution output i:
                    # a 1x1 conv aligns channels; upsampling happens in forward.
                    residual_func = self.add_sublayer(
                        "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
                        layers.ConvBN(
                            in_channels=in_channels[j],
                            out_channels=out_channels[i],
                            kernel_size=1,
                            bias_attr=False))
                    self.residual_func_list.append(residual_func)
                elif j < i:
                    # Higher-resolution branch j feeds lower-resolution output i:
                    # a chain of stride-2 3x3 convs downsamples it step by step.
                    pre_num_filters = in_channels[j]
                    for k in range(i - j):
                        if k == i - j - 1:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                layers.ConvBN(
                                    in_channels=pre_num_filters,
                                    out_channels=out_channels[i],
                                    kernel_size=3,
                                    stride=2,
                                    padding=1 if not padding_same else 'same',
                                    bias_attr=False))
                            pre_num_filters = out_channels[i]
                        else:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                layers.ConvBNReLU(
                                    in_channels=pre_num_filters,
                                    out_channels=out_channels[j],
                                    kernel_size=3,
                                    stride=2,
                                    padding=1 if not padding_same else 'same',
                                    bias_attr=False))
                            pre_num_filters = out_channels[j]
                        self.residual_func_list.append(residual_func)

    def forward(self, x):
        outs = []
        residual_func_idx = 0
        for i in range(self._actual_ch):
            residual = x[i]
            residual_shape = paddle.shape(residual)[-2:]
            for j in range(len(self._in_channels)):
                if j > i:
                    y = self.residual_func_list[residual_func_idx](x[j])
                    residual_func_idx += 1

                    y = F.interpolate(
                        y,
                        residual_shape,
                        mode='bilinear',
                        align_corners=self.align_corners)
                    residual = residual + y
                elif j < i:
                    y = x[j]
                    for k in range(i - j):
                        y = self.residual_func_list[residual_func_idx](y)
                        residual_func_idx += 1
                    residual = residual + y

            residual = F.relu(residual)
            outs.append(residual)

        return outs


@manager.BACKBONES.add_component
def HRNet_W18_Small_V1(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[1],
        stage1_num_channels=[32],
        stage2_num_modules=1,
        stage2_num_blocks=[2, 2],
        stage2_num_channels=[16, 32],
        stage3_num_modules=1,
        stage3_num_blocks=[2, 2, 2],
        stage3_num_channels=[16, 32, 64],
        stage4_num_modules=1,
        stage4_num_blocks=[2, 2, 2, 2],
        stage4_num_channels=[16, 32, 64, 128],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W18_Small_V2(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[2],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[2, 2],
        stage2_num_channels=[18, 36],
        stage3_num_modules=3,
        stage3_num_blocks=[2, 2, 2],
        stage3_num_channels=[18, 36, 72],
        stage4_num_modules=2,
        stage4_num_blocks=[2, 2, 2, 2],
        stage4_num_channels=[18, 36, 72, 144],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W18(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[18, 36],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[18, 36, 72],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[18, 36, 72, 144],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W30(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[30, 60],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[30, 60, 120],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[30, 60, 120, 240],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W32(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[32, 64],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[32, 64, 128],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[32, 64, 128, 256],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W40(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[40, 80],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[40, 80, 160],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[40, 80, 160, 320],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W44(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[44, 88],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[44, 88, 176],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[44, 88, 176, 352],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W48(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[48, 96],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[48, 96, 192],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[48, 96, 192, 384],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W60(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[60, 120],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[60, 120, 240],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[60, 120, 240, 480],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W64(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[64, 128],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[64, 128, 256],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[64, 128, 256, 512],
        **kwargs)
    return model
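

if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming paddle and the paddleseg/ppmatting
    # imports above resolve; the 512x512 input is an arbitrary example size.
    model = HRNet_W18()
    feats = model(paddle.rand([1, 3, 512, 512]))
    for ch, feat in zip(model.feat_channels, feats):
        # Expect channels [64, 18, 36, 72, 144] at strides 2, 4, 8, 16, 32.
        print(ch, feat.shape)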