Переглянути джерело

feat: 接口优化,感觉有点乱,不统一

tuonian 1 рік тому
батько
коміт
d811edb830
5 змінених файлів з 192 додано та 137 видалено
  1. 37 8
      api/file.py
  2. 93 44
      api/image.py
  3. 1 1
      tools/__init__.py
  4. 10 84
      tools/img.py
  5. 51 0
      utils/color.py

+ 37 - 8
api/file.py

@@ -1,4 +1,6 @@
 import os
+
+import cv2
 from flask import request, jsonify, send_file
 from werkzeug.utils import secure_filename
 from werkzeug.datastructures import FileStorage
@@ -6,11 +8,17 @@ import uuid
 from app import app
 from .resp import success_resp, error_resp
 import time
+from .req import  is_empty
 
 UPLOAD_PATH = ""
 OUTPUT_DIR = ""
 
 
+class MimeTypes:
+    png = "image/png"
+    stream = "application/octet-stream"
+
+
 def get_upload_dir():
     return UPLOAD_PATH
 
@@ -29,13 +37,17 @@ def set_options(upload_path, save_dir):
 @app.route("/outputs/<path:filename>")
 def get_outputs_files(filename):
     # 处理后的文件
-    return send_file(os.path.join(OUTPUT_DIR, filename), mimetype="image/png")
+    download = is_empty(request.values.get('download')) is False
+    mimetype = MimeTypes.stream if download else MimeTypes.png
+    return send_file(os.path.join(OUTPUT_DIR, filename), mimetype=mimetype)
 
 
 @app.route("/uploads/<path:filename>")
 def get_upload_files(filename):
     # 用户上传的原始文件
-    return send_file(os.path.join(UPLOAD_PATH, filename), mimetype="image/png")
+    download = is_empty(request.values.get('download')) is False
+    mimetype = MimeTypes.stream if download else MimeTypes.png
+    return send_file(os.path.join(UPLOAD_PATH, filename), mimetype=mimetype)
 
 
 @app.route("/upload", methods=['POST'])
@@ -44,13 +56,26 @@ def upload_file():
     if 'file' not in request.files:
         return jsonify(error_resp("请选择图片!"))
     file = request.files['file']
-    file_path, file_id = save_upload_file(file)
+    file_md5 = request.form.get('fileMd5')
+    file_path, file_id = save_upload_file(file, file_md5)
+    img = cv2.imread(file_path)
     return jsonify(success_resp({
         "fileId": file_id,
-        "url": file_url(file_path)
+        "url": file_url(file_path),
+        "height": img.shape[0],
+        "width": img.shape[1]
     }))
 
 
+@app.route("/download", methods = ['GET'])
+def download_file():
+    file_id = request.values.get("fileId")
+    file_path = get_file_id_path(file_id)
+    if file_path is None:
+        return "文件不存在", 404
+    return send_file(file_path, mimetype="application/octet-stream", attachment_filename=file_id)
+
+
 def file_url(path, output=False):
     if output is True:
         file_id = os.path.relpath(path, OUTPUT_DIR)
@@ -59,18 +84,22 @@ def file_url(path, output=False):
     return "/uploads/{}".format(file_id)
 
 
-def save_upload_file(file: FileStorage):
+def save_upload_file(file: FileStorage, md5 = None):
     # 如果区分用户,加文件前缀或者文件夹名称
     if file.filename == '':
         filename = "unknown.png"
     else:
         filename = secure_filename(file.filename)
     paths = os.path.splitext(filename)
-    t = time.strftime("%Y%m%d%H%M%S", time.localtime())
-    filename = "{}_{}{}_{}".format(paths[0], t, uuid.uuid4().hex, paths[1])
 
+    if md5 is not None and len(md5) > 0:
+        filename = "{}_{}{}".format(paths[0], md5, paths[1])
+    else:
+        t = time.strftime("%Y%m%d%H%M%S", time.localtime())
+        filename = "{}_{}_{}{}".format(paths[0], t, uuid.uuid4().hex, paths[1])
     file_path = get_upload_file_path(filename)
-    file.save(file_path)
+    if os.path.exists(file_path) is False:
+        file.save(file_path)
     return file_path, filename
 
 

+ 93 - 44
api/image.py

@@ -6,7 +6,9 @@ import numpy as np
 from .req import ReplaceForm, is_empty
 import tools
 import cv2
-from .file import get_upload_file_path, get_output_dir, file_url, get_output_file_path, get_file_id
+from .file import get_upload_file_path, get_output_dir, file_url, get_output_file_path, get_file_id, get_file_id_path
+from utils import color
+import time
 
 log = logging.create_logger(app)
 
@@ -33,23 +35,50 @@ def replace():
     if form.validate() is False:
         return jsonify(error_resp("参数错误!"))
 
+    img_path = get_file_id_path(form.file_id.data)
+
+    exist, alpha_path, path = tools.has_seg(img_path, get_output_dir())
+
+    if exist:
+        alpha = cv2.imread(alpha_path)
+        fg = cv2.imread(path)
+    else:
+        alpha, fg, _, path = tools.seg(img_path, get_output_dir())
+
     if is_empty(form.bg_file_id.data) and is_empty(form.background.data):
-        return jsonify(error_resp("请选择需要替换的背景!"))
+        return jsonify(success_resp({
+            "fileId": form.file_id.data,
+            "url": file_url(path, True)
+        }))
 
-    img_path = get_upload_file_path(form.file_id.data)
     bg_path = form.background.data
 
     if is_empty(form.bg_file_id.data) is False:
         bg_path = get_upload_file_path(form.bg_file_id.data)
-        if os.path.exists(bg_path):
+        if os.path.exists(bg_path) is False:
             bg_path = form.background.data
 
-    result = tools.replace(img_path=img_path, background=bg_path, save_dir=get_output_dir())
+    bg = color.get_bg(bg_path, fg.shape)
+    if bg is None:
+        return jsonify(success_resp({
+            "fileId": form.file_id.data,
+            "url": file_url(path, True)
+        }))
+
+    alpha = alpha / 255.0
+    alpha = alpha[:, :, np.newaxis]
+    com = alpha * fg + (1 - alpha) * bg
+    com = com.astype('uint8')
+    filename = os.path.basename(img_path)
+    names = os.path.splitext(filename)
+    save_name = "{}_{}{}".format(names[0], time.time(), names[1])
+    com_save_path = os.path.join(get_output_dir(), save_name)
+    cv2.imwrite(com_save_path, com)
 
-    return jsonify({
-        "fileId": get_file_id(result, True),
-        "url": file_url(result, True)
-    })
+    return jsonify(success_resp({
+        "fileId": get_file_id(com_save_path, True),
+        "url": file_url(com_save_path, True)
+    }))
 
 
 @app.route("/image/resize", methods=["POST", 'GET'])
@@ -65,13 +94,17 @@ def resize():
     rotate = request.values.get("rotate")
     flip = request.values.get("flip")
     rect = request.values.get('rect')
+    save = request.values.get('onSave')
+    download = request.values.get('download')
+    save = is_empty(save) is False
+    download = is_empty(download) is False
 
     if is_empty(reset):
         reset = True
     else:
         reset = False
 
-    file_path = get_upload_file_path(file_id)
+    file_path = get_file_id_path(file_id)
 
     is_get = request.method.upper() == "GET"
 
@@ -86,60 +119,76 @@ def resize():
     origin_h = img.shape[0]
     origin_w = img.shape[1]
 
-    img = crop(img, rect, origin_w, origin_h)
+    dst, r_w, r_h = crop(img, rect, origin_w, origin_h)
 
     change_size = True
     if is_empty(width) and is_empty(height):
         change_size = False
-        w = origin_w
-        h = origin_h
-    elif is_empty(width):
-        h = int(height)
-        w = round(h * origin_w / origin_h)
-    elif is_empty(height):
-        w = int(width)
-        h = round(w * origin_h / origin_w)
-    else:
-        w = int(width)
-        h = int(height)
-
-    dst = img
 
     if change_size:
-        dst = cv2.resize(img, (w, h))
+        if is_empty(width):
+            h = int(height)
+            w = round(h * origin_w / origin_h)
+        elif is_empty(height):
+            w = int(width)
+            h = round(w * origin_h / origin_w)
+        else:
+            w = int(width)
+            h = int(height)
+        origin = img if dst is None else dst
+        dst = cv2.resize(origin, (w, h))
+        r_w = w
+        r_h = h
+
+    if r_w is None:
+        r_w = origin_w
+    if r_h is None:
+        r_h = origin_h
 
     if is_empty(flip) is not True:
         flip = int(flip)
-        dst = cv2.flip(dst, flip)
+        origin = img if dst is None else dst
+        dst = cv2.flip(origin, flip)
 
-    if is_empty(rotate) is False:
-        dst = rot_degree(dst, float(rotate), w=w, h=h)
+    if is_empty(rotate) is False and int(rotate) != 0:
+        origin = img if dst is None else dst
+        dst, new_w, new_h = rot_degree(origin, float(rotate), w=r_w, h=r_h)
+        if new_w is not None:
+            r_w = new_w
+            r_h = new_h
+
+    result_path = file_path
 
     if dst is not None:
         cv2.imwrite(out_file, dst)
-        if is_get:
-            return send_file(out_file, mimetype="image/png")
-        else:
-            return jsonify(success_resp({
-                "fileId": get_file_id(out_file),
-                "url": file_url(out_file)
-            }))
+        result_path = out_file
+    else:
+        dst = img
+
+    if save:
+        result_path = get_output_file_path(file_id)
+        cv2.imwrite(result_path, dst)
 
     if is_get:
-        return send_file(file_path, mimetype="image/png")
+        mimetype = 'application/octet-stream' if download else 'image/png'
+        return send_file(result_path, mimetype=mimetype)
+
     return jsonify(success_resp({
-        "fileId": get_file_id(file_path),
-        "url": file_url(file_path)
+        "fileId": get_file_id(result_path, True),
+        "url": file_url(result_path, True),
+        "width": r_w,
+        "height": r_h,
     }))
 
 
 def crop(img, rect: str, w, h):
-    # 裁剪
+    # 裁剪,
+    # 如果没有改变,返回None
     if is_empty(rect):
-        return img
+        return None, None, None
     r = rect.split(',')
     if len(r) != 4:
-        return img
+        return None, None, None
 
     left = int(r[0])
     top = int(r[1])
@@ -156,9 +205,9 @@ def crop(img, rect: str, w, h):
         bottom = h
 
     if left == 0 and top == 0 and right == w and bottom == h:
-        return img
+        return None, None, None
 
-    return img[top:bottom, left:right]
+    return img[top:bottom, left:right], right-left, bottom-top
 
 
 def rot_degree(img, degree, w, h):
@@ -176,4 +225,4 @@ def rot_degree(img, degree, w, h):
     M[0, 2] += offset_x
     M[1, 2] += offset_y
     dst = cv2.warpAffine(img, M, (new_width, new_height))
-    return dst
+    return dst, new_width, new_height

+ 1 - 1
tools/__init__.py

@@ -28,4 +28,4 @@ manager.BACKBONES._components_dict.clear()
 manager.TRANSFORMS._components_dict.clear()
 
 from .model import load_model
-from .img import seg, replace
+from .img import seg, has_seg

+ 10 - 84
tools/img.py

@@ -11,21 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import os
+
 from paddleseg.utils import logger
 
-from tools.model import get_model
-import cv2
-import numpy as np
 from ppmatting.core import predict
 from ppmatting.utils import get_image_list
-
-current_path = os.path.abspath(os.path.dirname(__file__))
-
-
-def get_rel_path(path: str):
-    return "{}/../{}".format(current_path, path)
+from tools.model import get_model
 
 
 def seg(img_path: str, save_dir: str):
@@ -45,79 +37,13 @@ def seg(img_path: str, save_dir: str):
         fg_estimate=True)
 
 
-def replace(img_path: str, save_dir: str, background: str = None, width: int = None, height: int = None):
-    logger.info("replace: {}, {},{},{},{}".format(img_path, save_dir, background, width, height))
-    image_list, image_dir = get_image_list(img_path)
-    model = get_model()
-    alpha, fg, _, p = predict(
-        model=model.model,
-        model_path=model.path,
-        transforms=model.transforms,
-        image_list=image_list,
-        trimap_list=None,
-        save_dir=save_dir,
-        fg_estimate=False)
-
-    if background is None:
-        return p
-
-    img_ori = cv2.imread(img_path)
-    bg = get_bg(background, img_ori.shape)
-    if bg is None:
-        return p
-
-    alpha = alpha / 255.0
-    alpha = alpha[:, :, np.newaxis]
-    com = alpha * fg + (1 - alpha) * bg
-    com = com.astype('uint8')
-    com_save_path = os.path.join(save_dir, os.path.basename(img_path))
-    cv2.imwrite(com_save_path, com)
-    return com_save_path
-
-
-def get_bg(background, img_shape):
-    # 1、纯色
-    # 2、通道颜色
-    # 3、图片
-    bg = np.zeros(img_shape)
-    if background is None:
-        return None
-    if os.path.exists(background):
-        bg = cv2.imread(background)
-        bg = cv2.resize(bg, (img_shape[1], img_shape[0]))
-    elif background == 'r':
-        bg[:, :, 2] = 255
-    elif background == 'g':
-        bg[:, :, 1] = 255
-    elif background == 'b':
-        bg[:, :, 0] = 255
-    elif background == 'w':
-        bg[:, :, :] = 255
-    elif is_color_hex(background):
-        r, g, b, _ = hex_to_rgb(background)
-        bg[:, :, 2] = r
-        bg[:, :, 1] = g
-        bg[:, :, 0] = b
-    else:
-        return None
-
-    return bg
-
-
-def is_color_hex(color: str):
-    size = len(color)
-    if color.startswith("#"):
-        return size == 7 or size == 9
-    return False
+def has_seg(img_path: str, save_dir: str):
+    # 是否已经抠过图了
+    paths = os.path.splitext(img_path)
+    rgba = "{}_rgba{}".format(paths[0], paths[1])
+    alpha = "{}_alpha{}".format(paths[0], paths[1])
 
+    rgba_path = os.path.join(save_dir, rgba)
+    alpha_path = os.path.join(save_dir, alpha)
 
-def hex_to_rgb(color: str):
-    if color.startswith("#"):
-        color = color[1:len(color)]
-    r = int(color[0:2], 16)
-    g = int(color[2:4], 16)
-    b = int(color[4:6], 16)
-    a = 100
-    if len(color) == 8:
-        a = int(color[6:8], 16)
-    return r, g, b, a
+    return os.path.exists(rgba_path) and os.path.exists(alpha_path), alpha_path, rgba_path

+ 51 - 0
utils/color.py

@@ -0,0 +1,51 @@
+import os
+import cv2
+import numpy as np
+
+
+def get_bg(background, img_shape):
+    # 1、纯色
+    # 2、通道颜色
+    # 3、图片
+    bg = np.zeros(img_shape)
+    if background is None:
+        return None
+    if os.path.exists(background):
+        bg = cv2.imread(background)
+        bg = cv2.resize(bg, (img_shape[1], img_shape[0]))
+    elif background == 'r':
+        bg[:, :, 2] = 255
+    elif background == 'g':
+        bg[:, :, 1] = 255
+    elif background == 'b':
+        bg[:, :, 0] = 255
+    elif background == 'w':
+        bg[:, :, :] = 255
+    elif is_color_hex(background):
+        r, g, b, _ = hex_to_rgb(background)
+        bg[:, :, 2] = r
+        bg[:, :, 1] = g
+        bg[:, :, 0] = b
+    else:
+        return None
+
+    return bg
+
+
+def is_color_hex(color: str):
+    size = len(color)
+    if color.startswith("#"):
+        return size == 7 or size == 9
+    return False
+
+
+def hex_to_rgb(color: str):
+    if color.startswith("#"):
+        color = color[1:len(color)]
+    r = int(color[0:2], 16)
+    g = int(color[2:4], 16)
+    b = int(color[4:6], 16)
+    a = 100
+    if len(color) == 8:
+        a = int(color[6:8], 16)
+    return r, g, b, a