Просмотр исходного кода

更换模型,这个效果不错

tuonian 1 год назад
Родитель
Сommit
590a61ecbb

+ 3 - 1
.gitignore

@@ -119,4 +119,6 @@ test_tipc/web/models/
 # EISeg
 EISeg/eiseg/config/setting.txt
 
-/outputs
+/outputs
+/uploads
+!/uploads/1.txt

+ 2 - 0
install.sh

@@ -0,0 +1,2 @@
+#!/bin/sh
+python -m pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

+ 3 - 3
main.py

@@ -1,6 +1,6 @@
 from flask import Flask, request
 import os
-from tools import predict
+import tools
 from werkzeug.utils import secure_filename
 
 app = Flask(__name__)
@@ -29,10 +29,10 @@ def seg():
     file_path = get_upload_file_path(filename)
     file.save(file_path)
 
-    predict.seg(file_path, save_dir)
+    tools.seg(file_path, save_dir)
     return '{"code": 1}'
 
 
 if __name__ == '__main__':
-    predict.load_model()
+    tools.load_model()
     app.run(port=20201, host="0.0.0.0", debug=True)

BIN
models/ppmattingv2-stdc1-human_512.pdparams


BIN
outputs/20211009105824_alpha.png


BIN
outputs/20211009105824_rgba.png


BIN
outputs/cat_20_alpha.png


BIN
outputs/cat_20_rgba.png


BIN
outputs/demo111_alpha.png


BIN
outputs/demo111_rgba.png


BIN
outputs/head_120_alpha.png


BIN
outputs/head_120_rgba.png


BIN
outputs/human_alpha.png


BIN
outputs/human_rgba.png


+ 31 - 0
tools/__init__.py

@@ -0,0 +1,31 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import sys
+
+import paddle
+import paddleseg
+from paddleseg.cvlibs import manager
+from paddleseg.utils import get_sys_env, logger
+
+LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.join(LOCAL_PATH, '..'))
+
+manager.BACKBONES._components_dict.clear()
+manager.TRANSFORMS._components_dict.clear()
+
+from .model import load_model
+from .img import seg,replace

+ 87 - 0
tools/img.py

@@ -0,0 +1,87 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from paddleseg.utils import logger
+
+from tools.model import get_model
+import cv2
+import numpy as np
+from ppmatting.core import predict
+from ppmatting.utils import get_image_list
+
+current_path = os.path.abspath(os.path.dirname(__file__))
+
+
+def get_rel_path(path: str):
+    return "{}/../{}".format(current_path, path)
+
+
+def seg(img_path: str, save_dir: str):
+    image_list, image_dir = get_image_list(img_path)
+    logger.info('Number of predict images = {}'.format(len(image_list)))
+
+    model = get_model()
+
+    return predict(
+        model=model.model,
+        model_path=model.path,
+        transforms=model.transforms,
+        image_list=image_list,
+        image_dir=image_dir,
+        trimap_list=None,
+        save_dir=save_dir,
+        fg_estimate=True)
+
+
+def replace(img_path: str, save_dir: str, bg_color = "r"):
+    image_list, image_dir = get_image_list(img_path)
+    model = get_model()
+    alpha, fg = predict(
+        model=model.model,
+        model_path=model.path,
+        transforms=model.transforms,
+        image_list=image_list,
+        trimap_list=None,
+        save_dir=save_dir,
+        fg_estimate=False)
+
+    img_ori = cv2.imread(img_path)
+    bg = get_bg(bg_color, img_ori.shape)
+    alpha = alpha / 255.0
+    alpha = alpha[:, :, np.newaxis]
+    com = alpha * fg + (1 - alpha) * bg
+    com = com.astype('uint8')
+    com_save_path = os.path.join(save_dir, os.path.basename(img_path))
+    cv2.imwrite(com_save_path, com)
+
+
+def get_bg(background, img_shape):
+    bg = np.zeros(img_shape)
+    if background == 'r':
+        bg[:, :, 2] = 255
+    elif background is None or background == 'g':
+        bg[:, :, 1] = 255
+    elif background == 'b':
+        bg[:, :, 0] = 255
+    elif background == 'w':
+        bg[:, :, :] = 255
+
+    elif not os.path.exists(background):
+        raise Exception('The --background is not existed: {}'.format(
+            background))
+    else:
+        bg = cv2.imread(background)
+        bg = cv2.resize(bg, (img_shape[1], img_shape[0]))
+    return bg

+ 65 - 0
tools/model.py

@@ -0,0 +1,65 @@
+import os
+import sys
+
+import paddle
+import paddleseg
+from paddleseg.utils import get_sys_env, logger
+
+from ppmatting.transforms import Compose
+
+import ppmatting
+from ppmatting.core import load
+from ppmatting.utils import Config, MatBuilder
+
+current_path = os.path.abspath(os.path.dirname(__file__))
+
+
+def get_rel_path(path: str):
+    return os.path.join(current_path, '..', path)
+
+
+class MattingModel:
+    path = ""
+    config = ""
+    model: paddle.nn.layer.Layer
+    transforms: Compose
+    init = False
+
+    def __init__(self, p, c):
+        self.path = get_rel_path(p)
+        self.config = get_rel_path(c)
+
+
+_model: MattingModel
+
+modelDict = {
+    "ppmattingv2": MattingModel("models/ppmattingv2-stdc1-human_512.pdparams",
+                                "configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml"),
+
+    "ppmatting": MattingModel("models/ppmatting-hrnet_w18-human_512.pdparams",
+                              "configs/quick_start/ppmattingv2-stdc1-human_512.yml")
+}
+
+
+def load_model():
+    m = modelDict.get("ppmattingv2")
+    global _model
+    _model = m
+
+    model_path=get_rel_path("configs/quick_start/ppmattingv2-stdc1-human_512.yml")
+    cfg = Config(model_path)
+    builder = MatBuilder(cfg)
+
+    paddleseg.utils.show_env_info()
+    paddleseg.utils.show_cfg_info(cfg)
+    paddleseg.utils.set_device("cpu")
+
+    m.model = builder.model
+    m.transforms = ppmatting.transforms.Compose(builder.val_transforms)
+
+    load(m.model, m.path)
+    m.init = True
+
+
+def get_model():
+    return _model

+ 2 - 8
tools/predict.py

@@ -21,12 +21,6 @@ import paddleseg
 from paddleseg.cvlibs import manager
 from paddleseg.utils import get_sys_env, logger
 
-LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(LOCAL_PATH, '..'))
-
-manager.BACKBONES._components_dict.clear()
-manager.TRANSFORMS._components_dict.clear()
-
 import ppmatting
 from ppmatting.core import predict, load
 from ppmatting.utils import get_image_list, Config, MatBuilder
@@ -119,7 +113,7 @@ global model, transforms
 
 
 def load_model():
-    cfg = Config(get_rel_path("configs/quick_start/ppmattingv2-stdc1-human_512.yml"))
+    cfg = Config(get_rel_path("configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml"))
     builder = MatBuilder(cfg)
 
     paddleseg.utils.show_env_info()
@@ -128,7 +122,7 @@ def load_model():
 
     global model, transforms
     model = builder.model
-    model_path = get_rel_path("models/ppmatting-hrnet_w18-human_512.pdparams")
+    model_path = get_rel_path("models/ppmattingv2-stdc1-human_512.pdparams")
 
     paddleseg.utils.show_env_info()
     paddleseg.utils.show_cfg_info(cfg)


BIN
uploads/20211009105824.jpg


BIN
uploads/cat_20.png


BIN
uploads/demo111.jpeg


BIN
uploads/head_120.png