Browse Source

搞定抠图接口,更换背景

tuon 1 year ago
parent
commit
3962ceb4a2
10 changed files with 343 additions and 53 deletions
  1. 2 0
      api/__init__.py
  2. 92 0
      api/file.py
  3. 165 0
      api/image.py
  4. 12 0
      api/req.py
  5. 10 16
      api/resp.py
  6. 4 0
      app.py
  7. 8 24
      main.py
  8. 2 1
      requirements.txt
  9. 1 1
      tools/__init__.py
  10. 47 11
      tools/img.py

+ 2 - 0
api/__init__.py

@@ -0,0 +1,2 @@
+from .file import *
+from .image import *

+ 92 - 0
api/file.py

@@ -0,0 +1,92 @@
+import os
+from flask import request, jsonify, send_file
+from werkzeug.utils import secure_filename
+from werkzeug.datastructures import FileStorage
+import uuid
+from app import app
+from .resp import success_resp, error_resp
+import time
+
+UPLOAD_PATH = ""
+OUTPUT_DIR = ""
+
+
+def get_upload_dir():
+    return UPLOAD_PATH
+
+
+def get_output_dir():
+    return OUTPUT_DIR
+
+
+def set_options(upload_path, save_dir):
+    global UPLOAD_PATH, OUTPUT_DIR
+    UPLOAD_PATH = upload_path
+    OUTPUT_DIR = save_dir
+    print("dirs", UPLOAD_PATH, OUTPUT_DIR)
+
+
+@app.route("/outputs/<path:filename>")
+def get_outputs_files(filename):
+    # 处理后的文件
+    return send_file(os.path.join(OUTPUT_DIR, filename), mimetype="image/png")
+
+
+@app.route("/uploads/<path:filename>")
+def get_upload_files(filename):
+    # 用户上传的原始文件
+    return send_file(os.path.join(UPLOAD_PATH, filename), mimetype="image/png")
+
+
+@app.route("/upload", methods=['POST'])
+def upload_file():
+    # 上传图片
+    if 'file' not in request.files:
+        return jsonify(error_resp("请选择图片!"))
+    file = request.files['file']
+    file_path, file_id = save_upload_file(file)
+    return jsonify(success_resp({
+        "fileId": file_id,
+        "url": file_url(file_path)
+    }))
+
+
+def file_url(path, output=False):
+    if output is True:
+        file_id = os.path.relpath(path, OUTPUT_DIR)
+        return "/outputs/{}".format(file_id)
+    file_id = os.path.relpath(path, UPLOAD_PATH)
+    return "/uploads/{}".format(file_id)
+
+
+def save_upload_file(file: FileStorage):
+    # 如果区分用户,加文件前缀或者文件夹名称
+    if file.filename == '':
+        filename = "unknown.png"
+    else:
+        filename = secure_filename(file.filename)
+    paths = os.path.splitext(filename)
+    t = time.strftime("%Y%m%d%H%M%S", time.localtime())
+    filename = "{}_{}{}_{}".format(paths[0], t, uuid.uuid4().hex, paths[1])
+
+    file_path = get_upload_file_path(filename)
+    file.save(file_path)
+    return file_path, filename
+
+
+def get_output_file_path(filename: str, suffix: str = None, fun: str = None):
+    if filename == '':
+        filename = "unknown.png"
+    else:
+        filename = secure_filename(filename)
+    paths = os.path.splitext(filename)
+    if suffix is None or len(suffix) == 0:
+        suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
+    if fun is not None:
+        suffix = "{}_{}".format(fun, suffix)
+    filename = "{}_{}{}".format(paths[0], suffix, paths[1])
+    return os.path.join(OUTPUT_DIR, filename)
+
+
+def get_upload_file_path(name):
+    return os.path.join(UPLOAD_PATH, name)

+ 165 - 0
api/image.py

@@ -0,0 +1,165 @@
+import os
+from flask import request, logging, send_file, jsonify
+from werkzeug.datastructures import CombinedMultiDict
+from app import app
+from .resp import success_resp, error_resp
+import numpy as np
+from .req import ReplaceForm, is_empty
+import tools
+import cv2
+from .file import get_upload_file_path, get_output_dir, file_url, get_output_file_path
+
+log = logging.create_logger(app)
+
+
+@app.route("/image/seg", methods=['POST'])
+def seg():
+    if 'fileId' not in request.values:
+        return "请先上传图片!", 500
+    file_id = request.values.get('fileId')
+
+    file_path = get_upload_file_path(file_id)
+    _, fg, _, path = tools.seg(file_path, get_output_dir())
+    # filename = os.path.relpath(path, OUTPUT_DIR)
+    return jsonify(success_resp({
+        "fileId": file_id,
+        "url": file_url(path, True)
+    }))
+
+
+@app.route("/image/replace", methods=["POST"])
+def replace():
+    form_data = request.form
+    form = ReplaceForm(form_data)
+    if form.validate() is False:
+        return jsonify(error_resp("参数错误!"))
+
+    if is_empty(form.bg_file_id.data) and is_empty(form.background.data):
+        return jsonify(error_resp("请选择需要替换的背景!"))
+
+    img_path = get_upload_file_path(form.file_id.data)
+
+    bg_path = get_upload_file_path(form.bg_file_id.data)
+    if os.path.exists(bg_path) is False:
+        bg_path = form.background.data
+
+    result = tools.replace(img_path=img_path, background=bg_path, save_dir=get_output_dir())
+
+    return jsonify({
+        "fileId": form.file_id.data,
+        "url": file_url(result, True)
+    })
+
+
+@app.route("/image/resize", methods=["POST", 'GET'])
+def resize():
+    # flip: 翻转, 0 为沿X轴翻转,正数为沿Y轴翻转,负数为同时沿X轴和Y轴翻转
+    # reset: 是否重头开始,否则从上一次的处理开始,默认为重头开始
+    # resize: 重新设定尺寸
+    # rect: 裁剪,left,top,right,bottom 四个参数
+    reset = request.values.get("reset")
+    file_id = request.values.get("fileId")
+    width = request.values.get("width")
+    height = request.values.get("height")
+    rotate = request.values.get("rotate")
+    flip = request.values.get("flip")
+    rect = request.values.get('rect')
+
+    if is_empty(reset):
+        reset = True
+    else:
+        reset = False
+
+    file_path = get_upload_file_path(file_id)
+
+    if os.path.exists(file_path) is not True:
+        return "文件上传失败,请重新上传!", 400
+
+    out_file = get_output_file_path(file_id, "resize")
+    if reset is False and os.path.exists(out_file):
+        img = cv2.imread(out_file)
+    else:
+        img = cv2.imread(file_path)
+    origin_h = img.shape[0]
+    origin_w = img.shape[1]
+
+    img = crop(img, rect, origin_w, origin_h)
+
+    change_size = True
+    if is_empty(width) and is_empty(height):
+        change_size = False
+        w = origin_w
+        h = origin_h
+    elif is_empty(width):
+        h = int(height)
+        w = round(h * origin_w / origin_h)
+    elif is_empty(height):
+        w = int(width)
+        h = round(w * origin_h / origin_w)
+    else:
+        w = int(width)
+        h = int(height)
+
+    dst = img
+
+    if change_size:
+        dst = cv2.resize(img, (w, h))
+
+    if is_empty(flip) is not True:
+        flip = int(flip)
+        dst = cv2.flip(dst, flip)
+
+    if is_empty(rotate) is False:
+        dst = rot_degree(dst, float(rotate), w=w, h=h)
+
+    if dst is not None:
+        cv2.imwrite(out_file, dst)
+        return send_file(out_file, mimetype="image/png")
+
+    return send_file(file_path, mimetype="image/png")
+
+
+def crop(img, rect: str, w, h):
+    # 裁剪
+    if is_empty(rect):
+        return img
+    r = rect.split(',')
+    if len(r) != 4:
+        return img
+
+    left = int(r[0])
+    top = int(r[1])
+    right = int(r[2])
+    bottom = int(r[3])
+
+    if left < 0:
+        left = 0
+    if right > w:
+        right = w
+    if top < 0:
+        top = 0
+    if bottom > h:
+        bottom = h
+
+    if left == 0 and top == 0 and right == w and bottom == h:
+        return img
+
+    return img[top:bottom, left:right]
+
+
+def rot_degree(img, degree, w, h):
+    center = (w / 2, h / 2)
+
+    M = cv2.getRotationMatrix2D(center, degree, 1)
+    top_right = np.array((w - 1, 0)) - np.array(center)
+    bottom_right = np.array((w - 1, h - 1)) - np.array(center)
+    top_right_after_rot = M[0:2, 0:2].dot(top_right)
+    bottom_right_after_rot = M[0:2, 0:2].dot(bottom_right)
+    new_width = max(int(abs(bottom_right_after_rot[0] * 2) + 0.5), int(abs(top_right_after_rot[0] * 2) + 0.5))
+    new_height = max(int(abs(top_right_after_rot[1] * 2) + 0.5), int(abs(bottom_right_after_rot[1] * 2) + 0.5))
+    offset_x = (new_width - w) / 2
+    offset_y = (new_height - h) / 2
+    M[0, 2] += offset_x
+    M[1, 2] += offset_y
+    dst = cv2.warpAffine(img, M, (new_width, new_height))
+    return dst

+ 12 - 0
api/req.py

@@ -0,0 +1,12 @@
+from wtforms import Form, SubmitField, StringField, PasswordField, IntegerField, FileField, validators, FloatField
+
+
+def is_empty(s: str):
+    return s is None or len(s) == 0
+
+
+class ReplaceForm(Form):
+    # background: 可以是图片文件名称,也可以是r,g,b,w或者是颜色的hex
+    background = StringField(label="background", validators=[])
+    bg_file_id = StringField("bg_file_id", validators=[])
+    file_id = StringField(label="file_id", validators=[validators.DataRequired(message="请先上传图片")])

+ 10 - 16
api/resp.py

@@ -1,20 +1,14 @@
-class ApiResp:
-    # 基本响应
-    code = 0
-    msg = "Successful"
-    data = None
-
-    def __init__(self, code, msg=None, data=None):
-        self.code = code
-        if msg is not None and len(msg) > 0:
-            self.msg = msg
-        if data is not None:
-            self.data = data
-
-
+# 响应基类
 def error_resp(msg, code=-1):
-    return ApiResp(code=code, msg=msg)
+    return {
+        "code": code,
+        "msg": msg
+    }
 
 
 def success_resp(data):
-    return ApiResp(code=0, data=data)
+    return {
+        "code": 0,
+        "msg": "Successful",
+        "data": data
+    }

+ 4 - 0
app.py

@@ -0,0 +1,4 @@
+from flask import Flask
+
+
+app = Flask(__name__)

+ 8 - 24
main.py

@@ -1,10 +1,8 @@
-from flask import Flask, request, jsonify
 import os
-import tools
-from werkzeug.utils import secure_filename
-from api import resp
 
-app = Flask(__name__)
+import tools
+import api
+from app import app
 
 cur_dirs = os.path.abspath(os.path.dirname(__file__))
 save_dir = os.path.join(cur_dirs, 'outputs')
@@ -13,27 +11,13 @@ upload_dir = os.path.join(cur_dirs, "uploads")
 app.config['UPLOAD_FOLDER'] = upload_dir
 
 
-def get_upload_file_path(name):
-    return os.path.join(upload_dir, name)
-
-
-@app.route("/image/seg", methods=['POST'])
-def seg():
-    if 'file' not in request.files:
-        return '{}'
-    file = request.files['file']
-
-    if file.filename == '':
-        return '{}'
-
-    filename = secure_filename(file.filename)
-    file_path = get_upload_file_path(filename)
-    file.save(file_path)
-
-    _, _, _, path = tools.seg(file_path, save_dir)
-    return jsonify(resp.success_resp(path))
+def init():
+    if os.path.exists(upload_dir) is False:
+        os.makedirs(upload_dir)
+    api.set_options(upload_dir, save_dir)
 
 
 if __name__ == '__main__':
+    init()
     tools.load_model()
     app.run(port=20201, host="0.0.0.0", debug=True)

+ 2 - 1
requirements.txt

@@ -11,4 +11,5 @@ flask~=2.0.3
 six
 scipy
 pillow
-werkzeug
+werkzeug
+wtforms

+ 1 - 1
tools/__init__.py

@@ -28,4 +28,4 @@ manager.BACKBONES._components_dict.clear()
 manager.TRANSFORMS._components_dict.clear()
 
 from .model import load_model
-from .img import seg,replace
+from .img import seg, replace

+ 47 - 11
tools/img.py

@@ -45,10 +45,11 @@ def seg(img_path: str, save_dir: str):
         fg_estimate=True)
 
 
-def replace(img_path: str, save_dir: str, bg_color = "r"):
+def replace(img_path: str, save_dir: str, background: str = None, width: int = None, height: int = None):
+    logger.info("replace: {}, {},{},{},{}".format(img_path, save_dir, background, width, height))
     image_list, image_dir = get_image_list(img_path)
     model = get_model()
-    alpha, fg = predict(
+    alpha, fg, _, p = predict(
         model=model.model,
         model_path=model.path,
         transforms=model.transforms,
@@ -57,31 +58,66 @@ def replace(img_path: str, save_dir: str, bg_color = "r"):
         save_dir=save_dir,
         fg_estimate=False)
 
+    if background is None:
+        return p
+
     img_ori = cv2.imread(img_path)
-    bg = get_bg(bg_color, img_ori.shape)
+    bg = get_bg(background, img_ori.shape)
+    if bg is None:
+        return p
+
     alpha = alpha / 255.0
     alpha = alpha[:, :, np.newaxis]
     com = alpha * fg + (1 - alpha) * bg
     com = com.astype('uint8')
     com_save_path = os.path.join(save_dir, os.path.basename(img_path))
     cv2.imwrite(com_save_path, com)
+    return com_save_path
 
 
 def get_bg(background, img_shape):
+    # 1、纯色
+    # 2、通道颜色
+    # 3、图片
     bg = np.zeros(img_shape)
-    if background == 'r':
+    if background is None:
+        return None
+    if os.path.exists(background):
+        bg = cv2.imread(background)
+        bg = cv2.resize(bg, (img_shape[1], img_shape[0]))
+    elif background == 'r':
         bg[:, :, 2] = 255
-    elif background is None or background == 'g':
+    elif background == 'g':
         bg[:, :, 1] = 255
     elif background == 'b':
         bg[:, :, 0] = 255
     elif background == 'w':
         bg[:, :, :] = 255
-
-    elif not os.path.exists(background):
-        raise Exception('The --background is not existed: {}'.format(
-            background))
+    elif is_color_hex(background):
+        r, g, b, _ = hex_to_rgb(background)
+        bg[:, :, 2] = r
+        bg[:, :, 1] = g
+        bg[:, :, 0] = b
     else:
-        bg = cv2.imread(background)
-        bg = cv2.resize(bg, (img_shape[1], img_shape[0]))
+        return None
+
     return bg
+
+
+def is_color_hex(color: str):
+    size = len(color)
+    if color.startswith("#"):
+        return size == 7 or size == 9
+    return False
+
+
+def hex_to_rgb(color: str):
+    if color.startswith("#"):
+        color = color[1:len(color)]
+    r = int(color[0:2], 16)
+    g = int(color[2:4], 16)
+    b = int(color[4:6], 16)
+    a = 100
+    if len(color) == 8:
+        a = int(color[6:8], 16)
+    return r, g, b, a