export.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. def get_input_spec(model_name, shape, trimap):
  16. """
  17. Get the input spec accoring the model_name.
  18. Args:
  19. model_name (str): The model name
  20. shape (str): The shape of input image
  21. trimap (str): Whether a trimap is required
  22. """
  23. input_spec = [{"img": paddle.static.InputSpec(shape=shape, name='img')}]
  24. if trimap:
  25. shape[1] = 1
  26. input_spec[0]['trimap'] = paddle.static.InputSpec(
  27. shape=shape, name='trimap')
  28. if model_name == 'RVM':
  29. input_spec.append(
  30. paddle.static.InputSpec(
  31. shape=[None, 16, None, None], name='r1'))
  32. input_spec.append(
  33. paddle.static.InputSpec(
  34. shape=[None, 20, None, None], name='r2'))
  35. input_spec.append(
  36. paddle.static.InputSpec(
  37. shape=[None, 40, None, None], name='r3'))
  38. input_spec.append(
  39. paddle.static.InputSpec(
  40. shape=[None, 64, None, None], name='r4'))
  41. input_spec.append(
  42. paddle.static.InputSpec(
  43. shape=[1], name='downsample_ratio'))
  44. return input_spec