1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
def avg_max_reduce_channel_helper(x, use_concat=True):
    """Reduce a single tensor along axis 1 with both mean and max.

    Both reductions use ``keepdim=True``, so the reduced axis is kept with
    size 1 in each result.
    NOTE(review): axis 1 is presumably the channel axis (per the function
    name, e.g. an NCHW layout) -- confirm against callers.

    Args:
        x: A single tensor; lists and tuples are rejected by the assert.
        use_concat: When truthy, return the mean and max results
            concatenated along axis 1. Otherwise return them as a
            two-element list ``[mean, max]``.

    Returns:
        One concatenated tensor, or the list ``[mean, max]``.
    """
    # Sequences must be unpacked by the caller; only one tensor is handled.
    assert not isinstance(x, (list, tuple))

    # Order matters for the concatenated layout: mean first, then max.
    reduced = [
        paddle.mean(x, axis=1, keepdim=True),
        paddle.max(x, axis=1, keepdim=True),
    ]
    if not use_concat:
        return reduced
    return paddle.concat(reduced, axis=1)
def avg_max_reduce_channel(x):
    """Apply mean and max reduction along axis 1 to one or more tensors.

    Args:
        x: Either a single tensor, or a list/tuple of tensors.

    Returns:
        A tensor formed by concatenating the per-input mean and max results
        along axis 1. For a single tensor (or a one-element sequence) this
        is ``[mean, max]``; for a longer sequence the order is
        ``[mean_0, max_0, mean_1, max_1, ...]``.
    """
    if isinstance(x, (list, tuple)):
        if len(x) != 1:
            # Gather the unconcatenated [mean, max] pair of every input,
            # then join everything with a single concat call.
            collected = []
            for tensor in x:
                collected += avg_max_reduce_channel_helper(tensor, False)
            return paddle.concat(collected, axis=1)
        # A one-element sequence is handled like a bare tensor.
        x = x[0]
    return avg_max_reduce_channel_helper(x)
|