Browse Source

feat: 模型可配置

tuon 1 year ago
parent
commit
8eca0cc935
3 changed files with 41 additions and 27 deletions
  1. 1 0
      configs/__init__.py
  2. 35 0
      configs/config.py
  3. 5 27
      tools/model.py

+ 1 - 0
configs/__init__.py

@@ -0,0 +1 @@
+from .config import MattingModel, get_model

+ 35 - 0
configs/config.py

@@ -0,0 +1,35 @@
+import os
+import ppmatting
+
+current_path = os.path.abspath(os.path.dirname(__file__))
+
+
+def get_rel_path(path: str):
+    return os.path.join(current_path, '..', path)
+
+
+class MattingModel:
+    path = ""
+    config = ""
+    model = None
+    transforms: ppmatting.transforms.Compose
+    init = False
+
+    def __init__(self, p, c):
+        self.path = get_rel_path(p)
+        self.config = get_rel_path(c)
+
+
+_model: MattingModel
+
+modelDict = {
+    "ppmattingv2": MattingModel("models/ppmattingv2-stdc1-human_512.pdparams",
+                                "configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml"),
+
+    "ppmatting": MattingModel("models/ppmatting-hrnet_w18-human_512.pdparams",
+                              "configs/quick_start/ppmattingv2-stdc1-human_512.yml")
+}
+
+
+def get_model():
+    return modelDict.get("ppmattingv2")

+ 5 - 27
tools/model.py

@@ -3,14 +3,12 @@ import sys
 
 import paddle
 import paddleseg
-from paddleseg.utils import get_sys_env, logger
-
-from ppmatting.transforms import Compose
+from paddleseg.utils import get_sys_env
 
 import ppmatting
 from ppmatting.core import load
 from ppmatting.utils import Config, MatBuilder
-
+import configs
 current_path = os.path.abspath(os.path.dirname(__file__))
 
 
@@ -18,35 +16,15 @@ def get_rel_path(path: str):
     return os.path.join(current_path, '..', path)
 
 
-class MattingModel:
-    path = ""
-    config = ""
-    model = None
-    transforms: ppmatting.transforms.Compose
-    init = False
-
-    def __init__(self, p, c):
-        self.path = get_rel_path(p)
-        self.config = get_rel_path(c)
-
-
-_model: MattingModel
-
-modelDict = {
-    "ppmattingv2": MattingModel("models/ppmattingv2-stdc1-human_512.pdparams",
-                                "configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml"),
-
-    "ppmatting": MattingModel("models/ppmatting-hrnet_w18-human_512.pdparams",
-                              "configs/quick_start/ppmattingv2-stdc1-human_512.yml")
-}
+_model: configs.MattingModel
 
 
 def load_model():
-    m = modelDict.get("ppmattingv2")
+    m = configs.get_model()
     global _model
     _model = m
 
-    model_path=get_rel_path("configs/quick_start/ppmattingv2-stdc1-human_512.yml")
+    model_path = m.config
     cfg = Config(model_path)
     builder = MatBuilder(cfg)