# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from paddleseg.utils import logger
from ppmatting.core import predict
from ppmatting.utils import get_image_list
from tools.model import get_model


def seg(img_path: str, save_dir: str):
    """Run matting prediction on the image(s) found at ``img_path``.

    The image list and its base directory come from ``get_image_list``;
    the network, weights path and transforms come from ``get_model()``.
    Prediction runs without trimaps and with foreground estimation enabled.

    Args:
        img_path: Path handed to ``get_image_list`` to collect the inputs.
        save_dir: Directory handed to ``predict`` as the output location.

    Returns:
        Whatever ``ppmatting.core.predict`` returns.
    """
    image_list, image_dir = get_image_list(img_path)
    logger.info(f'Number of predict images = {len(image_list)}')

    model = get_model()
    return predict(
        model=model.model,
        model_path=model.path,
        transforms=model.transforms,
        image_list=image_list,
        image_dir=image_dir,
        trimap_list=None,
        save_dir=save_dir,
        fg_estimate=True)


def has_seg(img_path: str, save_dir: str):
    """Check whether the image has already been matted.

    Builds the expected ``<name>_rgba<ext>`` and ``<name>_alpha<ext>`` paths
    under ``save_dir`` and tests whether both files exist.

    Args:
        img_path: Path of the source image. A relative path keeps its
            directory components under ``save_dir``; an absolute (or
            drive-qualified) path contributes only its file name.
        save_dir: Directory in which the matting outputs are looked up.

    Returns:
        A tuple ``(exists, alpha_path, rgba_path)`` where ``exists`` is True
        only when both output files are present.
    """
    stem, ext = os.path.splitext(img_path)

    # os.path.join() throws away ``save_dir`` when the second component is
    # rooted or carries a drive, so an absolute ``img_path`` would be checked
    # next to the source image instead of inside ``save_dir``. Fall back to
    # the bare file name in that case.
    drive, tail = os.path.splitdrive(stem)
    if drive or tail[:1] in (os.sep, os.altsep):
        stem = os.path.basename(stem)

    # NOTE(review): the expected names reuse the input's extension; confirm
    # that predict() writes its outputs with that same extension.
    rgba_path = os.path.join(save_dir, f"{stem}_rgba{ext}")
    alpha_path = os.path.join(save_dir, f"{stem}_alpha{ext}")

    exists = os.path.exists(rgba_path) and os.path.exists(alpha_path)
    return exists, alpha_path, rgba_path