mobilenetv3.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. from paddle import ParamAttr
  17. from paddle.regularizer import L2Decay
  18. from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
  19. from paddleseg.cvlibs import manager
  20. from paddleseg.utils import utils, logger
  21. from paddleseg.models import layers
  # Public backbone factories exported by this module. Every factory that this
  # file registers with ``manager.BACKBONES`` is listed, including the dilated
  # os8 variants that were previously missing from the export list.
  __all__ = [
      "MobileNetV3_small_x0_35",
      "MobileNetV3_small_x0_5",
      "MobileNetV3_small_x0_75",
      "MobileNetV3_small_x1_0",
      "MobileNetV3_small_x1_25",
      "MobileNetV3_large_x0_35",
      "MobileNetV3_large_x0_5",
      "MobileNetV3_large_x0_75",
      "MobileNetV3_large_x1_0",
      "MobileNetV3_large_x1_25",
      "MobileNetV3_large_x1_0_os8",
      "MobileNetV3_small_x1_0_os8",
      "MobileNetV3_large_x1_0_os16",
  ]
  # Per-family lists of block path strings; the factory functions pass the
  # matching list to MobileNetV3 as ``stages_pattern``.
  MODEL_STAGES_PATTERN = {
      "MobileNetV3_small": [f"blocks[{i}]" for i in (0, 2, 7, 10)],
      "MobileNetV3_large": [f"blocks[{i}]" for i in (0, 2, 5, 11, 14)],
  }
  35. # "large", "small" is just for MobileNetV3_large, MobileNetV3_small respectively.
  36. # The type of "large" or "small" config is a list. Each element (list) represents a depthwise block, which is composed of k, exp, c, se, act, s and, optionally, d.
  37. # k: kernel_size
  38. # exp: middle channel number in depthwise block
  39. # c: output channel number in depthwise block
  40. # se: whether to use SE block
  41. # act: which activation to use
  42. # s: stride in depthwise block
  43. # d: dilation rate in depthwise block
  # Block tables keyed by variant name. Each row describes one residual block as
  #   [k, exp, c, se, act, s]  or  [k, exp, c, se, act, s, d]
  # where a missing ``d`` is treated as dilation 1 by MobileNetV3.__init__.
  # The "# xN" marks give the cumulative downsampling factor after that row,
  # counting the stride-2 stem convolution.
  NET_CONFIG = {
      "large": [
          # k, exp, c, se, act, s
          [3, 16, 16, False, "relu", 1],
          [3, 64, 24, False, "relu", 2],
          [3, 72, 24, False, "relu", 1],  # x4
          [5, 72, 40, True, "relu", 2],
          [5, 120, 40, True, "relu", 1],
          [5, 120, 40, True, "relu", 1],  # x8
          [3, 240, 80, False, "hardswish", 2],
          [3, 200, 80, False, "hardswish", 1],
          [3, 184, 80, False, "hardswish", 1],
          [3, 184, 80, False, "hardswish", 1],
          [3, 480, 112, True, "hardswish", 1],
          [3, 672, 112, True, "hardswish", 1],  # x16
          [5, 672, 160, True, "hardswish", 2],
          [5, 960, 160, True, "hardswish", 1],
          [5, 960, 160, True, "hardswish", 1],  # x32
      ],
      "small": [
          # k, exp, c, se, act, s
          [3, 16, 16, True, "relu", 2],
          [3, 72, 24, False, "relu", 2],
          [3, 88, 24, False, "relu", 1],
          [5, 96, 40, True, "hardswish", 2],
          [5, 240, 40, True, "hardswish", 1],
          [5, 240, 40, True, "hardswish", 1],
          [5, 120, 48, True, "hardswish", 1],
          [5, 144, 48, True, "hardswish", 1],
          [5, 288, 96, True, "hardswish", 2],
          [5, 576, 96, True, "hardswish", 1],
          [5, 576, 96, True, "hardswish", 1],
      ],
      # Variant of "large" whose later stride-2 rows use stride 1 and whose
      # following rows carry dilation 2 or 4 instead.
      "large_os8": [
          # k, exp, c, se, act, s, {d}
          [3, 16, 16, False, "relu", 1],
          [3, 64, 24, False, "relu", 2],
          [3, 72, 24, False, "relu", 1],  # x4
          [5, 72, 40, True, "relu", 2],
          [5, 120, 40, True, "relu", 1],
          [5, 120, 40, True, "relu", 1],  # x8
          [3, 240, 80, False, "hardswish", 1],
          [3, 200, 80, False, "hardswish", 1, 2],
          [3, 184, 80, False, "hardswish", 1, 2],
          [3, 184, 80, False, "hardswish", 1, 2],
          [3, 480, 112, True, "hardswish", 1, 2],
          [3, 672, 112, True, "hardswish", 1, 2],
          [5, 672, 160, True, "hardswish", 1, 2],
          [5, 960, 160, True, "hardswish", 1, 4],
          [5, 960, 160, True, "hardswish", 1, 4],
      ],
      # Variant of "small" with the same stride-to-dilation substitution.
      "small_os8": [
          # k, exp, c, se, act, s, {d}
          [3, 16, 16, True, "relu", 2],
          [3, 72, 24, False, "relu", 2],
          [3, 88, 24, False, "relu", 1],
          [5, 96, 40, True, "hardswish", 1],
          [5, 240, 40, True, "hardswish", 1, 2],
          [5, 240, 40, True, "hardswish", 1, 2],
          [5, 120, 48, True, "hardswish", 1, 2],
          [5, 144, 48, True, "hardswish", 1, 2],
          [5, 288, 96, True, "hardswish", 1, 2],
          [5, 576, 96, True, "hardswish", 1, 4],
          [5, 576, 96, True, "hardswish", 1, 4],
      ],
      # Variant of "large" where only the last stride-2 row is replaced by
      # stride 1 with dilation 2.
      "large_os16": [
          # k, exp, c, se, act, s, {d}
          [3, 16, 16, False, "relu", 1],
          [3, 64, 24, False, "relu", 2],
          [3, 72, 24, False, "relu", 1],  # x4
          [5, 72, 40, True, "relu", 2],
          [5, 120, 40, True, "relu", 1],
          [5, 120, 40, True, "relu", 1],  # x8
          [3, 240, 80, False, "hardswish", 2],
          [3, 200, 80, False, "hardswish", 1, 1],
          [3, 184, 80, False, "hardswish", 1, 1],
          [3, 184, 80, False, "hardswish", 1, 1],
          [3, 480, 112, True, "hardswish", 1, 1],
          [3, 672, 112, True, "hardswish", 1, 1],
          [5, 672, 160, True, "hardswish", 1, 2],
          [5, 960, 160, True, "hardswish", 1, 2],
          [5, 960, 160, True, "hardswish", 1, 2],
      ],
  }
  # Indices into ``blocks`` whose outputs MobileNetV3.forward collects, per family.
  OUT_INDEX = dict(large=[2, 5, 11, 14], small=[0, 2, 7, 10])
  129. def _make_divisible(v, divisor=8, min_value=None):
  130. if min_value is None:
  131. min_value = divisor
  132. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  133. if new_v < 0.9 * v:
  134. new_v += divisor
  135. return new_v
  136. def _create_act(act):
  137. if act == "hardswish":
  138. return nn.Hardswish()
  139. elif act == "relu":
  140. return nn.ReLU()
  141. elif act is None:
  142. return None
  143. else:
  144. raise RuntimeError(
  145. "The activation function is not supported: {}".format(act))
  class MobileNetV3(nn.Layer):
      """
      MobileNetV3 backbone that returns a list of intermediate feature maps.

      Args:
          config (list): MobileNetV3 depthwise blocks config. One row per block in
              the ``k, exp, c, se, act, s[, d]`` layout.
          stages_pattern (list): Block path strings of the model family. They are
              only forwarded to ``init_res``.
          out_index (list): Indices of the blocks whose outputs are returned by
              ``forward``.
          in_channels (int, optional): The channels of input image. Default: 3.
          scale (float, optional): The coefficient that controls the size of
              network parameters. Default: 1.0.
          class_squeeze (int, optional): Unscaled output channel number of the
              last 1x1 convolution. Default: 960.
          return_last_conv (bool, optional): Whether the output of the last 1x1
              convolution is appended to the returned features. Default: False.
          pretrained (str, optional): Passed to ``utils.load_entire_model`` when
              it is not None. Default: None.

      Attributes:
          feat_channels (list): Scaled channel number of every returned feature.
      """

      def __init__(self,
                   config,
                   stages_pattern,
                   out_index,
                   in_channels=3,
                   scale=1.0,
                   class_squeeze=960,
                   return_last_conv=False,
                   pretrained=None):
          super().__init__()

          self.cfg = config
          self.out_index = out_index
          self.scale = scale
          self.pretrained = pretrained
          self.class_squeeze = class_squeeze
          self.return_last_conv = return_last_conv
          inplanes = 16

          # Stem: 3x3 stride-2 convolution.
          self.conv = ConvBNLayer(
              in_c=in_channels,
              out_c=_make_divisible(inplanes * self.scale),
              filter_size=3,
              stride=2,
              padding=1,
              num_groups=1,
              if_act=True,
              act="hardswish")

          # One ResidualUnit per config row. The input width of a block is the
          # scaled output width of the previous row (or of the stem for row 0);
          # the optional trailing config entry is the dilation.
          self.blocks = nn.Sequential(*[
              ResidualUnit(
                  in_c=_make_divisible(inplanes * self.scale if i == 0 else
                                       self.cfg[i - 1][2] * self.scale),
                  mid_c=_make_divisible(self.scale * exp),
                  out_c=_make_divisible(self.scale * c),
                  filter_size=k,
                  stride=s,
                  use_se=se,
                  act=act,
                  dilation=td[0] if td else 1)
              for i, (k, exp, c, se, act, s, *td) in enumerate(self.cfg)
          ])

          self.last_second_conv = ConvBNLayer(
              in_c=_make_divisible(self.cfg[-1][2] * self.scale),
              out_c=_make_divisible(self.scale * self.class_squeeze),
              filter_size=1,
              stride=1,
              padding=0,
              num_groups=1,
              if_act=True,
              act="hardswish")

          # return feat_channels information
          out_channels = [config[idx][2] for idx in out_index]
          if return_last_conv:
              out_channels.append(class_squeeze)
          self.feat_channels = [
              _make_divisible(self.scale * c) for c in out_channels
          ]

          # Per-channel constants used to normalize the input in forward().
          # NOTE(review): three values are hard-coded, so this presumably
          # expects a 3-channel input even when in_channels differs - confirm.
          self.mean = paddle.to_tensor([0.485, 0.456, 0.406]).unsqueeze((0, 2, 3))
          self.std = paddle.to_tensor([0.229, 0.224, 0.225]).unsqueeze((0, 2, 3))

          self.init_res(stages_pattern)
          self.init_weight()

      def init_weight(self):
          """Load the weights referenced by ``self.pretrained`` when it is set."""
          if self.pretrained is not None:
              utils.load_entire_model(self, self.pretrained)

      def init_res(self, stages_pattern, return_patterns=None,
                   return_stages=None):
          """Resolve which stage patterns were requested.

          Args:
              stages_pattern (list): All stage patterns of the model family.
              return_patterns (list, optional): Explicitly requested patterns.
              return_stages (bool|int|list, optional): True selects every stage;
                  an int or a list of ints selects stages by index.

          Returns:
              list|None: The resolved patterns. The constructor ignores this
              value, so nothing in this class currently consumes it.
          """
          if return_patterns and return_stages:
              # NOTE(review): the message says 'return_patterns' is ignored, yet
              # it is 'return_stages' that gets dropped here. Kept as is because
              # the intended precedence cannot be determined from this file.
              msg = "The 'return_patterns' would be ignored when 'return_stages' is set."
              logger.warning(msg)
              return_stages = None

          if return_stages is True:
              return_patterns = stages_pattern

          # type() rather than isinstance(): bool is a subclass of int and must
          # not be wrapped into a list here.
          if type(return_stages) is int:
              return_stages = [return_stages]

          if isinstance(return_stages, list):
              num_stages = len(stages_pattern)
              valid_stages = [
                  val for val in return_stages if 0 <= val < num_stages
              ]
              # Warn exactly when something was dropped. This also covers an
              # index equal to len(stages_pattern), and an empty list no longer
              # reaches max()/min().
              if len(valid_stages) != len(return_stages):
                  msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
                  logger.warning(msg)
              return_patterns = [stages_pattern[i] for i in valid_stages]

          return return_patterns

      def forward(self, x):
          """Run the backbone.

          Args:
              x: Input tensor; it is normalized with ``self.mean`` and
                  ``self.std`` before the stem convolution.

          Returns:
              list: Outputs of the blocks listed in ``self.out_index``, in block
              order, followed by the output of the last 1x1 convolution when
              ``return_last_conv`` is set.
          """
          x = (x - self.mean) / self.std

          x = self.conv(x)

          feat_list = []
          for idx, block in enumerate(self.blocks):
              x = block(x)
              if idx in self.out_index:
                  feat_list.append(x)

          # The last convolution always runs; its output is returned only on request.
          x = self.last_second_conv(x)
          if self.return_last_conv:
              feat_list.append(x)

          return feat_list
  class ConvBNLayer(nn.Layer):
      """Bias-free Conv2D followed by BatchNorm and an optional activation.

      Args:
          in_c (int): Input channel number.
          out_c (int): Output channel number.
          filter_size (int): Convolution kernel size.
          stride (int): Convolution stride.
          padding (int): Convolution padding.
          num_groups (int, optional): Convolution groups. Default: 1.
          if_act (bool, optional): Whether to apply the activation. Default: True.
          act (str, optional): "hardswish", "relu" or None. Default: None.
          dilation (int, optional): Convolution dilation. Default: 1.
      """

      def __init__(self,
                   in_c,
                   out_c,
                   filter_size,
                   stride,
                   padding,
                   num_groups=1,
                   if_act=True,
                   act=None,
                   dilation=1):
          super().__init__()

          self.conv = Conv2D(
              in_channels=in_c,
              out_channels=out_c,
              kernel_size=filter_size,
              stride=stride,
              padding=padding,
              groups=num_groups,
              bias_attr=False,
              dilation=dilation)
          # The BatchNorm scale and bias are created with L2Decay(0.0).
          self.bn = BatchNorm(
              num_channels=out_c,
              act=None,
              epsilon=0.001,
              momentum=0.99,
              param_attr=ParamAttr(regularizer=L2Decay(0.0)),
              bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
          self.if_act = if_act
          self.act = _create_act(act)

      def forward(self, x):
          """Apply conv, batch norm and, when configured, the activation."""
          x = self.conv(x)
          x = self.bn(x)
          # With the defaults (if_act=True, act=None) no activation layer
          # exists; skip it instead of calling None.
          if self.if_act and self.act is not None:
              x = self.act(x)
          return x
  class ResidualUnit(nn.Layer):
      """Residual block: optional 1x1 expand, depthwise conv, optional SE, 1x1 projection.

      Args:
          in_c (int): Input channel number.
          mid_c (int): Channel number of the depthwise convolution.
          out_c (int): Output channel number.
          filter_size (int): Kernel size of the depthwise convolution.
          stride (int): Stride of the depthwise convolution.
          use_se (bool): Whether to insert an SEModule after the depthwise conv.
          act (str, optional): Activation of the expand and depthwise convs.
          dilation (int, optional): Dilation of the depthwise conv. Default: 1.
      """

      def __init__(self,
                   in_c,
                   mid_c,
                   out_c,
                   filter_size,
                   stride,
                   use_se,
                   act=None,
                   dilation=1):
          super().__init__()
          self.if_shortcut = stride == 1 and in_c == out_c
          self.if_se = use_se
          self.in_c = in_c
          self.mid_c = mid_c

          # There is not expand conv in pytorch version when in_c equaled to mid_c.
          if in_c != mid_c:
              self.expand_conv = ConvBNLayer(
                  in_c=in_c,
                  out_c=mid_c,
                  filter_size=1,
                  stride=1,
                  padding=0,
                  if_act=True,
                  act=act)

          # Padding grows with the dilation.
          depthwise_padding = int((filter_size - 1) // 2) * dilation
          self.bottleneck_conv = ConvBNLayer(
              in_c=mid_c,
              out_c=mid_c,
              filter_size=filter_size,
              stride=stride,
              padding=depthwise_padding,
              num_groups=mid_c,
              if_act=True,
              act=act,
              dilation=dilation)

          if self.if_se:
              self.mid_se = SEModule(mid_c)

          # Linear projection: no activation after the 1x1 conv.
          self.linear_conv = ConvBNLayer(
              in_c=mid_c,
              out_c=out_c,
              filter_size=1,
              stride=1,
              padding=0,
              if_act=False,
              act=None)

      def forward(self, x):
          """Run the block; the input is added back when the shortcut is enabled."""
          shortcut = x
          out = x if self.in_c == self.mid_c else self.expand_conv(x)
          out = self.bottleneck_conv(out)
          if self.if_se:
              out = self.mid_se(out)
          out = self.linear_conv(out)
          if not self.if_shortcut:
              return out
          return paddle.add(shortcut, out)
  # nn.Hardsigmoid can't transfer "slope" and "offset" in nn.functional.hardsigmoid
  class Hardsigmoid(nn.Layer):
      """Layer form of ``nn.functional.hardsigmoid`` with configurable slope and offset.

      Args:
          slope (float, optional): Forwarded as ``slope``. Default: 0.2.
          offset (float, optional): Forwarded as ``offset``. Default: 0.5.
      """

      def __init__(self, slope=0.2, offset=0.5):
          super().__init__()
          self.slope, self.offset = slope, offset

      def forward(self, x):
          """Apply the hard sigmoid with the stored slope and offset."""
          activated = nn.functional.hardsigmoid(
              x, slope=self.slope, offset=self.offset)
          return activated
  class SEModule(nn.Layer):
      """Channel gate: global average pool, two 1x1 convs, hard sigmoid, then rescale.

      Args:
          channel (int): Channel number of the input (and of the gate).
          reduction (int, optional): Divisor applied to ``channel`` to obtain the
              hidden width before it is rounded to a multiple of 8. Default: 4.
      """

      def __init__(self, channel, reduction=4):
          super().__init__()
          # Hidden width shared by both 1x1 convolutions.
          squeezed = _make_divisible(channel // reduction, 8)

          self.avg_pool = AdaptiveAvgPool2D(1)
          self.conv1 = Conv2D(
              in_channels=channel,
              out_channels=squeezed,
              kernel_size=1,
              stride=1,
              padding=0)
          self.relu = nn.ReLU()
          self.conv2 = Conv2D(
              in_channels=squeezed,
              out_channels=channel,
              kernel_size=1,
              stride=1,
              padding=0)
          self.hardsigmoid = Hardsigmoid(slope=0.1666667, offset=0.5)

      def forward(self, x):
          """Return ``x`` multiplied by the gate computed from it."""
          gate = self.avg_pool(x)
          gate = self.relu(self.conv1(gate))
          gate = self.hardsigmoid(self.conv2(gate))
          return paddle.multiply(x=x, y=gate)
  @manager.BACKBONES.add_component
  def MobileNetV3_small_x0_35(**kwargs):
      """Build MobileNetV3 from the "small" config with scale 0.35.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["small"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
          out_index=OUT_INDEX["small"],
          scale=0.35,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_small_x0_5(**kwargs):
      """Build MobileNetV3 from the "small" config with scale 0.5.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["small"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
          out_index=OUT_INDEX["small"],
          scale=0.5,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_small_x0_75(**kwargs):
      """Build MobileNetV3 from the "small" config with scale 0.75.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["small"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
          out_index=OUT_INDEX["small"],
          scale=0.75,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_small_x1_0(**kwargs):
      """Build MobileNetV3 from the "small" config with scale 1.0.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["small"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
          out_index=OUT_INDEX["small"],
          scale=1.0,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_small_x1_25(**kwargs):
      """Build MobileNetV3 from the "small" config with scale 1.25.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["small"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
          out_index=OUT_INDEX["small"],
          scale=1.25,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_large_x0_35(**kwargs):
      """Build MobileNetV3 from the "large" config with scale 0.35.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      model = MobileNetV3(
          config=NET_CONFIG["large"],
          scale=0.35,
          # Use the large-family stage patterns, matching the large config and
          # out_index passed here (the small-family list was passed before).
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
          out_index=OUT_INDEX["large"],
          **kwargs)
      return model
  @manager.BACKBONES.add_component
  def MobileNetV3_large_x0_5(**kwargs):
      """Build MobileNetV3 from the "large" config with scale 0.5.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["large"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
          out_index=OUT_INDEX["large"],
          scale=0.5,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_large_x0_75(**kwargs):
      """Build MobileNetV3 from the "large" config with scale 0.75.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["large"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
          out_index=OUT_INDEX["large"],
          scale=0.75,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_large_x1_0(**kwargs):
      """Build MobileNetV3 from the "large" config with scale 1.0.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["large"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
          out_index=OUT_INDEX["large"],
          scale=1.0,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_large_x1_25(**kwargs):
      """Build MobileNetV3 from the "large" config with scale 1.25.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["large"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
          out_index=OUT_INDEX["large"],
          scale=1.25,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_large_x1_0_os8(**kwargs):
      """Build MobileNetV3 from the dilated "large_os8" config with scale 1.0.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["large_os8"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
          out_index=OUT_INDEX["large"],
          scale=1.0,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_small_x1_0_os8(**kwargs):
      """Build MobileNetV3 from the dilated "small_os8" config with scale 1.0.

      Extra keyword arguments are forwarded to ``MobileNetV3``.
      """
      return MobileNetV3(
          config=NET_CONFIG["small_os8"],
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
          out_index=OUT_INDEX["small"],
          scale=1.0,
          **kwargs)
  @manager.BACKBONES.add_component
  def MobileNetV3_large_x1_0_os16(**kwargs):
      """Build MobileNetV3 from the "large_os16" config with scale 1.0.

      A caller-supplied ``out_index`` is used as given; otherwise the default
      large-family indices are used. All other keyword arguments are forwarded
      to ``MobileNetV3``.
      """
      # kwargs is this call's own dict, so filling in the default is local.
      kwargs.setdefault("out_index", OUT_INDEX["large"])
      return MobileNetV3(
          config=NET_CONFIG["large_os16"],
          scale=1.0,
          stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
          **kwargs)
  501. return model