
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import pickle
import shutil
from collections import deque, defaultdict

import numpy as np
import paddle
import paddle.nn.functional as F
from paddleseg.utils import TimeAverager, calculate_eta, resume, logger, train_profiler

from .val import evaluate


def visual_in_training(log_writer, vis_dict, step):
    """
    Visualize tensors in VisualDL.

    Args:
        log_writer (LogWriter): The log writer of vdl.
        vis_dict (dict): Dict of tensors. The shape of each tensor is (C, H, W).
        step (int): The current training step.
    """
    for key, value in vis_dict.items():
        value_shape = value.shape
        if value_shape[0] not in [1, 3]:
            # Keep only the first channel and restore the channel dim.
            value = value[0]
            value = value.unsqueeze(0)
        value = paddle.transpose(value, (1, 2, 0))  # (C, H, W) -> (H, W, C)
        min_v = paddle.min(value)
        max_v = paddle.max(value)
        # Rescale the tensor to the [0, 255] uint8 range expected by vdl.
        if (min_v > 0) and (max_v < 1):
            value = value * 255
        elif (min_v < 0 and min_v >= -1) and (max_v <= 1):
            value = (1 + value) / 2 * 255
        else:
            value = (value - min_v) / (max_v - min_v) * 255
        value = value.astype('uint8')
        value = value.numpy()
        log_writer.add_image(tag=key, img=value, step=step)
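
# A hypothetical call, for illustration only (the tensor name below is not
# defined in this module): logging the first predicted alpha matte of a batch,
#
#     visual_in_training(log_writer, {'predict/alpha': alpha_batch[0]}, step=1000)
#
# writes it to VisualDL as an image under the tag 'predict/alpha'.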


def save_best(best_model_dir, metrics_data, iter):
    """Write the best metrics, one 'key value' pair per line, ending with the best iter."""
    with open(os.path.join(best_model_dir, 'best_metrics.txt'), 'w') as f:
        for key, value in metrics_data.items():
            line = key + ' ' + str(value) + '\n'
            f.write(line)
        f.write('iter' + ' ' + str(iter) + '\n')
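
# For reference, save_best writes a plain-text file shaped like the following
# (values are illustrative):
#
#     sad 52.3
#     mse 0.009
#     iter 9000
#
# get_best below parses exactly this layout back into a dict.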


def get_best(best_file, metrics, resume_model=None):
    '''Get best metrics and iter from file'''
    best_metrics_data = {}
    best_iter = -1
    if os.path.exists(best_file) and (resume_model is not None):
        with open(best_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                line = line.strip()
                key, value = line.split(' ')
                best_metrics_data[key] = eval(value)
                if key == 'iter':
                    best_iter = eval(value)
    else:
        # Fresh training: every metric starts at +inf so the first
        # evaluation always becomes the best one.
        for key in metrics:
            best_metrics_data[key] = np.inf
    return best_metrics_data, best_iter
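
# E.g., resuming against the file sketched above (illustrative values),
#     get_best('output/best_model/best_metrics.txt', ['sad'],
#              resume_model='output/iter_9000')
# returns ({'sad': 52.3, 'mse': 0.009, 'iter': 9000}, 9000); without a
# resume_model it returns ({'sad': inf}, -1).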


def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          log_image_iters=1000,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          eval_begin_iters=None,
          metrics='sad',
          precision='fp32',
          amp_level='O1',
          profiler_options=None):
    """
    Launch training.

    Args:
        model (nn.Layer): A matting model.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of the model to resume training from.
        save_interval (int, optional): How many iters between saving model snapshots during training. Default: 1000.
        log_iters (int, optional): Display logging information every log_iters. Default: 10.
        log_image_iters (int, optional): Log images to vdl every log_image_iters. Default: 1000.
        num_workers (int, optional): Number of workers for the data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict, optional): A dict of losses; refer to the loss function of the model for details. Default: None.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to keep. Default: 5.
        eval_begin_iters (int, optional): The iter at which evaluation begins. If None, evaluation begins at iters // 2. Default: None.
        metrics (str|list, optional): The metrics to evaluate; any combination of ("sad", "mse", "grad", "conn"). Default: 'sad'.
        precision (str, optional): Use AMP if precision='fp16'; if precision='fp32', training runs normally. Default: 'fp32'.
        amp_level (str, optional): The auto mixed precision level. Accepted values are 'O1' and 'O2': 'O1' means mixed
            precision, where the input data type of each operator is cast according to the white list and black list;
            'O2' means pure fp16, where all operator parameters and input data are cast to fp16, except operators in
            the black list, operators without an fp16 kernel, and batch norm. Default: 'O1'.
        profiler_options (str, optional): The options of the train profiler.
    """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    # Use AMP
    if precision == 'fp16':
        logger.info('Use AMP to train. AMP level = {}'.format(amp_level))
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        if amp_level == 'O2':
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level='O2',
                save_dtype='float32')

    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
            ddp_model = paddle.DataParallel(model)
        else:
            ddp_model = paddle.DataParallel(model)

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True, )
    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    if isinstance(metrics, str):
        metrics = [metrics]
    elif not isinstance(metrics, list):
        metrics = ['sad']
    best_metrics_data, best_iter = get_best(
        os.path.join(save_dir, 'best_model', 'best_metrics.txt'),
        metrics,
        resume_model=resume_model)

    avg_loss = defaultdict(float)
    iters_per_epoch = len(batch_sampler)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    batch_start = time.time()

    iter = start_iter
    while iter < iters:
        for data in loader:
            iter += 1
            if iter > iters:
                break
            reader_cost_averager.record(time.time() - batch_start)

            if precision == 'fp16':
                with paddle.amp.auto_cast(
                        level=amp_level,
                        enable=True,
                        custom_white_list={
                            "elementwise_add", "batch_norm", "sync_batch_norm"
                        },
                        custom_black_list={'bilinear_interp_v2', 'pad3d'}):
                    logit_dict, loss_dict = ddp_model(
                        data) if nranks > 1 else model(data)

                scaled = scaler.scale(loss_dict['all'])  # scale the loss
                scaled.backward()  # do backward
                scaler.minimize(optimizer, scaled)  # update parameters
            else:
                logit_dict, loss_dict = ddp_model(
                    data) if nranks > 1 else model(data)
                loss_dict['all'].backward()
                optimizer.step()

            lr = optimizer.get_lr()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            train_profiler.add_profiler_step(profiler_options)
            model.clear_gradients()

            for key, value in loss_dict.items():
                avg_loss[key] += float(value)
            batch_cost_averager.record(
                time.time() - batch_start, num_samples=batch_size)

            if iter % log_iters == 0 and local_rank == 0:
                for key, value in avg_loss.items():
                    avg_loss[key] = value / log_iters
                remain_iters = iters - iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                # loss info
                loss_str = ' ' * 26 + '\t[LOSSES]'
                for key, value in avg_loss.items():
                    if key != 'all':
                        loss_str = loss_str + ' ' + key + '={:.4f}'.format(
                            value)
                logger.info(
                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}\n{}\n"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss['all'], lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(), eta,
                            loss_str))
                if use_vdl:
                    for key, value in avg_loss.items():
                        log_tag = 'Train/' + key
                        log_writer.add_scalar(log_tag, value, iter)
                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)

                    if iter % log_image_iters == 0:
                        vis_dict = {}
                        # ground truth
                        vis_dict['ground truth/img'] = data['img'][0]
                        for key in data['gt_fields']:
                            key = key[0]
                            vis_dict['/'.join(['ground truth', key])] = data[
                                key][0]
                        # predict
                        for key, value in logit_dict.items():
                            vis_dict['/'.join(['predict', key])] = logit_dict[
                                key][0]
                        visual_in_training(
                            log_writer=log_writer, vis_dict=vis_dict,
                            step=iter)

                for key in avg_loss.keys():
                    avg_loss[key] = 0.
                reader_cost_averager.reset()
                batch_cost_averager.reset()
            # save model
            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                # Drop the oldest snapshot once the checkpoint budget is hit.
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

            # eval model
            if eval_begin_iters is None:
                eval_begin_iters = iters // 2
            if (iter % save_interval == 0 or iter == iters) and (
                    val_dataset is not None
            ) and local_rank == 0 and iter >= eval_begin_iters:
                num_workers = 1 if num_workers > 0 else 0
                metrics_data = evaluate(
                    model,
                    val_dataset,
                    num_workers=num_workers,
                    print_detail=True,
                    save_results=False,
                    metrics=metrics,
                    precision=precision,
                    amp_level=amp_level)
                model.train()
            # save best model and add evaluation results to vdl
            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                if val_dataset is not None and iter >= eval_begin_iters:
                    if metrics_data[metrics[0]] < best_metrics_data[metrics[0]]:
                        best_iter = iter
                        best_metrics_data = metrics_data.copy()
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                        save_best(best_model_dir, best_metrics_data, iter)

                    show_list = []
                    for key, value in best_metrics_data.items():
                        show_list.append((key, value))
                    log_str = '[EVAL] The model with the best validation {} ({:.4f}) was saved at iter {}.'.format(
                        show_list[0][0], show_list[0][1], best_iter)
                    if len(show_list) > 1:
                        log_str += " While"
                        for i in range(1, len(show_list)):
                            log_str = log_str + ' {}: {:.4f},'.format(
                                show_list[i][0], show_list[i][1])
                        log_str = log_str[:-1]
                    logger.info(log_str)

                    if use_vdl:
                        for key, value in metrics_data.items():
                            log_writer.add_scalar('Evaluate/' + key, value,
                                                  iter)

            batch_start = time.time()

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
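
# A minimal launch sketch, assuming a model, datasets, and optimizer built
# elsewhere; every name below is illustrative and not defined in this module:
#
#     import paddle
#     from ppmatting.core import train
#
#     model = build_model()            # any matting nn.Layer
#     train_data = build_dataset()     # paddle.io.Dataset for training
#     val_data = build_dataset()       # paddle.io.Dataset for validation
#     lr = paddle.optimizer.lr.PolynomialDecay(0.01, decay_steps=100000)
#     opt = paddle.optimizer.Momentum(lr, parameters=model.parameters())
#
#     train(model,
#           train_data,
#           val_dataset=val_data,
#           optimizer=opt,
#           iters=100000,
#           batch_size=4,
#           use_vdl=True,
#           metrics=['sad', 'mse'])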