12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- from paddleseg.utils import logger
- from ppmatting.core import predict
- from ppmatting.utils import get_image_list
- from tools.model import get_model
def seg(img_path: str, save_dir: str):
    """Run matting prediction on the image(s) referenced by ``img_path``.

    The image list (and its base directory) is obtained from
    ``get_image_list``, the model wrapper from ``get_model()``, and both are
    handed to ``ppmatting``'s ``predict`` with no trimaps and
    ``fg_estimate=True``.

    Args:
        img_path: Path handed to ``get_image_list`` to discover input images.
        save_dir: Output directory forwarded unchanged to ``predict``.

    Returns:
        Whatever ``predict`` returns.
    """
    images, images_root = get_image_list(img_path)
    logger.info('Number of predict images = {}'.format(len(images)))

    wrapper = get_model()
    # The wrapper bundles the network, its weights path and its transforms.
    predict_kwargs = dict(
        model=wrapper.model,
        model_path=wrapper.path,
        transforms=wrapper.transforms,
        image_list=images,
        image_dir=images_root,
        trimap_list=None,
        save_dir=save_dir,
        fg_estimate=True,
    )
    return predict(**predict_kwargs)
def has_seg(img_path: str, save_dir: str):
    """Report whether the image has already been matted (background removed).

    The expected output names are built by inserting ``_alpha`` / ``_rgba``
    between the stem and the extension of ``img_path`` and joining the result
    onto ``save_dir``.

    Args:
        img_path: Path of the source image.
        save_dir: Directory in which the outputs are looked up.

    Returns:
        A tuple ``(done, alpha_path, rgba_path)``. ``done`` is True only when
        both the RGBA and the alpha files exist.
    """
    stem, ext = os.path.splitext(img_path)
    # NOTE(review): os.path.join discards save_dir when img_path is absolute,
    # and the input extension is reused for the outputs. Confirm both match
    # how predict() names the files it writes.
    alpha_path, rgba_path = (
        os.path.join(save_dir, f"{stem}_{tag}{ext}") for tag in ("alpha", "rgba")
    )
    done = all(map(os.path.exists, (rgba_path, alpha_path)))
    return done, alpha_path, rgba_path
|