model.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import os
  2. import sys
  3. import paddle
  4. import paddleseg
  5. from paddleseg.utils import get_sys_env, logger
  6. from ppmatting.transforms import Compose
  7. import ppmatting
  8. from ppmatting.core import load
  9. from ppmatting.utils import Config, MatBuilder
  10. current_path = os.path.abspath(os.path.dirname(__file__))
  11. def get_rel_path(path: str):
  12. return os.path.join(current_path, '..', path)
  13. class MattingModel:
  14. path = ""
  15. config = ""
  16. model: paddle.nn.layer.Layer
  17. transforms: Compose
  18. init = False
  19. def __init__(self, p, c):
  20. self.path = get_rel_path(p)
  21. self.config = get_rel_path(c)
  22. _model: MattingModel
  23. modelDict = {
  24. "ppmattingv2": MattingModel("models/ppmattingv2-stdc1-human_512.pdparams",
  25. "configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml"),
  26. "ppmatting": MattingModel("models/ppmatting-hrnet_w18-human_512.pdparams",
  27. "configs/quick_start/ppmattingv2-stdc1-human_512.yml")
  28. }
  29. def load_model():
  30. m = modelDict.get("ppmattingv2")
  31. global _model
  32. _model = m
  33. model_path=get_rel_path("configs/quick_start/ppmattingv2-stdc1-human_512.yml")
  34. cfg = Config(model_path)
  35. builder = MatBuilder(cfg)
  36. paddleseg.utils.show_env_info()
  37. paddleseg.utils.show_cfg_info(cfg)
  38. paddleseg.utils.set_device("cpu")
  39. m.model = builder.model
  40. m.transforms = ppmatting.transforms.Compose(builder.val_transforms)
  41. load(m.model, m.path)
  42. m.init = True
  43. def get_model():
  44. return _model