video.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import os
  2. import warnings
  3. import cv2
  4. import paddle
  5. import numpy as np
  6. import ppmatting
  7. import ppmatting.transforms as T
  8. class VideoReader(paddle.io.Dataset):
  9. """
  10. Read a video
  11. """
  12. def __init__(self, path, transforms=None):
  13. super().__init__()
  14. if not os.path.exists(path):
  15. raise IOError('There is not found about video path:{} '.format(
  16. path))
  17. self.cap_video = cv2.VideoCapture(path)
  18. if not self.cap_video.isOpened():
  19. raise IOError('Video can not be oepned normally')
  20. # Get some video property
  21. self.fps = int(self.cap_video.get(cv2.CAP_PROP_FPS))
  22. self.frames = int(self.cap_video.get(cv2.CAP_PROP_FRAME_COUNT))
  23. self.width = int(self.cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
  24. self.height = int(self.cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
  25. transforms = [] if transforms is None else transforms
  26. if transforms is None or isinstance(transforms, list):
  27. self.transforms = T.Compose(transforms)
  28. elif isinstance(transforms, T.Compose):
  29. self.transforms = transforms
  30. else:
  31. raise ValueError(
  32. "transforms type is error, it should be list or ppmatting,transforms.Compose"
  33. )
  34. def __len__(self):
  35. return self.frames
  36. def __getitem__(self, idx):
  37. if idx >= self.frames:
  38. raise IndexError('The frame {} is read failed.'.format(idx))
  39. self.cap_video.set(cv2.CAP_PROP_POS_FRAMES, idx)
  40. ret, frame = self.cap_video.retrieve()
  41. if not ret:
  42. warnings.warn(
  43. "the frame {} is read failed. Video reading exit.".format(idx))
  44. raise IndexError('The frame {} is read failed.'.format(idx))
  45. data = {'img': frame}
  46. if self.transforms is not None:
  47. data = self.transforms(data)
  48. data['ori_img'] = frame.transpose((2, 0, 1)) / 255.
  49. return data
  50. def release(self):
  51. self.cap_video.release()
  52. class VideoWriter:
  53. """
  54. Video writer.
  55. Args:
  56. path (str): The path to save a video.
  57. fps (int): The fps of the saved video.
  58. frame_size (tuple): The frame size (width, height) of the saved video.
  59. is_color (bool): Whethe to save the video in color format.
  60. """
  61. def __init__(self, path, fps, frame_size, is_color=True):
  62. self.is_color = is_color
  63. ppmatting.utils.mkdir(path)
  64. fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
  65. self.cap_out = cv2.VideoWriter(
  66. filename=path,
  67. fourcc=fourcc,
  68. fps=fps,
  69. frameSize=frame_size,
  70. isColor=is_color)
  71. def write(self, frames):
  72. """
  73. Save frames.
  74. Args:
  75. frames(Tensor|numpy.ndarray): If `frames` is a tensor, it's shape should be like [N, C, H, W].
  76. If it is a ndarray, it's shape should be like [H, W, 3] or [H, W]. The value is in [0, 1].
  77. """
  78. if isinstance(frames, paddle.Tensor):
  79. if frames.ndim != 4:
  80. raise ValueError(
  81. 'The frames should have the shape like [N, C, H, W], but it is {}'.
  82. format(frames.shape))
  83. n, c, h, w = frames.shape
  84. if not (c == 1 or c == 3):
  85. raise ValueError(
  86. 'the channels of frames should be 1 or 3, but it is {}'.
  87. format(c))
  88. if c == 1 and self.is_color:
  89. frames = paddle.repeat_interleave(frames, repeats=3, axis=1)
  90. frames = (frames.transpose(
  91. (0, 2, 3, 1)).numpy() * 255).astype('uint8')
  92. for i in range(n):
  93. frame = frames[i]
  94. self.cap_out.write(frame)
  95. else:
  96. frames = (frames * 255).astype('uint8')
  97. self.cap_out.write(frames)
  98. def release(self):
  99. self.cap_out.release()