model.py 839 B

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import os
  2. import sys
  3. import paddle
  4. import paddleseg
  5. from paddleseg.utils import get_sys_env
  6. import ppmatting
  7. from ppmatting.core import load
  8. from ppmatting.utils import Config, MatBuilder
  9. import configs
  10. current_path = os.path.abspath(os.path.dirname(__file__))
  11. def get_rel_path(path: str):
  12. return os.path.join(current_path, '..', path)
# Module-level handle to the matting model, assigned by load_model().
# This is a bare annotation with no value, so the name stays unbound
# (reading it raises NameError) until load_model() has run.
_model: configs.MattingModel
  14. def load_model():
  15. m = configs.get_model()
  16. global _model
  17. _model = m
  18. model_path = m.config
  19. cfg = Config(model_path)
  20. builder = MatBuilder(cfg)
  21. paddleseg.utils.show_env_info()
  22. paddleseg.utils.show_cfg_info(cfg)
  23. paddleseg.utils.set_device("cpu")
  24. m.model = builder.model
  25. m.transforms = ppmatting.transforms.Compose(builder.val_transforms)
  26. load(m.model, m.path)
  27. m.init = True
  28. def get_model():
  29. return _model