infer.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import codecs
  16. import os
  17. import sys
  18. import cv2
  19. import tqdm
  20. import yaml
  21. import numpy as np
  22. import paddle
  23. from paddle.inference import create_predictor, PrecisionType
  24. from paddle.inference import Config as PredictConfig
  25. from paddleseg.cvlibs import manager
  26. from paddleseg.utils import get_sys_env, logger
  27. LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
  28. sys.path.append(os.path.join(LOCAL_PATH, '..', '..'))
  29. manager.BACKBONES._components_dict.clear()
  30. manager.TRANSFORMS._components_dict.clear()
  31. import ppmatting.transforms as T
  32. from ppmatting.utils import get_image_list, mkdir, estimate_foreground_ml, VideoReader, VideoWriter
def parse_args():
    """Parse command-line arguments for matting-model deployment inference.

    Returns:
        argparse.Namespace: parsed arguments (config path, input image/video
        paths, device/TRT/MKLDNN switches, benchmark options, etc.).
    """
    parser = argparse.ArgumentParser(description='Deploy for matting model')
    parser.add_argument(
        "--config",
        dest="cfg",
        help="The config file.",
        default=None,
        type=str,
        required=True)
    parser.add_argument(
        '--image_path',
        dest='image_path',
        help='The directory or path or file list of the images to be predicted.',
        type=str,
        default=None)
    parser.add_argument(
        '--trimap_path',
        dest='trimap_path',
        help='The directory or path or file list of the triamp to help predicted.',
        type=str,
        default=None)
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size of one gpu or cpu. When video inference, it is invalid.',
        type=int,
        default=1)
    parser.add_argument(
        '--video_path',
        dest='video_path',
        help='The path of the video to be predicted.',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the predict result.',
        type=str,
        default='./output')
    parser.add_argument(
        '--device',
        choices=['cpu', 'gpu'],
        default="gpu",
        help="Select which device to inference, defaults to gpu.")
    # NOTE(review): `type=eval` turns "True"/"False" strings into booleans but
    # will evaluate arbitrary expressions from the command line — consider a
    # dedicated str->bool converter.
    parser.add_argument(
        '--fg_estimate',
        default=True,
        type=eval,
        choices=[True, False],
        help='Whether to estimate foreground when predicting.')
    parser.add_argument(
        '--cpu_threads',
        default=10,
        type=int,
        help='Number of threads to predict when using cpu.')
    parser.add_argument(
        '--enable_mkldnn',
        default=False,
        type=eval,
        choices=[True, False],
        help='Enable to use mkldnn to speed up when using cpu.')
    parser.add_argument(
        '--use_trt',
        default=False,
        type=eval,
        choices=[True, False],
        help='Whether to use Nvidia TensorRT to accelerate prediction.')
    parser.add_argument(
        "--precision",
        default="fp32",
        type=str,
        choices=["fp32", "fp16", "int8"],
        help='The tensorrt precision.')
    parser.add_argument(
        '--enable_auto_tune',
        default=False,
        type=eval,
        choices=[True, False],
        help='Whether to enable tuned dynamic shape. We uses some images to collect '
        'the dynamic shape for trt sub graph, which avoids setting dynamic shape manually.'
    )
    parser.add_argument(
        '--auto_tuned_shape_file',
        type=str,
        default="auto_tune_tmp.pbtxt",
        help='The temp file to save tuned dynamic shape.')
    parser.add_argument(
        "--benchmark",
        type=eval,
        default=False,
        help="Whether to log some information about environment, model, configuration and performance."
    )
    parser.add_argument(
        "--model_name",
        default="",
        type=str,
        help='When `--benchmark` is True, the specified model name is displayed.'
    )
    parser.add_argument(
        '--print_detail',
        default=True,
        type=eval,
        choices=[True, False],
        help='Print GLOG information of Paddle Inference.')
    return parser.parse_args()
  138. class DeployConfig:
  139. def __init__(self, path):
  140. with codecs.open(path, 'r', 'utf-8') as file:
  141. self.dic = yaml.load(file, Loader=yaml.FullLoader)
  142. self._transforms = self.load_transforms(self.dic['Deploy'][
  143. 'transforms'])
  144. self._dir = os.path.dirname(path)
  145. @property
  146. def transforms(self):
  147. return self._transforms
  148. @property
  149. def model(self):
  150. return os.path.join(self._dir, self.dic['Deploy']['model'])
  151. @property
  152. def params(self):
  153. return os.path.join(self._dir, self.dic['Deploy']['params'])
  154. @staticmethod
  155. def load_transforms(t_list):
  156. com = manager.TRANSFORMS
  157. transforms = []
  158. for t in t_list:
  159. ctype = t.pop('type')
  160. transforms.append(com[ctype](**t))
  161. return T.Compose(transforms)
  162. def use_auto_tune(args):
  163. return hasattr(PredictConfig, "collect_shape_range_info") \
  164. and hasattr(PredictConfig, "enable_tuned_tensorrt_dynamic_shape") \
  165. and args.device == "gpu" and args.use_trt and args.enable_auto_tune
  166. def auto_tune(args, imgs, img_nums):
  167. """
  168. Use images to auto tune the dynamic shape for trt sub graph.
  169. The tuned shape saved in args.auto_tuned_shape_file.
  170. Args:
  171. args(dict): input args.
  172. imgs(str, list[str]): the path for images.
  173. img_nums(int): the nums of images used for auto tune.
  174. Returns:
  175. None
  176. """
  177. logger.info("Auto tune the dynamic shape for GPU TRT.")
  178. assert use_auto_tune(args)
  179. if not isinstance(imgs, (list, tuple)):
  180. imgs = [imgs]
  181. num = min(len(imgs), img_nums)
  182. cfg = DeployConfig(args.cfg)
  183. pred_cfg = PredictConfig(cfg.model, cfg.params)
  184. pred_cfg.enable_use_gpu(100, 0)
  185. if not args.print_detail:
  186. pred_cfg.disable_glog_info()
  187. pred_cfg.collect_shape_range_info(args.auto_tuned_shape_file)
  188. predictor = create_predictor(pred_cfg)
  189. input_names = predictor.get_input_names()
  190. input_handle = predictor.get_input_handle(input_names[0])
  191. for i in range(0, num):
  192. data = {'img': imgs[i]}
  193. data = cfg.transforms(data)
  194. input_handle.reshape(data['img'].shape)
  195. input_handle.copy_from_cpu(data['img'])
  196. try:
  197. predictor.run()
  198. except:
  199. logger.info(
  200. "Auto tune fail. Usually, the error is out of GPU memory, "
  201. "because the model and image is too large. \n")
  202. del predictor
  203. if os.path.exists(args.auto_tuned_shape_file):
  204. os.remove(args.auto_tuned_shape_file)
  205. return
  206. logger.info("Auto tune success.\n")
  207. class Predictor:
  208. def __init__(self, args):
  209. """
  210. Prepare for prediction.
  211. The usage and docs of paddle inference, please refer to
  212. https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html
  213. """
  214. self.args = args
  215. self.cfg = DeployConfig(args.cfg)
  216. self._init_base_config()
  217. if args.device == 'cpu':
  218. self._init_cpu_config()
  219. else:
  220. self._init_gpu_config()
  221. self.predictor = create_predictor(self.pred_cfg)
  222. if hasattr(args, 'benchmark') and args.benchmark:
  223. import auto_log
  224. pid = os.getpid()
  225. gpu_id = None if args.device == 'cpu' else 0
  226. self.autolog = auto_log.AutoLogger(
  227. model_name=args.model_name,
  228. model_precision=args.precision,
  229. batch_size=args.batch_size,
  230. data_shape="dynamic",
  231. save_path=None,
  232. inference_config=self.pred_cfg,
  233. pids=pid,
  234. process_name=None,
  235. gpu_ids=gpu_id,
  236. time_keys=[
  237. 'preprocess_time', 'inference_time', 'postprocess_time'
  238. ],
  239. warmup=0,
  240. logger=logger)
  241. def _init_base_config(self):
  242. self.pred_cfg = PredictConfig(self.cfg.model, self.cfg.params)
  243. if not self.args.print_detail:
  244. self.pred_cfg.disable_glog_info()
  245. self.pred_cfg.enable_memory_optim()
  246. self.pred_cfg.switch_ir_optim(True)
  247. def _init_cpu_config(self):
  248. """
  249. Init the config for x86 cpu.
  250. """
  251. logger.info("Using CPU")
  252. self.pred_cfg.disable_gpu()
  253. if self.args.enable_mkldnn:
  254. logger.info("Using MKLDNN")
  255. # cache 1- different shapes for mkldnn
  256. self.pred_cfg.set_mkldnn_cache_capacity(10)
  257. self.pred_cfg.enable_mkldnn()
  258. self.pred_cfg.set_cpu_math_library_num_threads(self.args.cpu_threads)
  259. def _init_gpu_config(self):
  260. """
  261. Init the config for nvidia gpu.
  262. """
  263. logger.info("using GPU")
  264. self.pred_cfg.enable_use_gpu(100, 0)
  265. precision_map = {
  266. "fp16": PrecisionType.Half,
  267. "fp32": PrecisionType.Float32,
  268. "int8": PrecisionType.Int8
  269. }
  270. precision_mode = precision_map[self.args.precision]
  271. if self.args.use_trt:
  272. logger.info("Use TRT")
  273. self.pred_cfg.enable_tensorrt_engine(
  274. workspace_size=1 << 30,
  275. max_batch_size=1,
  276. min_subgraph_size=300,
  277. precision_mode=precision_mode,
  278. use_static=False,
  279. use_calib_mode=False)
  280. if use_auto_tune(self.args) and \
  281. os.path.exists(self.args.auto_tuned_shape_file):
  282. logger.info("Use auto tuned dynamic shape")
  283. allow_build_at_runtime = True
  284. self.pred_cfg.enable_tuned_tensorrt_dynamic_shape(
  285. self.args.auto_tuned_shape_file, allow_build_at_runtime)
  286. else:
  287. logger.info("Use manual set dynamic shape")
  288. min_input_shape = {"img": [1, 3, 100, 100]}
  289. max_input_shape = {"img": [1, 3, 2000, 3000]}
  290. opt_input_shape = {"img": [1, 3, 512, 1024]}
  291. self.pred_cfg.set_trt_dynamic_shape_info(
  292. min_input_shape, max_input_shape, opt_input_shape)
  293. def run(self, imgs, trimaps=None, imgs_dir=None):
  294. self.imgs_dir = imgs_dir
  295. num = len(imgs)
  296. input_names = self.predictor.get_input_names()
  297. input_handle = {}
  298. for i in range(len(input_names)):
  299. input_handle[input_names[i]] = self.predictor.get_input_handle(
  300. input_names[i])
  301. output_names = self.predictor.get_output_names()
  302. output_handle = self.predictor.get_output_handle(output_names[0])
  303. args = self.args
  304. for i in tqdm.tqdm(range(0, num, args.batch_size)):
  305. # warm up
  306. if i == 0 and args.benchmark:
  307. for _ in range(5):
  308. img_inputs = []
  309. if trimaps is not None:
  310. trimap_inputs = []
  311. trans_info = []
  312. for j in range(i, i + args.batch_size):
  313. img = imgs[j]
  314. trimap = trimaps[j] if trimaps is not None else None
  315. data = self._preprocess(img=img, trimap=trimap)
  316. img_inputs.append(data['img'])
  317. if trimaps is not None:
  318. trimap_inputs.append(data['trimap'][
  319. np.newaxis, :, :])
  320. trans_info.append(data['trans_info'])
  321. img_inputs = np.array(img_inputs)
  322. if trimaps is not None:
  323. trimap_inputs = (
  324. np.array(trimap_inputs)).astype('float32')
  325. input_handle['img'].copy_from_cpu(img_inputs)
  326. if trimaps is not None:
  327. input_handle['trimap'].copy_from_cpu(trimap_inputs)
  328. self.predictor.run()
  329. results = output_handle.copy_to_cpu()
  330. results = results.squeeze(1)
  331. for j in range(args.batch_size):
  332. trimap = trimap_inputs[
  333. j] if trimaps is not None else None
  334. result = self._postprocess(
  335. results[j], trans_info[j], trimap=trimap)
  336. # inference
  337. if args.benchmark:
  338. self.autolog.times.start()
  339. img_inputs = []
  340. if trimaps is not None:
  341. trimap_inputs = []
  342. trans_info = []
  343. for j in range(i, i + args.batch_size):
  344. img = imgs[j]
  345. trimap = trimaps[j] if trimaps is not None else None
  346. data = self._preprocess(img=img, trimap=trimap)
  347. img_inputs.append(data['img'])
  348. if trimaps is not None:
  349. trimap_inputs.append(data['trimap'][np.newaxis, :, :])
  350. trans_info.append(data['trans_info'])
  351. img_inputs = np.array(img_inputs)
  352. if trimaps is not None:
  353. trimap_inputs = (np.array(trimap_inputs)).astype('float32')
  354. input_handle['img'].copy_from_cpu(img_inputs)
  355. if trimaps is not None:
  356. input_handle['trimap'].copy_from_cpu(trimap_inputs)
  357. if args.benchmark:
  358. self.autolog.times.stamp()
  359. self.predictor.run()
  360. results = output_handle.copy_to_cpu()
  361. if args.benchmark:
  362. self.autolog.times.stamp()
  363. results = results.squeeze(1)
  364. for j in range(args.batch_size):
  365. trimap = trimap_inputs[j] if trimaps is not None else None
  366. result = self._postprocess(
  367. results[j], trans_info[j], trimap=trimap)
  368. self._save_imgs(result, imgs[i + j])
  369. if args.benchmark:
  370. self.autolog.times.end(stamp=True)
  371. logger.info("Finish")
  372. def _preprocess(self, img, trimap=None):
  373. data = {}
  374. data['img'] = img
  375. if trimap is not None:
  376. data['trimap'] = trimap
  377. data['gt_fields'] = ['trimap']
  378. data = self.cfg.transforms(data)
  379. return data
  380. def _postprocess(self, alpha, trans_info, trimap=None):
  381. """recover pred to origin shape"""
  382. if trimap is not None:
  383. trimap = trimap.squeeze(0)
  384. alpha[trimap == 0] = 0
  385. alpha[trimap == 255] = 1
  386. for item in trans_info[::-1]:
  387. if item[0] == 'resize':
  388. h, w = item[1][0], item[1][1]
  389. alpha = cv2.resize(
  390. alpha, (w, h), interpolation=cv2.INTER_LINEAR)
  391. elif item[0] == 'padding':
  392. h, w = item[1][0], item[1][1]
  393. alpha = alpha[0:h, 0:w]
  394. else:
  395. raise Exception("Unexpected info '{}' in im_info".format(item[
  396. 0]))
  397. return alpha
  398. def _save_imgs(self, alpha, img_path, fg=None):
  399. ori_img = cv2.imread(img_path)
  400. alpha = (alpha * 255).astype('uint8')
  401. if self.imgs_dir is not None:
  402. img_path = img_path.replace(self.imgs_dir, '')
  403. else:
  404. img_path = os.path.basename(img_path)
  405. name, ext = os.path.splitext(img_path)
  406. if name[0] == '/' or name[0] == '\\':
  407. name = name[1:]
  408. alpha_save_path = os.path.join(args.save_dir, name + '_alpha.png')
  409. rgba_save_path = os.path.join(args.save_dir, name + '_rgba.png')
  410. # save alpha
  411. mkdir(alpha_save_path)
  412. cv2.imwrite(alpha_save_path, alpha)
  413. # save rgba image
  414. mkdir(rgba_save_path)
  415. if fg is None:
  416. if args.fg_estimate:
  417. fg = estimate_foreground_ml(ori_img / 255.0,
  418. alpha / 255.0) * 255
  419. else:
  420. fg = ori_img
  421. else:
  422. fg = fg * 255
  423. fg = fg.astype('uint8')
  424. alpha = alpha[:, :, np.newaxis]
  425. rgba = np.concatenate([fg, alpha], axis=-1)
  426. cv2.imwrite(rgba_save_path, rgba)
  427. def run_video(self, video_path):
  428. """Video matting only support the trimap-free method"""
  429. input_names = self.predictor.get_input_names()
  430. input_handle = {}
  431. for i in range(len(input_names)):
  432. input_handle[input_names[i]] = self.predictor.get_input_handle(
  433. input_names[i])
  434. output_names = self.predictor.get_output_names()
  435. output_handle = {}
  436. output_handle['alpha'] = self.predictor.get_output_handle(output_names[
  437. 0])
  438. # Build reader and writer
  439. reader = VideoReader(video_path, self.cfg.transforms)
  440. base_name = os.path.basename(video_path)
  441. name = os.path.splitext(base_name)[0]
  442. alpha_save_path = os.path.join(args.save_dir, name + '_alpha.avi')
  443. fg_save_path = os.path.join(args.save_dir, name + '_fg.avi')
  444. writer_alpha = VideoWriter(
  445. alpha_save_path,
  446. reader.fps,
  447. frame_size=(reader.width, reader.height),
  448. is_color=False)
  449. writer_fg = VideoWriter(
  450. fg_save_path,
  451. reader.fps,
  452. frame_size=(reader.width, reader.height),
  453. is_color=True)
  454. for data in tqdm.tqdm(reader):
  455. trans_info = data['trans_info']
  456. _, h, w = data['img'].shape
  457. input_handle['img'].copy_from_cpu(data['img'][np.newaxis, ...])
  458. self.predictor.run()
  459. alpha = output_handle['alpha'].copy_to_cpu()
  460. alpha = alpha.squeeze()
  461. alpha = self._postprocess(alpha, trans_info)
  462. self._save_frame(
  463. alpha,
  464. fg=None,
  465. img=data['ori_img'],
  466. writer_alpha=writer_alpha,
  467. writer_fg=writer_fg)
  468. writer_alpha.release()
  469. writer_fg.release()
  470. reader.release()
  471. def _save_frame(self, alpha, fg, img, writer_alpha, writer_fg):
  472. if fg is None:
  473. img = img.transpose((1, 2, 0))
  474. if self.args.fg_estimate:
  475. fg = estimate_foreground_ml(img, alpha)
  476. else:
  477. fg = img
  478. fg = fg * alpha[:, :, np.newaxis]
  479. writer_alpha.write(alpha)
  480. writer_fg.write(fg)
  481. class PredictorRVM(Predictor):
  482. def __init__(self, args):
  483. super().__init__(args=args)
  484. def run(self, imgs, trimaps=None, imgs_dir=None):
  485. self.imgs_dir = imgs_dir
  486. num = len(imgs)
  487. input_names = self.predictor.get_input_names()
  488. input_handle = {}
  489. for i in range(len(input_names)):
  490. input_handle[input_names[i]] = self.predictor.get_input_handle(
  491. input_names[i])
  492. output_names = self.predictor.get_output_names()
  493. output_handle = {}
  494. output_handle['alpha'] = self.predictor.get_output_handle(output_names[
  495. 0])
  496. output_handle['fg'] = self.predictor.get_output_handle(output_names[1])
  497. output_handle['r1'] = self.predictor.get_output_handle(output_names[2])
  498. output_handle['r2'] = self.predictor.get_output_handle(output_names[3])
  499. output_handle['r3'] = self.predictor.get_output_handle(output_names[4])
  500. output_handle['r4'] = self.predictor.get_output_handle(output_names[5])
  501. args = self.args
  502. for i in tqdm.tqdm(range(0, num, args.batch_size)):
  503. # warm up
  504. if i == 0 and args.benchmark:
  505. for _ in range(5):
  506. img_inputs = []
  507. if trimaps is not None:
  508. trimap_inputs = []
  509. trans_info = []
  510. for j in range(i, i + args.batch_size):
  511. img = imgs[j]
  512. data = self._preprocess(img=img)
  513. img_inputs.append(data['img'])
  514. trans_info.append(data['trans_info'])
  515. img_inputs = np.array(img_inputs)
  516. n, _, h, w = img_inputs.shape
  517. downsample_ratio = min(512 / max(h, w), 1)
  518. downsample_ratio = np.array(
  519. [downsample_ratio], dtype='float32')
  520. input_handle['img'].copy_from_cpu(img_inputs)
  521. input_handle['downsample_ratio'].copy_from_cpu(
  522. downsample_ratio.astype('float32'))
  523. r_channels = [16, 20, 40, 64]
  524. for k in range(4):
  525. j = k + 1
  526. hj = int(np.ceil(int(h * downsample_ratio[0]) / 2**j))
  527. wj = int(np.ceil(int(w * downsample_ratio[0]) / 2**j))
  528. rj = np.zeros(
  529. (n, r_channels[k], hj, wj), dtype='float32')
  530. input_handle['r' + str(j)].copy_from_cpu(rj)
  531. self.predictor.run()
  532. alphas = output_handle['alpha'].copy_to_cpu()
  533. fgs = output_handle['fg'].copy_to_cpu()
  534. alphas = alphas.squeeze(1)
  535. for j in range(args.batch_size):
  536. alpha = self._postprocess(alphas[j], trans_info[j])
  537. fg = fgs[j]
  538. fg = np.transpose(fg, (1, 2, 0))
  539. fg = self._postprocess(fg, trans_info[j])
  540. # inference
  541. if args.benchmark:
  542. self.autolog.times.start()
  543. img_inputs = []
  544. if trimaps is not None:
  545. trimap_inputs = []
  546. trans_info = []
  547. for j in range(i, i + args.batch_size):
  548. img = imgs[j]
  549. data = self._preprocess(img=img)
  550. img_inputs.append(data['img'])
  551. trans_info.append(data['trans_info'])
  552. img_inputs = np.array(img_inputs)
  553. n, _, h, w = img_inputs.shape
  554. downsample_ratio = min(512 / max(h, w), 1)
  555. downsample_ratio = np.array([downsample_ratio], dtype='float32')
  556. input_handle['img'].copy_from_cpu(img_inputs)
  557. input_handle['downsample_ratio'].copy_from_cpu(
  558. downsample_ratio.astype('float32'))
  559. r_channels = [16, 20, 40, 64]
  560. for k in range(4):
  561. j = k + 1
  562. hj = int(np.ceil(int(h * downsample_ratio[0]) / 2**j))
  563. wj = int(np.ceil(int(w * downsample_ratio[0]) / 2**j))
  564. rj = np.zeros((n, r_channels[k], hj, wj), dtype='float32')
  565. input_handle['r' + str(j)].copy_from_cpu(rj)
  566. if args.benchmark:
  567. self.autolog.times.stamp()
  568. self.predictor.run()
  569. alphas = output_handle['alpha'].copy_to_cpu()
  570. fgs = output_handle['fg'].copy_to_cpu()
  571. if args.benchmark:
  572. self.autolog.times.stamp()
  573. alphas = alphas.squeeze(1)
  574. for j in range(args.batch_size):
  575. alpha = self._postprocess(alphas[j], trans_info[j])
  576. fg = fgs[j]
  577. fg = np.transpose(fg, (1, 2, 0))
  578. fg = self._postprocess(fg, trans_info[j])
  579. self._save_imgs(alpha, fg=fg, img_path=imgs[i + j])
  580. if args.benchmark:
  581. self.autolog.times.end(stamp=True)
  582. logger.info("Finish")
  583. def run_video(self, video_path):
  584. input_names = self.predictor.get_input_names()
  585. input_handle = {}
  586. for i in range(len(input_names)):
  587. input_handle[input_names[i]] = self.predictor.get_input_handle(
  588. input_names[i])
  589. output_names = self.predictor.get_output_names()
  590. output_handle = {}
  591. output_handle['alpha'] = self.predictor.get_output_handle(output_names[
  592. 0])
  593. output_handle['fg'] = self.predictor.get_output_handle(output_names[1])
  594. output_handle['r1'] = self.predictor.get_output_handle(output_names[2])
  595. output_handle['r2'] = self.predictor.get_output_handle(output_names[3])
  596. output_handle['r3'] = self.predictor.get_output_handle(output_names[4])
  597. output_handle['r4'] = self.predictor.get_output_handle(output_names[5])
  598. # Build reader and writer
  599. reader = VideoReader(video_path, self.cfg.transforms)
  600. base_name = os.path.basename(video_path)
  601. name = os.path.splitext(base_name)[0]
  602. alpha_save_path = os.path.join(args.save_dir, name + '_alpha.avi')
  603. fg_save_path = os.path.join(args.save_dir, name + '_fg.avi')
  604. writer_alpha = VideoWriter(
  605. alpha_save_path,
  606. reader.fps,
  607. frame_size=(reader.width, reader.height),
  608. is_color=False)
  609. writer_fg = VideoWriter(
  610. fg_save_path,
  611. reader.fps,
  612. frame_size=(reader.width, reader.height),
  613. is_color=True)
  614. r_channels = [16, 20, 40, 64]
  615. for i, data in tqdm.tqdm(enumerate(reader)):
  616. trans_info = data['trans_info']
  617. _, h, w = data['img'].shape
  618. if i == 0:
  619. downsample_ratio = min(512 / max(h, w), 1)
  620. downsample_ratio = np.array([downsample_ratio], dtype='float32')
  621. r_channels = [16, 20, 40, 64]
  622. for k in range(4):
  623. j = k + 1
  624. hj = int(np.ceil(int(h * downsample_ratio[0]) / 2**j))
  625. wj = int(np.ceil(int(w * downsample_ratio[0]) / 2**j))
  626. rj = np.zeros((1, r_channels[k], hj, wj), dtype='float32')
  627. input_handle['r' + str(j)].copy_from_cpu(rj)
  628. else:
  629. input_handle['r1'] = output_handle['r1']
  630. input_handle['r2'] = output_handle['r2']
  631. input_handle['r3'] = output_handle['r3']
  632. input_handle['r4'] = output_handle['r4']
  633. input_handle['img'].copy_from_cpu(data['img'][np.newaxis, ...])
  634. input_handle['downsample_ratio'].copy_from_cpu(
  635. downsample_ratio.astype('float32'))
  636. self.predictor.run()
  637. alpha = output_handle['alpha'].copy_to_cpu()
  638. fg = output_handle['fg'].copy_to_cpu()
  639. alpha = alpha.squeeze()
  640. alpha = self._postprocess(alpha, trans_info)
  641. fg = fg.squeeze().transpose((1, 2, 0))
  642. fg = self._postprocess(fg, trans_info)
  643. self._save_frame(alpha, fg, data['ori_img'], writer_alpha,
  644. writer_fg)
  645. writer_alpha.release()
  646. writer_fg.release()
  647. reader.release()
  648. def main(args):
  649. with open(args.cfg, 'r') as f:
  650. yaml_conf = yaml.load(f, Loader=yaml.FullLoader)
  651. model_name = yaml_conf.get('ModelName', None)
  652. if model_name == 'RVM':
  653. predector_ = PredictorRVM
  654. else:
  655. predector_ = Predictor
  656. if args.image_path is not None:
  657. imgs_list, imgs_dir = get_image_list(args.image_path)
  658. if args.trimap_path is None:
  659. trimaps_list = None
  660. else:
  661. trimaps_list, _ = get_image_list(args.trimap_path)
  662. if use_auto_tune(args):
  663. tune_img_nums = 10
  664. auto_tune(args, imgs_list, tune_img_nums)
  665. predictor = predector_(args)
  666. predictor.run(imgs=imgs_list, trimaps=trimaps_list, imgs_dir=imgs_dir)
  667. if use_auto_tune(args) and \
  668. os.path.exists(args.auto_tuned_shape_file):
  669. os.remove(args.auto_tuned_shape_file)
  670. if args.benchmark:
  671. predictor.autolog.report()
  672. elif args.video_path is not None:
  673. predictor = predector_(args)
  674. predictor.run_video(video_path=args.video_path)
  675. else:
  676. raise IOError("Please provide --image_path or --video_path.")
  677. if __name__ == '__main__':
  678. args = parse_args()
  679. main(args)