|
@@ -28,7 +28,7 @@ manager.BACKBONES._components_dict.clear()
|
|
|
manager.TRANSFORMS._components_dict.clear()
|
|
|
|
|
|
import ppmatting
|
|
|
-from ppmatting.core import predict
|
|
|
+from ppmatting.core import predict, load
|
|
|
from ppmatting.utils import get_image_list, Config, MatBuilder
|
|
|
|
|
|
current_path = os.path.abspath(os.path.dirname(__file__))
|
|
@@ -115,7 +115,10 @@ def get_rel_path(path: str):
|
|
|
return "{}/../{}".format(current_path, path)
|
|
|
|
|
|
|
|
|
-def seg(img_path: str, save_dir: str):
|
|
|
+global model, transforms
|
|
|
+
|
|
|
+
|
|
|
+def load_model():
|
|
|
cfg = Config(get_rel_path("configs/quick_start/ppmattingv2-stdc1-human_512.yml"))
|
|
|
builder = MatBuilder(cfg)
|
|
|
|
|
@@ -123,9 +126,30 @@ def seg(img_path: str, save_dir: str):
|
|
|
paddleseg.utils.show_cfg_info(cfg)
|
|
|
paddleseg.utils.set_device("cpu")
|
|
|
|
|
|
+ global model, transforms
|
|
|
model = builder.model
|
|
|
+ model_path = get_rel_path("models/ppmatting-hrnet_w18-human_512.pdparams")
|
|
|
+
|
|
|
+ paddleseg.utils.show_env_info()
|
|
|
+ paddleseg.utils.show_cfg_info(cfg)
|
|
|
+ paddleseg.utils.set_device("cpu")
|
|
|
+
|
|
|
transforms = ppmatting.transforms.Compose(builder.val_transforms)
|
|
|
|
|
|
+ load(model, model_path)
|
|
|
+
|
|
|
+
|
|
|
+def seg(img_path: str, save_dir: str):
|
|
|
+ # cfg = Config(get_rel_path("configs/quick_start/ppmattingv2-stdc1-human_512.yml"))
|
|
|
+ # builder = MatBuilder(cfg)
|
|
|
+
|
|
|
+ # paddleseg.utils.show_env_info()
|
|
|
+ # paddleseg.utils.show_cfg_info(cfg)
|
|
|
+ # paddleseg.utils.set_device("cpu")
|
|
|
+
|
|
|
+ # model = builder.model
|
|
|
+ # transforms = ppmatting.transforms.Compose(builder.val_transforms)
|
|
|
+
|
|
|
image_list, image_dir = get_image_list(img_path)
|
|
|
logger.info('Number of predict images = {}'.format(len(image_list)))
|
|
|
|
|
@@ -139,7 +163,7 @@ def seg(img_path: str, save_dir: str):
|
|
|
image_dir=image_dir,
|
|
|
trimap_list=None,
|
|
|
save_dir=save_dir,
|
|
|
- fg_estimate=args.fg_estimate)
|
|
|
+ fg_estimate=True)
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|