tuon, 1 year ago
commit
d284f3ea1f
100 changed files with 7,126 additions and 0 deletions
  1. 120 0
      .gitignore
  2. 106 0
      README.md
  3. 106 0
      README_CN.md
  4. 20 0
      configs/benchmarks/Composition-1k/closeform_composition1k.yml
  5. 20 0
      configs/benchmarks/Distinctions-646/closeform_distinctions646.yml
  6. 9 0
      configs/benchmarks/PPM/README.md
  7. 19 0
      configs/benchmarks/PPM/closeform.yml
  8. 6 0
      configs/benchmarks/PPM/fast.yml
  9. 6 0
      configs/benchmarks/PPM/knn.yml
  10. 6 0
      configs/benchmarks/PPM/learningbased.yml
  11. 6 0
      configs/benchmarks/PPM/randomwalks.yml
  12. 45 0
      configs/dim/dim-vgg16.yml
  13. 54 0
      configs/human_matting/human_matting-resnet34_vd.yml
  14. 5 0
      configs/modnet/modnet-hrnet_w18.yml
  15. 47 0
      configs/modnet/modnet-mobilenetv2.yml
  16. 5 0
      configs/modnet/modnet-resnet50_vd.yml
  17. 20 0
      configs/ppmatting/README.md
  18. 29 0
      configs/ppmatting/ppmatting-hrnet_w18-human_1024.yml
  19. 44 0
      configs/ppmatting/ppmatting-hrnet_w18-human_512.yml
  20. 7 0
      configs/ppmatting/ppmatting-hrnet_w48-composition.yml
  21. 55 0
      configs/ppmatting/ppmatting-hrnet_w48-distinctions.yml
  22. 66 0
      configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml
  23. 47 0
      configs/quick_start/modnet-mobilenetv2.yml
  24. 66 0
      configs/quick_start/ppmattingv2-stdc1-human_512.yml
  25. 9 0
      configs/rvm/README.md
  26. 18 0
      configs/rvm/rvm-mobilenetv3.yaml
  27. BIN
      demo/human.jpg
  28. 13 0
      deploy/human_matting_android_demo/.gitignore
  29. 201 0
      deploy/human_matting_android_demo/LICENSE
  30. 140 0
      deploy/human_matting_android_demo/README.md
  31. 141 0
      deploy/human_matting_android_demo/README_CN.md
  32. 1 0
      deploy/human_matting_android_demo/app/.gitignore
  33. 118 0
      deploy/human_matting_android_demo/app/build.gradle
  34. 172 0
      deploy/human_matting_android_demo/app/gradlew
  35. 84 0
      deploy/human_matting_android_demo/app/gradlew.bat
  36. 8 0
      deploy/human_matting_android_demo/app/local.properties
  37. 21 0
      deploy/human_matting_android_demo/app/proguard-rules.pro
  38. 26 0
      deploy/human_matting_android_demo/app/src/androidTest/java/com/baidu/paddle/lite/demo/ExampleInstrumentedTest.java
  39. 29 0
      deploy/human_matting_android_demo/app/src/main/AndroidManifest.xml
  40. BIN
      deploy/human_matting_android_demo/app/src/main/assets/image_matting/images/bg.jpg
  41. BIN
      deploy/human_matting_android_demo/app/src/main/assets/image_matting/images/human.jpg
  42. 2 0
      deploy/human_matting_android_demo/app/src/main/assets/image_matting/labels/label_list
  43. 127 0
      deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/AppCompatPreferenceActivity.java
  44. 562 0
      deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/MainActivity.java
  45. 302 0
      deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/Predictor.java
  46. 158 0
      deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/SettingsActivity.java
  47. 87 0
      deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/Utils.java
  48. 38 0
      deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/config/Config.java
  49. 79 0
      deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/preprocess/Preprocess.java
  50. 151 0
      deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/visual/Visualize.java
  51. 34 0
      deploy/human_matting_android_demo/app/src/main/res/drawable-v24/ic_launcher_foreground.xml
  52. 170 0
      deploy/human_matting_android_demo/app/src/main/res/drawable/ic_launcher_background.xml
  53. BIN
      deploy/human_matting_android_demo/app/src/main/res/drawable/paddle_logo.png
  54. 132 0
      deploy/human_matting_android_demo/app/src/main/res/layout/activity_main.xml
  55. 21 0
      deploy/human_matting_android_demo/app/src/main/res/menu/menu_action_options.xml
  56. 5 0
      deploy/human_matting_android_demo/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml
  57. 5 0
      deploy/human_matting_android_demo/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml
  58. BIN
      deploy/human_matting_android_demo/app/src/main/res/mipmap-hdpi/ic_launcher.png
  59. BIN
      deploy/human_matting_android_demo/app/src/main/res/mipmap-hdpi/ic_launcher_round.png
  60. BIN
      deploy/human_matting_android_demo/app/src/main/res/mipmap-mdpi/ic_launcher.png
  61. BIN
      deploy/human_matting_android_demo/app/src/main/res/mipmap-mdpi/ic_launcher_round.png
  62. BIN
      deploy/human_matting_android_demo/app/src/main/res/mipmap-xhdpi/ic_launcher.png
  63. BIN
      deploy/human_matting_android_demo/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png
  64. BIN
      deploy/human_matting_android_demo/app/src/main/res/mipmap-xxhdpi/ic_launcher.png
  65. BIN
      deploy/human_matting_android_demo/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png
  66. BIN
      deploy/human_matting_android_demo/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png
  67. BIN
      deploy/human_matting_android_demo/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png
  68. 39 0
      deploy/human_matting_android_demo/app/src/main/res/values/arrays.xml
  69. 6 0
      deploy/human_matting_android_demo/app/src/main/res/values/colors.xml
  70. 22 0
      deploy/human_matting_android_demo/app/src/main/res/values/strings.xml
  71. 25 0
      deploy/human_matting_android_demo/app/src/main/res/values/styles.xml
  72. 59 0
      deploy/human_matting_android_demo/app/src/main/res/xml/settings.xml
  73. 17 0
      deploy/human_matting_android_demo/app/src/test/java/com/baidu/paddle/lite/demo/ExampleUnitTest.java
  74. 27 0
      deploy/human_matting_android_demo/build.gradle
  75. 17 0
      deploy/human_matting_android_demo/gradle.properties
  76. BIN
      deploy/human_matting_android_demo/gradle/wrapper/gradle-wrapper.jar
  77. 6 0
      deploy/human_matting_android_demo/gradle/wrapper/gradle-wrapper.properties
  78. 172 0
      deploy/human_matting_android_demo/gradlew
  79. 84 0
      deploy/human_matting_android_demo/gradlew.bat
  80. 1 0
      deploy/human_matting_android_demo/settings.gradle
  81. 782 0
      deploy/python/infer.py
  82. 68 0
      docs/data_prepare_cn.md
  83. 72 0
      docs/data_prepare_en.md
  84. 270 0
      docs/full_develop_cn.md
  85. 269 0
      docs/full_develop_en.md
  86. 8 0
      docs/online_demo_cn.md
  87. 10 0
      docs/online_demo_en.md
  88. 116 0
      docs/quick_start_cn.md
  89. 119 0
      docs/quick_start_en.md
  90. 36 0
      main.py
  91. BIN
      models/ppmatting-hrnet_w18-human_512.pdparams
  92. BIN
      outputs/human_alpha.png
  93. BIN
      outputs/human_rgba.png
  94. 1 0
      ppmatting/__init__.py
  95. 6 0
      ppmatting/core/__init__.py
  96. 217 0
      ppmatting/core/bg_replace_video.py
  97. 212 0
      ppmatting/core/predict.py
  98. 168 0
      ppmatting/core/predict_video.py
  99. 353 0
      ppmatting/core/train.py
  100. 176 0
      ppmatting/core/val.py

+ 120 - 0
.gitignore

@@ -0,0 +1,120 @@
+# Mac system
+.DS_Store
+
+# Pycharm
+.idea/
+
+# VSCode
+.vscode/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# js
+node_modules/
+package-lock.json
+test_tipc/web/models/
+
+# EISeg
+EISeg/eiseg/config/setting.txt

+ 106 - 0
README.md

@@ -0,0 +1,106 @@
+English | [简体中文](README_CN.md)
+
+# Image Matting
+
+## Contents
+* [Introduction](#Introduction)
+* [Update Notes](#Update-Notes)
+* [Community](#Community)
+* [Models](#Models)
+* [Tutorials](#Tutorials)
+* [Acknowledgement](#Acknowledgement)
+* [Citation](#Citation)
+
+
+## Introduction
+
+Image Matting is the technique of extracting the foreground from an image by estimating its color and transparency.
+It is widely used in the film industry for background replacement, image composition, and visual effects.
+Each pixel in the image has a value that represents its foreground transparency, called alpha.
+The set of all alpha values in an image is called the alpha matte.
+The foreground can be separated by extracting the part of the image covered by the matte.
+
+
+<p align="center">
+<img src="https://user-images.githubusercontent.com/30919197/179751613-d26f2261-7bcf-4066-a0a4-4c818e7065f0.gif" width="100%" height="100%">
+</p>
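+
+The alpha matte can be used directly for background replacement by compositing the foreground over a new background as `I = alpha * F + (1 - alpha) * B`. The snippet below is a minimal sketch of this step using only OpenCV and NumPy; it is not part of this repository's API, it assumes the matte has the same resolution as the image, and the output path is a placeholder.
+
+``` python
+import cv2
+import numpy as np
+
+# demo/human.jpg ships with this commit; the alpha matte is a model output such as outputs/human_alpha.png.
+image = cv2.imread("demo/human.jpg").astype(np.float32)
+alpha = cv2.imread("outputs/human_alpha.png", cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0
+
+# Example new background: plain white, same size as the input image.
+background = np.full_like(image, 255.0)
+
+# Alpha compositing: I = alpha * F + (1 - alpha) * B, broadcasting alpha over the color channels.
+alpha = alpha[:, :, None]
+composite = alpha * image + (1.0 - alpha) * background
+cv2.imwrite("outputs/human_bg_replaced.png", composite.astype(np.uint8))
+```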
+
+## Update Notes
+* 2022.11
+  * **Release the self-developed lightweight matting SOTA model PP-MattingV2**. Compared with MODNet, PP-MattingV2 is 44.6% faster in inference and its average error is 17.91% lower.
+  * Adjust the document structure and improve the model zoo information.
+  * [FastDeploy](https://github.com/PaddlePaddle/FastDeploy) supports the PP-MattingV2, PP-Matting, PP-HumanMatting and MODNet models.
+* 2022.07
+  * Release the PP-Matting code. Add the ClosedFormMatting, KNNMatting, FastMatting, LearningBasedMatting and RandomWalksMatting traditional machine learning algorithms.
+  Add the GCA model.
+  * Improve the directory structure. Support specifying metrics for evaluation.
+* 2022.04
+  * **Release the self-developed high-accuracy matting SOTA model PP-Matting**. Add the PP-HumanMatting high-resolution human matting model.
+  * Add the Grad and Conn evaluation metrics. Add a foreground evaluation function, which uses the [ML](https://arxiv.org/pdf/2006.14970.pdf) algorithm to estimate the foreground during prediction and background replacement.
+  * Add GradientLoss and LaplacianLoss. Add the RandomSharpen, RandomReJpeg and RSSN data augmentation strategies.
+
+* 2021.11
+  * **The Matting project is released**, implementing the image matting function.
+  * Support the DIM and MODNet matting models. Support model export and Python deployment. Support background replacement. Support human matting deployment on Android.
+
+## Community
+
+* If you have any questions, suggestions or feature requests, please create an issue in [GitHub Issues](https://github.com/PaddlePaddle/PaddleSeg/issues).
+* You are welcome to scan the QR code below and join the PaddleSeg WeChat group to communicate with us.
+<div align="center">
+<img src="https://user-images.githubusercontent.com/30883834/213601179-0813a896-11e1-4514-b612-d145e068ba86.jpeg"  width = "200" />  
+</div>
+
+## Models
+
+For the most common application scenario -- human matting -- we have trained and open-sourced **high-quality human matting models**.
+Depending on your actual application scenario, you can deploy them directly or fine-tune them.
+
+The model zoo includes our self-developed high-accuracy model PP-Matting and the lightweight model PP-MattingV2.
+- PP-Matting is a high-accuracy matting model developed by PaddleSeg, which realizes high-resolution image matting under semantic guidance through its Guidance Flow design.
+    It is the recommended choice when accuracy matters most. Two pre-trained models are open-sourced, at the 512 and 1024 resolution levels.
+
+- PP-MattingV2 is a lightweight matting SOTA model developed by PaddleSeg. It extracts high-level semantic information through double pyramid pooling and spatial attention,
+    and uses a multi-level feature fusion mechanism for both semantic and detail prediction.
+
+| Model | SAD | MSE | Grad | Conn |Params(M) | FLOPs(G) | FPS | Config File | Checkpoint | Inference Model |
+| - | - | -| - | - | - | - | -| - | - | - |
+| PP-MattingV2-512   |40.59|0.0038|33.86|38.90| 8.95 | 7.51 | 98.89 |[cfg](../configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/ppmattingv2-stdc1-human_512.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/ppmattingv2-stdc1-human_512.zip) |
+| PP-Matting-512     |31.56|0.0022|31.80|30.13| 24.5 | 91.28 | 28.9 |[cfg](../configs/ppmatting/ppmatting-hrnet_w18-human_512.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/ppmatting-hrnet_w18-human_512.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/ppmatting-hrnet_w18-human_512.zip) |
+| PP-Matting-1024    |66.22|0.0088|32.90|64.80| 24.5 | 91.28 | 13.4(1024X1024) |[cfg](../configs/ppmatting/ppmatting-hrnet_w18-human_1024.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/ppmatting-hrnet_w18-human_1024.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/ppmatting-hrnet_w18-human_1024.zip) |
+| PP-HumanMatting    |53.15|0.0054|43.75|52.03| 63.9 | 135.8 (2048X2048)| 32.8(2048X2048)|[cfg](../configs/human_matting/human_matting-resnet34_vd.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/human_matting-resnet34_vd.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/pp-humanmatting-resnet34_vd.zip) |
+| MODNet-MobileNetV2 |50.07|0.0053|35.55|48.37| 6.5 | 15.7 | 68.4 |[cfg](../configs/modnet/modnet-mobilenetv2.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/modnet-mobilenetv2.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/modnet-mobilenetv2.zip) |
+| MODNet-ResNet50_vd |39.01|0.0038|32.29|37.38| 92.2 | 151.6 | 29.0 |[cfg](../configs/modnet/modnet-resnet50_vd.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/modnet-resnet50_vd.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/modnet-resnet50_vd.zip) |
+| MODNet-HRNet_W18   |35.55|0.0035|31.73|34.07| 10.2 | 28.5 | 62.6 |[cfg](../configs/modnet/modnet-hrnet_w18.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/modnet-hrnet_w18.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/modnet-hrnet_w18.zip) |
+| DIM-VGG16          |32.31|0.0233|28.89|31.45| 28.4 | 175.5| 30.4 |[cfg](../configs/dim/dim-vgg16.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/dim-vgg16.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/dim-vgg16.zip) |
+
+**Note**:
+* The evaluation dataset is composed of PPM-100 and the human part of AIM-500, 195 images in total, named [PPM-AIM-195](https://paddleseg.bj.bcebos.com/matting/datasets/PPM-AIM-195.zip).
+* FLOPs and FPS are computed with a default model input size of (512, 512) on a Tesla V100 32G GPU. FPS is measured with Paddle Inference.
+* DIM is a trimap-based matting method, so its metrics are calculated only in the transition area.
+    If no trimap image is provided, the region where 0 < alpha < 255, after dilation and erosion with a radius of 25 pixels, is used as the transition area (see the sketch below).
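+
+The sketch below illustrates how the SAD and MSE metrics above can be computed from a predicted alpha matte and a ground-truth matte, together with one possible construction of the transition mask just described. It is only an illustration using NumPy and OpenCV; the exact normalization and mask-generation conventions of this repository's evaluation code (`ppmatting/core/val.py`) may differ.
+
+``` python
+import cv2
+import numpy as np
+
+def transition_mask(alpha_gt, radius=25):
+    """One reading of the note above: the 0 < alpha < 255 region grown by `radius` pixels."""
+    unknown = ((alpha_gt > 0) & (alpha_gt < 255)).astype(np.uint8)
+    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * radius + 1, 2 * radius + 1))
+    return cv2.dilate(unknown, kernel).astype(np.float64)
+
+def sad_and_mse(alpha_pred, alpha_gt, mask=None):
+    """SAD (sum of absolute differences, in thousands) and MSE over the masked region."""
+    pred = alpha_pred.astype(np.float64) / 255.0
+    gt = alpha_gt.astype(np.float64) / 255.0
+    if mask is None:
+        mask = np.ones_like(gt)
+    diff = (pred - gt) * mask
+    sad = np.abs(diff).sum() / 1000.0
+    mse = (diff ** 2).sum() / max(mask.sum(), 1.0)
+    return sad, mse
+```
+
+For a trimap-based method such as DIM, `sad_and_mse(pred, gt, transition_mask(gt))` would restrict the evaluation to the transition band, while the other models in the table would be evaluated over the whole image.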
+
+## Tutorials
+* [Online experience](docs/online_demo_en.md)
+* [Quick start](docs/quick_start_en.md)
+* [Full development](docs/full_develop_en.md)
+* [Human matting android deployment](deploy/human_matting_android_demo/README.md)
+* [Human matting .NET deployment](https://gitee.com/raoyutian/PaddleSegSharp)
+* [Dataset preparation](docs/data_prepare_en.md)
+* AI Studio tutorials
+  * [The Matting tutorial of PaddleSeg](https://aistudio.baidu.com/aistudio/projectdetail/3876411?contributionType=1)
+  * [The image matting tutorial of PP-Matting](https://aistudio.baidu.com/aistudio/projectdetail/5002963?contributionType=1)
+
+## Acknowledgement
+* Thanks to [Qian bin](https://github.com/qianbin1989228) for their contributions.
+* Thanks to [GFM](https://arxiv.org/abs/2010.16188) for the algorithm support.
+
+## Citation
+```
+@article{chen2022pp,
+  title={PP-Matting: High-Accuracy Natural Image Matting},
+  author={Chen, Guowei and Liu, Yi and Wang, Jian and Peng, Juncai and Hao, Yuying and Chu, Lutao and Tang, Shiyu and Wu, Zewu and Chen, Zeyu and Yu, Zhiliang and others},
+  journal={arXiv preprint arXiv:2204.09433},
+  year={2022}
+}
+```

+ 106 - 0
README_CN.md

@@ -0,0 +1,106 @@
+简体中文 | [English](README.md)
+
+# Image Matting
+
+## Contents
+* [Introduction](#Introduction)
+* [Update Notes](#Update-Notes)
+* [Community](#Community)
+* [Models](#Models)
+* [Tutorials](#Tutorials)
+* [Community Contributions](#Community-Contributions)
+* [Citation](#Citation)
+
+
+## Introduction
+
+Image Matting is the technique of extracting the foreground from an image by estimating its color and transparency. It is widely used in the film industry for background replacement, image composition, and visual effects.
+Each pixel in the image has a value representing its foreground transparency, called alpha; the set of all alpha values in an image is called the alpha matte. The foreground can be separated by extracting the part of the image covered by the matte.
+
+
+<p align="center">
+<img src="https://user-images.githubusercontent.com/30919197/179751613-d26f2261-7bcf-4066-a0a4-4c818e7065f0.gif" width="100%" height="100%">
+</p>
+
+## Update Notes
+* 2022.11
+  * **Release the self-developed lightweight matting SOTA model PP-MattingV2**. Compared with MODNet, PP-MattingV2 is 44.6% faster in inference and its average error is 17.91% lower.
+  * Adjust the document structure and improve the model zoo information.
+  * [FastDeploy](https://github.com/PaddlePaddle/FastDeploy) deployment supports the PP-MattingV2, PP-Matting, PP-HumanMatting and MODNet models.
+* 2022.07
+  * Release the PP-Matting code; add the ClosedFormMatting, KNNMatting, FastMatting, LearningBasedMatting and RandomWalksMatting traditional machine learning algorithms; add the GCA model.
+  * Improve the directory structure; support specifying metrics for evaluation.
+* 2022.04
+  * **Release the self-developed high-accuracy matting SOTA model PP-Matting**; add the PP-HumanMatting high-resolution human matting model.
+  * Add the Grad and Conn evaluation metrics; add a foreground evaluation function that uses the [ML](https://arxiv.org/pdf/2006.14970.pdf) algorithm to estimate the foreground during prediction and background replacement.
+  * Add GradientLoss and LaplacianLoss; add the RandomSharpen, RandomReJpeg and RSSN data augmentation strategies.
+* 2021.11
+  * **The Matting project is released**, implementing the image matting function.
+  * Support the DIM and MODNet matting models; support model export and Python deployment; support background replacement; support human matting deployment on Android.
+
+## Community
+
+* If you have any questions or feature suggestions, please open an issue in [GitHub Issues](https://github.com/PaddlePaddle/PaddleSeg/issues).
+* **You are welcome to join the PaddleSeg WeChat user group 👫** (scan the QR code below and fill in a short questionnaire to join), where you can talk directly with the on-duty team members and community experts, and **receive a 30 GB learning gift pack 🎁**:
+  * 🔥 Deep learning video tutorials and a curated collection of image segmentation papers
+  * 🔥 Recordings of previous PaddleSeg live streams, plus the latest release news and live-stream announcements
+  * 🔥 PaddleSeg's self-built human segmentation dataset and curated open-source datasets
+  * 🔥 PaddleSeg pre-trained models and application collections for vertical scenarios, covering human segmentation, interactive segmentation, and more
+  * 🔥 PaddleSeg end-to-end industrial practice examples, including quality-inspection defect segmentation, matting, road segmentation, and more
+<div align="center">
+<img src="https://user-images.githubusercontent.com/30883834/213601179-0813a896-11e1-4514-b612-d145e068ba86.jpeg"  width = "200" />  
+</div>
+
+## Models
+
+For the most common application scenario -- human matting -- we have trained and open-sourced a **high-quality human matting model zoo**. Depending on your actual application scenario, you can deploy the models directly or fine-tune them.
+
+The model zoo includes our self-developed high-accuracy PP-Matting model and the lightweight PP-MattingV2 model.
+- PP-Matting is a high-accuracy matting model developed by PaddleSeg, which realizes high-resolution image matting under semantic guidance through its Guidance Flow design. It is the recommended choice when accuracy matters most.
+    Pre-trained models are provided at the 512 and 1024 resolution levels.
+- PP-MattingV2 is a lightweight matting SOTA model developed by PaddleSeg. It extracts high-level semantic information through double pyramid pooling and spatial attention, and uses a multi-level feature fusion mechanism for both semantic and detail prediction.
+    Compared with MODNet, its inference speed is 44.6% higher and its average error is 17.91% lower. It is the recommended choice when speed matters most.
+
+| Model | SAD | MSE | Grad | Conn |Params(M) | FLOPs(G) | FPS | Config File | Checkpoint | Inference Model |
+| - | - | -| - | - | - | - | -| - | - | - |
+| PP-MattingV2-512   |40.59|0.0038|33.86|38.90| 8.95 | 7.51 | 98.89 |[cfg](../configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/ppmattingv2-stdc1-human_512.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/ppmattingv2-stdc1-human_512.zip) |
+| PP-Matting-512     |31.56|0.0022|31.80|30.13| 24.5 | 91.28 | 28.9 |[cfg](../configs/ppmatting/ppmatting-hrnet_w18-human_512.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/ppmatting-hrnet_w18-human_512.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/ppmatting-hrnet_w18-human_512.zip) |
+| PP-Matting-1024    |66.22|0.0088|32.90|64.80| 24.5 | 91.28 | 13.4(1024X1024) |[cfg](../configs/ppmatting/ppmatting-hrnet_w18-human_1024.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/ppmatting-hrnet_w18-human_1024.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/ppmatting-hrnet_w18-human_1024.zip) |
+| PP-HumanMatting    |53.15|0.0054|43.75|52.03| 63.9 | 135.8 (2048X2048)| 32.8(2048X2048)|[cfg](../configs/human_matting/human_matting-resnet34_vd.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/human_matting-resnet34_vd.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/pp-humanmatting-resnet34_vd.zip) |
+| MODNet-MobileNetV2 |50.07|0.0053|35.55|48.37| 6.5 | 15.7 | 68.4 |[cfg](../configs/modnet/modnet-mobilenetv2.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/modnet-mobilenetv2.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/modnet-mobilenetv2.zip) |
+| MODNet-ResNet50_vd |39.01|0.0038|32.29|37.38| 92.2 | 151.6 | 29.0 |[cfg](../configs/modnet/modnet-resnet50_vd.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/modnet-resnet50_vd.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/modnet-resnet50_vd.zip) |
+| MODNet-HRNet_W18   |35.55|0.0035|31.73|34.07| 10.2 | 28.5 | 62.6 |[cfg](../configs/modnet/modnet-hrnet_w18.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/modnet-hrnet_w18.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/modnet-hrnet_w18.zip) |
+| DIM-VGG16          |32.31|0.0233|28.89|31.45| 28.4 | 175.5| 30.4 |[cfg](../configs/dim/dim-vgg16.yml)| [model](https://paddleseg.bj.bcebos.com/matting/models/dim-vgg16.pdparams) | [model inference](https://paddleseg.bj.bcebos.com/matting/models/deploy/dim-vgg16.zip) |
+
+**Note**:
+* The evaluation dataset is composed of PPM-100 and the human part of AIM-500, 195 images in total: [PPM-AIM-195](https://paddleseg.bj.bcebos.com/matting/datasets/PPM-AIM-195.zip).
+* FLOPs and FPS are computed with a default model input size of (512, 512) on a Tesla V100 32G GPU. FPS is measured with the Paddle Inference library.
+* DIM is a trimap-based matting method, so its metrics are calculated only in the transition area. If no trimap is provided, the region where 0 < alpha < 255, after dilation and erosion with a radius of 25 pixels, is used as the transition area.
+
+## Tutorials
+* [Online demo](docs/online_demo_cn.md)
+* [Quick start](docs/quick_start_cn.md)
+* [Full development](docs/full_develop_cn.md)
+* [Human matting Android deployment](deploy/human_matting_android_demo/README_CN.md)
+* [Human matting .NET deployment](https://gitee.com/raoyutian/PaddleSegSharp)
+* [Dataset preparation](docs/data_prepare_cn.md)
+* Third-party AI Studio tutorials
+  * [PaddleSeg Matting tutorial](https://aistudio.baidu.com/aistudio/projectdetail/3876411?contributionType=1)
+  * [PP-Matting image matting tutorial](https://aistudio.baidu.com/aistudio/projectdetail/5002963?contributionType=1)
+
+## Community Contributions
+* Thanks to [Qian bin (Qianbin)](https://github.com/qianbin1989228) and other developers for their contributions.
+* Thanks to Jizhizi Li et al. for the [GFM](https://arxiv.org/abs/2010.16188) matting framework, which supported the algorithm development of PP-Matting.
+
+## Citation
+```
+@article{chen2022pp,
+  title={PP-Matting: High-Accuracy Natural Image Matting},
+  author={Chen, Guowei and Liu, Yi and Wang, Jian and Peng, Juncai and Hao, Yuying and Chu, Lutao and Tang, Shiyu and Wu, Zewu and Chen, Zeyu and Yu, Zhiliang and others},
+  journal={arXiv preprint arXiv:2204.09433},
+  year={2022}
+}
+```
+
+## Reference Documentation
+https://gitee.com/paddlepaddle/PaddleSeg/blob/release/2.8/Matting/docs/quick_start_cn.md

+ 20 - 0
configs/benchmarks/Composition-1k/closeform_composition1k.yml

@@ -0,0 +1,20 @@
+
+
+val_dataset:
+  type: Composition1K
+  dataset_root: data/Composition-1k
+  val_file: val.txt
+  separator: '|'
+  transforms:
+    - type: LoadImages
+    - type: ResizeByShort
+      short_size: 512
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: True
+
+model:
+  type: CloseFormMatting
+

+ 20 - 0
configs/benchmarks/Distinctions-646/closeform_distinctions646.yml

@@ -0,0 +1,20 @@
+
+
+val_dataset:
+  type: Distinctions646
+  dataset_root: data/Distinctions-646
+  val_file: val.txt
+  separator: '|'
+  transforms:
+    - type: LoadImages
+    - type: ResizeByShort
+      short_size: 512
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: True
+
+model:
+  type: CloseFormMatting
+

+ 9 - 0
configs/benchmarks/PPM/README.md

@@ -0,0 +1,9 @@
+### PPM
+
+| Method | SAD | MSE | Grad | Conn |
+|-|-|-|-|-|
+|ClosedFormMatting|40.6251|0.0782|55.5716|40.6646|
+|KNNMatting|41.5604|0.0681|52.5200|42.1784|
+|FastMatting|35.8735|0.0492|48.9267|35.6183|
+|LearningBasedMatting|40.5506|0.0776|55.3923|40.5690|
+|RandomWalksMatting|54.6315|0.0962|69.8779|54.0870|

+ 19 - 0
configs/benchmarks/PPM/closeform.yml

@@ -0,0 +1,19 @@
+
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: ResizeByShort
+      short_size: 512
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: True
+
+model:
+  type: CloseFormMatting
+

+ 6 - 0
configs/benchmarks/PPM/fast.yml

@@ -0,0 +1,6 @@
+
+_base_: closeform.yml
+
+model:
+  type: FastMatting
+

+ 6 - 0
configs/benchmarks/PPM/knn.yml

@@ -0,0 +1,6 @@
+
+_base_: closeform.yml
+
+model:
+  type: KNNMatting
+

+ 6 - 0
configs/benchmarks/PPM/learningbased.yml

@@ -0,0 +1,6 @@
+
+_base_: closeform.yml
+
+model:
+  type: LearningBasedMatting
+

+ 6 - 0
configs/benchmarks/PPM/randomwalks.yml

@@ -0,0 +1,6 @@
+
+_base_: closeform.yml
+
+model:
+  type: RandomWalksMatting
+

+ 45 - 0
configs/dim/dim-vgg16.yml

@@ -0,0 +1,45 @@
+batch_size: 16
+iters: 100000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: RandomCropByAlpha
+      crop_size: [[320, 320], [480, 480], [640, 640]]
+    - type: Resize
+      target_size: [320, 320]
+    - type: RandomDistort
+    - type: RandomBlur
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+  get_trimap: True
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitLong
+      max_long: 3840
+    - type: Normalize
+  mode: val
+  get_trimap: True
+
+model:
+  type: DIM
+  backbone:
+    type: VGG16
+    input_channels: 4
+    pretrained: https://paddleseg.bj.bcebos.com/matting/models/DIM_VGG16_pretrained/model.pdparams
+  pretrained: Null
+
+optimizer:
+  type: adam
+
+learning_rate:
+  value: 0.001

+ 54 - 0
configs/human_matting/human_matting-resnet34_vd.yml

@@ -0,0 +1,54 @@
+batch_size: 4
+iters: 50000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: RandomResize
+      size: [2048, 2048]
+      scale: [0.3, 1.5]
+    - type: RandomCrop
+      crop_size: [2048, 2048]
+    - type: RandomDistort
+    - type: RandomBlur
+      prob: 0.1
+    - type: RandomHorizontalFlip
+    - type: Padding
+      target_size: [2048, 2048]
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: ResizeByShort
+      short_size: 2048
+    - type: ResizeToIntMult
+      mult_int: 128
+    - type: Normalize
+  mode: val
+  get_trimap: False
+
+model:
+  type: HumanMatting
+  backbone:
+    type: ResNet34_vd
+    pretrained: https://paddleseg.bj.bcebos.com/matting/models/ResNet34_vd_pretrained/model.pdparams
+  pretrained: Null
+  if_refine: True
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 4.0e-5
+
+lr_scheduler:
+  type: PiecewiseDecay
+  boundaries: [30000, 40000]
+  values: [0.001, 0.0001, 0.00001]

+ 5 - 0
configs/modnet/modnet-hrnet_w18.yml

@@ -0,0 +1,5 @@
+_base_: modnet-mobilenetv2.yml
+model:
+  backbone:
+    type: HRNet_W18
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz

+ 47 - 0
configs/modnet/modnet-mobilenetv2.yml

@@ -0,0 +1,47 @@
+batch_size: 16
+iters: 100000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: RandomCrop
+      crop_size: [512, 512]
+    - type: RandomDistort
+    - type: RandomBlur
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: ResizeByShort
+      short_size: 512
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+
+model:
+  type: MODNet
+  backbone:
+    type: MobileNetV2
+    pretrained: https://paddleseg.bj.bcebos.com/matting/models/MobileNetV2_pretrained/model.pdparams
+  pretrained: Null
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 4.0e-5
+
+lr_scheduler:
+  type: PiecewiseDecay
+  boundaries: [40000, 80000]
+  values: [0.02, 0.002, 0.0002]

+ 5 - 0
configs/modnet/modnet-resnet50_vd.yml

@@ -0,0 +1,5 @@
+_base_: modnet-mobilenetv2.yml
+model:
+  backbone:
+    type: ResNet50_vd
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz

+ 20 - 0
configs/ppmatting/README.md

@@ -0,0 +1,20 @@
+# PP-Matting: High-Accuracy Natural Image Matting
+
+## Reference
+
+> Chen G, Liu Y, Wang J, et al. PP-Matting: High-Accuracy Natural Image Matting[J]. arXiv preprint arXiv:2204.09433, 2022.
+
+## Performance
+
+### Composition-1k
+
+| Model | Backbone | Resolution | Training Iters | SAD $\downarrow$ | MSE $\downarrow$ | Grad $\downarrow$ | Conn $\downarrow$ | Links |
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+|PP-Matting|HRNet_W48|512x512|300000|46.22|0.005|22.69|45.40|[model](https://paddleseg.bj.bcebos.com/matting/models/ppmatting-hrnet_w48-composition.pdparams)|
+
+
+### Distinctions-646
+
+| Model | Backbone | Resolution | Training Iters | SAD $\downarrow$ | MSE $\downarrow$ | Grad $\downarrow$ | Conn $\downarrow$ | Links |
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+|PP-Matting|HRNet_W48|512x512|300000|40.69|0.009|43.91|40.56|[model](https://paddleseg.bj.bcebos.com/matting/models/ppmatting-hrnet_w48-distinctions.pdparams)|

+ 29 - 0
configs/ppmatting/ppmatting-hrnet_w18-human_1024.yml

@@ -0,0 +1,29 @@
+_base_: 'ppmatting-hrnet_w18-human_512.yml'
+
+
+train_dataset:
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 1024
+    - type: RandomCrop
+      crop_size: [1024, 1024]
+    - type: RandomDistort
+    - type: RandomBlur
+      prob: 0.1
+    - type: RandomNoise
+      prob: 0.5
+    - type: RandomReJpeg
+      prob: 0.2
+    - type: RandomHorizontalFlip
+    - type: Normalize
+
+val_dataset:
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 1024
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+

+ 44 - 0
configs/ppmatting/ppmatting-hrnet_w18-human_512.yml

@@ -0,0 +1,44 @@
+_base_: 'ppmatting-hrnet_w48-distinctions.yml'
+
+batch_size: 4
+iters: 200000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 512
+    - type: RandomCrop
+      crop_size: [512, 512]
+    - type: RandomDistort
+    - type: RandomBlur
+      prob: 0.1
+    - type: RandomNoise
+      prob: 0.5
+    - type: RandomReJpeg
+      prob: 0.2
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 512
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+
+model:
+  backbone:
+    type: HRNet_W18
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz

+ 7 - 0
configs/ppmatting/ppmatting-hrnet_w48-composition.yml

@@ -0,0 +1,7 @@
+_base_: 'ppmatting-hrnet_w48-distinctions.yml'
+
+train_dataset:
+  dataset_root: data/matting/Composition-1k
+
+val_dataset:
+  dataset_root: data/matting/Composition-1k

+ 55 - 0
configs/ppmatting/ppmatting-hrnet_w48-distinctions.yml

@@ -0,0 +1,55 @@
+batch_size: 4
+iters: 300000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/matting/Distinctions-646
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: Padding
+      target_size: [512, 512]
+    - type: RandomCrop
+      crop_size: [[512, 512],[640, 640], [800, 800]]
+    - type: Resize
+      target_size: [512, 512]
+    - type: RandomDistort
+    - type: RandomBlur
+      prob: 0.1
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+  separator: '|'
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/matting/Distinctions-646
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 1536
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+  separator: '|'
+
+model:
+  type: PPMatting
+  backbone:
+    type: HRNet_W48
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w48_ssld.tar.gz
+  pretrained: Null
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 4.0e-5
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.01
+  end_lr: 0
+  power: 0.9

+ 66 - 0
configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml

@@ -0,0 +1,66 @@
+batch_size: 16  # total batch size: 16
+iters: 100000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 512
+    - type: RandomCrop
+      crop_size: [512, 512]
+    - type: Padding
+      target_size: [512, 512]
+    - type: RandomDistort
+    - type: RandomBlur
+      prob: 0.1
+    - type: RandomSharpen
+      prob: 0.2
+    - type: RandomNoise
+      prob: 0.5
+    - type: RandomReJpeg
+      prob: 0.2
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 512
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+
+model:
+  type: PPMattingV2
+  backbone:
+    type: STDC1
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/PP_STDCNet1.tar.gz
+  decoder_channels: [128, 96, 64, 32, 16]
+  head_channel: 8
+  dpp_output_channel: 256
+  dpp_merge_type: add
+  
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 5.0e-4
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.01
+  end_lr: 0
+  power: 0.9
+  warmup_iters: 1000
+  warmup_start_lr: 1.0e-5
+

+ 47 - 0
configs/quick_start/modnet-mobilenetv2.yml

@@ -0,0 +1,47 @@
+batch_size: 1
+iters: 1000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: RandomCrop
+      crop_size: [512, 512]
+    - type: RandomDistort
+    - type: RandomBlur
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: ResizeByShort
+      short_size: 512
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+
+model:
+  type: MODNet
+  backbone:
+    type: MobileNetV2
+    pretrained: https://paddleseg.bj.bcebos.com/matting/models/MobileNetV2_pretrained/model.pdparams
+  pretrained: Null
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 4.0e-5
+
+lr_scheduler:
+  type: PiecewiseDecay
+  boundaries: [40000, 80000]
+  values: [0.02, 0.002, 0.0002]

+ 66 - 0
configs/quick_start/ppmattingv2-stdc1-human_512.yml

@@ -0,0 +1,66 @@
+batch_size: 1 
+iters: 1000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 512
+    - type: RandomCrop
+      crop_size: [512, 512]
+    - type: Padding
+      target_size: [512, 512]
+    - type: RandomDistort
+    - type: RandomBlur
+      prob: 0.1
+    - type: RandomSharpen
+      prob: 0.2
+    - type: RandomNoise
+      prob: 0.5
+    - type: RandomReJpeg
+      prob: 0.2
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 512
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+
+model:
+  type: PPMattingV2
+  backbone:
+    type: STDC1
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/PP_STDCNet1.tar.gz
+  decoder_channels: [128, 96, 64, 32, 16]
+  head_channel: 8
+  dpp_output_channel: 256
+  dpp_merge_type: add
+  
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 5.0e-4
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.01
+  end_lr: 0
+  power: 0.9
+  warmup_iters: 100
+  warmup_start_lr: 1.0e-5
+

+ 9 - 0
configs/rvm/README.md

@@ -0,0 +1,9 @@
+# Robust High-Resolution Video Matting with Temporal Guidance
+
+## Reference
+> Lin S, Yang L, Saleemi I, et al. Robust high-resolution video matting with temporal guidance[C]//Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. 2022: 238-247.
+
+## Download
+| Model | Backbone | Links |
+|:-:|:-:|:-:|
+| RVM | MobileNetV3-Large|[model](https://paddleseg.bj.bcebos.com/matting/models/rvm-mobilenetv3.pdparams)|

+ 18 - 0
configs/rvm/rvm-mobilenetv3.yaml

@@ -0,0 +1,18 @@
+model:
+  type: RVM
+  backbone:
+    type: MobileNetV3_large_x1_0_os16
+    out_index: [0, 2, 5]
+    return_last_conv: True
+  refiner: 'deep_guided_filter'
+  downsample_ratio: 0.25
+  pretrained: Null
+
+val_dataset:
+  transforms:
+    - type: LoadImages
+      to_rgb: True
+    - type: Normalize
+      mean: [0, 0, 0]
+      std: [1, 1, 1]
+

BIN
demo/human.jpg


+ 13 - 0
deploy/human_matting_android_demo/.gitignore

@@ -0,0 +1,13 @@
+*.iml
+.gradle
+/local.properties
+/.idea/caches
+/.idea/libraries
+/.idea/modules.xml
+/.idea/workspace.xml
+/.idea/navEditor.xml
+/.idea/assetWizardSettings.xml
+.DS_Store
+/build
+/captures
+.externalNativeBuild

+ 201 - 0
deploy/human_matting_android_demo/LICENSE

@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.

+ 140 - 0
deploy/human_matting_android_demo/README.md

@@ -0,0 +1,140 @@
+English | [简体中文](README_CN.md)
+
+# Human Matting Android Demo
+This demo implements human matting on Android, based on the [MODNet](https://github.com/PaddlePaddle/PaddleSeg/tree/develop/contrib/Matting) algorithm from [PaddleSeg](https://github.com/paddlepaddle/paddleseg/tree/develop).
+
+You can directly download and install the example project's [apk](https://paddleseg.bj.bcebos.com/matting/models/deploy/app-debug.apk) to try it out.
+
+## 1. Results
+<div align="center">
+<img src=https://user-images.githubusercontent.com/14087480/141890516-6aad4691-9ab3-4baf-99e5-f1afa1b21281.png  width="50%">
+
+</div>
+
+
+## 2. Android Demo Instructions
+
+### 2.1 Requirements
+* Android Studio 3.4;
+* Android mobile phone;
+
+### 2.2 Installation
+* Open Android Studio. In the "Welcome to Android Studio" window, click "Open an existing Android Studio project". In the path selection window that appears, select the folder containing this Android demo, then click the "Open" button in the lower right corner to import the project. The Paddle-Lite inference library required by the demo is downloaded automatically during the build;
+* Connect an Android phone via USB;
+* After the project is loaded, click Run->Run 'App' on the menu bar, select the connected Android device in the "Select Deployment Target" window, and click the "OK" button;
+
+*Note: this Android demo is based on [Paddle-Lite](https://paddlelite.paddlepaddle.org.cn/), Paddle-Lite version 2.8.0.*
+
+### 2.3 Prediction
+* In the human matting demo, a human image is loaded by default, and the CPU prediction result and prediction latency are shown below the image;
+* You can also load test images from the album or the camera by tapping the "Open local album" and "Open camera to take photos" buttons in the upper right corner, and then run prediction.
+
+*Note: photos taken inside the demo are compressed automatically. To test the effect on an original photo, take it with the phone's camera app and open it from the album for prediction.*
+
+## 3. Secondary Development
+The inference library or the model can be updated as needed for secondary development. Updating the model involves two steps: model export and model conversion.
+
+### 3.1 Update Inference Library
+The [Paddle-Lite website](https://paddlelite.paddlepaddle.org.cn/) provides pre-compiled Android inference libraries; you can also compile them yourself by following the official documentation.
+
+The Paddle-Lite inference library on Android mainly contains three files:
+
+* PaddlePredictor.jar;
+* arm64-v8a/libpaddle_lite_jni.so;
+* armeabi-v7a/libpaddle_lite_jni.so;
+
+Two ways to obtain them are described below:
+
+* Use a pre-compiled version of the inference library. For the latest pre-compiled files, see the [releases](https://github.com/PaddlePaddle/Paddle-Lite/releases/); this demo uses this [version](https://paddlelite-demo.bj.bcebos.com/libs/android/paddle_lite_libs_v2_8_0.tar.gz).
+
+    After extracting the archive, PaddlePredictor.jar is located at java/PaddlePredictor.jar;
+
+    the arm64-v8a .so file is located at java/libs/arm64-v8a;
+
+    the armeabi-v7a .so file is located at java/libs/armeabi-v7a;
+
+* Compile the Paddle-Lite inference library manually.
+For development environment preparation and build instructions, refer to [Paddle-Lite source code compilation](https://paddle-lite.readthedocs.io/zh/release-v2.8/source_compile/compile_env.html).
+
+With the above files prepared, refer to the [java_api](https://paddle-lite.readthedocs.io/zh/release-v2.8/api_reference/java_api_doc.html) documentation to run inference on Android. For details on how to use the inference library, refer to the "Update Inference Library" section of the [Paddle-Lite-Demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo) documentation.
+
+### 3.2 Model Export
+This demo uses MODNet with an HRNet_W18 backbone for human matting. Please refer to the official [training tutorial](https://github.com/PaddlePaddle/PaddleSeg/tree/develop/contrib/Matting) for model training. Three backbones are provided: MobileNetV2, ResNet50_vd and HRNet_W18. Balancing accuracy and speed, this Android demo uses HRNet_W18 as the backbone. The trained dynamic-graph model can be downloaded directly from the official website for algorithm verification.
+
+To run inference on Android phones, the dynamic-graph model needs to be exported as a static-graph model, with a fixed input image size at export time.
+
+First, update the [PaddleSeg](https://github.com/paddlepaddle/paddleseg/tree/develop) repository and `cd` into the `PaddleSeg/contrib/Matting` directory. Put the downloaded modnet-hrnet_w18.pdparams (a model you trained yourself also works) in the current directory (`PaddleSeg/contrib/Matting`). Then edit the config file `configs/modnet_mobilenetv2.yml` (note: HRNet_W18 is used, but its config file `modnet_hrnet_w18.yml` is based on `modnet_mobilenetv2.yml`) and modify the val_dataset field as follows:
+
+``` yml
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: ResizeByShort
+      short_size: 256
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+```
+In the above modification, pay special attention to the short_size: 256 field, which directly determines the size of the image used for inference. If the value is too small, prediction accuracy suffers; if it is too large, inference on the phone slows down (and the app may even crash due to performance limits). In practical testing, this field is set to 256 for HRNet_W18.
+
+After modifying the configuration file, run the following command to export the static graph:
+``` shell
+python export.py \
+    --config configs/modnet/modnet_hrnet_w18.yml \
+    --model_path modnet-hrnet_w18.pdparams \
+    --save_dir output
+```
+
+After the export, an `output` folder is generated in the current directory; the files inside are the static-graph model files.
+
+### 3.3 Model Conversion
+
+#### 3.3.1 Model Conversion Tool
+Once the static-graph model and parameter files exported from PaddleSeg are ready, optimize the model with the opt tool provided by Paddle-Lite and convert it to the file format supported by Paddle-Lite.
+
+First, install Paddle-Lite:
+
+``` shell
+pip install paddlelite==2.8.0
+```
+
+Then use the following Python script to convert:
+
+``` python
+# Import the Paddle-Lite opt API
+from paddlelite.lite import *
+
+# 1. Create an opt instance
+opt = Opt()
+
+# 2. Specify the static graph model and parameter files
+opt.set_model_file('./output/model.pdmodel')
+opt.set_param_file('./output/model.pdiparams')
+
+# 3. Specify the target platform: arm, x86, opencl or npu
+opt.set_valid_places("arm")
+# 4. Specify the output model format: naive_buffer or protobuf
+opt.set_model_type("naive_buffer")
+# 5. Specify the output model path (the .nb suffix is added automatically)
+opt.set_optimize_out("./output/hrnet_w18")
+# 6. Run the model optimization
+opt.run()
+```
+
+After conversion, the `hrnet_w18.nb` file will be generated in the `output` directory.
+
+#### 3.3.2 Update Model
+Use the optimized `.nb` file to replace the existing file under `app/src/main/assets/image_matting/models/modnet` in the Android project.
+
+Then change the image input size in the project: open the `string.xml` file and modify the entry as shown in the example below:
+``` xml
+<string name="INPUT_SHAPE_DEFAULT">1,3,256,256</string>
+```
+The values 1,3,256,256 represent batch size, channels, height and width respectively. Usually only height and width need to be modified, and they must match the size set during model export.
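+
+As a rough sketch of how such a shape string can be consumed, the snippet below parses it and applies it to the input tensor. It is a simplified stand-in for what the demo does via `Utils.parseLongsFromString` and `Tensor.resize`; the class and method names here are illustrative only.
+
+``` java
+import com.baidu.paddle.lite.Tensor;
+
+public class InputShapeSketch {
+    // Parse "1,3,256,256" into {1, 3, 256, 256}
+    static long[] parseShape(String shapeString) {
+        String[] parts = shapeString.split(",");
+        long[] shape = new long[parts.length];
+        for (int i = 0; i < parts.length; i++) {
+            shape[i] = Long.parseLong(parts[i].trim());
+        }
+        return shape;
+    }
+
+    // Apply the parsed shape to the model's input tensor before feeding data
+    static void applyShape(Tensor inputTensor, String shapeString) {
+        long[] shape = parseShape(shapeString);  // {batch, channel, height, width}
+        inputTensor.resize(shape);               // height/width must match the exported model
+    }
+}
+```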
+
+The entire Android demo is implemented in Java without embedded C++ code, so it is relatively easy to build and run. In the future, this demo could also be ported to a Java web project to perform human matting on the web.

+ 141 - 0
deploy/human_matting_android_demo/README_CN.md

@@ -0,0 +1,141 @@
+简体中文 | [English](README.md)
+
+# Human Matting Android Demo
+基于[PaddleSeg](https://github.com/paddlepaddle/paddleseg/tree/develop)的[MODNet](https://github.com/PaddlePaddle/PaddleSeg/tree/develop/contrib/Matting)算法实现人像抠图(安卓版demo)。
+
+可以直接下载安装本示例工程中的[apk](https://paddleseg.bj.bcebos.com/matting/models/deploy/app-debug.apk)进行体验。
+
+## 1. 效果展示
+<div align="center">
+<img src=https://user-images.githubusercontent.com/14087480/141890516-6aad4691-9ab3-4baf-99e5-f1afa1b21281.png  width="50%">
+
+</div>
+
+
+## 2. 安卓Demo使用说明
+
+### 2.1 要求
+* Android Studio 3.4;
+* Android手机;
+
+### 2.2 一键安装
+* 打开Android Studio,在"Welcome to Android Studio"窗口点击"Open an existing Android Studio project",在弹出的路径选择窗口中选择本安卓demo对应的文件夹,然后点击右下角的"Open"按钮即可导入工程,构建工程的过程中会自动下载demo需要的Lite预测库;
+* 通过USB连接Android手机;
+* 载入工程后,点击菜单栏的Run->Run 'App'按钮,在弹出的"Select Deployment Target"窗口选择已经连接的Android设备,然后点击"OK"按钮;
+
+*注:此安卓demo基于[Paddle-Lite](https://paddlelite.paddlepaddle.org.cn/)开发,PaddleLite版本为2.8.0。*
+
+### 2.3 预测
+* 在人像抠图Demo中,默认会载入一张人像图像,并会在图像下方给出CPU的预测结果和预测延时;
+* 在人像抠图Demo中,你还可以通过右上角的"打开本地相册"和"打开摄像头拍照"按钮分别从相册或相机中加载测试图像然后进行预测推理;
+
+*注意:demo中拍照时照片会自动压缩,想测试拍照原图效果,可使用手机相机拍照后从相册中打开进行预测。*
+
+## 3. 二次开发
+可按需要更新预测库或模型进行二次开发,其中更新模型分为模型导出和模型转换两个步骤。
+
+### 3.1 更新预测库
+[Paddle-Lite官网](https://paddlelite.paddlepaddle.org.cn/)提供了预编译版本的安卓预测库,也可以参考官网自行编译。
+
+Paddle-Lite在安卓端的预测库主要包括三个文件:
+
+* PaddlePredictor.jar;
+* arm64-v8a/libpaddle_lite_jni.so;
+* armeabi-v7a/libpaddle_lite_jni.so;
+
+下面分别介绍两种方法:
+
+* 使用预编译版本的预测库,最新的预编译文件参考:[release](https://github.com/PaddlePaddle/Paddle-Lite/releases/),此demo使用的[版本](https://paddlelite-demo.bj.bcebos.com/libs/android/paddle_lite_libs_v2_8_0.tar.gz)
+
+    解压上面文件,PaddlePredictor.jar位于:java/PaddlePredictor.jar;
+
+    arm64-v8a相关so位于:java/libs/arm64-v8a;
+
+    armeabi-v7a相关so位于:java/libs/armeabi-v7a;
+
+* 手动编译Paddle-Lite预测库
+开发环境的准备和编译方法参考:[Paddle-Lite源码编译](https://paddle-lite.readthedocs.io/zh/release-v2.8/source_compile/compile_env.html)。
+
+准备好上述文件,即可参考[java_api](https://paddle-lite.readthedocs.io/zh/release-v2.8/api_reference/java_api_doc.html)在安卓端进行推理。具体使用预测库的方法可参考[Paddle-Lite-Demo](https://github.com/PaddlePaddle/Paddle-Lite-Demo)中更新预测库部分的文档。
+
+### 3.2 模型导出
+此demo的人像抠图采用Backbone为HRNet_W18的MODNet模型,模型[训练教程](https://github.com/PaddlePaddle/PaddleSeg/tree/develop/contrib/Matting)请参考官网,官网提供了3种不同性能的Backone:MobileNetV2、ResNet50_vd和HRNet_W18。本安卓demo综合考虑精度和速度要求,采用了HRNet_W18作为Backone。可以直接从官网下载训练好的动态图模型进行算法验证。
+
+为了能够在安卓手机上进行推理,需要将动态图模型导出为静态图模型,导出时固定图像输入尺寸即可。
+
+首先git最新的[PaddleSeg](https://github.com/paddlepaddle/paddleseg/tree/develop)项目,然后cd进入到PaddleSeg/contrib/Matting目录。将下载下来的modnet-hrnet_w18.pdparams动态图模型文件(也可以自行训练得到)放置在当前文件夹(PaddleSeg/contrib/Matting)下面。然后修改配置文件 configs/modnet_mobilenetv2.yml(注意:虽然采用hrnet18模型,但是该模型依赖的配置文件modnet_hrnet_w18.yml本身依赖modnet_mobilenetv2.yml),修改其中的val_dataset字段如下:
+
+``` yml
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: ResizeByShort
+      short_size: 256
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+```
+上述修改中尤其注意short_size: 256这个字段,这个值直接决定我们最终的推理图像采用的尺寸大小。这个字段值设置太小会影响预测精度,设置太大会影响手机推理速度(甚至造成手机因性能问题无法完成推理而崩溃)。经过实际测试,对于hrnet18,该字段设置为256较好。
+
+完成配置文件修改后,采用下面的命令进行静态图导出:
+``` shell
+python export.py \
+    --config configs/modnet/modnet_hrnet_w18.yml \
+    --model_path modnet-hrnet_w18.pdparams \
+    --save_dir output
+```
+
+转换完成后在当前目录下会生成output文件夹,该文件夹中的文件即为转出来的静态图文件。
+
+### 3.3 模型转换
+
+#### 3.3.1 模型转换工具
+准备好PaddleSeg导出来的静态图模型和参数文件后,需要使用Paddle-Lite提供的opt对模型进行优化,并转换成Paddle-Lite支持的文件格式。
+
+首先安装PaddleLite:
+
+``` shell
+pip install paddlelite==2.8.0
+```
+
+然后使用下面的python脚本进行转换:
+
+``` python
+# 引用Paddlelite预测库
+from paddlelite.lite import *
+
+# 1. 创建opt实例
+opt=Opt()
+
+# 2. 指定静态模型路径
+opt.set_model_file('./output/model.pdmodel')
+opt.set_param_file('./output/model.pdiparams')
+
+# 3. 指定转化类型: arm、x86、opencl、npu
+opt.set_valid_places("arm")
+# 4. 指定模型转化类型: naive_buffer、protobuf
+opt.set_model_type("naive_buffer")
+# 5. 输出模型地址
+opt.set_optimize_out("./output/hrnet_w18")
+# 6. 执行模型优化
+opt.run()
+```
+
+转换完成后在output目录下会生成对应的hrnet_w18.nb文件。
+
+#### 3.3.2 更新模型
+将优化好的`.nb`文件,替换安卓程序中的 app/src/main/assets/image_matting/
+models/modnet下面的文件即可。
+
+然后在工程中修改图像输入尺寸:打开string.xml文件,修改示例如下:
+``` xml
+<string name="INPUT_SHAPE_DEFAULT">1,3,256,256</string>
+```
+1,3,256,256分别表示图像对应的batchsize、channel、height、width,我们一般修改height和width即可,这里的height和width需要和静态图导出时设置的尺寸一致。
+
+整个安卓demo采用java实现,没有内嵌C++代码,构建和执行比较简单。未来也可以将本demo移植到java web项目中实现web版人像抠图。

+ 1 - 0
deploy/human_matting_android_demo/app/.gitignore

@@ -0,0 +1 @@
+/build

+ 118 - 0
deploy/human_matting_android_demo/app/build.gradle

@@ -0,0 +1,118 @@
+import java.security.MessageDigest
+
+apply plugin: 'com.android.application'
+
+android {
+    compileSdkVersion 28
+    defaultConfig {
+        applicationId "com.baidu.paddleseg.lite.demo.human_matting"
+        minSdkVersion 15
+        versionCode 1
+        versionName "1.0"
+        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+    }
+    buildTypes {
+        release {
+            minifyEnabled false
+            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
+        }
+    }
+}
+
+dependencies {
+    implementation fileTree(include: ['*.jar'], dir: 'libs')
+    implementation 'com.android.support:appcompat-v7:28.0.0'
+    implementation 'com.android.support.constraint:constraint-layout:1.1.3'
+    implementation 'com.android.support:design:28.0.0'
+    testImplementation 'junit:junit:4.12'
+    androidTestImplementation 'com.android.support.test:runner:1.0.2'
+    androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
+    implementation files('libs/PaddlePredictor.jar')  // add the Paddle-Lite Java API jar
+}
+
+def paddleLiteLibs = 'https://paddlelite-demo.bj.bcebos.com/libs/android/paddle_lite_libs_v2_8_0.tar.gz'
+task downloadAndExtractPaddleLiteLibs(type: DefaultTask) {
+    doFirst {
+        println "Downloading and extracting Paddle Lite libs"
+    }
+    doLast {
+        // Prepare cache folder for libs
+        if (!file("cache").exists()) {
+            mkdir "cache"
+        }
+        // Generate cache name for libs
+        MessageDigest messageDigest = MessageDigest.getInstance('MD5')
+        messageDigest.update(paddleLiteLibs.bytes)
+        String cacheName = new BigInteger(1, messageDigest.digest()).toString(32)
+        // Download libs
+        if (!file("cache/${cacheName}.tar.gz").exists()) {
+            ant.get(src: paddleLiteLibs, dest: file("cache/${cacheName}.tar.gz"))
+        }
+        // Unpack libs
+        copy {
+            from tarTree("cache/${cacheName}.tar.gz")
+            into "cache/${cacheName}"
+        }
+        // Copy PaddlePredictor.jar
+        if (!file("libs/PaddlePredictor.jar").exists()) {
+            copy {
+                from "cache/${cacheName}/java/PaddlePredictor.jar"
+                into "libs"
+            }
+        }
+        // Copy libpaddle_lite_jni.so for armeabi-v7a and arm64-v8a
+        if (!file("src/main/jniLibs/armeabi-v7a/libpaddle_lite_jni.so").exists()) {
+            copy {
+                from "cache/${cacheName}/java/libs/armeabi-v7a/"
+                into "src/main/jniLibs/armeabi-v7a"
+            }
+        }
+        if (!file("src/main/jniLibs/arm64-v8a/libpaddle_lite_jni.so").exists()) {
+            copy {
+                from "cache/${cacheName}/java/libs/arm64-v8a/"
+                into "src/main/jniLibs/arm64-v8a"
+            }
+        }
+    }
+}
+preBuild.dependsOn downloadAndExtractPaddleLiteLibs
+
+def paddleLiteModels = [
+        [
+                'src' : 'https://paddleseg.bj.bcebos.com/matting/models/deploy/hrnet_w18.tar.gz',
+                'dest' : 'src/main/assets/image_matting/models/modnet'
+        ],
+]
+task downloadAndExtractPaddleLiteModels(type: DefaultTask) {
+    doFirst {
+        println "Downloading and extracting Paddle Lite models"
+    }
+    doLast {
+        // Prepare cache folder for models
+        if (!file("cache").exists()) {
+            mkdir "cache"
+        }
+        paddleLiteModels.eachWithIndex { model, index ->
+            MessageDigest messageDigest = MessageDigest.getInstance('MD5')
+            messageDigest.update(model.src.bytes)
+            String cacheName = new BigInteger(1, messageDigest.digest()).toString(32)
+            // Download model file
+            if (!file("cache/${cacheName}.tar.gz").exists()) {
+                ant.get(src: model.src, dest: file("cache/${cacheName}.tar.gz"))
+            }
+            // Unpack model file
+            copy {
+                from tarTree("cache/${cacheName}.tar.gz")
+                into "cache/${cacheName}"
+            }
+            // Copy model file
+            if (!file("${model.dest}/hrnet_w18.nb").exists()) {
+                copy {
+                    from "cache/${cacheName}"
+                    into "${model.dest}"
+                }
+            }
+        }
+    }
+}
+preBuild.dependsOn downloadAndExtractPaddleLiteModels

+ 172 - 0
deploy/human_matting_android_demo/app/gradlew

@@ -0,0 +1,172 @@
+#!/usr/bin/env sh
+
+##############################################################################
+##
+##  Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+    ls=`ls -ld "$PRG"`
+    link=`expr "$ls" : '.*-> \(.*\)$'`
+    if expr "$link" : '/.*' > /dev/null; then
+        PRG="$link"
+    else
+        PRG=`dirname "$PRG"`"/$link"
+    fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS=""
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn () {
+    echo "$*"
+}
+
+die () {
+    echo
+    echo "$*"
+    echo
+    exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+  CYGWIN* )
+    cygwin=true
+    ;;
+  Darwin* )
+    darwin=true
+    ;;
+  MINGW* )
+    msys=true
+    ;;
+  NONSTOP* )
+    nonstop=true
+    ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+        # IBM's JDK on AIX uses strange locations for the executables
+        JAVACMD="$JAVA_HOME/jre/sh/java"
+    else
+        JAVACMD="$JAVA_HOME/bin/java"
+    fi
+    if [ ! -x "$JAVACMD" ] ; then
+        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+    fi
+else
+    JAVACMD="java"
+    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+    MAX_FD_LIMIT=`ulimit -H -n`
+    if [ $? -eq 0 ] ; then
+        if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+            MAX_FD="$MAX_FD_LIMIT"
+        fi
+        ulimit -n $MAX_FD
+        if [ $? -ne 0 ] ; then
+            warn "Could not set maximum file descriptor limit: $MAX_FD"
+        fi
+    else
+        warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+    fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+    GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin, switch paths to Windows format before running java
+if $cygwin ; then
+    APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+    CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+    JAVACMD=`cygpath --unix "$JAVACMD"`
+
+    # We build the pattern for arguments to be converted via cygpath
+    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+    SEP=""
+    for dir in $ROOTDIRSRAW ; do
+        ROOTDIRS="$ROOTDIRS$SEP$dir"
+        SEP="|"
+    done
+    OURCYGPATTERN="(^($ROOTDIRS))"
+    # Add a user-defined pattern to the cygpath arguments
+    if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+        OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+    fi
+    # Now convert the arguments - kludge to limit ourselves to /bin/sh
+    i=0
+    for arg in "$@" ; do
+        CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+        CHECK2=`echo "$arg"|egrep -c "^-"`                                 ### Determine if an option
+
+        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then                    ### Added a condition
+            eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+        else
+            eval `echo args$i`="\"$arg\""
+        fi
+        i=$((i+1))
+    done
+    case $i in
+        (0) set -- ;;
+        (1) set -- "$args0" ;;
+        (2) set -- "$args0" "$args1" ;;
+        (3) set -- "$args0" "$args1" "$args2" ;;
+        (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+        (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+        (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+        (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+        (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+        (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+    esac
+fi
+
+# Escape application args
+save () {
+    for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+    echo " "
+}
+APP_ARGS=$(save "$@")
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
+if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
+  cd "$(dirname "$0")"
+fi
+
+exec "$JAVACMD" "$@"

+ 84 - 0
deploy/human_matting_android_demo/app/gradlew.bat

@@ -0,0 +1,84 @@
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem  Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS=
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto init
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto init
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:init
+@rem Get command-line arguments, handling Windows variants
+
+if not "%OS%" == "Windows_NT" goto win9xME_args
+
+:win9xME_args
+@rem Slurp the command line arguments.
+set CMD_LINE_ARGS=
+set _SKIP=2
+
+:win9xME_args_slurp
+if "x%~1" == "x" goto execute
+
+set CMD_LINE_ARGS=%*
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if  not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega

+ 8 - 0
deploy/human_matting_android_demo/app/local.properties

@@ -0,0 +1,8 @@
+## This file must *NOT* be checked into Version Control Systems,
+# as it contains information specific to your local configuration.
+#
+# Location of the SDK. This is only used by Gradle.
+# For customization when using a Version Control System, please read the
+# header note.
+#Mon Nov 25 17:01:52 CST 2019
+sdk.dir=/Users/chenlingchi/Library/Android/sdk

+ 21 - 0
deploy/human_matting_android_demo/app/proguard-rules.pro

@@ -0,0 +1,21 @@
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+#   http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+#   public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile

+ 26 - 0
deploy/human_matting_android_demo/app/src/androidTest/java/com/baidu/paddle/lite/demo/ExampleInstrumentedTest.java

@@ -0,0 +1,26 @@
+package com.baidu.paddle.lite.demo;
+
+import android.content.Context;
+import android.support.test.InstrumentationRegistry;
+import android.support.test.runner.AndroidJUnit4;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import static org.junit.Assert.*;
+
+/**
+ * Instrumented test, which will execute on an Android device.
+ *
+ * @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
+ */
+@RunWith(AndroidJUnit4.class)
+public class ExampleInstrumentedTest {
+    @Test
+    public void useAppContext() {
+        // Context of the app under test.
+        Context appContext = InstrumentationRegistry.getTargetContext();
+
+        assertEquals("com.baidu.paddle.lite.demo", appContext.getPackageName());
+    }
+}

+ 29 - 0
deploy/human_matting_android_demo/app/src/main/AndroidManifest.xml

@@ -0,0 +1,29 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+          package="com.paddle.demo.matting">
+
+    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
+    <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
+    <uses-permission android:name="android.permission.CAMERA"/>
+
+    <application
+            android:allowBackup="true"
+            android:largeHeap="true"
+            android:icon="@mipmap/ic_launcher"
+            android:label="@string/app_name"
+            android:roundIcon="@mipmap/ic_launcher_round"
+            android:supportsRtl="true"
+            android:theme="@style/AppTheme">
+        <activity android:name=".MainActivity">
+            <intent-filter>
+                <action android:name="android.intent.action.MAIN"/>
+                <category android:name="android.intent.category.LAUNCHER"/>
+            </intent-filter>
+        </activity>
+        <activity
+                android:name=".SettingsActivity"
+                android:label="Settings">
+        </activity>
+    </application>
+
+</manifest>

BIN
deploy/human_matting_android_demo/app/src/main/assets/image_matting/images/bg.jpg


BIN
deploy/human_matting_android_demo/app/src/main/assets/image_matting/images/human.jpg


+ 2 - 0
deploy/human_matting_android_demo/app/src/main/assets/image_matting/labels/label_list

@@ -0,0 +1,2 @@
+background
+human

+ 127 - 0
deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/AppCompatPreferenceActivity.java

@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2014 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.paddle.demo.matting;
+
+import android.content.res.Configuration;
+import android.os.Bundle;
+import android.preference.PreferenceActivity;
+import android.support.annotation.LayoutRes;
+import android.support.annotation.Nullable;
+import android.support.v7.app.ActionBar;
+import android.support.v7.app.AppCompatDelegate;
+import android.support.v7.widget.Toolbar;
+import android.view.MenuInflater;
+import android.view.View;
+import android.view.ViewGroup;
+
+/**
+ * A {@link android.preference.PreferenceActivity} which implements and proxies the necessary calls
+ * to be used with AppCompat.
+ * <p>
+ * This technique can be used with an {@link android.app.Activity} class, not just
+ * {@link android.preference.PreferenceActivity}.
+ */
+public abstract class AppCompatPreferenceActivity extends PreferenceActivity {
+    private AppCompatDelegate mDelegate;
+
+    @Override
+    protected void onCreate(Bundle savedInstanceState) {
+        getDelegate().installViewFactory();
+        getDelegate().onCreate(savedInstanceState);
+        super.onCreate(savedInstanceState);
+    }
+
+    @Override
+    protected void onPostCreate(Bundle savedInstanceState) {
+        super.onPostCreate(savedInstanceState);
+        getDelegate().onPostCreate(savedInstanceState);
+    }
+
+    public ActionBar getSupportActionBar() {
+        return getDelegate().getSupportActionBar();
+    }
+
+    public void setSupportActionBar(@Nullable Toolbar toolbar) {
+        getDelegate().setSupportActionBar(toolbar);
+    }
+
+    @Override
+    public MenuInflater getMenuInflater() {
+        return getDelegate().getMenuInflater();
+    }
+
+    @Override
+    public void setContentView(@LayoutRes int layoutResID) {
+        getDelegate().setContentView(layoutResID);
+    }
+
+    @Override
+    public void setContentView(View view) {
+        getDelegate().setContentView(view);
+    }
+
+    @Override
+    public void setContentView(View view, ViewGroup.LayoutParams params) {
+        getDelegate().setContentView(view, params);
+    }
+
+    @Override
+    public void addContentView(View view, ViewGroup.LayoutParams params) {
+        getDelegate().addContentView(view, params);
+    }
+
+    @Override
+    protected void onPostResume() {
+        super.onPostResume();
+        getDelegate().onPostResume();
+    }
+
+    @Override
+    protected void onTitleChanged(CharSequence title, int color) {
+        super.onTitleChanged(title, color);
+        getDelegate().setTitle(title);
+    }
+
+    @Override
+    public void onConfigurationChanged(Configuration newConfig) {
+        super.onConfigurationChanged(newConfig);
+        getDelegate().onConfigurationChanged(newConfig);
+    }
+
+    @Override
+    protected void onStop() {
+        super.onStop();
+        getDelegate().onStop();
+    }
+
+    @Override
+    protected void onDestroy() {
+        super.onDestroy();
+        getDelegate().onDestroy();
+    }
+
+    public void invalidateOptionsMenu() {
+        getDelegate().invalidateOptionsMenu();
+    }
+
+    private AppCompatDelegate getDelegate() {
+        if (mDelegate == null) {
+            mDelegate = AppCompatDelegate.create(this, null);
+        }
+        return mDelegate;
+    }
+}

+ 562 - 0
deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/MainActivity.java

@@ -0,0 +1,562 @@
+package com.paddle.demo.matting;
+
+import android.Manifest;
+import android.app.ProgressDialog;
+import android.content.ContentResolver;
+import android.content.ContentValues;
+import android.content.Intent;
+import android.content.SharedPreferences;
+import android.content.pm.PackageManager;
+import android.database.Cursor;
+import android.graphics.Bitmap;
+import android.graphics.BitmapFactory;
+import android.net.Uri;
+import android.os.Bundle;
+import android.os.Environment;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Message;
+import android.preference.PreferenceManager;
+import android.provider.MediaStore;
+import android.support.annotation.NonNull;
+import android.support.v4.app.ActivityCompat;
+import android.support.v4.content.ContextCompat;
+import android.support.v7.app.AppCompatActivity;
+import android.text.method.ScrollingMovementMethod;
+import android.util.Log;
+import android.view.Display;
+import android.view.Menu;
+import android.view.MenuInflater;
+import android.view.MenuItem;
+import android.view.View;
+import android.widget.Button;
+import android.widget.ImageView;
+import android.widget.TextView;
+import android.widget.Toast;
+
+import com.paddle.demo.matting.config.Config;
+import com.paddle.demo.matting.preprocess.Preprocess;
+import com.paddle.demo.matting.visual.Visualize;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+public class MainActivity extends AppCompatActivity {
+
+    private static final String TAG = MainActivity.class.getSimpleName();
+
+    //定义图像来源
+    public static final int OPEN_GALLERY_REQUEST_CODE = 0;//本地相册
+    public static final int TAKE_PHOTO_REQUEST_CODE = 1;//摄像头拍摄
+
+    //定义模型推理相关变量
+    public static final int REQUEST_LOAD_MODEL = 0;
+    public static final int REQUEST_RUN_MODEL = 1;
+    public static final int RESPONSE_LOAD_MODEL_SUCCESSED = 0;
+    public static final int RESPONSE_LOAD_MODEL_FAILED = 1;
+    public static final int RESPONSE_RUN_MODEL_SUCCESSED = 2;
+    public static final int RESPONSE_RUN_MODEL_FAILED = 3;
+
+    protected ProgressDialog pbLoadModel = null;
+    protected ProgressDialog pbRunModel = null;
+
+    //定义操作流程线程句柄
+    protected HandlerThread worker = null; // 工作线程(加载和运行模型)
+    protected Handler receiver = null; // 接收来自工作线程的消息
+    protected Handler sender = null; // 发送消息给工作线程
+
+    protected TextView tvInputSetting;//输入信息面板
+    protected ImageView ivInputImage;//输入图像面板
+    protected TextView tvOutputResult;//输出结果面板
+    protected TextView tvInferenceTime;//推理时间面板
+
+    // 模型配置
+    Config config = new Config();
+
+    protected Predictor predictor = new Predictor();
+
+    Preprocess preprocess = new Preprocess();
+
+    Visualize visualize = new Visualize();
+
+    //定义图像存储路径
+    Uri photoUri;
+
+    private Button btnSaveImg;
+
+    @Override
+    protected void onCreate(Bundle savedInstanceState) {
+        super.onCreate(savedInstanceState);
+        setContentView(R.layout.activity_main);
+
+        //绑定保存图像按钮
+        btnSaveImg = (Button) findViewById(R.id.save_img);
+
+        //定义消息接收线程
+        receiver = new Handler() {
+            @Override
+            public void handleMessage(Message msg) {
+                switch (msg.what) {
+                    case RESPONSE_LOAD_MODEL_SUCCESSED:
+                        pbLoadModel.dismiss();
+                        onLoadModelSuccessed();
+                        break;
+                    case RESPONSE_LOAD_MODEL_FAILED:
+                        pbLoadModel.dismiss();
+                        Toast.makeText(MainActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show();
+                        onLoadModelFailed();
+                        break;
+                    case RESPONSE_RUN_MODEL_SUCCESSED:
+                        pbRunModel.dismiss();
+                        onRunModelSuccessed();
+                        break;
+                    case RESPONSE_RUN_MODEL_FAILED:
+                        pbRunModel.dismiss();
+                        Toast.makeText(MainActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show();
+                        onRunModelFailed();
+                        break;
+                    default:
+                        break;
+                }
+            }
+        };
+
+        //定义工作线程
+        worker = new HandlerThread("Predictor Worker");
+        worker.start();
+
+        //定义发送消息线程
+        sender = new Handler(worker.getLooper()) {
+            public void handleMessage(Message msg) {
+                switch (msg.what) {
+                    case REQUEST_LOAD_MODEL:
+                        // load model and reload test image
+                        if (onLoadModel()) {
+                            receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_SUCCESSED);
+                        } else {
+                            receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_FAILED);
+                        }
+                        break;
+                    case REQUEST_RUN_MODEL:
+                        // run model if model is loaded
+                        if (onRunModel()) {
+                            receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_SUCCESSED);
+                        } else {
+                            receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_FAILED);
+                        }
+                        break;
+                    default:
+                        break;
+                }
+            }
+        };
+
+        tvInputSetting = findViewById(R.id.tv_input_setting);
+        ivInputImage = findViewById(R.id.iv_input_image);
+        tvInferenceTime = findViewById(R.id.tv_inference_time);
+        tvOutputResult = findViewById(R.id.tv_output_result);
+        tvInputSetting.setMovementMethod(ScrollingMovementMethod.getInstance());
+        tvOutputResult.setMovementMethod(ScrollingMovementMethod.getInstance());
+    }
+
+    public boolean onLoadModel() {
+        return predictor.init(MainActivity.this, config);
+    }
+
+    public boolean onRunModel() {
+        return predictor.isLoaded() && predictor.runModel(preprocess,visualize);
+    }
+
+    public void onLoadModelFailed() {
+
+    }
+    public void onRunModelFailed() {
+    }
+
+    public void loadModel() {
+        pbLoadModel = ProgressDialog.show(this, "", "加载模型中...", false, false);
+        sender.sendEmptyMessage(REQUEST_LOAD_MODEL);
+    }
+
+    public void runModel() {
+        pbRunModel = ProgressDialog.show(this, "", "推理中...", false, false);
+        sender.sendEmptyMessage(REQUEST_RUN_MODEL);
+    }
+
+    public void onLoadModelSuccessed() {
+        // load test image from file_paths and run model
+        try {
+            if (config.imagePath.isEmpty()||config.bgPath.isEmpty()) {
+                return;
+            }
+            Bitmap image = null;
+            Bitmap bg = null;
+
+            //加载待抠图像(如果是拍照或者本地相册读取,则第一个字符为“/”。否则就是从默认路径下读取图片)
+            if (!config.imagePath.substring(0, 1).equals("/")) {
+                InputStream imageStream = getAssets().open(config.imagePath);
+                image = BitmapFactory.decodeStream(imageStream);
+            } else {
+                if (!new File(config.imagePath).exists()) {
+                    return;
+                }
+                image = BitmapFactory.decodeFile(config.imagePath);
+            }
+
+            //加载背景图像
+            if (!config.bgPath.substring(0, 1).equals("/")) {
+                InputStream imageStream = getAssets().open(config.bgPath);
+                bg = BitmapFactory.decodeStream(imageStream);
+            } else {
+                if (!new File(config.bgPath).exists()) {
+                    return;
+                }
+                bg = BitmapFactory.decodeFile(config.bgPath);
+            }
+
+            if (image != null && bg != null && predictor.isLoaded()) {
+                predictor.setInputImage(image,bg);
+                runModel();
+            }
+        } catch (IOException e) {
+            Toast.makeText(MainActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show();
+            e.printStackTrace();
+        }
+    }
+
+    public void onRunModelSuccessed() {
+        // 获取抠图结果并更新UI
+        tvInferenceTime.setText("推理耗时: " + predictor.inferenceTime() + " ms");
+        Bitmap outputImage = predictor.outputImage();
+        if (outputImage != null) {
+            ivInputImage.setImageBitmap(outputImage);
+        }
+        tvOutputResult.setText(predictor.outputResult());
+        tvOutputResult.scrollTo(0, 0);
+    }
+
+
+    public void onImageChanged(Bitmap image) {
+        Bitmap bg = null;
+        try {
+            //加载背景图像
+            if (!config.bgPath.substring(0, 1).equals("/")) {
+                InputStream imageStream = getAssets().open(config.bgPath);
+                bg = BitmapFactory.decodeStream(imageStream);
+            } else {
+                if (!new File(config.bgPath).exists()) {
+                    return;
+                }
+                bg = BitmapFactory.decodeFile(config.bgPath);
+            }
+        } catch (IOException e) {
+            Toast.makeText(MainActivity.this, "加载背景图失败!", Toast.LENGTH_SHORT).show();
+            e.printStackTrace();
+        }
+
+        // rerun model if users pick test image from gallery or camera
+        //设置预测器图像
+        if (image != null && predictor.isLoaded()) {
+            predictor.setInputImage(image,bg);
+            runModel();
+        }
+    }
+
+    public void onImageChanged(String path) {
+        Bitmap bg = null;
+        try {
+            //加载背景图像
+            if (!config.bgPath.substring(0, 1).equals("/")) {
+                InputStream imageStream = getAssets().open(config.bgPath);
+                bg = BitmapFactory.decodeStream(imageStream);
+            } else {
+                if (!new File(config.bgPath).exists()) {
+                    return;
+                }
+                bg = BitmapFactory.decodeFile(config.bgPath);
+            }
+        } catch (IOException e) {
+            Toast.makeText(MainActivity.this, "加载背景图失败!", Toast.LENGTH_SHORT).show();
+            e.printStackTrace();
+        }
+
+        //设置预测器图像
+        Bitmap image = BitmapFactory.decodeFile(path);
+        predictor.setInputImage(image, bg);
+        runModel();
+    }
+
+    //打开设置页面
+    public void onSettingsClicked() {
+        startActivity(new Intent(MainActivity.this, SettingsActivity.class));
+    }
+
+    @Override
+    public boolean onCreateOptionsMenu(Menu menu) {
+        MenuInflater inflater = getMenuInflater();
+        inflater.inflate(R.menu.menu_action_options, menu);
+        return true;
+    }
+
+    @Override
+    public boolean onOptionsItemSelected(MenuItem item) {
+        switch (item.getItemId()) {
+            case android.R.id.home:
+                finish();
+                break;
+            case R.id.open_gallery:
+                if (requestAllPermissions()) {
+                    openGallery();
+                }
+                break;
+            case R.id.take_photo:
+                if (requestAllPermissions()) {
+                    takePhoto();
+                }
+                break;
+            case R.id.settings:
+                if (requestAllPermissions()) {
+                    // make sure we have SDCard r&w permissions to load model from SDCard
+                    onSettingsClicked();
+                }
+                break;
+        }
+        return super.onOptionsItemSelected(item);
+    }
+    @Override
+    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions,
+                                           @NonNull int[] grantResults) {
+        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
+        if (grantResults[0] != PackageManager.PERMISSION_GRANTED || grantResults[1] != PackageManager.PERMISSION_GRANTED) {
+            Toast.makeText(this, "Permission Denied", Toast.LENGTH_SHORT).show();
+        }
+    }
+
+    private Bitmap getBitMapFromPath(String imageFilePath) {
+        Display currentDisplay = getWindowManager().getDefaultDisplay();
+        int dw = currentDisplay.getWidth();
+        int dh = currentDisplay.getHeight();
+        // Load up the image's dimensions not the image itself
+        BitmapFactory.Options bmpFactoryOptions = new BitmapFactory.Options();
+        bmpFactoryOptions.inJustDecodeBounds = true;
+        Bitmap bmp = BitmapFactory.decodeFile(imageFilePath, bmpFactoryOptions);
+        int heightRatio = (int) Math.ceil(bmpFactoryOptions.outHeight
+                / (float) dh);
+        int widthRatio = (int) Math.ceil(bmpFactoryOptions.outWidth
+                / (float) dw);
+
+        // If both of the ratios are greater than 1,
+        // one of the sides of the image is greater than the screen
+        if (heightRatio > 1 && widthRatio > 1) {
+            if (heightRatio > widthRatio) {
+                // Height ratio is larger, scale according to it
+                bmpFactoryOptions.inSampleSize = heightRatio;
+            } else {
+                // Width ratio is larger, scale according to it
+                bmpFactoryOptions.inSampleSize = widthRatio;
+            }
+        }
+        // Decode it for real
+        bmpFactoryOptions.inJustDecodeBounds = false;
+        bmp = BitmapFactory.decodeFile(imageFilePath, bmpFactoryOptions);
+        return bmp;
+    }
+
+    @Override
+    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
+        super.onActivityResult(requestCode, resultCode, data);
+        //if (resultCode == RESULT_OK && data != null) {
+        if (resultCode == RESULT_OK) {
+            switch (requestCode) {
+                case OPEN_GALLERY_REQUEST_CODE:
+                    try {
+                        ContentResolver resolver = getContentResolver();
+                        Uri uri = data.getData();
+                        Bitmap image = MediaStore.Images.Media.getBitmap(resolver, uri);
+                        //判断图像尺寸(如果图像过大推理会奔溃)
+                        int width = image.getWidth();
+                        int height = image.getHeight();
+                        Bitmap scaleImage;
+                        if(width > 800) {
+                            int new_width = 800;
+                            int new_height = (int)(height*1.0/width*new_width);
+                            scaleImage = Bitmap.createScaledBitmap(image, new_width, new_height,true);
+                        }
+                        else{
+                            scaleImage = image.copy(Bitmap.Config.ARGB_8888, true);
+                        }
+                        String[] proj = {MediaStore.Images.Media.DATA};
+                        Cursor cursor = managedQuery(uri, proj, null, null, null);
+                        cursor.moveToFirst();
+                        onImageChanged(scaleImage);
+                    } catch (IOException e) {
+                        Log.e(TAG, e.toString());
+                    }
+                    break;
+
+                case TAKE_PHOTO_REQUEST_CODE:
+                    //Bitmap image = (Bitmap) data.getParcelableExtra("data");//获取缩略图
+                    Bitmap image = null;
+                    if (photoUri == null)
+                        return;
+                    //通过Uri和selection来获取真实的图片路径
+                    Cursor cursor = getContentResolver().query(photoUri, null, null, null, null);
+                    if (cursor != null) {
+                        if (cursor.moveToFirst()){
+                            String path = cursor.getString(cursor.getColumnIndex(MediaStore.Images.Media.DATA));
+                            //获得图片
+                            image = getBitMapFromPath(path);
+                        }
+                        cursor.close();
+                    }
+                    photoUri = null;
+                    onImageChanged(image);
+                    break;
+                default:
+                    break;
+            }
+        }
+    }
+    private boolean requestAllPermissions() {
+        if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE)
+                != PackageManager.PERMISSION_GRANTED || ContextCompat.checkSelfPermission(this,
+                Manifest.permission.CAMERA)
+                != PackageManager.PERMISSION_GRANTED) {
+            ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.WRITE_EXTERNAL_STORAGE,
+                            Manifest.permission.CAMERA},
+                    0);
+            return false;
+        }
+        return true;
+    }
+
+    private void openGallery() {
+        Intent intent = new Intent(Intent.ACTION_PICK, null);
+        intent.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*");
+        startActivityForResult(intent, OPEN_GALLERY_REQUEST_CODE);
+    }
+
+    private void takePhoto() {
+        Intent takePhotoIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
+        ContentValues values = new ContentValues();
+        photoUri = getContentResolver().insert(
+                MediaStore.Images.Media.EXTERNAL_CONTENT_URI, values);
+        takePhotoIntent.putExtra(android.provider.MediaStore.EXTRA_OUTPUT, photoUri);
+        if (takePhotoIntent.resolveActivity(getPackageManager()) != null) {
+            startActivityForResult(takePhotoIntent, TAKE_PHOTO_REQUEST_CODE);
+        }
+    }
+
+    @Override
+    public boolean onPrepareOptionsMenu(Menu menu) {
+        boolean isLoaded = predictor.isLoaded();
+        menu.findItem(R.id.open_gallery).setEnabled(isLoaded);
+        menu.findItem(R.id.take_photo).setEnabled(isLoaded);
+        return super.onPrepareOptionsMenu(menu);
+    }
+
+    @Override
+    protected void onResume() {
+        Log.i(TAG,"begin onResume");
+        super.onResume();
+
+        SharedPreferences sharedPreferences = PreferenceManager.getDefaultSharedPreferences(this);
+        boolean settingsChanged = false;
+        String model_path = sharedPreferences.getString(getString(R.string.MODEL_PATH_KEY),
+                getString(R.string.MODEL_PATH_DEFAULT));
+        String label_path = sharedPreferences.getString(getString(R.string.LABEL_PATH_KEY),
+                getString(R.string.LABEL_PATH_DEFAULT));
+        String image_path = sharedPreferences.getString(getString(R.string.IMAGE_PATH_KEY),
+                getString(R.string.IMAGE_PATH_DEFAULT));
+        String bg_path = sharedPreferences.getString(getString(R.string.BG_PATH_KEY),
+                getString(R.string.BG_PATH_DEFAULT));
+        settingsChanged |= !model_path.equalsIgnoreCase(config.modelPath);
+        settingsChanged |= !label_path.equalsIgnoreCase(config.labelPath);
+        settingsChanged |= !image_path.equalsIgnoreCase(config.imagePath);
+        settingsChanged |= !bg_path.equalsIgnoreCase(config.bgPath);
+        int cpu_thread_num = Integer.parseInt(sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY),
+                getString(R.string.CPU_THREAD_NUM_DEFAULT)));
+        settingsChanged |= cpu_thread_num != config.cpuThreadNum;
+        String cpu_power_mode =
+                sharedPreferences.getString(getString(R.string.CPU_POWER_MODE_KEY),
+                        getString(R.string.CPU_POWER_MODE_DEFAULT));
+        settingsChanged |= !cpu_power_mode.equalsIgnoreCase(config.cpuPowerMode);
+        String input_color_format =
+                sharedPreferences.getString(getString(R.string.INPUT_COLOR_FORMAT_KEY),
+                        getString(R.string.INPUT_COLOR_FORMAT_DEFAULT));
+        settingsChanged |= !input_color_format.equalsIgnoreCase(config.inputColorFormat);
+        long[] input_shape =
+                Utils.parseLongsFromString(sharedPreferences.getString(getString(R.string.INPUT_SHAPE_KEY),
+                        getString(R.string.INPUT_SHAPE_DEFAULT)), ",");
+
+        settingsChanged |= input_shape.length != config.inputShape.length;
+
+        if (!settingsChanged) {
+            for (int i = 0; i < input_shape.length; i++) {
+                settingsChanged |= input_shape[i] != config.inputShape[i];
+            }
+        }
+
+        if (settingsChanged) {
+            config.init(model_path,label_path,image_path,bg_path,cpu_thread_num,cpu_power_mode,
+                    input_color_format,input_shape);
+            preprocess.init(config);
+            // 更新UI
+            tvInputSetting.setText("算法模型: " + config.modelPath.substring(config.modelPath.lastIndexOf("/") + 1));
+            tvInputSetting.scrollTo(0, 0);
+            // 如果配置发生改变则重新加载模型并预测
+            loadModel();
+        }
+    }
+
+    @Override
+    protected void onDestroy() {
+        if (predictor != null) {
+            predictor.releaseModel();
+        }
+        worker.quit();
+        super.onDestroy();
+    }
+
+    //保存图像
+    public void clickSaveImg(View view){
+        //取出图像
+        Bitmap outputImage = predictor.outputImage();
+        if(outputImage==null) {
+            Toast.makeText(getApplicationContext(),"当前没有图像",Toast.LENGTH_SHORT).show();
+            return;
+        }
+
+        //保存图像
+        File f = new File(Environment.getExternalStorageDirectory() + "/newphoto.jpg");
+        if (f.exists()) {
+            f.delete();
+        }
+        try {
+            FileOutputStream out = new FileOutputStream(f);
+            outputImage.compress(Bitmap.CompressFormat.JPEG, 80, out);
+            out.flush();
+            out.close();
+        } catch (FileNotFoundException e) {
+            e.printStackTrace();
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+
+        // 把文件插入到系统图库
+        try {
+            MediaStore.Images.Media.insertImage(this.getContentResolver(),
+                    f.getAbsolutePath(), Environment.getExternalStorageDirectory() + "/newphoto.jpg", null);
+        } catch (FileNotFoundException e) {
+            e.printStackTrace();
+        }
+        // 通知图库更新
+        sendBroadcast(new Intent(Intent.ACTION_MEDIA_SCANNER_SCAN_FILE,
+                Uri.parse("file://" + "/sdcard/matting/")));
+
+        Toast.makeText(getApplicationContext(),"保存成功",Toast.LENGTH_SHORT).show();
+    }
+}

+ 302 - 0
deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/Predictor.java

@@ -0,0 +1,302 @@
+package com.paddle.demo.matting;
+
+import android.content.Context;
+import android.graphics.Bitmap;
+import android.util.Log;
+
+import com.baidu.paddle.lite.MobileConfig;
+import com.baidu.paddle.lite.PaddlePredictor;
+import com.baidu.paddle.lite.PowerMode;
+import com.baidu.paddle.lite.Tensor;
+import com.paddle.demo.matting.config.Config;
+
+import com.paddle.demo.matting.preprocess.Preprocess;
+import com.paddle.demo.matting.visual.Visualize;
+
+import java.io.File;
+import java.io.InputStream;
+import java.util.Date;
+import java.util.Vector;
+
+public class Predictor {
+    private static final String TAG = Predictor.class.getSimpleName();
+    protected Vector<String> wordLabels = new Vector<String>();
+
+    Config config = new Config();
+
+    protected Bitmap inputImage = null;//输入图像
+    protected Bitmap scaledImage = null;//尺度归一化图像
+    protected Bitmap bgImage = null;//背景图像
+    protected Bitmap outputImage = null;//输出图像
+    protected String outputResult = "";//输出结果
+    protected float preprocessTime = 0;//预处理时间
+    protected float postprocessTime = 0;//后处理时间
+
+    public boolean isLoaded = false;
+    public int warmupIterNum = 0;
+    public int inferIterNum = 1;
+    protected Context appCtx = null;
+    public int cpuThreadNum = 1;
+    public String cpuPowerMode = "LITE_POWER_HIGH";
+    public String modelPath = "";
+    public String modelName = "";
+    protected PaddlePredictor paddlePredictor = null;
+    protected float inferenceTime = 0;
+
+    public Predictor() {
+        super();
+    }
+
+    public boolean init(Context appCtx, String modelPath, int cpuThreadNum, String cpuPowerMode) {
+        this.appCtx = appCtx;
+        isLoaded = loadModel(modelPath, cpuThreadNum, cpuPowerMode);
+        return isLoaded;
+    }
+
+    public boolean init(Context appCtx, Config config) {
+
+        if (config.inputShape.length != 4) {
+            Log.i(TAG, "size of input shape should be: 4");
+            return false;
+        }
+        if (config.inputShape[0] != 1) {
+            Log.i(TAG, "only one batch is supported in the matting demo, you can use any batch size in " +
+                    "your Apps!");
+            return false;
+        }
+        if (config.inputShape[1] != 1 && config.inputShape[1] != 3) {
+            Log.i(TAG, "only one/three channels are supported in the image matting demo, you can use any " +
+                    "channel size in your Apps!");
+            return false;
+        }
+        if (!config.inputColorFormat.equalsIgnoreCase("RGB") && !config.inputColorFormat.equalsIgnoreCase("BGR")) {
+            Log.i(TAG, "only RGB and BGR color format is supported.");
+            return false;
+        }
+        init(appCtx, config.modelPath, config.cpuThreadNum, config.cpuPowerMode);
+
+        if (!isLoaded()) {
+            return false;
+        }
+        this.config = config;
+
+        return isLoaded;
+    }
+
+
+    public boolean isLoaded() {
+        return paddlePredictor != null && isLoaded;
+    }
+
+    //加载真值标签label,在本matting实例中用不到(在图像分类和图像分割任务中如果是多类目标那么一般需要提供label)
+    protected boolean loadLabel(String labelPath) {
+        wordLabels.clear();
+        // load word labels from file
+        try {
+            InputStream assetsInputStream = appCtx.getAssets().open(labelPath);
+            int available = assetsInputStream.available();
+            byte[] lines = new byte[available];
+            assetsInputStream.read(lines);
+            assetsInputStream.close();
+            String words = new String(lines);
+            String[] contents = words.split("\n");
+            for (String content : contents) {
+                wordLabels.add(content);
+            }
+            Log.i(TAG, "word label size: " + wordLabels.size());
+        } catch (Exception e) {
+            Log.e(TAG, e.getMessage());
+            return false;
+        }
+        return true;
+    }
+
+    public Tensor getInput(int idx) {
+        if (!isLoaded()) {
+            return null;
+        }
+        return paddlePredictor.getInput(idx);
+    }
+
+    public Tensor getOutput(int idx) {
+        if (!isLoaded()) {
+            return null;
+        }
+        return paddlePredictor.getOutput(idx);
+    }
+
+    protected boolean loadModel(String modelPath, int cpuThreadNum, String cpuPowerMode) {
+        // 如果已经加载过模型则先释放掉
+        releaseModel();
+
+        // 加载模型
+        if (modelPath.isEmpty()) {
+            return false;
+        }
+        String realPath = modelPath;
+        if (!modelPath.substring(0, 1).equals("/")) {
+            // read model files from custom file_paths if the first character of mode file_paths is '/'
+            // otherwise copy model to cache from assets
+            realPath = appCtx.getCacheDir() + "/" + modelPath;
+            Utils.copyDirectoryFromAssets(appCtx, modelPath, realPath);
+        }
+        if (realPath.isEmpty()) {
+            return false;
+        }
+        MobileConfig config = new MobileConfig();
+        //修改模型
+        config.setModelFromFile(realPath + File.separator + "hrnet_w18.nb");
+        config.setThreads(cpuThreadNum);
+        if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_HIGH")) {
+            config.setPowerMode(PowerMode.LITE_POWER_HIGH);
+        } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_LOW")) {
+            config.setPowerMode(PowerMode.LITE_POWER_LOW);
+        } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_FULL")) {
+            config.setPowerMode(PowerMode.LITE_POWER_FULL);
+        } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_NO_BIND")) {
+            config.setPowerMode(PowerMode.LITE_POWER_NO_BIND);
+        } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_RAND_HIGH")) {
+            config.setPowerMode(PowerMode.LITE_POWER_RAND_HIGH);
+        } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_RAND_LOW")) {
+            config.setPowerMode(PowerMode.LITE_POWER_RAND_LOW);
+        } else {
+            Log.e(TAG, "unknown cpu power mode!");
+            return false;
+        }
+        paddlePredictor = PaddlePredictor.createPaddlePredictor(config);
+        this.cpuThreadNum = cpuThreadNum;
+        this.cpuPowerMode = cpuPowerMode;
+        this.modelPath = realPath;
+        this.modelName = realPath.substring(realPath.lastIndexOf("/") + 1);
+        return true;
+    }
+
+    public boolean runModel() {
+        if (!isLoaded()) {
+            return false;
+        }
+        // warm up
+        for (int i = 0; i < warmupIterNum; i++){
+            paddlePredictor.run();
+        }
+        // inference
+        Date start = new Date();
+        for (int i = 0; i < inferIterNum; i++) {
+            paddlePredictor.run();
+        }
+        Date end = new Date();
+        inferenceTime = (end.getTime() - start.getTime()) / (float) inferIterNum;
+        return true;
+    }
+
+    public boolean runModel(Bitmap image) {
+        setInputImage(image,bgImage);
+        return runModel();
+    }
+
+    //正式的推理主函数
+    public boolean runModel(Preprocess preprocess, Visualize visualize) {
+        if (inputImage == null || bgImage == null) {
+            return false;
+        }
+
+        // set input shape
+        Tensor inputTensor = getInput(0);
+        inputTensor.resize(config.inputShape);
+
+        // pre-process image
+        Date start = new Date();
+
+        // setInputImage(scaledImage);
+        preprocess.init(config);
+        preprocess.to_array(scaledImage);
+        preprocess.normalize(preprocess.inputData);
+
+        // feed input tensor with pre-processed data
+        inputTensor.setData(preprocess.inputData);
+
+        Date end = new Date();
+        preprocessTime = (float) (end.getTime() - start.getTime());
+
+        // inference
+        runModel();
+
+        // post-process
+        start = new Date();
+        Tensor outputTensor = getOutput(0);
+        this.outputImage = visualize.draw(inputImage, outputTensor, bgImage);
+        end = new Date();
+        postprocessTime = (float) (end.getTime() - start.getTime());
+
+        outputResult = "";
+
+        return true;
+    }
+    public void releaseModel() {
+        paddlePredictor = null;
+        isLoaded = false;
+        cpuThreadNum = 1;
+        cpuPowerMode = "LITE_POWER_HIGH";
+        modelPath = "";
+        modelName = "";
+    }
+
+    public void setConfig(Config config){
+        this.config = config;
+    }
+
+    public Bitmap inputImage() {
+        return inputImage;
+    }
+
+    public Bitmap outputImage() {
+        return outputImage;
+    }
+
+    public String outputResult() {
+        return outputResult;
+    }
+
+    public float preprocessTime() {
+        return preprocessTime;
+    }
+
+    public float postprocessTime() {
+        return postprocessTime;
+    }
+
+    public String modelPath() {
+        return modelPath;
+    }
+
+    public String modelName() {
+        return modelName;
+    }
+
+    public int cpuThreadNum() {
+        return cpuThreadNum;
+    }
+
+    public String cpuPowerMode() {
+        return cpuPowerMode;
+    }
+
+    public float inferenceTime() {
+        return inferenceTime;
+    }
+
+    public void setInputImage(Bitmap image,Bitmap bg) {
+        if (image == null || bg == null) {
+            return;
+        }
+        // Pre-process the input image: copy to ARGB_8888 and scale to the model input size
+        Bitmap rgbaImage = image.copy(Bitmap.Config.ARGB_8888, true);
+        Bitmap scaleImage = Bitmap.createScaledBitmap(rgbaImage, (int) this.config.inputShape[3], (int) this.config.inputShape[2], true);
+        this.inputImage = rgbaImage;
+        this.scaledImage = scaleImage;
+
+        // Pre-process the background image
+        this.bgImage = bg.copy(Bitmap.Config.ARGB_8888, true);
+    }
+
+}
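
Taken together, Predictor wires Config, Preprocess and Visualize into a load → feed → run → composite pipeline. The sketch below shows one plausible way a caller could drive it; it is illustrative only (the real flow lives in MainActivity), the class name is made up, and the Predictor is assumed to have been initialized with an application Context by the init(...) defined earlier in this file.

    package com.paddle.demo.matting;

    import android.graphics.Bitmap;

    import com.paddle.demo.matting.config.Config;
    import com.paddle.demo.matting.preprocess.Preprocess;
    import com.paddle.demo.matting.visual.Visualize;

    // Hypothetical helper, not part of this commit.
    public class MattingPipelineSketch {
        // Runs one image through the matting pipeline and returns the composited result,
        // or null if model loading or inference fails.
        public static Bitmap matte(Predictor predictor, Config config,
                                   Bitmap human, Bitmap background) {
            predictor.setConfig(config);
            if (!predictor.loadModel(config.modelPath, config.cpuThreadNum, config.cpuPowerMode)) {
                return null;                                // missing model or unknown power mode
            }
            predictor.setInputImage(human, background);     // scales the input to config.inputShape
            if (!predictor.runModel(new Preprocess(), new Visualize())) {
                return null;
            }
            return predictor.outputImage();                 // foreground composited over the background
        }
    }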

+ 158 - 0
deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/SettingsActivity.java

@@ -0,0 +1,158 @@
+package com.paddle.demo.matting;
+
+import android.content.SharedPreferences;
+import android.os.Bundle;
+import android.preference.CheckBoxPreference;
+import android.preference.EditTextPreference;
+import android.preference.ListPreference;
+import android.support.v7.app.ActionBar;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class SettingsActivity extends AppCompatPreferenceActivity implements SharedPreferences.OnSharedPreferenceChangeListener {
+    ListPreference lpChoosePreInstalledModel = null;
+    CheckBoxPreference cbEnableCustomSettings = null;
+    EditTextPreference etModelPath = null;
+    EditTextPreference etLabelPath = null;
+    EditTextPreference etImagePath = null;
+    ListPreference lpCPUThreadNum = null;
+    ListPreference lpCPUPowerMode = null;
+    ListPreference lpInputColorFormat = null;
+
+
+
+    List<String> preInstalledModelPaths = null;
+    List<String> preInstalledLabelPaths = null;
+    List<String> preInstalledImagePaths = null;
+    List<String> preInstalledCPUThreadNums = null;
+    List<String> preInstalledCPUPowerModes = null;
+    List<String> preInstalledInputColorFormats = null;
+
+
+    @Override
+    public void onCreate(Bundle savedInstanceState) {
+        super.onCreate(savedInstanceState);
+        addPreferencesFromResource(R.xml.settings);
+        ActionBar supportActionBar = getSupportActionBar();
+        if (supportActionBar != null) {
+            supportActionBar.setDisplayHomeAsUpEnabled(true);
+        }
+
+        // initialize the pre-installed model settings
+        preInstalledModelPaths = new ArrayList<String>();
+        preInstalledLabelPaths = new ArrayList<String>();
+        preInstalledImagePaths = new ArrayList<String>();
+
+        preInstalledCPUThreadNums = new ArrayList<String>();
+        preInstalledCPUPowerModes = new ArrayList<String>();
+        preInstalledInputColorFormats = new ArrayList<String>();
+        // add the default pre-installed matting model
+        preInstalledModelPaths.add(getString(R.string.MODEL_PATH_DEFAULT));
+        preInstalledLabelPaths.add(getString(R.string.LABEL_PATH_DEFAULT));
+        preInstalledImagePaths.add(getString(R.string.IMAGE_PATH_DEFAULT));
+        preInstalledCPUThreadNums.add(getString(R.string.CPU_THREAD_NUM_DEFAULT));
+        preInstalledCPUPowerModes.add(getString(R.string.CPU_POWER_MODE_DEFAULT));
+        preInstalledInputColorFormats.add(getString(R.string.INPUT_COLOR_FORMAT_DEFAULT));
+        // initialize UI components
+        lpChoosePreInstalledModel =
+                (ListPreference) findPreference(getString(R.string.CHOOSE_PRE_INSTALLED_MODEL_KEY));
+        String[] preInstalledModelNames = new String[preInstalledModelPaths.size()];
+        for (int i = 0; i < preInstalledModelPaths.size(); i++) {
+            preInstalledModelNames[i] =
+                    preInstalledModelPaths.get(i).substring(preInstalledModelPaths.get(i).lastIndexOf("/") + 1);
+        }
+        lpChoosePreInstalledModel.setEntries(preInstalledModelNames);
+        lpChoosePreInstalledModel.setEntryValues(preInstalledModelPaths.toArray(new String[preInstalledModelPaths.size()]));
+        cbEnableCustomSettings =
+                (CheckBoxPreference) findPreference(getString(R.string.ENABLE_CUSTOM_SETTINGS_KEY));
+        etModelPath = (EditTextPreference) findPreference(getString(R.string.MODEL_PATH_KEY));
+        etModelPath.setTitle("Model Path (SDCard: " + Utils.getSDCardDirectory() + ")");
+        etLabelPath = (EditTextPreference) findPreference(getString(R.string.LABEL_PATH_KEY));
+        etImagePath = (EditTextPreference) findPreference(getString(R.string.IMAGE_PATH_KEY));
+        lpCPUThreadNum =
+                (ListPreference) findPreference(getString(R.string.CPU_THREAD_NUM_KEY));
+        lpCPUPowerMode =
+                (ListPreference) findPreference(getString(R.string.CPU_POWER_MODE_KEY));
+        lpInputColorFormat =
+                (ListPreference) findPreference(getString(R.string.INPUT_COLOR_FORMAT_KEY));
+    }
+
+    private void reloadPreferenceAndUpdateUI() {
+        SharedPreferences sharedPreferences = getPreferenceScreen().getSharedPreferences();
+        boolean enableCustomSettings =
+                sharedPreferences.getBoolean(getString(R.string.ENABLE_CUSTOM_SETTINGS_KEY), false);
+        String modelPath = sharedPreferences.getString(getString(R.string.CHOOSE_PRE_INSTALLED_MODEL_KEY),
+                getString(R.string.MODEL_PATH_DEFAULT));
+        int modelIdx = lpChoosePreInstalledModel.findIndexOfValue(modelPath);
+        if (modelIdx >= 0 && modelIdx < preInstalledModelPaths.size()) {
+            if (!enableCustomSettings) {
+                SharedPreferences.Editor editor = sharedPreferences.edit();
+                editor.putString(getString(R.string.MODEL_PATH_KEY), preInstalledModelPaths.get(modelIdx));
+                editor.putString(getString(R.string.LABEL_PATH_KEY), preInstalledLabelPaths.get(modelIdx));
+                editor.putString(getString(R.string.IMAGE_PATH_KEY), preInstalledImagePaths.get(modelIdx));
+                editor.putString(getString(R.string.CPU_THREAD_NUM_KEY), preInstalledCPUThreadNums.get(modelIdx));
+                editor.putString(getString(R.string.CPU_POWER_MODE_KEY), preInstalledCPUPowerModes.get(modelIdx));
+                editor.putString(getString(R.string.INPUT_COLOR_FORMAT_KEY),
+                        preInstalledInputColorFormats.get(modelIdx));
+                editor.commit();
+            }
+            lpChoosePreInstalledModel.setSummary(modelPath);
+        }
+        cbEnableCustomSettings.setChecked(enableCustomSettings);
+        etModelPath.setEnabled(enableCustomSettings);
+        etLabelPath.setEnabled(enableCustomSettings);
+        etImagePath.setEnabled(enableCustomSettings);
+        lpCPUThreadNum.setEnabled(enableCustomSettings);
+        lpCPUPowerMode.setEnabled(enableCustomSettings);
+        lpInputColorFormat.setEnabled(enableCustomSettings);
+        modelPath = sharedPreferences.getString(getString(R.string.MODEL_PATH_KEY),
+                getString(R.string.MODEL_PATH_DEFAULT));
+        String labelPath = sharedPreferences.getString(getString(R.string.LABEL_PATH_KEY),
+                getString(R.string.LABEL_PATH_DEFAULT));
+        String imagePath = sharedPreferences.getString(getString(R.string.IMAGE_PATH_KEY),
+                getString(R.string.IMAGE_PATH_DEFAULT));
+        String cpuThreadNum = sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY),
+                getString(R.string.CPU_THREAD_NUM_DEFAULT));
+        String cpuPowerMode = sharedPreferences.getString(getString(R.string.CPU_POWER_MODE_KEY),
+                getString(R.string.CPU_POWER_MODE_DEFAULT));
+        String inputColorFormat = sharedPreferences.getString(getString(R.string.INPUT_COLOR_FORMAT_KEY),
+                getString(R.string.INPUT_COLOR_FORMAT_DEFAULT));
+        etModelPath.setSummary(modelPath);
+        etModelPath.setText(modelPath);
+        etLabelPath.setSummary(labelPath);
+        etLabelPath.setText(labelPath);
+        etImagePath.setSummary(imagePath);
+        etImagePath.setText(imagePath);
+        lpCPUThreadNum.setValue(cpuThreadNum);
+        lpCPUThreadNum.setSummary(cpuThreadNum);
+        lpCPUPowerMode.setValue(cpuPowerMode);
+        lpCPUPowerMode.setSummary(cpuPowerMode);
+        lpInputColorFormat.setValue(inputColorFormat);
+        lpInputColorFormat.setSummary(inputColorFormat);
+
+    }
+
+    @Override
+    protected void onResume() {
+        super.onResume();
+        getPreferenceScreen().getSharedPreferences().registerOnSharedPreferenceChangeListener(this);
+        reloadPreferenceAndUpdateUI();
+    }
+
+    @Override
+    protected void onPause() {
+        super.onPause();
+        getPreferenceScreen().getSharedPreferences().unregisterOnSharedPreferenceChangeListener(this);
+    }
+
+    @Override
+    public void onSharedPreferenceChanged(SharedPreferences sharedPreferences, String key) {
+        if (key.equals(getString(R.string.CHOOSE_PRE_INSTALLED_MODEL_KEY))) {
+            SharedPreferences.Editor editor = sharedPreferences.edit();
+            editor.putBoolean(getString(R.string.ENABLE_CUSTOM_SETTINGS_KEY), false);
+            editor.commit();
+        }
+        reloadPreferenceAndUpdateUI();
+    }
+}
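
SettingsActivity only persists the chosen values; the inference side still has to read them back before building its Config. A small hypothetical reader, assuming the keys live in the default SharedPreferences (which is what getPreferenceScreen().getSharedPreferences() uses unless configured otherwise):

    package com.paddle.demo.matting;

    import android.content.Context;
    import android.content.SharedPreferences;
    import android.preference.PreferenceManager;

    // Illustrative only; MainActivity has its own preference-reading logic.
    public class SettingsReaderSketch {
        // ListPreference stores its value as a string, so the thread count needs parsing.
        public static int cpuThreadNum(Context ctx) {
            SharedPreferences sp = PreferenceManager.getDefaultSharedPreferences(ctx);
            String raw = sp.getString(ctx.getString(R.string.CPU_THREAD_NUM_KEY),
                    ctx.getString(R.string.CPU_THREAD_NUM_DEFAULT));
            return Integer.parseInt(raw);
        }

        public static String modelPath(Context ctx) {
            SharedPreferences sp = PreferenceManager.getDefaultSharedPreferences(ctx);
            return sp.getString(ctx.getString(R.string.MODEL_PATH_KEY),
                    ctx.getString(R.string.MODEL_PATH_DEFAULT));
        }
    }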

+ 87 - 0
deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/Utils.java

@@ -0,0 +1,87 @@
+package com.paddle.demo.matting;
+
+import android.content.Context;
+import android.os.Environment;
+
+import java.io.*;
+
+public class Utils {
+    private static final String TAG = Utils.class.getSimpleName();
+
+    public static void copyFileFromAssets(Context appCtx, String srcPath, String dstPath) {
+        if (srcPath.isEmpty() || dstPath.isEmpty()) {
+            return;
+        }
+        InputStream is = null;
+        OutputStream os = null;
+        try {
+            is = new BufferedInputStream(appCtx.getAssets().open(srcPath));
+            os = new BufferedOutputStream(new FileOutputStream(new File(dstPath)));
+            byte[] buffer = new byte[1024];
+            int length = 0;
+            while ((length = is.read(buffer)) != -1) {
+                os.write(buffer, 0, length);
+            }
+        } catch (FileNotFoundException e) {
+            e.printStackTrace();
+        } catch (IOException e) {
+            e.printStackTrace();
+        } finally {
+            try {
+                if (os != null) os.close();
+                if (is != null) is.close();
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        }
+    }
+
+    public static void copyDirectoryFromAssets(Context appCtx, String srcDir, String dstDir) {
+        if (srcDir.isEmpty() || dstDir.isEmpty()) {
+            return;
+        }
+        try {
+            if (!new File(dstDir).exists()) {
+                new File(dstDir).mkdirs();
+            }
+            for (String fileName : appCtx.getAssets().list(srcDir)) {
+                String srcSubPath = srcDir + File.separator + fileName;
+                String dstSubPath = dstDir + File.separator + fileName;
+                if (new File(srcSubPath).isDirectory()) {
+                    copyDirectoryFromAssets(appCtx, srcSubPath, dstSubPath);
+                } else {
+                    copyFileFromAssets(appCtx, srcSubPath, dstSubPath);
+                }
+            }
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    public static float[] parseFloatsFromString(String string, String delimiter) {
+        String[] pieces = string.trim().toLowerCase().split(delimiter);
+        float[] floats = new float[pieces.length];
+        for (int i = 0; i < pieces.length; i++) {
+            floats[i] = Float.parseFloat(pieces[i].trim());
+        }
+        return floats;
+    }
+
+    public static long[] parseLongsFromString(String string, String delimiter) {
+        String[] pieces = string.trim().toLowerCase().split(delimiter);
+        long[] longs = new long[pieces.length];
+        for (int i = 0; i < pieces.length; i++) {
+            longs[i] = Long.parseLong(pieces[i].trim());
+        }
+        return longs;
+    }
+
+    public static String getSDCardDirectory() {
+        return Environment.getExternalStorageDirectory().getAbsolutePath();
+    }
+
+    public static boolean isSupportedNPU() {
+        String hardware = android.os.Build.HARDWARE;
+        return hardware.equalsIgnoreCase("kirin810") || hardware.equalsIgnoreCase("kirin990");
+    }
+}
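
The two parsing helpers are plain Java and are what would turn a shape string such as the INPUT_SHAPE_DEFAULT value "1,3,256,256" into the long[] that Config expects. A throwaway check of their behavior (not part of the demo):

    package com.paddle.demo.matting;

    import java.util.Arrays;

    public class UtilsParseSketch {
        public static void main(String[] args) {
            long[] shape = Utils.parseLongsFromString("1,3,256,256", ",");
            float[] scale = Utils.parseFloatsFromString("0.5, 0.5, 0.5", ",");
            System.out.println(Arrays.toString(shape));  // [1, 3, 256, 256]
            System.out.println(Arrays.toString(scale));  // [0.5, 0.5, 0.5]
        }
    }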

+ 38 - 0
deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/config/Config.java

@@ -0,0 +1,38 @@
+package com.paddle.demo.matting.config;
+
+import android.graphics.Bitmap;
+
+public class Config {
+
+    public String modelPath = "";
+    public String labelPath = "";
+    public String imagePath = "";
+    public String bgPath = "";
+    public int cpuThreadNum = 1;
+    public String cpuPowerMode = "";
+    public String inputColorFormat = "";
+    public long[] inputShape = new long[]{1, 3, 256, 256};
+
+
+    public void init(String modelPath, String labelPath, String imagePath,String bgPath, int cpuThreadNum,
+                     String cpuPowerMode, String inputColorFormat,long[] inputShape){
+
+        this.modelPath = modelPath;
+        this.labelPath = labelPath;
+        this.imagePath = imagePath;
+        this.bgPath = bgPath;
+        this.cpuThreadNum = cpuThreadNum;
+        this.cpuPowerMode = cpuPowerMode;
+        this.inputColorFormat = inputColorFormat;
+        this.inputShape = inputShape;
+    }
+
+    public void setInputShape(Bitmap inputImage){
+        this.inputShape[0] = 1;
+        this.inputShape[1] = 3;
+        this.inputShape[2] = inputImage.getHeight();
+        this.inputShape[3] = inputImage.getWidth();
+
+    }
+
+}
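
A sketch of how a caller might fill in Config before inference; the values mirror the defaults in res/values/strings.xml and the helper class name is invented for illustration.

    package com.paddle.demo.matting.config;

    import android.graphics.Bitmap;

    public class ConfigSketch {
        public static Config defaults() {
            Config cfg = new Config();
            cfg.init("image_matting/models/modnet",      // modelPath (relative to assets)
                     "image_matting/labels/label_list",  // labelPath
                     "image_matting/images/human.jpg",   // imagePath
                     "image_matting/images/bg.jpg",      // bgPath
                     1,                                  // cpuThreadNum
                     "LITE_POWER_HIGH",                  // cpuPowerMode
                     "RGB",                              // inputColorFormat
                     new long[]{1, 3, 256, 256});        // NCHW input shape
            return cfg;
        }

        // Optionally let height and width follow the actual input image (N and C stay 1 and 3).
        public static void followImage(Config cfg, Bitmap image) {
            cfg.setInputShape(image);
        }
    }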

+ 79 - 0
deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/preprocess/Preprocess.java

@@ -0,0 +1,79 @@
+package com.paddle.demo.matting.preprocess;
+
+import android.graphics.Bitmap;
+import android.util.Log;
+
+import com.paddle.demo.matting.config.Config;
+
+import static android.graphics.Color.blue;
+import static android.graphics.Color.green;
+import static android.graphics.Color.red;
+
+public class Preprocess {
+
+    private static final String TAG = Preprocess.class.getSimpleName();
+
+    Config config;
+    int channels;
+    int width;
+    int height;
+
+    public float[] inputData;
+
+    public void init(Config config){
+        this.config = config;
+        this.channels = (int) config.inputShape[1];
+        this.height = (int) config.inputShape[2];
+        this.width = (int) config.inputShape[3];
+        this.inputData = new float[channels * width * height];
+    }
+
+    public void normalize(float[] inputData){
+        for (int i = 0; i < inputData.length; i++) {
+            inputData[i] = (float) ((inputData[i] / 255 - 0.5) / 0.5);
+        }
+    }
+
+    public boolean to_array(Bitmap inputImage){
+
+        if (channels == 3) {
+            int[] channelIdx = null;
+            if (config.inputColorFormat.equalsIgnoreCase("RGB")) {
+                channelIdx = new int[]{0, 1, 2};
+            } else if (config.inputColorFormat.equalsIgnoreCase("BGR")) {
+                channelIdx = new int[]{2, 1, 0};
+            } else {
+                Log.i(TAG, "unknown color format " + config.inputColorFormat + ", only RGB and BGR color format is " +
+                        "supported!");
+                return false;
+            }
+            int[] channelStride = new int[]{width * height, width * height * 2};
+
+            for (int y = 0; y < height; y++) {
+                for (int x = 0; x < width; x++) {
+                    int color = inputImage.getPixel(x, y);
+                    float[] rgb = new float[]{(float) red(color) , (float) green(color) ,
+                            (float) blue(color)};
+                    inputData[y * width + x] = rgb[channelIdx[0]] ;
+                    inputData[y * width + x + channelStride[0]] = rgb[channelIdx[1]] ;
+                    inputData[y * width + x + channelStride[1]] = rgb[channelIdx[2]];
+                }
+            }
+        } else if (channels == 1) {
+            for (int y = 0; y < height; y++) {
+                for (int x = 0; x < width; x++) {
+                    int color = inputImage.getPixel(x, y);
+                    float gray = (float) (red(color) + green(color) + blue(color));
+                    inputData[y * width + x] = gray;
+                }
+            }
+        } else {
+            Log.i(TAG, "unsupported channel size " + Integer.toString(channels) + ",  only channel 1 and 3 is " +
+                    "supported!");
+            return false;
+        }
+        return true;
+
+    }
+
+}
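
to_array() lays the image out in planar CHW order (one full plane per channel) and normalize() maps every value from [0, 255] to [-1, 1] via (v / 255 - 0.5) / 0.5. A toy, self-contained illustration of that indexing and mapping (hypothetical class, not used by the demo):

    public class PreprocessLayoutSketch {
        public static void main(String[] args) {
            int width = 4, height = 2;               // toy shape
            int hw = width * height;                 // stride of one channel plane
            float[] data = new float[3 * hw];

            int x = 1, y = 0;
            data[y * width + x]          = 255f;     // R plane
            data[y * width + x + hw]     = 128f;     // G plane
            data[y * width + x + 2 * hw] = 0f;       // B plane

            for (int i = 0; i < data.length; i++) {  // same formula as Preprocess.normalize()
                data[i] = (float) ((data[i] / 255 - 0.5) / 0.5);
            }
            System.out.printf("R=%.2f G=%.2f B=%.2f%n",
                    data[y * width + x], data[y * width + x + hw], data[y * width + x + 2 * hw]);
            // prints roughly R=1.00 G=0.00 B=-1.00
        }
    }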

+ 151 - 0
deploy/human_matting_android_demo/app/src/main/java/com/paddle/demo/matting/visual/Visualize.java

@@ -0,0 +1,151 @@
+package com.paddle.demo.matting.visual;
+
+import android.graphics.Bitmap;
+import android.graphics.Canvas;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.util.Log;
+
+import com.baidu.paddle.lite.Tensor;
+
+import java.io.FileOutputStream;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+
+public class Visualize {
+    private static final String TAG = Visualize.class.getSimpleName();
+
+    public Bitmap segdraw(Bitmap inputImage, Tensor outputTensor,Bitmap bg){
+        // Overlay colors (ARGB): black for the background class, yellow (red + green) for the foreground class
+        final int[] colors_map = {0xFF000000, 0xFFFFFF00};
+        long[] output = outputTensor.getLongData();
+        long outputShape[] = outputTensor.shape();
+        long outputSize = 1;
+
+        for (long s : outputShape) {
+            outputSize *= s;
+        }
+
+        int[] objectColor = new int[(int)outputSize];
+
+        for(int i=0;i<output.length;i++){
+            objectColor[i] = colors_map[(int)output[i]];
+        }
+
+        Bitmap.Config config = inputImage.getConfig();
+        Bitmap outputImage = null;
+        if(outputShape.length==3){
+            outputImage = Bitmap.createBitmap(objectColor, (int)outputShape[2], (int)outputShape[1], config);
+            outputImage = Bitmap.createScaledBitmap(outputImage, inputImage.getWidth(), inputImage.getHeight(),true);
+        }
+
+        else if (outputShape.length==4){
+            outputImage = Bitmap.createBitmap(objectColor, (int)outputShape[3], (int)outputShape[2], config);
+        }
+        Bitmap bmOverlay = Bitmap.createBitmap(inputImage.getWidth(), inputImage.getHeight() , inputImage.getConfig());
+        Canvas canvas = new Canvas(bmOverlay);
+        canvas.drawBitmap(inputImage, new Matrix(), null);
+
+        Paint paint = new Paint();
+        paint.setAlpha(0x80); // overlay the mask at 50% opacity
+        canvas.drawBitmap(outputImage, 0, 0, paint);
+
+        return bmOverlay;
+    }
+
+    public Bitmap draw(Bitmap inputImage, Tensor outputTensor,Bitmap bg){
+        float[] output = outputTensor.getFloatData();
+        long outputShape[] = outputTensor.shape();
+        int outputSize = 1;
+
+        for (long s : outputShape) {
+            outputSize *= s;
+        }
+        List<Float> arralist = new LinkedList<>();
+        for (int i=0; i<outputSize;i++){
+            arralist.add((float)output[i]);
+        }
+
+        Bitmap mALPHA_IMAGE = floatArrayToBitmap(arralist,(int)outputShape[3],(int)outputShape[2]);
+
+        // Resize the alpha matte and background to the input image size
+        Bitmap alpha = Bitmap.createScaledBitmap(mALPHA_IMAGE,inputImage.getWidth(),inputImage.getHeight(),true);
+        Bitmap bgImg = Bitmap.createScaledBitmap(bg,inputImage.getWidth(),inputImage.getHeight(),true);
+
+        // Composite the final image
+        Bitmap result = synthetizeBitmap(inputImage,bgImg, alpha);
+        return result;
+    }
+
+    // Convert the float alpha values into a grayscale Bitmap
+    private Bitmap floatArrayToBitmap(List<Float>  floatArray,int width,int height){
+        byte alpha = (byte) 255 ;
+        Bitmap bmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888) ;
+        ByteBuffer byteBuffer = ByteBuffer.allocate(width * height * 4); // 4 bytes per ARGB_8888 pixel
+        float maxValue = Collections.max(floatArray);
+        float minValue = Collections.min(floatArray);
+        float delta = maxValue - minValue + 1e-11f; // avoid division by zero when the matte is constant
+
+        int i = 0;
+        for (float value : floatArray) {
+            byte grayValue = (byte) (((value - minValue) / delta) * 255);
+            byteBuffer.put(4 * i, grayValue);
+            byteBuffer.put(4 * i + 1, grayValue);
+            byteBuffer.put(4 * i + 2, grayValue);
+            byteBuffer.put(4 * i + 3, alpha);
+            i++;
+        }
+        bmp.copyPixelsFromBuffer(byteBuffer) ;
+        return bmp ;
+    }
+
+    // Composite the foreground and background images using the predicted alpha matte
+    private Bitmap synthetizeBitmap(Bitmap front,Bitmap background, Bitmap alpha){
+        int width = front.getWidth();
+        int height = front.getHeight();
+        Bitmap result=Bitmap.createBitmap(width,height, Bitmap.Config.ARGB_8888);
+        int[] frontPixels = new int[width * height];
+        int[] backgroundPixels = new int[width * height];
+        int[] alphaPixels = new int[width * height];
+        front.getPixels(frontPixels,0,width,0,0,width,height);
+        background.getPixels(backgroundPixels,0,width,0,0,width,height);
+        alpha.getPixels(alphaPixels,0,width,0,0,width,height);
+        float frontA = 0,frontR = 0,frontG = 0,frontB = 0;
+        float backgroundR = 0,backgroundG = 0,backgroundB = 0;
+        float alphaR = 0,alphaG = 0,alphaB = 0;
+        int index=0;
+
+        // Per-pixel compositing (kept simple here; slow and could be optimized later)
+        for (int row=0; row < height; row++){
+            for (int col=0; col < width; col++){
+                index = width*row +col;
+
+                // foreground pixel channels
+                frontA=(frontPixels[index]>>24)&0xff;
+                frontR=(frontPixels[index]>>16)&0xff;
+                frontG=(frontPixels[index]>>8)&0xff;
+                frontB=frontPixels[index]&0xff;
+
+                // alpha matte pixel channels
+                alphaR=(alphaPixels[index]>>16)&0xff;
+                alphaG=(alphaPixels[index]>>8)&0xff;
+                alphaB=alphaPixels[index]&0xff;
+
+                // background pixel channels
+                backgroundR=(backgroundPixels[index]>>16)&0xff;
+                backgroundG=(backgroundPixels[index]>>8)&0xff;
+                backgroundB=backgroundPixels[index]&0xff;
+
+                // composite: out = fg * alpha/255 + bg * (1 - alpha/255)
+                frontR= frontR*alphaR/255 + backgroundR*(1-alphaR/255);
+                frontG=frontG*alphaG/255 + backgroundG*(1-alphaG/255);
+                frontB=frontB*alphaB/255 + backgroundB*(1-alphaB/255);
+                frontPixels[index]=(int)frontA<<24|((int)frontR<<16)|((int)frontG<<8)|(int)frontB;
+            }
+        }
+        result.setPixels(frontPixels, 0, width, 0, 0, width, height);
+        return result;
+    }
+}
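
synthetizeBitmap() applies the standard matting equation per color channel: out = fg * alpha / 255 + bg * (1 - alpha / 255). Isolated as a tiny sketch (hypothetical class, same arithmetic):

    public class CompositeSketch {
        // Blends one 0-255 channel value of the foreground over the background.
        static int compositeChannel(int fg, int bg, int alpha) {
            return Math.round(fg * (alpha / 255f) + bg * (1f - alpha / 255f));
        }

        public static void main(String[] args) {
            // e.g. a semi-transparent hair pixel: fg=200, bg=40, alpha=128 -> about 120
            System.out.println(compositeChannel(200, 40, 128));
        }
    }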

+ 34 - 0
deploy/human_matting_android_demo/app/src/main/res/drawable-v24/ic_launcher_foreground.xml

@@ -0,0 +1,34 @@
+<vector xmlns:android="http://schemas.android.com/apk/res/android"
+    xmlns:aapt="http://schemas.android.com/aapt"
+    android:width="108dp"
+    android:height="108dp"
+    android:viewportWidth="108"
+    android:viewportHeight="108">
+    <path
+        android:fillType="evenOdd"
+        android:pathData="M32,64C32,64 38.39,52.99 44.13,50.95C51.37,48.37 70.14,49.57 70.14,49.57L108.26,87.69L108,109.01L75.97,107.97L32,64Z"
+        android:strokeWidth="1"
+        android:strokeColor="#00000000">
+        <aapt:attr name="android:fillColor">
+            <gradient
+                android:endX="78.5885"
+                android:endY="90.9159"
+                android:startX="48.7653"
+                android:startY="61.0927"
+                android:type="linear">
+                <item
+                    android:color="#44000000"
+                    android:offset="0.0" />
+                <item
+                    android:color="#00000000"
+                    android:offset="1.0" />
+            </gradient>
+        </aapt:attr>
+    </path>
+    <path
+        android:fillColor="#FFFFFF"
+        android:fillType="nonZero"
+        android:pathData="M66.94,46.02L66.94,46.02C72.44,50.07 76,56.61 76,64L32,64C32,56.61 35.56,50.11 40.98,46.06L36.18,41.19C35.45,40.45 35.45,39.3 36.18,38.56C36.91,37.81 38.05,37.81 38.78,38.56L44.25,44.05C47.18,42.57 50.48,41.71 54,41.71C57.48,41.71 60.78,42.57 63.68,44.05L69.11,38.56C69.84,37.81 70.98,37.81 71.71,38.56C72.44,39.3 72.44,40.45 71.71,41.19L66.94,46.02ZM62.94,56.92C64.08,56.92 65,56.01 65,54.88C65,53.76 64.08,52.85 62.94,52.85C61.8,52.85 60.88,53.76 60.88,54.88C60.88,56.01 61.8,56.92 62.94,56.92ZM45.06,56.92C46.2,56.92 47.13,56.01 47.13,54.88C47.13,53.76 46.2,52.85 45.06,52.85C43.92,52.85 43,53.76 43,54.88C43,56.01 43.92,56.92 45.06,56.92Z"
+        android:strokeWidth="1"
+        android:strokeColor="#00000000" />
+</vector>

+ 170 - 0
deploy/human_matting_android_demo/app/src/main/res/drawable/ic_launcher_background.xml

@@ -0,0 +1,170 @@
+<?xml version="1.0" encoding="utf-8"?>
+<vector xmlns:android="http://schemas.android.com/apk/res/android"
+    android:width="108dp"
+    android:height="108dp"
+    android:viewportWidth="108"
+    android:viewportHeight="108">
+    <path
+        android:fillColor="#008577"
+        android:pathData="M0,0h108v108h-108z" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M9,0L9,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,0L19,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M29,0L29,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M39,0L39,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M49,0L49,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M59,0L59,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M69,0L69,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M79,0L79,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M89,0L89,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M99,0L99,108"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,9L108,9"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,19L108,19"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,29L108,29"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,39L108,39"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,49L108,49"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,59L108,59"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,69L108,69"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,79L108,79"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,89L108,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M0,99L108,99"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,29L89,29"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,39L89,39"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,49L89,49"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,59L89,59"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,69L89,69"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M19,79L89,79"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M29,19L29,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M39,19L39,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M49,19L49,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M59,19L59,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M69,19L69,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+    <path
+        android:fillColor="#00000000"
+        android:pathData="M79,19L79,89"
+        android:strokeWidth="0.8"
+        android:strokeColor="#33FFFFFF" />
+</vector>

BIN
deploy/human_matting_android_demo/app/src/main/res/drawable/paddle_logo.png


+ 132 - 0
deploy/human_matting_android_demo/app/src/main/res/layout/activity_main.xml

@@ -0,0 +1,132 @@
+<?xml version="1.0" encoding="utf-8"?>
+<android.support.constraint.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
+                                             xmlns:app="http://schemas.android.com/apk/res-auto"
+                                             xmlns:tools="http://schemas.android.com/tools"
+                                             android:layout_width="match_parent"
+                                             android:layout_height="match_parent">
+
+    <RelativeLayout
+            android:layout_width="match_parent"
+            android:layout_height="match_parent">
+
+        <LinearLayout
+            android:id="@+id/v_input_info"
+            android:layout_width="fill_parent"
+            android:layout_height="wrap_content"
+            android:layout_alignParentTop="true"
+            android:orientation="vertical">
+
+            <TextView
+                android:id="@+id/tv_input_setting"
+                android:layout_width="wrap_content"
+                android:layout_height="wrap_content"
+                android:layout_marginLeft="12dp"
+                android:layout_marginTop="10dp"
+                android:layout_marginRight="12dp"
+                android:layout_marginBottom="5dp"
+                android:lineSpacingExtra="4dp"
+                android:maxLines="6"
+                android:scrollbars="vertical"
+                android:singleLine="false"
+                android:text="" />
+
+        </LinearLayout>
+
+        <RelativeLayout
+                android:layout_width="match_parent"
+                android:layout_height="match_parent"
+                android:layout_above="@+id/v_output_info"
+                android:layout_below="@+id/v_input_info">
+
+            <ImageView
+                    android:id="@+id/iv_input_image"
+                    android:layout_width="400dp"
+                    android:layout_height="400dp"
+                    android:layout_centerHorizontal="true"
+                    android:layout_centerVertical="true"
+                    android:layout_marginLeft="12dp"
+                    android:layout_marginRight="12dp"
+                    android:layout_marginTop="5dp"
+                    android:layout_marginBottom="5dp"
+                    android:adjustViewBounds="true"
+                    android:scaleType="fitCenter"/>
+        </RelativeLayout>
+
+
+        <RelativeLayout
+                android:id="@+id/v_output_info"
+                android:layout_width="wrap_content"
+                android:layout_height="wrap_content"
+                android:layout_alignParentBottom="true"
+                android:layout_centerHorizontal="true">
+
+            <TextView
+                    android:id="@+id/tv_output_result"
+                    android:layout_width="wrap_content"
+                    android:layout_height="wrap_content"
+                    android:layout_alignParentTop="true"
+                    android:layout_centerHorizontal="true"
+                    android:layout_centerVertical="true"
+                    android:scrollbars="vertical"
+                    android:layout_marginLeft="12dp"
+                    android:layout_marginRight="12dp"
+                    android:layout_marginTop="5dp"
+                    android:layout_marginBottom="5dp"
+                    android:textAlignment="center"
+                    android:lineSpacingExtra="5dp"
+                    android:singleLine="false"
+                    android:maxLines="5"
+                    android:text=""
+                    android:gravity="center_horizontal" />
+
+            <TextView
+                    android:id="@+id/tv_inference_time"
+                    android:layout_width="wrap_content"
+                    android:layout_height="wrap_content"
+                    android:layout_below="@+id/tv_output_result"
+                    android:layout_centerHorizontal="true"
+                    android:layout_centerVertical="true"
+                    android:textAlignment="center"
+                    android:layout_marginLeft="12dp"
+                    android:layout_marginRight="12dp"
+                    android:layout_marginTop="5dp"
+                    android:layout_marginBottom="10dp"
+                    android:text=""
+                    android:gravity="center_horizontal" />
+
+            <Button
+                android:id="@+id/save_img"
+                android:layout_width="wrap_content"
+                android:layout_height="wrap_content"
+                android:layout_below="@+id/tv_inference_time"
+                android:layout_centerHorizontal="true"
+                android:layout_centerVertical="true"
+                android:textAlignment="center"
+                android:layout_marginLeft="12dp"
+                android:layout_marginRight="12dp"
+                android:layout_marginTop="5dp"
+                android:layout_marginBottom="10dp"
+                android:text="保存图像"
+                android:gravity="center"
+                android:onClick="clickSaveImg"/>
+
+            <ImageView
+                android:id="@+id/paddlelogo"
+                android:layout_width="400dp"
+                android:layout_height="40dp"
+                android:layout_below="@+id/save_img"
+                android:layout_centerHorizontal="true"
+                android:layout_centerVertical="true"
+                android:layout_marginLeft="12dp"
+                android:layout_marginRight="12dp"
+                android:layout_marginTop="5dp"
+                android:layout_marginBottom="5dp"
+                android:adjustViewBounds="true"
+                android:scaleType="fitCenter"
+                android:src="@drawable/paddle_logo"/>
+
+        </RelativeLayout>
+
+    </RelativeLayout>
+
+</android.support.constraint.ConstraintLayout>

+ 21 - 0
deploy/human_matting_android_demo/app/src/main/res/menu/menu_action_options.xml

@@ -0,0 +1,21 @@
+<menu xmlns:android="http://schemas.android.com/apk/res/android"
+      xmlns:app="http://schemas.android.com/apk/res-auto">
+    <group android:id="@+id/pick_image">
+        <item
+            android:id="@+id/open_gallery"
+            android:title="打开本地相册"
+            app:showAsAction="withText" />
+
+        <item
+            android:id="@+id/take_photo"
+            android:title="打开摄像头拍照"
+            app:showAsAction="withText" />
+    </group>
+
+    <group>
+        <item
+            android:id="@+id/settings"
+            android:title="设置"
+            app:showAsAction="withText" />
+    </group>
+</menu>

+ 5 - 0
deploy/human_matting_android_demo/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml

@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
+    <background android:drawable="@drawable/ic_launcher_background" />
+    <foreground android:drawable="@drawable/ic_launcher_foreground" />
+</adaptive-icon>

+ 5 - 0
deploy/human_matting_android_demo/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml

@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
+    <background android:drawable="@drawable/ic_launcher_background" />
+    <foreground android:drawable="@drawable/ic_launcher_foreground" />
+</adaptive-icon>

BIN
deploy/human_matting_android_demo/app/src/main/res/mipmap-hdpi/ic_launcher.png


BIN
deploy/human_matting_android_demo/app/src/main/res/mipmap-hdpi/ic_launcher_round.png


BIN
deploy/human_matting_android_demo/app/src/main/res/mipmap-mdpi/ic_launcher.png


BIN
deploy/human_matting_android_demo/app/src/main/res/mipmap-mdpi/ic_launcher_round.png


BIN
deploy/human_matting_android_demo/app/src/main/res/mipmap-xhdpi/ic_launcher.png


BIN
deploy/human_matting_android_demo/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png


BIN
deploy/human_matting_android_demo/app/src/main/res/mipmap-xxhdpi/ic_launcher.png


BIN
deploy/human_matting_android_demo/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png


BIN
deploy/human_matting_android_demo/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png


BIN
deploy/human_matting_android_demo/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png


+ 39 - 0
deploy/human_matting_android_demo/app/src/main/res/values/arrays.xml

@@ -0,0 +1,39 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources>
+    <string-array name="cpu_thread_num_entries">
+        <item>1 threads</item>
+        <item>2 threads</item>
+        <item>4 threads</item>
+        <item>8 threads</item>
+    </string-array>
+    <string-array name="cpu_thread_num_values">
+        <item>1</item>
+        <item>2</item>
+        <item>4</item>
+        <item>8</item>
+    </string-array>
+    <string-array name="cpu_power_mode_entries">
+        <item>HIGH(only big cores)</item>
+        <item>LOW(only LITTLE cores)</item>
+        <item>FULL(all cores)</item>
+        <item>NO_BIND(depends on system)</item>
+        <item>RAND_HIGH</item>
+        <item>RAND_LOW</item>
+    </string-array>
+    <string-array name="cpu_power_mode_values">
+        <item>LITE_POWER_HIGH</item>
+        <item>LITE_POWER_LOW</item>
+        <item>LITE_POWER_FULL</item>
+        <item>LITE_POWER_NO_BIND</item>
+        <item>LITE_POWER_RAND_HIGH</item>
+        <item>LITE_POWER_RAND_LOW</item>
+    </string-array>
+    <string-array name="input_color_format_entries">
+        <item>BGR color format</item>
+        <item>RGB color format</item>
+    </string-array>
+    <string-array name="input_color_format_values">
+        <item>BGR</item>
+        <item>RGB</item>
+    </string-array>
+</resources>

+ 6 - 0
deploy/human_matting_android_demo/app/src/main/res/values/colors.xml

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources>
+    <color name="colorPrimary">#008577</color>
+    <color name="colorPrimaryDark">#00574B</color>
+    <color name="colorAccent">#D81B60</color>
+</resources>

+ 22 - 0
deploy/human_matting_android_demo/app/src/main/res/values/strings.xml

@@ -0,0 +1,22 @@
+<resources>
+<string name="app_name">人像抠图(发丝级背景替换)</string>
+<!-- image matting settings -->
+<string name="CHOOSE_PRE_INSTALLED_MODEL_KEY">CHOOSE_PRE_INSTALLED_MODEL_KEY</string>
+<string name="ENABLE_CUSTOM_SETTINGS_KEY">ENABLE_CUSTOM_SETTINGS_KEY</string>
+<string name="MODEL_PATH_KEY">MODEL_PATH_KEY</string>
+<string name="LABEL_PATH_KEY">LABEL_PATH_KEY</string>
+<string name="IMAGE_PATH_KEY">IMAGE_PATH_KEY</string>
+    <string name="BG_PATH_KEY">BG_PATH_KEY</string>
+<string name="CPU_THREAD_NUM_KEY">CPU_THREAD_NUM_KEY</string>
+<string name="CPU_POWER_MODE_KEY">CPU_POWER_MODE_KEY</string>
+<string name="INPUT_COLOR_FORMAT_KEY">INPUT_COLOR_FORMAT_KEY</string>
+<string name="INPUT_SHAPE_KEY">INPUT_SHAPE_KEY</string>
+<string name="MODEL_PATH_DEFAULT">image_matting/models/modnet</string>
+<string name="LABEL_PATH_DEFAULT">image_matting/labels/label_list</string>
+<string name="IMAGE_PATH_DEFAULT">image_matting/images/human.jpg</string>
+    <string name="BG_PATH_DEFAULT">image_matting/images/bg.jpg</string>
+<string name="CPU_THREAD_NUM_DEFAULT">1</string>
+<string name="CPU_POWER_MODE_DEFAULT">LITE_POWER_HIGH</string>
+<string name="INPUT_COLOR_FORMAT_DEFAULT">RGB</string>
+<string name="INPUT_SHAPE_DEFAULT">1,3,256,256</string>
+</resources>

+ 25 - 0
deploy/human_matting_android_demo/app/src/main/res/values/styles.xml

@@ -0,0 +1,25 @@
+<resources>
+
+    <!-- Base application theme. -->
+    <style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
+        <!-- Customize your theme here. -->
+        <item name="colorPrimary">@color/colorPrimary</item>
+        <item name="colorPrimaryDark">@color/colorPrimaryDark</item>
+        <item name="colorAccent">@color/colorAccent</item>
+        <item name="actionOverflowMenuStyle">@style/OverflowMenuStyle</item>
+    </style>
+
+    <style name="OverflowMenuStyle" parent="Widget.AppCompat.Light.PopupMenu.Overflow">
+        <item name="overlapAnchor">false</item>
+    </style>
+
+    <style name="AppTheme.NoActionBar">
+        <item name="windowActionBar">false</item>
+        <item name="windowNoTitle">true</item>
+    </style>
+
+    <style name="AppTheme.AppBarOverlay" parent="ThemeOverlay.AppCompat.Dark.ActionBar"/>
+
+    <style name="AppTheme.PopupOverlay" parent="ThemeOverlay.AppCompat.Light"/>
+
+</resources>

+ 59 - 0
deploy/human_matting_android_demo/app/src/main/res/xml/settings.xml

@@ -0,0 +1,59 @@
+<?xml version="1.0" encoding="utf-8"?>
+<PreferenceScreen xmlns:android="http://schemas.android.com/apk/res/android" >
+    <PreferenceCategory android:title="Model Settings">
+        <ListPreference
+                android:defaultValue="@string/MODEL_PATH_DEFAULT"
+                android:key="@string/CHOOSE_PRE_INSTALLED_MODEL_KEY"
+                android:negativeButtonText="@null"
+                android:positiveButtonText="@null"
+                android:title="Choose pre-installed models" />
+        <CheckBoxPreference
+                android:defaultValue="false"
+                android:key="@string/ENABLE_CUSTOM_SETTINGS_KEY"
+                android:summaryOn="Enable"
+                android:summaryOff="Disable"
+                android:title="Enable custom settings"/>
+        <EditTextPreference
+                android:key="@string/MODEL_PATH_KEY"
+                android:defaultValue="@string/MODEL_PATH_DEFAULT"
+                android:title="Model Path" />
+        <EditTextPreference
+                android:key="@string/LABEL_PATH_KEY"
+                android:defaultValue="@string/LABEL_PATH_DEFAULT"
+                android:title="Label Path" />
+        <EditTextPreference
+                android:key="@string/IMAGE_PATH_KEY"
+                android:defaultValue="@string/IMAGE_PATH_DEFAULT"
+                android:title="Image Path" />
+    </PreferenceCategory>
+    <PreferenceCategory android:title="CPU Settings">
+        <ListPreference
+                android:defaultValue="@string/CPU_THREAD_NUM_DEFAULT"
+                android:key="@string/CPU_THREAD_NUM_KEY"
+                android:negativeButtonText="@null"
+                android:positiveButtonText="@null"
+                android:title="CPU Thread Num"
+                android:entries="@array/cpu_thread_num_entries"
+                android:entryValues="@array/cpu_thread_num_values"/>
+        <ListPreference
+                android:defaultValue="@string/CPU_POWER_MODE_DEFAULT"
+                android:key="@string/CPU_POWER_MODE_KEY"
+                android:negativeButtonText="@null"
+                android:positiveButtonText="@null"
+                android:title="CPU Power Mode"
+                android:entries="@array/cpu_power_mode_entries"
+                android:entryValues="@array/cpu_power_mode_values"/>
+    </PreferenceCategory>
+    <PreferenceCategory android:title="Input Settings">
+        <ListPreference
+                android:defaultValue="@string/INPUT_COLOR_FORMAT_DEFAULT"
+                android:key="@string/INPUT_COLOR_FORMAT_KEY"
+                android:negativeButtonText="@null"
+                android:positiveButtonText="@null"
+                android:title="Input Color Format: BGR or RGB"
+                android:entries="@array/input_color_format_entries"
+                android:entryValues="@array/input_color_format_values"/>
+
+    </PreferenceCategory>
+</PreferenceScreen>
+

+ 17 - 0
deploy/human_matting_android_demo/app/src/test/java/com/baidu/paddle/lite/demo/ExampleUnitTest.java

@@ -0,0 +1,17 @@
+package com.baidu.paddle.lite.demo;
+
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+/**
+ * Example local unit test, which will execute on the development machine (host).
+ *
+ * @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
+ */
+public class ExampleUnitTest {
+    @Test
+    public void addition_isCorrect() {
+        assertEquals(4, 2 + 2);
+    }
+}

+ 27 - 0
deploy/human_matting_android_demo/build.gradle

@@ -0,0 +1,27 @@
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+
+buildscript {
+    repositories {
+        google()
+        jcenter()
+        
+    }
+    dependencies {
+        classpath 'com.android.tools.build:gradle:3.4.0'
+        
+        // NOTE: Do not place your application dependencies here; they belong
+        // in the individual module build.gradle files
+    }
+}
+
+allprojects {
+    repositories {
+        google()
+        jcenter()
+        
+    }
+}
+
+task clean(type: Delete) {
+    delete rootProject.buildDir
+}

+ 17 - 0
deploy/human_matting_android_demo/gradle.properties

@@ -0,0 +1,17 @@
+# Project-wide Gradle settings.
+# IDE (e.g. Android Studio) users:
+# Gradle settings configured through the IDE *will override*
+# any settings specified in this file.
+# For more details on how to configure your build environment visit
+# http://www.gradle.org/docs/current/userguide/build_environment.html
+# Specifies the JVM arguments used for the daemon process.
+# The setting is particularly useful for tweaking memory settings.
+org.gradle.jvmargs=-Xmx1536m
+# When configured, Gradle will run in incubating parallel mode.
+# This option should only be used with decoupled projects. More details, visit
+# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
+# org.gradle.parallel=true
+android.injected.testOnly=false
+
+
+

BIN
deploy/human_matting_android_demo/gradle/wrapper/gradle-wrapper.jar


+ 6 - 0
deploy/human_matting_android_demo/gradle/wrapper/gradle-wrapper.properties

@@ -0,0 +1,6 @@
+#Thu Aug 22 15:05:37 CST 2019
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-5.1.1-all.zip

+ 172 - 0
deploy/human_matting_android_demo/gradlew

@@ -0,0 +1,172 @@
+#!/usr/bin/env sh
+
+##############################################################################
+##
+##  Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+    ls=`ls -ld "$PRG"`
+    link=`expr "$ls" : '.*-> \(.*\)$'`
+    if expr "$link" : '/.*' > /dev/null; then
+        PRG="$link"
+    else
+        PRG=`dirname "$PRG"`"/$link"
+    fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS=""
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn () {
+    echo "$*"
+}
+
+die () {
+    echo
+    echo "$*"
+    echo
+    exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+  CYGWIN* )
+    cygwin=true
+    ;;
+  Darwin* )
+    darwin=true
+    ;;
+  MINGW* )
+    msys=true
+    ;;
+  NONSTOP* )
+    nonstop=true
+    ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+        # IBM's JDK on AIX uses strange locations for the executables
+        JAVACMD="$JAVA_HOME/jre/sh/java"
+    else
+        JAVACMD="$JAVA_HOME/bin/java"
+    fi
+    if [ ! -x "$JAVACMD" ] ; then
+        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+    fi
+else
+    JAVACMD="java"
+    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+    MAX_FD_LIMIT=`ulimit -H -n`
+    if [ $? -eq 0 ] ; then
+        if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+            MAX_FD="$MAX_FD_LIMIT"
+        fi
+        ulimit -n $MAX_FD
+        if [ $? -ne 0 ] ; then
+            warn "Could not set maximum file descriptor limit: $MAX_FD"
+        fi
+    else
+        warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+    fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+    GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin, switch paths to Windows format before running java
+if $cygwin ; then
+    APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+    CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+    JAVACMD=`cygpath --unix "$JAVACMD"`
+
+    # We build the pattern for arguments to be converted via cygpath
+    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+    SEP=""
+    for dir in $ROOTDIRSRAW ; do
+        ROOTDIRS="$ROOTDIRS$SEP$dir"
+        SEP="|"
+    done
+    OURCYGPATTERN="(^($ROOTDIRS))"
+    # Add a user-defined pattern to the cygpath arguments
+    if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+        OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+    fi
+    # Now convert the arguments - kludge to limit ourselves to /bin/sh
+    i=0
+    for arg in "$@" ; do
+        CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+        CHECK2=`echo "$arg"|egrep -c "^-"`                                 ### Determine if an option
+
+        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then                    ### Added a condition
+            eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+        else
+            eval `echo args$i`="\"$arg\""
+        fi
+        i=$((i+1))
+    done
+    case $i in
+        (0) set -- ;;
+        (1) set -- "$args0" ;;
+        (2) set -- "$args0" "$args1" ;;
+        (3) set -- "$args0" "$args1" "$args2" ;;
+        (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+        (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+        (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+        (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+        (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+        (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+    esac
+fi
+
+# Escape application args
+save () {
+    for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+    echo " "
+}
+APP_ARGS=$(save "$@")
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
+if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
+  cd "$(dirname "$0")"
+fi
+
+exec "$JAVACMD" "$@"

+ 84 - 0
deploy/human_matting_android_demo/gradlew.bat

@@ -0,0 +1,84 @@
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem  Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS=
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto init
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto init
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:init
+@rem Get command-line arguments, handling Windows variants
+
+if not "%OS%" == "Windows_NT" goto win9xME_args
+
+:win9xME_args
+@rem Slurp the command line arguments.
+set CMD_LINE_ARGS=
+set _SKIP=2
+
+:win9xME_args_slurp
+if "x%~1" == "x" goto execute
+
+set CMD_LINE_ARGS=%*
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if  not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega

+ 1 - 0
deploy/human_matting_android_demo/settings.gradle

@@ -0,0 +1 @@
+include ':app'

+ 782 - 0
deploy/python/infer.py

@@ -0,0 +1,782 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import codecs
+import os
+import sys
+
+import cv2
+import tqdm
+import yaml
+import numpy as np
+import paddle
+from paddle.inference import create_predictor, PrecisionType
+from paddle.inference import Config as PredictConfig
+from paddleseg.cvlibs import manager
+from paddleseg.utils import get_sys_env, logger
+
+LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.join(LOCAL_PATH, '..', '..'))
+manager.BACKBONES._components_dict.clear()
+manager.TRANSFORMS._components_dict.clear()
+
+import ppmatting.transforms as T
+from ppmatting.utils import get_image_list, mkdir, estimate_foreground_ml, VideoReader, VideoWriter
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Deploy for matting model')
+    parser.add_argument(
+        "--config",
+        dest="cfg",
+        help="The config file.",
+        default=None,
+        type=str,
+        required=True)
+    parser.add_argument(
+        '--image_path',
+        dest='image_path',
+        help='The directory or path or file list of the images to be predicted.',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--trimap_path',
+        dest='trimap_path',
+        help='The directory, path, or file list of the trimaps used to assist prediction.',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--batch_size',
+        dest='batch_size',
+        help='Mini batch size for one GPU or CPU. It is ignored for video inference.',
+        type=int,
+        default=1)
+    parser.add_argument(
+        '--video_path',
+        dest='video_path',
+        help='The path of the video to be predicted.',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the predict result.',
+        type=str,
+        default='./output')
+    parser.add_argument(
+        '--device',
+        choices=['cpu', 'gpu'],
+        default="gpu",
+        help="Select which device to inference, defaults to gpu.")
+    parser.add_argument(
+        '--fg_estimate',
+        default=True,
+        type=eval,
+        choices=[True, False],
+        help='Whether to estimate foreground when predicting.')
+
+    parser.add_argument(
+        '--cpu_threads',
+        default=10,
+        type=int,
+        help='Number of threads to predict when using cpu.')
+    parser.add_argument(
+        '--enable_mkldnn',
+        default=False,
+        type=eval,
+        choices=[True, False],
+        help='Enable to use mkldnn to speed up when using cpu.')
+    parser.add_argument(
+        '--use_trt',
+        default=False,
+        type=eval,
+        choices=[True, False],
+        help='Whether to use Nvidia TensorRT to accelerate prediction.')
+    parser.add_argument(
+        "--precision",
+        default="fp32",
+        type=str,
+        choices=["fp32", "fp16", "int8"],
+        help='The tensorrt precision.')
+    parser.add_argument(
+        '--enable_auto_tune',
+        default=False,
+        type=eval,
+        choices=[True, False],
+        help='Whether to enable tuned dynamic shape. Some images are used to collect '
+        'the dynamic shapes of the TRT subgraph, which avoids setting them manually.'
+    )
+    parser.add_argument(
+        '--auto_tuned_shape_file',
+        type=str,
+        default="auto_tune_tmp.pbtxt",
+        help='The temp file to save tuned dynamic shape.')
+
+    parser.add_argument(
+        "--benchmark",
+        type=eval,
+        default=False,
+        help="Whether to log some information about environment, model, configuration and performance."
+    )
+    parser.add_argument(
+        "--model_name",
+        default="",
+        type=str,
+        help='When `--benchmark` is True, the specified model name is displayed.'
+    )
+    parser.add_argument(
+        '--print_detail',
+        default=True,
+        type=eval,
+        choices=[True, False],
+        help='Print GLOG information of Paddle Inference.')
+
+    return parser.parse_args()
+
+
+class DeployConfig:
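+    """Parse the deploy yaml produced by model export.
+
+    The yaml is expected to contain a `Deploy` section with `model`, `params`
+    and `transforms` entries, which are exposed through the properties below.
+    """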
+    def __init__(self, path):
+        with codecs.open(path, 'r', 'utf-8') as file:
+            self.dic = yaml.load(file, Loader=yaml.FullLoader)
+        self._transforms = self.load_transforms(self.dic['Deploy'][
+            'transforms'])
+        self._dir = os.path.dirname(path)
+
+    @property
+    def transforms(self):
+        return self._transforms
+
+    @property
+    def model(self):
+        return os.path.join(self._dir, self.dic['Deploy']['model'])
+
+    @property
+    def params(self):
+        return os.path.join(self._dir, self.dic['Deploy']['params'])
+
+    @staticmethod
+    def load_transforms(t_list):
+        com = manager.TRANSFORMS
+        transforms = []
+        for t in t_list:
+            ctype = t.pop('type')
+            transforms.append(com[ctype](**t))
+        return T.Compose(transforms)
+
+
+def use_auto_tune(args):
+    return hasattr(PredictConfig, "collect_shape_range_info") \
+        and hasattr(PredictConfig, "enable_tuned_tensorrt_dynamic_shape") \
+        and args.device == "gpu" and args.use_trt and args.enable_auto_tune
+
+
+def auto_tune(args, imgs, img_nums):
+    """
+    Use images to auto tune the dynamic shape for the TRT subgraph.
+    The tuned shapes are saved to args.auto_tuned_shape_file.
+
+    Args:
+        args (argparse.Namespace): input args.
+        imgs (str|list[str]): the path(s) of the images.
+        img_nums (int): the number of images used for auto tuning.
+    Returns:
+        None
+    """
+    logger.info("Auto tune the dynamic shape for GPU TRT.")
+
+    assert use_auto_tune(args)
+
+    if not isinstance(imgs, (list, tuple)):
+        imgs = [imgs]
+    num = min(len(imgs), img_nums)
+
+    cfg = DeployConfig(args.cfg)
+    pred_cfg = PredictConfig(cfg.model, cfg.params)
+    pred_cfg.enable_use_gpu(100, 0)
+    if not args.print_detail:
+        pred_cfg.disable_glog_info()
+    pred_cfg.collect_shape_range_info(args.auto_tuned_shape_file)
+
+    predictor = create_predictor(pred_cfg)
+    input_names = predictor.get_input_names()
+    input_handle = predictor.get_input_handle(input_names[0])
+
+    for i in range(0, num):
+        data = {'img': imgs[i]}
+        data = cfg.transforms(data)
+        input_handle.reshape(data['img'].shape)
+        input_handle.copy_from_cpu(data['img'])
+        try:
+            predictor.run()
+        except Exception:
+            logger.info(
+                "Auto tune failed. Usually the error is running out of GPU memory "
+                "because the model and image are too large.\n")
+            del predictor
+            if os.path.exists(args.auto_tuned_shape_file):
+                os.remove(args.auto_tuned_shape_file)
+            return
+
+    logger.info("Auto tune success.\n")
+
+
+class Predictor:
+    def __init__(self, args):
+        """
+        Prepare for prediction.
+        For the usage and docs of Paddle Inference, please refer to
+        https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html
+        """
+        self.args = args
+        self.cfg = DeployConfig(args.cfg)
+
+        self._init_base_config()
+        if args.device == 'cpu':
+            self._init_cpu_config()
+        else:
+            self._init_gpu_config()
+
+        self.predictor = create_predictor(self.pred_cfg)
+
+        if hasattr(args, 'benchmark') and args.benchmark:
+            import auto_log
+            pid = os.getpid()
+            gpu_id = None if args.device == 'cpu' else 0
+            self.autolog = auto_log.AutoLogger(
+                model_name=args.model_name,
+                model_precision=args.precision,
+                batch_size=args.batch_size,
+                data_shape="dynamic",
+                save_path=None,
+                inference_config=self.pred_cfg,
+                pids=pid,
+                process_name=None,
+                gpu_ids=gpu_id,
+                time_keys=[
+                    'preprocess_time', 'inference_time', 'postprocess_time'
+                ],
+                warmup=0,
+                logger=logger)
+
+    def _init_base_config(self):
+        self.pred_cfg = PredictConfig(self.cfg.model, self.cfg.params)
+        if not self.args.print_detail:
+            self.pred_cfg.disable_glog_info()
+        self.pred_cfg.enable_memory_optim()
+        self.pred_cfg.switch_ir_optim(True)
+
+    def _init_cpu_config(self):
+        """
+        Init the config for x86 cpu.
+        """
+        logger.info("Using CPU")
+        self.pred_cfg.disable_gpu()
+        if self.args.enable_mkldnn:
+            logger.info("Using MKLDNN")
+            # cache 10 different input shapes for MKL-DNN
+            self.pred_cfg.set_mkldnn_cache_capacity(10)
+            self.pred_cfg.enable_mkldnn()
+        self.pred_cfg.set_cpu_math_library_num_threads(self.args.cpu_threads)
+
+    def _init_gpu_config(self):
+        """
+        Init the config for nvidia gpu.
+        """
+        logger.info("using GPU")
+        self.pred_cfg.enable_use_gpu(100, 0)
+        precision_map = {
+            "fp16": PrecisionType.Half,
+            "fp32": PrecisionType.Float32,
+            "int8": PrecisionType.Int8
+        }
+        precision_mode = precision_map[self.args.precision]
+
+        if self.args.use_trt:
+            logger.info("Use TRT")
+            self.pred_cfg.enable_tensorrt_engine(
+                workspace_size=1 << 30,
+                max_batch_size=1,
+                min_subgraph_size=300,
+                precision_mode=precision_mode,
+                use_static=False,
+                use_calib_mode=False)
+
+            if use_auto_tune(self.args) and \
+                os.path.exists(self.args.auto_tuned_shape_file):
+                logger.info("Use auto tuned dynamic shape")
+                allow_build_at_runtime = True
+                self.pred_cfg.enable_tuned_tensorrt_dynamic_shape(
+                    self.args.auto_tuned_shape_file, allow_build_at_runtime)
+            else:
+                logger.info("Use manual set dynamic shape")
+                min_input_shape = {"img": [1, 3, 100, 100]}
+                max_input_shape = {"img": [1, 3, 2000, 3000]}
+                opt_input_shape = {"img": [1, 3, 512, 1024]}
+                self.pred_cfg.set_trt_dynamic_shape_info(
+                    min_input_shape, max_input_shape, opt_input_shape)
+
+    def run(self, imgs, trimaps=None, imgs_dir=None):
+        self.imgs_dir = imgs_dir
+        num = len(imgs)
+        input_names = self.predictor.get_input_names()
+        input_handle = {}
+
+        for i in range(len(input_names)):
+            input_handle[input_names[i]] = self.predictor.get_input_handle(
+                input_names[i])
+        output_names = self.predictor.get_output_names()
+        output_handle = self.predictor.get_output_handle(output_names[0])
+        args = self.args
+
+        for i in tqdm.tqdm(range(0, num, args.batch_size)):
+            # warm up
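+            # When benchmarking, run the first batch a few extra times so that
+            # one-off initialization costs do not skew the timings measured below.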
+            if i == 0 and args.benchmark:
+                for _ in range(5):
+                    img_inputs = []
+                    if trimaps is not None:
+                        trimap_inputs = []
+                    trans_info = []
+                    for j in range(i, i + args.batch_size):
+                        img = imgs[j]
+                        trimap = trimaps[j] if trimaps is not None else None
+                        data = self._preprocess(img=img, trimap=trimap)
+                        img_inputs.append(data['img'])
+                        if trimaps is not None:
+                            trimap_inputs.append(data['trimap'][
+                                np.newaxis, :, :])
+                        trans_info.append(data['trans_info'])
+                    img_inputs = np.array(img_inputs)
+                    if trimaps is not None:
+                        trimap_inputs = (
+                            np.array(trimap_inputs)).astype('float32')
+
+                    input_handle['img'].copy_from_cpu(img_inputs)
+                    if trimaps is not None:
+                        input_handle['trimap'].copy_from_cpu(trimap_inputs)
+                    self.predictor.run()
+                    results = output_handle.copy_to_cpu()
+
+                    results = results.squeeze(1)
+                    for j in range(args.batch_size):
+                        trimap = trimap_inputs[
+                            j] if trimaps is not None else None
+                        result = self._postprocess(
+                            results[j], trans_info[j], trimap=trimap)
+
+            # inference
+            if args.benchmark:
+                self.autolog.times.start()
+
+            img_inputs = []
+            if trimaps is not None:
+                trimap_inputs = []
+            trans_info = []
+            for j in range(i, i + args.batch_size):
+                img = imgs[j]
+                trimap = trimaps[j] if trimaps is not None else None
+                data = self._preprocess(img=img, trimap=trimap)
+                img_inputs.append(data['img'])
+                if trimaps is not None:
+                    trimap_inputs.append(data['trimap'][np.newaxis, :, :])
+                trans_info.append(data['trans_info'])
+            img_inputs = np.array(img_inputs)
+            if trimaps is not None:
+                trimap_inputs = (np.array(trimap_inputs)).astype('float32')
+
+            input_handle['img'].copy_from_cpu(img_inputs)
+            if trimaps is not None:
+                input_handle['trimap'].copy_from_cpu(trimap_inputs)
+
+            if args.benchmark:
+                self.autolog.times.stamp()
+
+            self.predictor.run()
+
+            results = output_handle.copy_to_cpu()
+
+            if args.benchmark:
+                self.autolog.times.stamp()
+
+            results = results.squeeze(1)
+            for j in range(args.batch_size):
+                trimap = trimap_inputs[j] if trimaps is not None else None
+                result = self._postprocess(
+                    results[j], trans_info[j], trimap=trimap)
+                self._save_imgs(result, imgs[i + j])
+
+            if args.benchmark:
+                self.autolog.times.end(stamp=True)
+        logger.info("Finish")
+
+    def _preprocess(self, img, trimap=None):
+        data = {}
+        data['img'] = img
+        if trimap is not None:
+            data['trimap'] = trimap
+            data['gt_fields'] = ['trimap']
+        data = self.cfg.transforms(data)
+        return data
+
+    def _postprocess(self, alpha, trans_info, trimap=None):
+        """recover pred to origin shape"""
+        if trimap is not None:
+            trimap = trimap.squeeze(0)
+            alpha[trimap == 0] = 0
+            alpha[trimap == 255] = 1
+        for item in trans_info[::-1]:
+            if item[0] == 'resize':
+                h, w = item[1][0], item[1][1]
+                alpha = cv2.resize(
+                    alpha, (w, h), interpolation=cv2.INTER_LINEAR)
+            elif item[0] == 'padding':
+                h, w = item[1][0], item[1][1]
+                alpha = alpha[0:h, 0:w]
+            else:
+                raise Exception("Unexpected info '{}' in im_info".format(item[
+                    0]))
+        return alpha
+
+    def _save_imgs(self, alpha, img_path, fg=None):
+        ori_img = cv2.imread(img_path)
+        alpha = (alpha * 255).astype('uint8')
+
+        if self.imgs_dir is not None:
+            img_path = img_path.replace(self.imgs_dir, '')
+        else:
+            img_path = os.path.basename(img_path)
+        name, ext = os.path.splitext(img_path)
+        if name[0] == '/' or name[0] == '\\':
+            name = name[1:]
+
+        alpha_save_path = os.path.join(self.args.save_dir, name + '_alpha.png')
+        rgba_save_path = os.path.join(self.args.save_dir, name + '_rgba.png')
+
+        # save alpha
+        mkdir(alpha_save_path)
+        cv2.imwrite(alpha_save_path, alpha)
+
+        # save rgba image
+        mkdir(rgba_save_path)
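+        # If the model did not predict a foreground, either estimate it from the
+        # image and alpha (slower, cleaner edges) or fall back to the original pixels.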
+        if fg is None:
+            if self.args.fg_estimate:
+                fg = estimate_foreground_ml(ori_img / 255.0,
+                                            alpha / 255.0) * 255
+            else:
+                fg = ori_img
+        else:
+            fg = fg * 255
+        fg = fg.astype('uint8')
+        alpha = alpha[:, :, np.newaxis]
+        rgba = np.concatenate([fg, alpha], axis=-1)
+        cv2.imwrite(rgba_save_path, rgba)
+
+    def run_video(self, video_path):
+        """Video matting only support the trimap-free method"""
+        input_names = self.predictor.get_input_names()
+        input_handle = {}
+
+        for i in range(len(input_names)):
+            input_handle[input_names[i]] = self.predictor.get_input_handle(
+                input_names[i])
+        output_names = self.predictor.get_output_names()
+        output_handle = {}
+        output_handle['alpha'] = self.predictor.get_output_handle(output_names[
+            0])
+
+        # Build reader and writer
+        reader = VideoReader(video_path, self.cfg.transforms)
+        base_name = os.path.basename(video_path)
+        name = os.path.splitext(base_name)[0]
+        alpha_save_path = os.path.join(self.args.save_dir, name + '_alpha.avi')
+        fg_save_path = os.path.join(self.args.save_dir, name + '_fg.avi')
+        writer_alpha = VideoWriter(
+            alpha_save_path,
+            reader.fps,
+            frame_size=(reader.width, reader.height),
+            is_color=False)
+        writer_fg = VideoWriter(
+            fg_save_path,
+            reader.fps,
+            frame_size=(reader.width, reader.height),
+            is_color=True)
+
+        for data in tqdm.tqdm(reader):
+            trans_info = data['trans_info']
+            _, h, w = data['img'].shape
+
+            input_handle['img'].copy_from_cpu(data['img'][np.newaxis, ...])
+
+            self.predictor.run()
+
+            alpha = output_handle['alpha'].copy_to_cpu()
+
+            alpha = alpha.squeeze()
+            alpha = self._postprocess(alpha, trans_info)
+            self._save_frame(
+                alpha,
+                fg=None,
+                img=data['ori_img'],
+                writer_alpha=writer_alpha,
+                writer_fg=writer_fg)
+
+        writer_alpha.release()
+        writer_fg.release()
+        reader.release()
+
+    def _save_frame(self, alpha, fg, img, writer_alpha, writer_fg):
+        if fg is None:
+            img = img.transpose((1, 2, 0))
+            if self.args.fg_estimate:
+                fg = estimate_foreground_ml(img, alpha)
+            else:
+                fg = img
+        fg = fg * alpha[:, :, np.newaxis]
+
+        writer_alpha.write(alpha)
+        writer_fg.write(fg)
+
+
+class PredictorRVM(Predictor):
+    def __init__(self, args):
+        super().__init__(args=args)
+
+    def run(self, imgs, trimaps=None, imgs_dir=None):
+        self.imgs_dir = imgs_dir
+        num = len(imgs)
+        input_names = self.predictor.get_input_names()
+        input_handle = {}
+
+        for i in range(len(input_names)):
+            input_handle[input_names[i]] = self.predictor.get_input_handle(
+                input_names[i])
+        output_names = self.predictor.get_output_names()
+        output_handle = {}
+        output_handle['alpha'] = self.predictor.get_output_handle(output_names[
+            0])
+        output_handle['fg'] = self.predictor.get_output_handle(output_names[1])
+        output_handle['r1'] = self.predictor.get_output_handle(output_names[2])
+        output_handle['r2'] = self.predictor.get_output_handle(output_names[3])
+        output_handle['r3'] = self.predictor.get_output_handle(output_names[4])
+        output_handle['r4'] = self.predictor.get_output_handle(output_names[5])
+
+        args = self.args
+
+        for i in tqdm.tqdm(range(0, num, args.batch_size)):
+            # warm up
+            if i == 0 and args.benchmark:
+                for _ in range(5):
+                    img_inputs = []
+                    if trimaps is not None:
+                        trimap_inputs = []
+                    trans_info = []
+                    for j in range(i, i + args.batch_size):
+                        img = imgs[j]
+                        data = self._preprocess(img=img)
+                        img_inputs.append(data['img'])
+                        trans_info.append(data['trans_info'])
+                    img_inputs = np.array(img_inputs)
+                    n, _, h, w = img_inputs.shape
+                    downsample_ratio = min(512 / max(h, w), 1)
+                    downsample_ratio = np.array(
+                        [downsample_ratio], dtype='float32')
+
+                    input_handle['img'].copy_from_cpu(img_inputs)
+                    input_handle['downsample_ratio'].copy_from_cpu(
+                        downsample_ratio.astype('float32'))
+                    r_channels = [16, 20, 40, 64]
+                    for k in range(4):
+                        j = k + 1
+                        hj = int(np.ceil(int(h * downsample_ratio[0]) / 2**j))
+                        wj = int(np.ceil(int(w * downsample_ratio[0]) / 2**j))
+                        rj = np.zeros(
+                            (n, r_channels[k], hj, wj), dtype='float32')
+                        input_handle['r' + str(j)].copy_from_cpu(rj)
+
+                    self.predictor.run()
+                    alphas = output_handle['alpha'].copy_to_cpu()
+                    fgs = output_handle['fg'].copy_to_cpu()
+                    alphas = alphas.squeeze(1)
+                    for j in range(args.batch_size):
+                        alpha = self._postprocess(alphas[j], trans_info[j])
+                        fg = fgs[j]
+                        fg = np.transpose(fg, (1, 2, 0))
+                        fg = self._postprocess(fg, trans_info[j])
+
+            # inference
+            if args.benchmark:
+                self.autolog.times.start()
+
+            img_inputs = []
+            if trimaps is not None:
+                trimap_inputs = []
+            trans_info = []
+            for j in range(i, i + args.batch_size):
+                img = imgs[j]
+                data = self._preprocess(img=img)
+                img_inputs.append(data['img'])
+                trans_info.append(data['trans_info'])
+            img_inputs = np.array(img_inputs)
+            n, _, h, w = img_inputs.shape
+            downsample_ratio = min(512 / max(h, w), 1)
+            downsample_ratio = np.array([downsample_ratio], dtype='float32')
+
+            input_handle['img'].copy_from_cpu(img_inputs)
+            input_handle['downsample_ratio'].copy_from_cpu(
+                downsample_ratio.astype('float32'))
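+            # RVM is a recurrent model: r1-r4 are its hidden states. For independent
+            # images they are simply zero-initialized at the downsampled resolutions.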
+            r_channels = [16, 20, 40, 64]
+            for k in range(4):
+                j = k + 1
+                hj = int(np.ceil(int(h * downsample_ratio[0]) / 2**j))
+                wj = int(np.ceil(int(w * downsample_ratio[0]) / 2**j))
+                rj = np.zeros((n, r_channels[k], hj, wj), dtype='float32')
+                input_handle['r' + str(j)].copy_from_cpu(rj)
+
+            if args.benchmark:
+                self.autolog.times.stamp()
+
+            self.predictor.run()
+            alphas = output_handle['alpha'].copy_to_cpu()
+            fgs = output_handle['fg'].copy_to_cpu()
+
+            if args.benchmark:
+                self.autolog.times.stamp()
+
+            alphas = alphas.squeeze(1)
+            for j in range(args.batch_size):
+                alpha = self._postprocess(alphas[j], trans_info[j])
+                fg = fgs[j]
+                fg = np.transpose(fg, (1, 2, 0))
+                fg = self._postprocess(fg, trans_info[j])
+                self._save_imgs(alpha, fg=fg, img_path=imgs[i + j])
+
+            if args.benchmark:
+                self.autolog.times.end(stamp=True)
+        logger.info("Finish")
+
+    def run_video(self, video_path):
+        input_names = self.predictor.get_input_names()
+        input_handle = {}
+
+        for i in range(len(input_names)):
+            input_handle[input_names[i]] = self.predictor.get_input_handle(
+                input_names[i])
+        output_names = self.predictor.get_output_names()
+        output_handle = {}
+        output_handle['alpha'] = self.predictor.get_output_handle(output_names[
+            0])
+        output_handle['fg'] = self.predictor.get_output_handle(output_names[1])
+        output_handle['r1'] = self.predictor.get_output_handle(output_names[2])
+        output_handle['r2'] = self.predictor.get_output_handle(output_names[3])
+        output_handle['r3'] = self.predictor.get_output_handle(output_names[4])
+        output_handle['r4'] = self.predictor.get_output_handle(output_names[5])
+
+        # Build reader and writer
+        reader = VideoReader(video_path, self.cfg.transforms)
+        base_name = os.path.basename(video_path)
+        name = os.path.splitext(base_name)[0]
+        alpha_save_path = os.path.join(self.args.save_dir, name + '_alpha.avi')
+        fg_save_path = os.path.join(self.args.save_dir, name + '_fg.avi')
+        writer_alpha = VideoWriter(
+            alpha_save_path,
+            reader.fps,
+            frame_size=(reader.width, reader.height),
+            is_color=False)
+        writer_fg = VideoWriter(
+            fg_save_path,
+            reader.fps,
+            frame_size=(reader.width, reader.height),
+            is_color=True)
+
+        r_channels = [16, 20, 40, 64]
+        for i, data in tqdm.tqdm(enumerate(reader)):
+            trans_info = data['trans_info']
+            _, h, w = data['img'].shape
+            if i == 0:
+                downsample_ratio = min(512 / max(h, w), 1)
+                downsample_ratio = np.array([downsample_ratio], dtype='float32')
+                r_channels = [16, 20, 40, 64]
+                for k in range(4):
+                    j = k + 1
+                    hj = int(np.ceil(int(h * downsample_ratio[0]) / 2**j))
+                    wj = int(np.ceil(int(w * downsample_ratio[0]) / 2**j))
+                    rj = np.zeros((1, r_channels[k], hj, wj), dtype='float32')
+                    input_handle['r' + str(j)].copy_from_cpu(rj)
+            else:
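+                # For subsequent frames, reuse the recurrent states produced by the
+                # previous frame instead of re-initializing them with zeros.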
+                input_handle['r1'] = output_handle['r1']
+                input_handle['r2'] = output_handle['r2']
+                input_handle['r3'] = output_handle['r3']
+                input_handle['r4'] = output_handle['r4']
+
+            input_handle['img'].copy_from_cpu(data['img'][np.newaxis, ...])
+            input_handle['downsample_ratio'].copy_from_cpu(
+                downsample_ratio.astype('float32'))
+
+            self.predictor.run()
+
+            alpha = output_handle['alpha'].copy_to_cpu()
+            fg = output_handle['fg'].copy_to_cpu()
+
+            alpha = alpha.squeeze()
+            alpha = self._postprocess(alpha, trans_info)
+            fg = fg.squeeze().transpose((1, 2, 0))
+            fg = self._postprocess(fg, trans_info)
+            self._save_frame(alpha, fg, data['ori_img'], writer_alpha,
+                             writer_fg)
+        writer_alpha.release()
+        writer_fg.release()
+        reader.release()
+
+
+def main(args):
+    with open(args.cfg, 'r') as f:
+        yaml_conf = yaml.load(f, Loader=yaml.FullLoader)
+    model_name = yaml_conf.get('ModelName', None)
+    if model_name == 'RVM':
+        predictor_cls = PredictorRVM
+    else:
+        predictor_cls = Predictor
+
+    if args.image_path is not None:
+        imgs_list, imgs_dir = get_image_list(args.image_path)
+        if args.trimap_path is None:
+            trimaps_list = None
+        else:
+            trimaps_list, _ = get_image_list(args.trimap_path)
+
+        if use_auto_tune(args):
+            tune_img_nums = 10
+            auto_tune(args, imgs_list, tune_img_nums)
+
+        predictor = predictor_cls(args)
+        predictor.run(imgs=imgs_list, trimaps=trimaps_list, imgs_dir=imgs_dir)
+
+        if use_auto_tune(args) and \
+            os.path.exists(args.auto_tuned_shape_file):
+            os.remove(args.auto_tuned_shape_file)
+
+        if args.benchmark:
+            predictor.autolog.report()
+
+    elif args.video_path is not None:
+        predictor = predictor_cls(args)
+        predictor.run_video(video_path=args.video_path)
+
+    else:
+        raise IOError("Please provide --image_path or --video_path.")
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)

+ 68 - 0
docs/data_prepare_cn.md

@@ -0,0 +1,68 @@
+# Data Preparation
+
+When training on a custom dataset, the data needs to be prepared in the corresponding format. Two forms of dataset are supported: offline composition and online composition.
+
+## Offline composition
+If the images have already been composited offline, or no composition is needed, organize the dataset as follows.
+```
+dataset_root/
+|--train/
+|  |--fg/
+|  |--alpha/
+|
+|--val/
+|  |--fg/
+|  |--alpha/
+|
+|--train.txt
+|
+|--val.txt
+```
+The fg folder stores the original images. Each image name in the fg folder must correspond one-to-one with a name in the alpha folder, and the two corresponding images must have the same resolution.
+
+The contents of train.txt and val.txt are as follows:
+```
+train/fg/14299313536_ea3e61076c_o.jpg
+train/fg/14429083354_23c8fddff5_o.jpg
+train/fg/14559969490_d33552a324_o.jpg
+...
+```
+
+## Online composition
+Data loading supports online composition: the images fed to the network are composited on the fly from existing foreground, alpha, and background images, similar to the Composition-1k dataset used in [Deep Image Matting](https://arxiv.org/pdf/1703.03872.pdf). The dataset should be organized as follows:
+```
+Composition-1k/
+|--bg/
+|
+|--train/
+|  |--fg/
+|  |--alpha/
+|
+|--val/
+|  |--fg/
+|  |--alpha/
+|  |--trimap/ (if exists)
+|
+|--train.txt
+|
+|--val.txt
+```
+
+The fg folder stores the foreground images and the bg folder stores the background images.
+
+The contents of train.txt are as follows:
+```
+train/fg/fg1.jpg bg/bg1.jpg
+train/fg/fg2.jpg bg/bg2.jpg
+train/fg/fg3.jpg bg/bg3.jpg
+...
+```
+The first column is the foreground image and the second column is the background image.
+
+The contents of val.txt are as follows. If no corresponding trimap exists, the third column can be omitted and trimaps will be generated automatically.
+```
+val/fg/fg1.jpg bg/bg1.jpg val/trimap/trimap1.jpg
+val/fg/fg2.jpg bg/bg2.jpg val/trimap/trimap2.jpg
+val/fg/fg3.jpg bg/bg3.jpg val/trimap/trimap3.jpg
+...
+```
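+
+For reference, the pair list can be produced with a small script. The sketch below is only an illustration under the layout above; the helper name `write_pair_list` and the random pairing strategy are assumptions, not part of this repository.
+
+```python
+import os
+import random
+
+
+def write_pair_list(dataset_root, split='train', seed=0):
+    """Write `<split>.txt`, pairing each foreground with a random background."""
+    fg_dir = os.path.join(dataset_root, split, 'fg')
+    bg_dir = os.path.join(dataset_root, 'bg')
+    fgs = sorted(os.listdir(fg_dir))
+    bgs = sorted(os.listdir(bg_dir))
+    random.seed(seed)
+    with open(os.path.join(dataset_root, split + '.txt'), 'w') as f:
+        for fg in fgs:
+            bg = random.choice(bgs)
+            f.write('{}/fg/{} bg/{}\n'.format(split, fg, bg))
+
+
+write_pair_list('Composition-1k', 'train')
+```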

+ 72 - 0
docs/data_prepare_en.md

@@ -0,0 +1,72 @@
+# Dataset Preparation
+
+When training on a custom dataset, you need to prepare the dataset in the appropriate format.
+We provide two forms of dataset: one for offline composition and one for online composition.
+
+## Offline composition
+If the images have been composited offline or do not need to be composited, the dataset should be organized as follows.
+```
+dataset_root/
+|--train/
+|  |--fg/
+|  |--alpha/
+|
+|--val/
+|  |--fg/
+|  |--alpha/
+|
+|--train.txt
+|
+|--val.txt
+```
+The fg folder stores the original images. Each image name in the fg folder must correspond one-to-one with a name in the alpha folder,
+and corresponding images in the two folders must have the same resolution.
+
+train.txt and val.txt contents are as follows.
+```
+train/fg/14299313536_ea3e61076c_o.jpg
+train/fg/14429083354_23c8fddff5_o.jpg
+train/fg/14559969490_d33552a324_o.jpg
+...
+```
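+
+For reference, these file lists can be generated by scanning the `fg` folders. The snippet below is only a sketch under the layout described above; the helper name `make_file_list` is an assumption, not part of this repository.
+
+```python
+import os
+
+
+def make_file_list(dataset_root, split='train'):
+    """Write `<split>.txt` with one relative fg path per line."""
+    fg_dir = os.path.join(dataset_root, split, 'fg')
+    names = sorted(os.listdir(fg_dir))
+    with open(os.path.join(dataset_root, split + '.txt'), 'w') as f:
+        for name in names:
+            f.write('{}/fg/{}\n'.format(split, name))
+
+
+make_file_list('dataset_root', 'train')
+make_file_list('dataset_root', 'val')
+```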
+
+## Online composition
+Data loading supports online composition, that is, the images fed to the network are composited on the fly from foreground, alpha, and background images,
+similar to the Composition-1k dataset used in [Deep Image Matting](https://arxiv.org/pdf/1703.03872.pdf).
+The dataset should be organized as follows:
+```
+Composition-1k/
+|--bg/
+|
+|--train/
+|  |--fg/
+|  |--alpha/
+|
+|--val/
+|  |--fg/
+|  |--alpha/
+|  |--trimap/ (if exists)
+|
+|--train.txt
+|
+|--val.txt
+```
+
+The fg folder stores the foreground images and the bg folder stores the background images.
+
+The contents of train.txt are as follows:
+```
+train/fg/fg1.jpg bg/bg1.jpg
+train/fg/fg2.jpg bg/bg2.jpg
+train/fg/fg3.jpg bg/bg3.jpg
+...
+```
+The first column is the foreground image and the second column is the background image.
+
+The contents of val.txt are as follows. If trimaps do not exist in the dataset, the third column is not needed and the code will generate trimaps automatically.
+```
+val/fg/fg1.jpg bg/bg1.jpg val/trimap/trimap1.jpg
+val/fg/fg2.jpg bg/bg2.jpg val/trimap/trimap2.jpg
+val/fg/fg3.jpg bg/bg3.jpg val/trimap/trimap3.jpg
+...
+```

+ 270 - 0
docs/full_develop_cn.md

@@ -0,0 +1,270 @@
+# Full Development
+
+## Contents
+* [Installation](#installation)
+* [Dataset preparation](#dataset-preparation)
+* [Model selection](#model-selection)
+* [Training](#training)
+* [Evaluation](#evaluation)
+* [Prediction](#prediction)
+* [Background replacement](#background-replacement)
+* [Export and deployment](#export-and-deployment)
+
+## Installation
+
+#### 1. Install PaddlePaddle
+
+Version requirements
+
+* PaddlePaddle >= 2.0.2
+
+* Python >= 3.7+
+
+Due to the high computational cost of matting models, the GPU version of PaddlePaddle is recommended.
+CUDA 10.0 or later is recommended. See the [PaddlePaddle official website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html) for the installation tutorial.
+
+#### 2. Download the PaddleSeg repository
+
+```shell
+git clone https://github.com/PaddlePaddle/PaddleSeg
+```
+
+#### 3. Installation
+
+```shell
+cd PaddleSeg/Matting
+pip install -r requirements.txt
+```
+
+
+## Dataset preparation
+
+We use MODNet's open-source [PPM-100](https://github.com/ZHKKKe/PPM) dataset as the example dataset for this tutorial. For custom datasets, refer to [dataset preparation](data_prepare_cn.md).
+
+
+Download the prepared PPM-100 dataset:
+```shell
+mkdir data && cd data
+wget https://paddleseg.bj.bcebos.com/matting/datasets/PPM-100.zip
+unzip PPM-100.zip
+cd ..
+```
+
+The dataset structure is as follows:
+
+```
+PPM-100/
+|--train/
+|  |--fg/
+|  |--alpha/
+|
+|--val/
+|  |--fg/
+|  |--alpha
+|
+|--train.txt
+|
+|--val.txt
+```
+
+**Note**: This dataset is only used for the tutorial demonstration; it is not sufficient to train a convergent model.
+
+## Model selection
+
+The Matting project is driven directly by config files, which are placed in the [configs](../configs/) directory. You can choose the appropriate config file for training, prediction and other workflows according to your needs. Trimap-based methods (DIM) do not support video processing yet.
+
+This tutorial uses the [configs/quick_start/ppmattingv2-stdc1-human_512.yml](../configs/quick_start/ppmattingv2-stdc1-human_512.yml) config file for demonstration.
+
+
+## Training
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/train.py \
+       --config configs/quick_start/ppmattingv2-stdc1-human_512.yml \
+       --do_eval \
+       --use_vdl \
+       --save_interval 500 \
+       --num_workers 5 \
+       --save_dir output
+```
+
+**Note:** Using `--do_eval` slows down training and increases GPU memory consumption; turn it on or off as needed.
+When it is on, the historical best model (by SAD) is saved to `{save_dir}/best_model`, and a `best_sad.txt` is generated in that directory recording the metrics and iteration at that point.
+
+`--num_workers` reads data with multiple processes to speed up data preprocessing.
+
+Run the following command to view more parameters:
+```shell
+python tools/train.py --help
+```
+To use multiple GPUs, launch with `python -m paddle.distributed.launch`.
+
+## Finetuning
+To finetune from a pretrained model, add a `model.pretrained` field to the config file whose value is the URL or local path of the pretrained model weights. The following takes finetuning the official PP-MattingV2 model as an example.
+
+First, download the pretrained model.
+Download the pretrained model from the [model zoo](../README_CN.md/#模型库) and place it in the pretrained_models directory.
+```shell
+mkdir pretrained_models && cd pretrained_models
+wget https://paddleseg.bj.bcebos.com/matting/models/ppmattingv2-stdc1-human_512.pdparams
+cd ..
+```
+Then modify the `train_dataset.dataset_root`, `val_dataset.dataset_root`, and `model.pretrained` fields in the config file. The learning rate can be reduced appropriately; the remaining fields can stay unchanged.
+```yaml
+train_dataset:
+  type: MattingDataset
+  dataset_root: path/to/your/dataset # Path to your own dataset
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: path/to/your/dataset # Path to your own dataset
+  mode: val
+
+model:
+  type: PPMattingV2
+  backbone:
+    type: STDC1
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/PP_STDCNet1.tar.gz
+  decoder_channels: [128, 96, 64, 32, 16]
+  head_channel: 8
+  dpp_output_channel: 256
+  dpp_merge_type: add
+  pretrained: pretrained_models/ppmattingv2-stdc1-human_512.pdparams # The pretrained model file just downloaded
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.001  # The learning rate can be reduced appropriately
+  end_lr: 0
+  power: 0.9
+  warmup_iters: 1000
+  warmup_start_lr: 1.0e-5
+```
+You can then finetune the model following the instructions in the `Training` section.
+
+## Evaluation
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/val.py \
+       --config configs/quick_start/ppmattingv2-stdc1-human_512.yml \
+       --model_path output/best_model/model.pdparams \
+       --save_dir ./output/results \
+       --save_results
+```
+`--save_results` keeps the image prediction results when enabled; turn it off to speed up evaluation.
+
+You can directly download the model we provide for evaluation.
+
+Run the following command to view more parameters:
+```shell
+python tools/val.py --help
+```
+
+## Prediction
+
+### Image prediction
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/predict.py \
+    --config configs/quick_start/ppmattingv2-stdc1-human_512.yml \
+    --model_path output/best_model/model.pdparams \
+    --image_path data/PPM-100/val/fg/ \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+If the model requires trimap information, pass the trimap path through `--trimap_path`.
+
+`--fg_estimate False` turns off foreground estimation, which speeds up prediction but slightly reduces image quality.
+
+You can directly download the model we provide for prediction.
+
+Run the following command to view more parameters:
+```shell
+python tools/predict.py --help
+```
+
+### Video prediction
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/predict_video.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path output/best_model/model.pdparams \
+    --video_path path/to/video \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+
+
+## Background replacement
+### Image background replacement
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/bg_replace.py \
+    --config configs/quick_start/ppmattingv2-stdc1-human_512.yml \
+    --model_path output/best_model/model.pdparams \
+    --image_path path/to/your/image \
+    --background path/to/your/background/image \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+If the model requires trimap information, pass the trimap path through `--trimap_path`.
+
+`--background` can be a path to a background image, or one of ('r', 'g', 'b', 'w'), representing a red, green, blue, or white background. If not provided, a green background is used.
+
+`--fg_estimate False` turns off foreground estimation, which speeds up prediction but slightly reduces image quality.
+
+**Note:** `--image_path` must be the specific path of a single image.
+
+You can directly download the model we provide for background replacement.
+
+Run the following command to view more parameters:
+```shell
+python tools/bg_replace.py --help
+```
+### Video background replacement
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/bg_replace_video.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path output/best_model/model.pdparams \
+    --video_path path/to/video \
+    --background 'g' \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+
+
+## Export and deployment
+### Model export
+```shell
+python tools/export.py \
+    --config configs/quick_start/ppmattingv2-stdc1-human_512.yml \
+    --model_path output/best_model/model.pdparams \
+    --save_dir output/export \
+    --input_shape 1 3 512 512
+```
+If the model (e.g. DIM) requires trimap input, add the `--trimap` flag.
+
+Run the following command to view more parameters:
+```shell
+python tools/export.py --help
+```
+
+### Deployment
+```shell
+python deploy/python/infer.py \
+    --config output/export/deploy.yaml \
+    --image_path data/PPM-100/val/fg/ \
+    --save_dir output/results \
+    --fg_estimate True
+```
+If the model requires trimap information, pass the trimap path through `--trimap_path`.
+
+`--fg_estimate False` turns off foreground estimation, which speeds up prediction but slightly reduces image quality.
+
+`--video_path` can be passed a video path to perform video matting.
+
+Run the following command to view more parameters:
+```shell
+python deploy/python/infer.py --help
+```

+ 269 - 0
docs/full_develop_en.md

@@ -0,0 +1,269 @@
+# Full Development
+
+## Contents
+* [Installation](#Installation)
+* [Dataset preparation](#Dataset-preparation)
+* [Model selection](#Model-selection)
+* [Training](#Training)
+* [Evaluation](#Evaluation)
+* [Prediction](#Prediction)
+* [Background Replacement](#Background-Replacement)
+* [Export and Deployment](#Export-and-Deployment)
+
+## Installation
+
+#### 1. Install PaddlePaddle
+
+Versions
+
+* PaddlePaddle >= 2.0.2
+
+* Python >= 3.7+
+
+Due to the high computational cost of matting models, the GPU version of PaddlePaddle is recommended.
+CUDA 10.0 or later is recommended. See [PaddlePaddle official website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html) for the installation tutorial.
+
+#### 2. Download the PaddleSeg repository
+
+```shell
+git clone https://github.com/PaddlePaddle/PaddleSeg
+```
+
+#### 3. Installation
+
+```shell
+cd PaddleSeg/Matting
+pip install -r requirements.txt
+```
+
+
+## Dataset preparation
+
+We use MODNet's open-source [PPM-100](https://github.com/ZHKKKe/PPM) dataset as the demo dataset for this tutorial.
+For custom datasets, refer to [dataset preparation](data_prepare_en.md).
+
+Download the prepared PPM-100 dataset.
+```shell
+mkdir data && cd data
+wget https://paddleseg.bj.bcebos.com/matting/datasets/PPM-100.zip
+unzip PPM-100.zip
+cd ..
+```
+
+The dataset structure is as follows.
+
+```
+PPM-100/
+|--train/
+|  |--fg/
+|  |--alpha/
+|
+|--val/
+|  |--fg/
+|  |--alpha
+|
+|--train.txt
+|
+|--val.txt
+```
+
+**Note**: This dataset is only used for the tutorial demonstration; it is not sufficient to train a convergent model.
+
+## Model selection
+
+The Matting project is driven directly by config files, which are placed in the [configs](../configs/) directory.
+You can select a config file according to your needs for training, prediction, and other workflows.
+Trimap-based methods (DIM) do not support video processing yet.
+
+This tutorial uses [configs/quick_start/ppmattingv2-stdc1-human_512.yml](../configs/quick_start/ppmattingv2-stdc1-human_512.yml) for teaching demonstrations.
+
+## Training
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/train.py \
+       --config configs/quick_start/ppmattingv2-stdc1-human_512.yml \
+       --do_eval \
+       --use_vdl \
+       --save_interval 500 \
+       --num_workers 5 \
+       --save_dir output
+```
+
+Using `--do_eval` slows down training and increases GPU memory consumption; turn it on or off as needed.
+When `--do_eval` is on, the historical best model (by SAD) is saved to `{save_dir}/best_model`, and a `best_sad.txt` is generated in that directory recording the metrics and iteration at that point.
+
+`--num_workers` reads data with multiple processes to speed up data preprocessing.
+
+Run the following command to view more parameters.
+```shell
+python tools/train.py --help
+```
+If you want to use multiple GPUs, please launch with `python -m paddle.distributed.launch`.
+
+## Finetune
+If you want to finetune from a pretrained model, set the `model.pretrained` field in the config file to the URL or file path of the pretrained model weights. Here we use the official PP-MattingV2 pretrained model for finetuning as an example.
+
+First, download the pretrained model from [Models](../README.md/#Models) and place it in `pretrained_models`.
+```shell
+mkdir pretrained_models && cd pretrained_models
+wget https://paddleseg.bj.bcebos.com/matting/models/ppmattingv2-stdc1-human_512.pdparams
+cd ..
+```
+Then modify the `train_dataset.dataset_root`, `val_dataset.dataset_root`, and `model.pretrained` fields in the config file. Reducing the learning rate is recommended; the rest of the config file can stay unchanged.
+```yaml
+train_dataset:
+  type: MattingDataset
+  dataset_root: path/to/your/dataset # Path to your own dataset
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: path/to/your/dataset # Path to your own dataset
+  mode: val
+
+model:
+  type: PPMattingV2
+  backbone:
+    type: STDC1
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/PP_STDCNet1.tar.gz
+  decoder_channels: [128, 96, 64, 32, 16]
+  head_channel: 8
+  dpp_output_channel: 256
+  dpp_merge_type: add
+  pretrained: pretrained_models/ppmattingv2-stdc1-human_512.pdparams # The pretrained model file just downloaded
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.001  # lr is recommended to be reduced
+  end_lr: 0
+  power: 0.9
+  warmup_iters: 1000
+  warmup_start_lr: 1.0e-5
+```
+Finally, you can finetune the model with your dataset following the instructions in `Training`.
+
+## Evaluation
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/val.py \
+       --config configs/quick_start/ppmattingv2-stdc1-human_512.yml \
+       --model_path output/best_model/model.pdparams \
+       --save_dir ./output/results \
+       --save_results
+```
+`--save_results` saves the prediction results when enabled. Turning it off speeds up evaluation.
+
+You can directly download the provided model for evaluation.
+
+Run the following command to view more parameters.
+```shell
+python tools/val.py --help
+```
+
+## Prediction
+### Image Prediction
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/predict.py \
+    --config configs/quick_start/ppmattingv2-stdc1-human_512.yml \
+    --model_path output/best_model/model.pdparams \
+    --image_path data/PPM-100/val/fg/ \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+If the model requires trimap information, pass the trimap path through `--trimap_path`.
+
+`--fg_estimate False` can turn off foreground estimation, which improves prediction speed but reduces image quality.
+
+You can directly download the provided model for prediction.
+
+Run the following command to view more parameters.
+```shell
+python tools/predict.py --help
+```
+
+### Video Prediction
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/predict_video.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path output/best_model/model.pdparams \
+    --video_path path/to/video \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+
+
+## Background Replacement
+### Image Background Replacement
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/bg_replace.py \
+    --config configs/quick_start/ppmattingv2-stdc1-human_512.yml \
+    --model_path output/best_model/model.pdparams \
+    --image_path path/to/your/image \
+    --background path/to/your/background/image \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+If the model requires trimap information, pass the trimap path through `--trimap_path`.
+
+`--background` can be a path to a background image, or one of ('r', 'g', 'b', 'w'), representing a red, green, blue, or white background. If it is not specified, a green background is used.
+
+`--fg_estimate False` turns off foreground estimation, which improves prediction speed but slightly reduces image quality.
+
+**Note:** `--image_path` must be the specific path of a single image.
+
+You can directly download the provided model for background replacement.
+
+Run the following command to view more parameters.
+```shell
+python tools/bg_replace.py --help
+```
+
+### Video Background Replacement
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/bg_replace_video.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path output/best_model/model.pdparams \
+    --video_path path/to/video \
+    --background 'g' \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+
+## Export and Deployment
+### Model Export
+```shell
+python tools/export.py \
+    --config configs/quick_start/ppmattingv2-stdc1-human_512.yml \
+    --model_path output/best_model/model.pdparams \
+    --save_dir output/export \
+    --input_shape 1 3 512 512
+```
+If the model requires trimap information (e.g. DIM), the `--trimap` flag is needed.
+
+Run the following command to view more parameters.
+```shell
+python tools/export.py --help
+```
+
+### Deployment
+```shell
+python deploy/python/infer.py \
+    --config output/export/deploy.yaml \
+    --image_path data/PPM-100/val/fg/ \
+    --save_dir output/results \
+    --fg_estimate True
+```
+If the model requires trimap information, pass the trimap path through `--trimap_path`.
+
+`--fg_estimate False` turns off foreground estimation, which improves prediction speed but slightly reduces image quality.
+
+`--video_path` can be passed a video path to perform video matting.
+
+Run the following command to view more parameters.
+```shell
+python deploy/python/infer.py --help
+```

+ 8 - 0
docs/online_demo_cn.md

@@ -0,0 +1,8 @@
+# Online Experience
+Welcome to "[Non-Code Matting](https://coolseg.cn)", an online matting application developed by external developers based on the PP-Matting model. Besides matting, the website also provides background replacement, ID photo production, and other features.
+
+<p align="center">
+<img src="https://user-images.githubusercontent.com/48433081/165077834-c3191509-aeaf-45c8-b226-656174f4c152.gif" width="70%" height="70%">
+</p>
+
+**Note**: The website is developed and maintained by external developers and may be unstable. We recommend the code-based [quick start](./quick_start_cn.md) instead.

+ 10 - 0
docs/online_demo_en.md

@@ -0,0 +1,10 @@
+# Online experience
+
+Welcome to "[Non-Code Matting](https://coolseg.cn)", an online image matting application developed by community developers based on the PP-Matting model.
+In addition to matting, the website also provides background replacement, ID photo production, and other features.
+
+<p align="center">
+<img src="https://user-images.githubusercontent.com/48433081/165077834-c3191509-aeaf-45c8-b226-656174f4c152.gif" width="70%" height="70%">
+</p>
+
+**Note**: The website is developed and maintained by community developers and may be unstable. We recommend the code-based [quick start](./quick_start_en.md) instead.

+ 116 - 0
docs/quick_start_cn.md

@@ -0,0 +1,116 @@
+# Quick Start
+
+## Installation
+
+#### 1. Install PaddlePaddle
+
+Version requirements
+
+* PaddlePaddle >= 2.0.2
+
+* Python >= 3.7+
+
+Due to the high computational cost of matting models, the GPU version of PaddlePaddle is recommended.
+CUDA 10.0 or later is recommended. See the [PaddlePaddle official website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html) for the installation tutorial.
+
+#### 2. Download the PaddleSeg repository
+
+```shell
+git clone https://github.com/PaddlePaddle/PaddleSeg
+```
+
+#### 3. Installation
+
+```shell
+cd PaddleSeg/Matting
+pip install -r requirements.txt
+```
+
+## Download the pretrained model
+Download the pretrained model from the [model zoo](../README_CN.md/#模型库) and place it in the pretrained_models directory. Here we take PP-MattingV2 as an example.
+```shell
+mkdir pretrained_models && cd pretrained_models
+wget https://paddleseg.bj.bcebos.com/matting/models/ppmattingv2-stdc1-human_512.pdparams
+cd ..
+```
+
+## Prediction
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/predict.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path pretrained_models/ppmattingv2-stdc1-human_512.pdparams \
+    --image_path demo/human.jpg \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+
+The prediction results are as follows:
+<div align="center">
+<img src="https://user-images.githubusercontent.com/30919197/201861635-0d139592-7da5-44b1-9bfa-7502d9643320.png"  width = "90%"  />
+</div>
+
+**Note**: `--config` must match `--model_path`.
+
+## Background replacement
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/bg_replace.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path pretrained_models/ppmattingv2-stdc1-human_512.pdparams \
+    --image_path demo/human.jpg \
+    --background 'g' \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+The background replacement effect is as follows:
+<div align="center">
+<img src="https://user-images.githubusercontent.com/30919197/201861644-15dd5ccf-fb6e-4440-a731-8e7c1d464699.png"  width = "90%"  />
+</div>
+
+**Notes:**
+* `--image_path` must be the specific path of a single image.
+* `--config` must match `--model_path`.
+* `--background` can be a path to a background image, or one of ('r', 'g', 'b', 'w'), representing a red, green, blue, or white background. If not provided, a green background is used.
+
+
+## Video prediction
+Run the following command for video prediction; remember to pass the video to be predicted through `--video_path`.
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/predict_video.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path pretrained_models/ppmattingv2-stdc1-human_512.pdparams \
+    --video_path path/to/video \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+The prediction results are as follows:
+
+<p align="center">
+<img src="https://paddleseg.bj.bcebos.com/matting/demo/v1.gif"  height="200">  
+<img src="https://paddleseg.bj.bcebos.com/matting/demo/v1_alpha.gif"  height="200">
+<img src="https://paddleseg.bj.bcebos.com/matting/demo/v1_fg.gif"  height="200">
+</p>
+
+
+## Video background replacement
+Run the following command for video background replacement; remember to pass the video through `--video_path`.
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/bg_replace_video.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path pretrained_models/ppmattingv2-stdc1-human_512.pdparams \
+    --video_path path/to/video \
+    --background 'g' \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+The background replacement effect is as follows:
+<p align="center">
+<img src="https://paddleseg.bj.bcebos.com/matting/demo/v1.gif"  height="200">  
+<img src="https://paddleseg.bj.bcebos.com/matting/demo/v1_bgv1.gif"  height="200">
+</p>
+
+**Notes:**
+* `--background` can be a path to a background image or a background video, or one of ('r', 'g', 'b', 'w'), representing a red, green, blue, or white background. If not provided, a green background is used.

+ 119 - 0
docs/quick_start_en.md

@@ -0,0 +1,119 @@
+# Quick Start
+
+## Installation
+
+#### 1. Install PaddlePaddle
+
+Versions
+
+* PaddlePaddle >= 2.0.2
+
+* Python >= 3.7+
+
+Due to the high computational cost of matting models, the GPU version of PaddlePaddle is recommended.
+CUDA 10.0 or later is recommended. See [PaddlePaddle official website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html) for the installation tutorial.
+
+#### 2. Download the PaddleSeg repository
+
+```shell
+git clone https://github.com/PaddlePaddle/PaddleSeg
+```
+
+#### 3. Installation
+
+```shell
+cd PaddleSeg/Matting
+pip install -r requirements.txt
+```
+
+
+## Download pre-trained model
+Download the pre-trained model from [Models](../README.md/#Models) and place it in `pretrained_models`. Here we take PP-MattingV2 as an example.
+```shell
+mkdir pretrained_models && cd pretrained_models
+wget https://paddleseg.bj.bcebos.com/matting/models/ppmattingv2-stdc1-human_512.pdparams
+cd ..
+```
+
+## Prediction
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/predict.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path pretrained_models/ppmattingv2-stdc1-human_512.pdparams \
+    --image_path demo/human.jpg \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+
+Prediction results are as follows:
+<div align="center">
+<img src="https://user-images.githubusercontent.com/30919197/201861635-0d139592-7da5-44b1-9bfa-7502d9643320.png"  width = "90%"  />
+</div>
+
+**Note**: `--config` needs to match `--model_path`.
+
+## Background Replacement
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/bg_replace.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path pretrained_models/ppmattingv2-stdc1-human_512.pdparams \
+    --image_path demo/human.jpg \
+    --background 'g' \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+The background replacement effect is as follows:
+<div align="center">
+<img src="https://user-images.githubusercontent.com/30919197/201861644-15dd5ccf-fb6e-4440-a731-8e7c1d464699.png"  width = "90%"  />
+</div>
+
+**Notes:**
+* `--image_path` must be the specific path of an image.
+* `--config` needs to match `--model_path`.
+* `--background` can be a path to a background image, or one of ('r', 'g', 'b', 'w'), representing a red, green, blue, or white background; green is used by default if not passed.
+
+
+## Video Prediction
+
+Run the following command to predict the video, and remember to pass the video path by `--video_path`.
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/predict_video.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path pretrained_models/ppmattingv2-stdc1-human_512.pdparams \
+    --video_path path/to/video \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+Prediction results are as follows:
+
+<p align="center">
+<img src="https://paddleseg.bj.bcebos.com/matting/demo/v1.gif"  height="200">  
+<img src="https://paddleseg.bj.bcebos.com/matting/demo/v1_alpha.gif"  height="200">
+<img src="https://paddleseg.bj.bcebos.com/matting/demo/v1_fg.gif"  height="200">
+</p>
+
+
+## Video Background Replacement
+Run the following command to replace the video background, and remember to pass the video path by `--video_path`.
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python tools/bg_replace_video.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path pretrained_models/ppmattingv2-stdc1-human_512.pdparams \
+    --video_path path/to/video \
+    --background 'g' \
+    --save_dir ./output/results \
+    --fg_estimate True
+```
+The background replacement effect is as follows:
+<p align="center">
+<img src="https://paddleseg.bj.bcebos.com/matting/demo/v1.gif"  height="200">  
+<img src="https://paddleseg.bj.bcebos.com/matting/demo/v1_bgv1.gif"  height="200">
+</p>
+
+**Notes:**
+* `--background` can be a background image path, a background video path, or one of ('r', 'g', 'b', 'w'), representing a red, green, blue, or white background. It defaults to green if not provided; see the example below for using a background video.
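+
+For example, to use another video as the background (the background path below is a placeholder):
+```shell
+python tools/bg_replace_video.py \
+    --config configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml \
+    --model_path pretrained_models/ppmattingv2-stdc1-human_512.pdparams \
+    --video_path path/to/video \
+    --background path/to/background_video.mp4 \
+    --save_dir ./output/results \
+    --fg_estimate True
+```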

+ 36 - 0
main.py

@@ -0,0 +1,36 @@
+from flask import Flask, request
+import os
+from tools import predict
+from werkzeug.utils import secure_filename
+
+app = Flask(__name__)
+
+cur_dirs = os.path.abspath(os.path.dirname(__file__))
+save_dir = os.path.join(cur_dirs, 'outputs')
+upload_dir = os.path.join(cur_dirs, "uploads")
+
+# Make sure the output and upload directories exist before they are used.
+os.makedirs(save_dir, exist_ok=True)
+os.makedirs(upload_dir, exist_ok=True)
+
+app.config['UPLOAD_FOLDER'] = upload_dir
+
+
+def get_upload_file_path(name):
+    return os.path.join(upload_dir, name)
+
+
+@app.route("/image/seg", methods=['POST'])
+def seg():
+    if 'file' not in request.files:
+        return '{}'
+    file = request.files['file']
+
+    if file.filename == '':
+        return '{}'
+
+    filename = secure_filename(file.filename)
+    file_path = get_upload_file_path(filename)
+    file.save(file_path)
+
+    predict.seg(file_path, save_dir)
+
+    # A Flask view must return a response; the matting results are written to save_dir.
+    return '{}'
+
+
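+# Example request once the server is running (the image path is illustrative):
+#   curl -F "file=@demo/human.jpg" http://127.0.0.1:20201/image/seg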
+if __name__ == '__main__':
+    app.run(port=20201, host="0.0.0.0", debug=True)

BIN
models/ppmatting-hrnet_w18-human_512.pdparams


BIN
outputs/human_alpha.png


BIN
outputs/human_rgba.png


+ 1 - 0
ppmatting/__init__.py

@@ -0,0 +1 @@
+from . import ml, metrics, transforms, datasets, models, utils

+ 6 - 0
ppmatting/core/__init__.py

@@ -0,0 +1,6 @@
+from .val import evaluate
+from .val_ml import evaluate_ml
+from .train import train
+from .predict import predict
+from .predict_video import predict_video
+from .bg_replace_video import bg_replace_video

+ 217 - 0
ppmatting/core/bg_replace_video.py

@@ -0,0 +1,217 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+import time
+from collections.abc import Iterable
+
+import cv2
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddleseg import utils
+from paddleseg.core import infer
+from paddleseg.utils import logger, progbar, TimeAverager
+
+import ppmatting.transforms as T
+from ppmatting.utils import mkdir, estimate_foreground_ml, VideoReader, VideoWriter
+
+
+def build_loader_writter(video_path, transforms, save_dir):
+    reader = VideoReader(video_path, transforms)
+    loader = paddle.io.DataLoader(reader)
+    base_name = os.path.basename(video_path)
+    name = os.path.splitext(base_name)[0]
+    save_path = os.path.join(save_dir, name + '.avi')
+
+    writer = VideoWriter(
+        save_path,
+        reader.fps,
+        frame_size=(reader.width, reader.height),
+        is_color=True)
+
+    return loader, writer
+
+
+def reverse_transform(img, trans_info):
+    """recover pred to origin shape"""
+    for item in trans_info[::-1]:
+        if item[0][0] == 'resize':
+            h, w = item[1][0], item[1][1]
+            img = F.interpolate(img, [h, w], mode='bilinear')
+        elif item[0][0] == 'padding':
+            h, w = item[1][0], item[1][1]
+            img = img[:, :, 0:h, 0:w]
+        else:
+            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
+    return img
+
+
+def postprocess(fg, alpha, img, bg, trans_info, writer, fg_estimate):
+    """
+    Postprocess for prediction results.
+
+    Args:
+        fg (Tensor): The foreground, value should be in [0, 1].
+        alpha (Tensor): The alpha, value should be in [0, 1].
+        img (Tensor): The original image, value should be in [0, 1].
+        bg (Tensor): The background, value should be in [0, 1].
+        trans_info (list): A list of the shape transformations.
+        writer (VideoWriter): The VideoWriter instance used to save the composited frames.
+        fg_estimate (bool): Whether to estimate foreground. It is invalid when fg is not None.
+
+    """
+    alpha = reverse_transform(alpha, trans_info)
+    bg = F.interpolate(bg, size=alpha.shape[-2:], mode='bilinear')
+    if fg is None:
+        if fg_estimate:
+            img = img.transpose((0, 2, 3, 1)).squeeze().numpy()
+            alpha = alpha.squeeze().numpy()
+            fg = estimate_foreground_ml(img, alpha)
+            bg = bg.transpose((0, 2, 3, 1)).squeeze().numpy()
+        else:
+            fg = img
+    else:
+        fg = reverse_transform(fg, trans_info)
+    if len(alpha.shape) == 2:
+        alpha = alpha[:, :, None]
+    new_img = alpha * fg + (1 - alpha) * bg
+    writer.write(new_img)
+
+
+def get_bg(bg_path, shape):
+    bg = paddle.zeros((1, 3, shape[0], shape[1]))
+    # special color
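+    # Channels follow OpenCV's BGR order, so index 2 is red and index 0 is blue.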
+    if bg_path == 'r':
+        bg[:, 2, :, :] = 1
+    elif bg_path == 'g':
+        bg[:, 1, :, :] = 1
+    elif bg_path == 'b':
+        bg[:, 0, :, :] = 1
+    elif bg_path == 'w':
+        bg = bg + 1
+
+    elif not os.path.exists(bg_path):
+        raise Exception('The background path is not found: {}'.format(bg_path))
+    # image
+    elif bg_path.endswith(
+        ('.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG', '.png')):
+        bg = cv2.imread(bg_path)
+        bg = bg[np.newaxis, :, :, :]
+        bg = paddle.to_tensor(bg) / 255.
+        bg = bg.transpose((0, 3, 1, 2))
+
+    elif bg_path.lower().endswith(
+        ('.mp4', '.avi', '.mov', '.m4v', '.dat', '.rm', '.rmvb', '.wmv', '.asf',
+         '.asx', '.3gp', '.mkv', '.flv', '.vob')):
+        transforms = T.Compose([T.Normalize(mean=(0, 0, 0), std=(1, 1, 1))])
+        bg = VideoReader(bg_path, transforms=transforms)
+        bg = paddle.io.DataLoader(bg)
+        bg = iter(bg)
+
+    else:
+        raise IOError('The background path is invalid, please check it')
+
+    return bg
+
+
+def bg_replace_video(model,
+                     model_path,
+                     transforms,
+                     video_path,
+                     bg_path='g',
+                     save_dir='output',
+                     fg_estimate=True):
+    """
+    predict and visualize the video.
+
+    Args:
+        model (nn.Layer): Used to predict for input video.
+        model_path (str): The path of pretrained model.
+        transforms (transforms.Compose): Preprocess for frames of video.
+        video_path (str): The video path to be predicted.
+        bg_path (str): The background, which can be an image path, a video path, or one of ('r', 'g', 'b', 'w'). Default: 'g'.
+        save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
+        fg_estimate (bool, optional): Whether to estimate foreground when predicting. It is invalid if the foreground is predicted by model. Default: True
+    """
+    utils.utils.load_entire_model(model, model_path)
+    model.eval()
+
+    # Build loader and writer for video
+    loader, writer = build_loader_writter(
+        video_path, transforms, save_dir=save_dir)
+    # Get bg
+    bg_reader = get_bg(
+        bg_path, shape=(loader.dataset.height, loader.dataset.width))
+
+    logger.info("Start to predict...")
+    progbar_pred = progbar.Progbar(target=len(loader), verbose=1)
+    preprocess_cost_averager = TimeAverager()
+    infer_cost_averager = TimeAverager()
+    postprocess_cost_averager = TimeAverager()
+    batch_start = time.time()
+    with paddle.no_grad():
+        for i, data in enumerate(loader):
+            preprocess_cost_averager.record(time.time() - batch_start)
+
+            infer_start = time.time()
+            result = model(data)  # result maybe a Tensor or a dict
+            if isinstance(result, paddle.Tensor):
+                alpha = result
+                fg = None
+            else:
+                alpha = result['alpha']
+                fg = result.get('fg', None)
+            infer_cost_averager.record(time.time() - infer_start)
+
+            # postprocess
+            postprocess_start = time.time()
+            if isinstance(bg_reader, Iterable):
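+                # The background is a video: take its next frame and restart it when it runs out.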
+                try:
+                    bg = next(bg_reader)
+                except StopIteration:
+                    bg_reader = get_bg(
+                        bg_path,
+                        shape=(loader.dataset.height, loader.dataset.width))
+                    bg = next(bg_reader)
+                finally:
+                    bg = bg['ori_img']
+            else:
+                bg = bg_reader
+            postprocess(
+                fg,
+                alpha,
+                data['ori_img'],
+                bg=bg,
+                trans_info=data['trans_info'],
+                writer=writer,
+                fg_estimate=fg_estimate)
+            postprocess_cost_averager.record(time.time() - postprocess_start)
+
+            preprocess_cost = preprocess_cost_averager.get_average()
+            infer_cost = infer_cost_averager.get_average()
+            postprocess_cost = postprocess_cost_averager.get_average()
+            progbar_pred.update(i + 1, [('preprocess_cost', preprocess_cost),
+                                        ('infer_cost', infer_cost),
+                                        ('postprocess_cost', postprocess_cost)])
+
+            preprocess_cost_averager.reset()
+            infer_cost_averager.reset()
+            postprocess_cost_averager.reset()
+            batch_start = time.time()
+    if hasattr(model, 'reset'):
+        model.reset()
+    loader.dataset.release()
+    if isinstance(bg, VideoReader):
+        bg_reader.release()

+ 212 - 0
ppmatting/core/predict.py

@@ -0,0 +1,212 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+import time
+
+import cv2
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddleseg import utils
+from paddleseg.core import infer
+from paddleseg.utils import logger, progbar, TimeAverager
+
+from ppmatting.utils import mkdir, estimate_foreground_ml
+
+
+def partition_list(arr, m):
+    """split the list 'arr' into m pieces"""
+    n = int(math.ceil(len(arr) / float(m)))
+    return [arr[i:i + n] for i in range(0, len(arr), n)]
+
+
+def save_result(alpha, path, im_path, trimap=None, fg_estimate=True, fg=None):
+    """
+    Save alpha and rgba.
+
+    Args:
+        alpha (numpy.ndarray): The value of alpha should be in [0, 255], shape should be [h, w].
+        path (str): The save path
+        im_path (str): The original image path.
+        trimap (str, optional): The trimap if provided. Default: None.
+        fg_estimate (bool, optional): Whether to estimate the foreground, Default: True.
+        fg (numpy.ndarray, optional): The foreground, if provided, fg_estimate is invalid. Default: None.
+    """
+    dirname = os.path.dirname(path)
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+    basename = os.path.basename(path)
+    name = os.path.splitext(basename)[0]
+    alpha_save_path = os.path.join(dirname, name + '_alpha.png')
+    rgba_save_path = os.path.join(dirname, name + '_rgba.png')
+
+    # save alpha matte
+    if trimap is not None:
+        trimap = cv2.imread(trimap, 0)
+        alpha[trimap == 0] = 0
+        alpha[trimap == 255] = 255
+    alpha = (alpha).astype('uint8')
+    cv2.imwrite(alpha_save_path, alpha)
+
+    # save rgba
+    im = cv2.imread(im_path)
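+    # If the model did not output a foreground, optionally estimate it from the image and the alpha matte.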
+    if fg is None:
+        if fg_estimate:
+            fg = estimate_foreground_ml(im / 255.0, alpha / 255.0) * 255
+        else:
+            fg = im
+    fg = fg.astype('uint8')
+    alpha = alpha[:, :, np.newaxis]
+    rgba = np.concatenate((fg, alpha), axis=-1)
+    cv2.imwrite(rgba_save_path, rgba)
+
+    return fg
+
+
+def reverse_transform(img, trans_info):
+    """recover pred to origin shape"""
+    for item in trans_info[::-1]:
+        if item[0] == 'resize':
+            h, w = item[1][0], item[1][1]
+            img = F.interpolate(img, [h, w], mode='bilinear')
+        elif item[0] == 'padding':
+            h, w = item[1][0], item[1][1]
+            img = img[:, :, 0:h, 0:w]
+        else:
+            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
+    return img
+
+
+def preprocess(img, transforms, trimap=None):
+    data = {}
+    data['img'] = img
+    if trimap is not None:
+        data['trimap'] = trimap
+        data['gt_fields'] = ['trimap']
+    data['trans_info'] = []
+    data = transforms(data)
+    data['img'] = paddle.to_tensor(data['img'])
+    data['img'] = data['img'].unsqueeze(0)
+    if trimap is not None:
+        data['trimap'] = paddle.to_tensor(data['trimap'])
+        data['trimap'] = data['trimap'].unsqueeze((0, 1))
+
+    return data
+
+
+def predict(model,
+            model_path,
+            transforms,
+            image_list,
+            image_dir=None,
+            trimap_list=None,
+            save_dir='output',
+            fg_estimate=True):
+    """
+    predict and visualize the image_list.
+
+    Args:
+        model (nn.Layer): Used to predict for input image.
+        model_path (str): The path of pretrained model.
+        transforms (transforms.Compose): Preprocess for input image.
+        image_list (list): A list of image path to be predicted.
+        image_dir (str, optional): The root directory of the images predicted. Default: None.
+        trimap_list (list, optional): A list of trimap of image_list. Default: None.
+        save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
+        fg_estimate (bool, optional): Whether to estimate the foreground when the model does not predict it. Default: True.
+    """
+    utils.utils.load_entire_model(model, model_path)
+    model.eval()
+    nranks = paddle.distributed.get_world_size()
+    local_rank = paddle.distributed.get_rank()
+    if nranks > 1:
+        img_lists = partition_list(image_list, nranks)
+        trimap_lists = partition_list(
+            trimap_list, nranks) if trimap_list is not None else None
+    else:
+        img_lists = [image_list]
+        trimap_lists = [trimap_list] if trimap_list is not None else None
+
+    logger.info("Start to predict...")
+    progbar_pred = progbar.Progbar(target=len(img_lists[0]), verbose=1)
+    preprocess_cost_averager = TimeAverager()
+    infer_cost_averager = TimeAverager()
+    postprocess_cost_averager = TimeAverager()
+    batch_start = time.time()
+    with paddle.no_grad():
+        for i, im_path in enumerate(img_lists[local_rank]):
+            preprocess_start = time.time()
+            trimap = trimap_lists[local_rank][
+                i] if trimap_list is not None else None
+            data = preprocess(img=im_path, transforms=transforms, trimap=trimap)
+            preprocess_cost_averager.record(time.time() - preprocess_start)
+
+            infer_start = time.time()
+            result = model(data)
+            infer_cost_averager.record(time.time() - infer_start)
+
+            postprocess_start = time.time()
+            if isinstance(result, paddle.Tensor):
+                alpha = result
+                fg = None
+            else:
+                alpha = result['alpha']
+                fg = result.get('fg', None)
+
+            alpha = reverse_transform(alpha, data['trans_info'])
+            alpha = (alpha.numpy()).squeeze()
+            alpha = (alpha * 255).astype('uint8')
+            if fg is not None:
+                fg = reverse_transform(fg, data['trans_info'])
+                fg = (fg.numpy()).squeeze().transpose((1, 2, 0))
+                fg = (fg * 255).astype('uint8')
+
+            # get the saved name
+            if image_dir is not None:
+                im_file = im_path.replace(image_dir, '')
+            else:
+                im_file = os.path.basename(im_path)
+            if im_file[0] == '/' or im_file[0] == '\\':
+                im_file = im_file[1:]
+
+            save_path = os.path.join(save_dir, im_file)
+            mkdir(save_path)
+            fg = save_result(
+                alpha,
+                save_path,
+                im_path=im_path,
+                trimap=trimap,
+                fg_estimate=fg_estimate,
+                fg=fg)
+
+            # RVM keeps recurrent states that need to be reset between independent images.
+            if hasattr(model, 'reset'):
+                model.reset()
+
+            postprocess_cost_averager.record(time.time() - postprocess_start)
+
+            preprocess_cost = preprocess_cost_averager.get_average()
+            infer_cost = infer_cost_averager.get_average()
+            postprocess_cost = postprocess_cost_averager.get_average()
+            if local_rank == 0:
+                progbar_pred.update(i + 1,
+                                    [('preprocess_cost', preprocess_cost),
+                                     ('infer_cost', infer_cost),
+                                     ('postprocess_cost', postprocess_cost)])
+
+            preprocess_cost_averager.reset()
+            infer_cost_averager.reset()
+            postprocess_cost_averager.reset()
+    return alpha, fg

+ 168 - 0
ppmatting/core/predict_video.py

@@ -0,0 +1,168 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+import time
+
+import cv2
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddleseg import utils
+from paddleseg.core import infer
+from paddleseg.utils import logger, progbar, TimeAverager
+
+from ppmatting.utils import mkdir, estimate_foreground_ml, VideoReader, VideoWriter
+
+
+def build_loader_writter(video_path, transforms, save_dir):
+    reader = VideoReader(video_path, transforms)
+    loader = paddle.io.DataLoader(reader)
+    base_name = os.path.basename(video_path)
+    name = os.path.splitext(base_name)[0]
+    alpha_save_path = os.path.join(save_dir, name + '_alpha.avi')
+    fg_save_path = os.path.join(save_dir, name + '_fg.avi')
+
+    writer_alpha = VideoWriter(
+        alpha_save_path,
+        reader.fps,
+        frame_size=(reader.width, reader.height),
+        is_color=False)
+    writer_fg = VideoWriter(
+        fg_save_path,
+        reader.fps,
+        frame_size=(reader.width, reader.height),
+        is_color=True)
+    writers = {'alpha': writer_alpha, 'fg': writer_fg}
+
+    return loader, writers
+
+
+def reverse_transform(img, trans_info):
+    """recover pred to origin shape"""
+    for item in trans_info[::-1]:
+        if item[0][0] == 'resize':
+            h, w = item[1][0], item[1][1]
+            img = F.interpolate(img, [h, w], mode='bilinear')
+        elif item[0][0] == 'padding':
+            h, w = item[1][0], item[1][1]
+            img = img[:, :, 0:h, 0:w]
+        else:
+            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
+    return img
+
+
+def postprocess(fg, alpha, img, trans_info, writers, fg_estimate):
+    """
+    Postprocess for prediction results.
+
+    Args:
+        fg (Tensor): The foreground, value should be in [0, 1].
+        alpha (Tensor): The alpha, value should be in [0, 1].
+        img (Tensor): The original image, value should be in [0, 1].
+        trans_info (list): A list of the shape transformations.
+        writers (dict): A dict of VideoWriter instance.
+        fg_estimate (bool): Whether to estimate foreground. It is invalid when fg is not None.
+
+    """
+    alpha = reverse_transform(alpha, trans_info)
+    if fg is None:
+        if fg_estimate:
+            img = img.transpose((0, 2, 3, 1)).squeeze().numpy()
+            alpha = alpha.squeeze().numpy()
+            fg = estimate_foreground_ml(img, alpha)
+        else:
+            fg = img
+    else:
+        fg = reverse_transform(fg, trans_info)
+
+    if len(alpha.shape) == 2:
+        fg = alpha[:, :, None] * fg
+    else:
+        fg = alpha * fg
+    writers['alpha'].write(alpha)
+    writers['fg'].write(fg)
+
+
+def predict_video(model,
+                  model_path,
+                  transforms,
+                  video_path,
+                  save_dir='output',
+                  fg_estimate=True):
+    """
+    predict and visualize the video.
+
+    Args:
+        model (nn.Layer): Used to predict for input video.
+        model_path (str): The path of pretrained model.
+        transforms (transforms.Compose): Preprocess for frames of video.
+        video_path (str): the video path to be predicted.
+        save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
+        fg_estimate (bool, optional): Whether to estimate foreground when predicting. It is invalid if the foreground is predicted by model. Default: True
+    """
+    utils.utils.load_entire_model(model, model_path)
+    model.eval()
+
+    # Build loader and writer for video
+    loader, writers = build_loader_writter(
+        video_path, transforms, save_dir=save_dir)
+
+    logger.info("Start to predict...")
+    progbar_pred = progbar.Progbar(target=len(loader), verbose=1)
+    preprocess_cost_averager = TimeAverager()
+    infer_cost_averager = TimeAverager()
+    postprocess_cost_averager = TimeAverager()
+    batch_start = time.time()
+    with paddle.no_grad():
+        for i, data in enumerate(loader):
+            preprocess_cost_averager.record(time.time() - batch_start)
+
+            infer_start = time.time()
+            result = model(data)  # result maybe a Tensor or a dict
+            if isinstance(result, paddle.Tensor):
+                alpha = result
+                fg = None
+            else:
+                alpha = result['alpha']
+                fg = result.get('fg', None)
+            infer_cost_averager.record(time.time() - infer_start)
+
+            postprocess_start = time.time()
+            postprocess(
+                fg,
+                alpha,
+                data['ori_img'],
+                trans_info=data['trans_info'],
+                writers=writers,
+                fg_estimate=fg_estimate)
+            postprocess_cost_averager.record(time.time() - postprocess_start)
+
+            preprocess_cost = preprocess_cost_averager.get_average()
+            infer_cost = infer_cost_averager.get_average()
+            postprocess_cost = postprocess_cost_averager.get_average()
+            progbar_pred.update(i + 1, [('preprocess_cost', preprocess_cost),
+                                        ('infer_cost', infer_cost),
+                                        ('postprocess_cost', postprocess_cost)])
+
+            preprocess_cost_averager.reset()
+            infer_cost_averager.reset()
+            postprocess_cost_averager.reset()
+            batch_start = time.time()
+    if hasattr(model, 'reset'):
+        model.reset()
+    loader.dataset.release()
+    for k, v in writers.items():
+        v.release()

+ 353 - 0
ppmatting/core/train.py

@@ -0,0 +1,353 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+from collections import deque, defaultdict
+import pickle
+import shutil
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddleseg.utils import TimeAverager, calculate_eta, resume, logger, train_profiler
+
+from .val import evaluate
+
+
+def visual_in_traning(log_writer, vis_dict, step):
+    """
+    Visual in vdl
+
+    Args:
+        log_writer (LogWriter): The log writer of vdl.
+        vis_dict (dict): Dict of tensor. The shape of thesor is (C, H, W)
+    """
+    for key, value in vis_dict.items():
+        value_shape = value.shape
+        if value_shape[0] not in [1, 3]:
+            value = value[0]
+            value = value.unsqueeze(0)
+        value = paddle.transpose(value, (1, 2, 0))
+        min_v = paddle.min(value)
+        max_v = paddle.max(value)
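+        # Rescale the tensor to [0, 255] for image logging, depending on its value range.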
+        if (min_v > 0) and (max_v < 1):
+            value = value * 255
+        elif (min_v < 0 and min_v >= -1) and (max_v <= 1):
+            value = (1 + value) / 2 * 255
+        else:
+            value = (value - min_v) / (max_v - min_v) * 255
+
+        value = value.astype('uint8')
+        value = value.numpy()
+        log_writer.add_image(tag=key, img=value, step=step)
+
+
+def save_best(best_model_dir, metrics_data, iter):
+    with open(os.path.join(best_model_dir, 'best_metrics.txt'), 'w') as f:
+        for key, value in metrics_data.items():
+            line = key + ' ' + str(value) + '\n'
+            f.write(line)
+        f.write('iter' + ' ' + str(iter) + '\n')
+
+
+def get_best(best_file, metrics, resume_model=None):
+    '''Get best metrics and iter from file'''
+    best_metrics_data = {}
+    if os.path.exists(best_file) and (resume_model is not None):
+        values = []
+        with open(best_file, 'r') as f:
+            lines = f.readlines()
+            for line in lines:
+                line = line.strip()
+                key, value = line.split(' ')
+                best_metrics_data[key] = eval(value)
+                if key == 'iter':
+                    best_iter = eval(value)
+    else:
+        for key in metrics:
+            best_metrics_data[key] = np.inf
+        best_iter = -1
+    return best_metrics_data, best_iter
+
+
+def train(model,
+          train_dataset,
+          val_dataset=None,
+          optimizer=None,
+          save_dir='output',
+          iters=10000,
+          batch_size=2,
+          resume_model=None,
+          save_interval=1000,
+          log_iters=10,
+          log_image_iters=1000,
+          num_workers=0,
+          use_vdl=False,
+          losses=None,
+          keep_checkpoint_max=5,
+          eval_begin_iters=None,
+          metrics='sad',
+          precision='fp32',
+          amp_level='O1',
+          profiler_options=None):
+    """
+    Launch training.
+    Args:
+        model(nn.Layer): A matting model.
+        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
+        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
+        optimizer (paddle.optimizer.Optimizer): The optimizer.
+        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
+        iters (int, optional): How many iters to train the model. Default: 10000.
+        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
+        resume_model (str, optional): The path of resume model.
+        save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
+        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
+        log_image_iters (int, optional): Log image to vdl. Default: 1000.
+        num_workers (int, optional): Num workers for data loader. Default: 0.
+        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
+        losses (dict, optional): A dict of loss, refer to the loss function of the model for details. Default: None.
+        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
+        eval_begin_iters (int, optional): The iter at which evaluation begins. If None, evaluation begins at iters/2. Default: None.
+        metrics (str|list, optional): The metrics to evaluate, which may be any combination of ("sad", "mse", "grad", "conn"). Default: 'sad'.
+        precision (str, optional): Use AMP (automatic mixed precision) if precision='fp16'. If precision='fp32', training runs normally.
+        amp_level (str, optional): Auto mixed precision level. Accepted values are 'O1' and 'O2': O1 represents mixed precision, where the input
+            data type of each operator is cast according to the white and black lists; O2 represents pure fp16, where all operator parameters and
+            input data are cast to fp16, except for operators in the black list and those without an fp16 kernel (e.g. batch norm). Default: 'O1'.
+        profiler_options (str, optional): The option of the train profiler.
+    """
+    model.train()
+    nranks = paddle.distributed.ParallelEnv().nranks
+    local_rank = paddle.distributed.ParallelEnv().local_rank
+
+    start_iter = 0
+    if resume_model is not None:
+        start_iter = resume(model, optimizer, resume_model)
+
+    if not os.path.isdir(save_dir):
+        if os.path.exists(save_dir):
+            os.remove(save_dir)
+        os.makedirs(save_dir)
+
+    # Use amp
+    if precision == 'fp16':
+        logger.info('use AMP to train. AMP level = {}'.format(amp_level))
+        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+        if amp_level == 'O2':
+            model, optimizer = paddle.amp.decorate(
+                models=model,
+                optimizers=optimizer,
+                level='O2',
+                save_dtype='float32')
+
+    if nranks > 1:
+        # Initialize parallel environment if not done.
+        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
+        ):
+            paddle.distributed.init_parallel_env()
+            ddp_model = paddle.DataParallel(model)
+        else:
+            ddp_model = paddle.DataParallel(model)
+
+    batch_sampler = paddle.io.DistributedBatchSampler(
+        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+
+    loader = paddle.io.DataLoader(
+        train_dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        return_list=True, )
+
+    if use_vdl:
+        from visualdl import LogWriter
+        log_writer = LogWriter(save_dir)
+
+    if isinstance(metrics, str):
+        metrics = [metrics]
+    elif not isinstance(metrics, list):
+        metrics = ['sad']
+    best_metrics_data, best_iter = get_best(
+        os.path.join(save_dir, 'best_model', 'best_metrics.txt'),
+        metrics,
+        resume_model=resume_model)
+    avg_loss = defaultdict(float)
+    iters_per_epoch = len(batch_sampler)
+    reader_cost_averager = TimeAverager()
+    batch_cost_averager = TimeAverager()
+    save_models = deque()
+    batch_start = time.time()
+
+    iter = start_iter
+    while iter < iters:
+        for data in loader:
+            iter += 1
+            if iter > iters:
+                break
+            reader_cost_averager.record(time.time() - batch_start)
+
+            if precision == 'fp16':
+                with paddle.amp.auto_cast(
+                        level=amp_level,
+                        enable=True,
+                        custom_white_list={
+                            "elementwise_add", "batch_norm", "sync_batch_norm"
+                        },
+                        custom_black_list={'bilinear_interp_v2', 'pad3d'}):
+                    logit_dict, loss_dict = ddp_model(
+                        data) if nranks > 1 else model(data)
+
+                scaled = scaler.scale(loss_dict['all'])  # scale the loss
+                scaled.backward()  # do backward
+                scaler.minimize(optimizer, scaled)  # update parameters
+            else:
+                logit_dict, loss_dict = ddp_model(
+                    data) if nranks > 1 else model(data)
+                loss_dict['all'].backward()
+                optimizer.step()
+
+            lr = optimizer.get_lr()
+            if isinstance(optimizer._learning_rate,
+                          paddle.optimizer.lr.LRScheduler):
+                optimizer._learning_rate.step()
+
+            train_profiler.add_profiler_step(profiler_options)
+
+            model.clear_gradients()
+
+            for key, value in loss_dict.items():
+                avg_loss[key] += float(value)
+            batch_cost_averager.record(
+                time.time() - batch_start, num_samples=batch_size)
+
+            if (iter) % log_iters == 0 and local_rank == 0:
+                for key, value in avg_loss.items():
+                    avg_loss[key] = value / log_iters
+                remain_iters = iters - iter
+                avg_train_batch_cost = batch_cost_averager.get_average()
+                avg_train_reader_cost = reader_cost_averager.get_average()
+                eta = calculate_eta(remain_iters, avg_train_batch_cost)
+                # loss info
+                loss_str = ' ' * 26 + '\t[LOSSES]'
+                for key, value in avg_loss.items():
+                    if key != 'all':
+                        loss_str = loss_str + ' ' + key + '={:.4f}'.format(
+                            value)
+                logger.info(
+                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}\n{}\n"
+                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
+                            avg_loss['all'], lr, avg_train_batch_cost,
+                            avg_train_reader_cost,
+                            batch_cost_averager.get_ips_average(
+                            ), eta, loss_str))
+                if use_vdl:
+                    for key, value in avg_loss.items():
+                        log_tag = 'Train/' + key
+                        log_writer.add_scalar(log_tag, value, iter)
+
+                    log_writer.add_scalar('Train/lr', lr, iter)
+                    log_writer.add_scalar('Train/batch_cost',
+                                          avg_train_batch_cost, iter)
+                    log_writer.add_scalar('Train/reader_cost',
+                                          avg_train_reader_cost, iter)
+                    if iter % log_image_iters == 0:
+                        vis_dict = {}
+                        # ground truth
+                        vis_dict['ground truth/img'] = data['img'][0]
+                        for key in data['gt_fields']:
+                            key = key[0]
+                            vis_dict['/'.join(['ground truth', key])] = data[
+                                key][0]
+                        # predict
+                        for key, value in logit_dict.items():
+                            vis_dict['/'.join(['predict', key])] = logit_dict[
+                                key][0]
+                        visual_in_traning(
+                            log_writer=log_writer, vis_dict=vis_dict, step=iter)
+
+                for key in avg_loss.keys():
+                    avg_loss[key] = 0.
+                reader_cost_averager.reset()
+                batch_cost_averager.reset()
+
+            # save model
+            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
+                current_save_dir = os.path.join(save_dir,
+                                                "iter_{}".format(iter))
+                if not os.path.isdir(current_save_dir):
+                    os.makedirs(current_save_dir)
+                paddle.save(model.state_dict(),
+                            os.path.join(current_save_dir, 'model.pdparams'))
+                paddle.save(optimizer.state_dict(),
+                            os.path.join(current_save_dir, 'model.pdopt'))
+                save_models.append(current_save_dir)
+                if len(save_models) > keep_checkpoint_max > 0:
+                    model_to_remove = save_models.popleft()
+                    shutil.rmtree(model_to_remove)
+
+            # eval model
+            if eval_begin_iters is None:
+                eval_begin_iters = iters // 2
+            if (iter % save_interval == 0 or iter == iters) and (
+                    val_dataset is not None
+            ) and local_rank == 0 and iter >= eval_begin_iters:
+                eval_num_workers = 1 if num_workers > 0 else 0
+                metrics_data = evaluate(
+                    model,
+                    val_dataset,
+                    num_workers=eval_num_workers,
+                    print_detail=True,
+                    save_results=False,
+                    metrics=metrics,
+                    precision=precision,
+                    amp_level=amp_level)
+                model.train()
+
+            # save best model and add evaluation results to vdl
+            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
+                if val_dataset is not None and iter >= eval_begin_iters:
+                    if metrics_data[metrics[0]] < best_metrics_data[metrics[0]]:
+                        best_iter = iter
+                        best_metrics_data = metrics_data.copy()
+                        best_model_dir = os.path.join(save_dir, "best_model")
+                        paddle.save(
+                            model.state_dict(),
+                            os.path.join(best_model_dir, 'model.pdparams'))
+                        save_best(best_model_dir, best_metrics_data, iter)
+
+                    show_list = []
+                    for key, value in best_metrics_data.items():
+                        show_list.append((key, value))
+                    log_str = '[EVAL] The model with the best validation {} ({:.4f}) was saved at iter {}.'.format(
+                        show_list[0][0], show_list[0][1], best_iter)
+                    if len(show_list) > 1:
+                        log_str += " While"
+                        for i in range(1, len(show_list)):
+                            log_str = log_str + ' {}: {:.4f},'.format(
+                                show_list[i][0], show_list[i][1])
+                        log_str = log_str[:-1]
+                    logger.info(log_str)
+
+                    if use_vdl:
+                        for key, value in metrics_data.items():
+                            log_writer.add_scalar('Evaluate/' + key, value,
+                                                  iter)
+
+            batch_start = time.time()
+
+    # Sleep for half a second to let dataloader release resources.
+    time.sleep(0.5)
+    if use_vdl:
+        log_writer.close()

+ 176 - 0
ppmatting/core/val.py

@@ -0,0 +1,176 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import cv2
+import numpy as np
+import time
+import paddle
+import paddle.nn.functional as F
+from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar
+
+from ppmatting.metrics import metrics_class_dict
+
+np.set_printoptions(suppress=True)
+
+
+def save_alpha_pred(alpha, path):
+    """
+    The value of alpha should be in [0, 255], shape should be [h, w].
+    """
+    dirname = os.path.dirname(path)
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+    alpha = (alpha).astype('uint8')
+    cv2.imwrite(path, alpha)
+
+
+def reverse_transform(alpha, trans_info):
+    """recover pred to origin shape"""
+    for item in trans_info[::-1]:
+        if item[0][0] == 'resize':
+            h, w = item[1][0], item[1][1]
+            alpha = F.interpolate(alpha, [h, w], mode='bilinear')
+        elif item[0][0] == 'padding':
+            h, w = item[1][0], item[1][1]
+            alpha = alpha[:, :, 0:h, 0:w]
+        else:
+            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
+    return alpha
+
+
+def evaluate(model,
+             eval_dataset,
+             num_workers=0,
+             print_detail=True,
+             save_dir='output/results',
+             save_results=True,
+             metrics='sad',
+             precision='fp32',
+             amp_level='O1'):
+    model.eval()
+    nranks = paddle.distributed.ParallelEnv().nranks
+    local_rank = paddle.distributed.ParallelEnv().local_rank
+    if nranks > 1:
+        # Initialize parallel environment if not done.
+        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
+        ):
+            paddle.distributed.init_parallel_env()
+
+    loader = paddle.io.DataLoader(
+        eval_dataset,
+        batch_size=1,
+        drop_last=False,
+        num_workers=num_workers,
+        return_list=True, )
+
+    total_iters = len(loader)
+    # Get metric instances and data saving
+    metrics_ins = {}
+    metrics_data = {}
+    if isinstance(metrics, str):
+        metrics = [metrics]
+    elif not isinstance(metrics, list):
+        metrics = ['sad']
+    for key in metrics:
+        key = key.lower()
+        metrics_ins[key] = metrics_class_dict[key]()
+        metrics_data[key] = None
+
+    if print_detail:
+        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
+                    format(len(eval_dataset), total_iters))
+    progbar_val = progbar.Progbar(
+        target=total_iters, verbose=1 if nranks < 2 else 2)
+    reader_cost_averager = TimeAverager()
+    batch_cost_averager = TimeAverager()
+    batch_start = time.time()
+
+    img_name = ''
+    i = 0
+    with paddle.no_grad():
+        for iter, data in enumerate(loader):
+            reader_cost_averager.record(time.time() - batch_start)
+            if precision == 'fp16':
+                with paddle.amp.auto_cast(
+                        level=amp_level,
+                        enable=True,
+                        custom_white_list={
+                            "elementwise_add", "batch_norm", "sync_batch_norm"
+                        },
+                        custom_black_list={'bilinear_interp_v2', 'pad3d'}):
+                    alpha_pred = model(data)
+                    alpha_pred = reverse_transform(alpha_pred,
+                                                   data['trans_info'])
+            else:
+                alpha_pred = model(data)
+                alpha_pred = reverse_transform(alpha_pred, data['trans_info'])
+
+            alpha_pred = alpha_pred.numpy()
+
+            alpha_gt = data['alpha'].numpy() * 255
+            trimap = data.get('ori_trimap')
+            if trimap is not None:
+                trimap = trimap.numpy().astype('uint8')
+            alpha_pred = np.round(alpha_pred * 255)
+            for key in metrics_ins.keys():
+                metrics_data[key] = metrics_ins[key].update(alpha_pred,
+                                                            alpha_gt, trimap)
+
+            if save_results:
+                alpha_pred_one = alpha_pred[0].squeeze()
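+                # Constrain the prediction with the trimap: known foreground stays 255, known background stays 0.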
+                if trimap is not None:
+                    trimap = trimap.squeeze().astype('uint8')
+                    alpha_pred_one[trimap == 255] = 255
+                    alpha_pred_one[trimap == 0] = 0
+
+                save_name = data['img_name'][0]
+                name, ext = os.path.splitext(save_name)
+                if save_name == img_name:
+                    save_name = name + '_' + str(i) + ext
+                    i += 1
+                else:
+                    img_name = save_name
+                    save_name = name + '_' + str(i) + ext
+                    i = 1
+
+                save_alpha_pred(alpha_pred_one,
+                                os.path.join(save_dir, save_name))
+
+            batch_cost_averager.record(
+                time.time() - batch_start, num_samples=len(alpha_gt))
+            batch_cost = batch_cost_averager.get_average()
+            reader_cost = reader_cost_averager.get_average()
+
+            if local_rank == 0 and print_detail:
+                show_list = [(k, v) for k, v in metrics_data.items()]
+                show_list = show_list + [('batch_cost', batch_cost),
+                                         ('reader_cost', reader_cost)]
+                progbar_val.update(iter + 1, show_list)
+
+            reader_cost_averager.reset()
+            batch_cost_averager.reset()
+            batch_start = time.time()
+
+    for key in metrics_ins.keys():
+        metrics_data[key] = metrics_ins[key].evaluate()
+    log_str = '[EVAL] '
+    for key, value in metrics_data.items():
+        log_str = log_str + key + ': {:.4f}, '.format(value)
+    log_str = log_str[:-2]
+
+    logger.info(log_str)
+    return metrics_data

Some files were not shown because too many files changed in this diff