# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
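
"""Iteration-based training loop for matting models, with optional automatic
mixed precision (AMP), multi-GPU data parallelism, VisualDL logging, periodic
evaluation, checkpoint rotation and best-model tracking."""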

import os
import time
from collections import deque, defaultdict
import pickle
import shutil

import numpy as np
import paddle
import paddle.nn.functional as F
from paddleseg.utils import TimeAverager, calculate_eta, resume, logger, train_profiler

from .val import evaluate


def visual_in_traning(log_writer, vis_dict, step):
    """
    Visualize tensors in VisualDL during training.

    Args:
        log_writer (LogWriter): The log writer of VisualDL.
        vis_dict (dict): Dict of tensors to visualize. The shape of each tensor is (C, H, W).
    """
    for key, value in vis_dict.items():
        value_shape = value.shape
        # If the channel dimension is neither 1 nor 3, visualize the first channel only.
        if value_shape[0] not in [1, 3]:
            value = value[0]
            value = value.unsqueeze(0)
        value = paddle.transpose(value, (1, 2, 0))
        min_v = paddle.min(value)
        max_v = paddle.max(value)
        # Map values into [0, 255] depending on the original value range.
        if (min_v > 0) and (max_v < 1):
            value = value * 255
        elif (min_v < 0 and min_v >= -1) and (max_v <= 1):
            value = (1 + value) / 2 * 255
        else:
            value = (value - min_v) / (max_v - min_v) * 255

        value = value.astype('uint8')
        value = value.numpy()
        log_writer.add_image(tag=key, img=value, step=step)
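
# Illustrative only: `train` below feeds `visual_in_traning` a dict such as
#   {'ground truth/img': img_tensor, 'predict/alpha': logit_tensor}
# where each tensor is (C, H, W) and VisualDL renders one image per key.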


def save_best(best_model_dir, metrics_data, iter):
    with open(os.path.join(best_model_dir, 'best_metrics.txt'), 'w') as f:
        for key, value in metrics_data.items():
            line = key + ' ' + str(value) + '\n'
            f.write(line)
        f.write('iter' + ' ' + str(iter) + '\n')
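
# The resulting best_metrics.txt is a plain "key value" text file that
# `get_best` parses back, e.g. (values illustrative):
#   sad 45.1234
#   mse 0.0123
#   iter 9000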


def get_best(best_file, metrics, resume_model=None):
    '''Get the best metrics and the corresponding iter from file.'''
    best_metrics_data = {}
    best_iter = -1
    if os.path.exists(best_file) and (resume_model is not None):
        with open(best_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                line = line.strip()
                key, value = line.split(' ')
                best_metrics_data[key] = eval(value)
                if key == 'iter':
                    best_iter = eval(value)
    else:
        for key in metrics:
            best_metrics_data[key] = np.inf
    return best_metrics_data, best_iter


def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          log_image_iters=1000,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          eval_begin_iters=None,
          metrics='sad',
          precision='fp32',
          amp_level='O1',
          profiler_options=None):
- """
- Launch training.
- Args:
- model(nn.Layer): A matting model.
- train_dataset (paddle.io.Dataset): Used to read and process training datasets.
- val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
- optimizer (paddle.optimizer.Optimizer): The optimizer.
- save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
- iters (int, optional): How may iters to train the model. Defualt: 10000.
- batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
- resume_model (str, optional): The path of resume model.
- save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
- log_iters (int, optional): Display logging information at every log_iters. Default: 10.
- log_image_iters (int, optional): Log image to vdl. Default: 1000.
- num_workers (int, optional): Num workers for data loader. Default: 0.
- use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
- losses (dict, optional): A dict of loss, refer to the loss function of the model for details. Default: None.
- keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
- eval_begin_iters (int): The iters begin evaluation. It will evaluate at iters/2 if it is None. Defalust: None.
- metrics(str|list, optional): The metrics to evaluate, it may be the combination of ("sad", "mse", "grad", "conn").
- precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', the training is normal.
- amp_level (str, optional): Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision,
- the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators
- parameters and input data will be casted to fp16, except operators in black_list, don’t support fp16 kernel and batchnorm. Default is O1(amp)
- profiler_options (str, optional): The option of train profiler.
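
    Example:
        A minimal sketch of a typical call. The model, dataset and optimizer
        here are illustrative placeholders, not classes defined in this file;
        the model is expected to return (logit_dict, loss_dict) when called on
        a batch, with the total loss stored under the key 'all'.

            model = SomeMattingModel(...)
            train_dataset = SomeMattingDataset(mode='train')
            val_dataset = SomeMattingDataset(mode='val')
            optimizer = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=model.parameters())
            train(model, train_dataset, val_dataset=val_dataset,
                  optimizer=optimizer, save_dir='output', iters=10000,
                  batch_size=2, metrics=['sad', 'mse'])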
- """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    # Use amp
    if precision == 'fp16':
        logger.info('use AMP to train. AMP level = {}'.format(amp_level))
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        if amp_level == 'O2':
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level='O2',
                save_dtype='float32')
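
    # Note: GradScaler scales the loss before backward so fp16 gradients do not
    # underflow; under level 'O2' the model and optimizer are decorated so that
    # parameters run in fp16 while checkpoints are still saved as float32.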

    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
            ddp_model = paddle.DataParallel(model)
        else:
            ddp_model = paddle.DataParallel(model)
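
    # When nranks > 1, the training loop below runs forward/backward through
    # `ddp_model` so gradients are synchronized across devices; otherwise the
    # bare `model` is used.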

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True, )
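
    # Each `data` batch drawn from `loader` in the loop below is a dict with at
    # least 'img' and the ground-truth keys listed in data['gt_fields'].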

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    if isinstance(metrics, str):
        metrics = [metrics]
    elif not isinstance(metrics, list):
        metrics = ['sad']
    best_metrics_data, best_iter = get_best(
        os.path.join(save_dir, 'best_model', 'best_metrics.txt'),
        metrics,
        resume_model=resume_model)

    avg_loss = defaultdict(float)
    iters_per_epoch = len(batch_sampler)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    batch_start = time.time()

    iter = start_iter
    while iter < iters:
        for data in loader:
            iter += 1
            if iter > iters:
                break
            reader_cost_averager.record(time.time() - batch_start)

            if precision == 'fp16':
                with paddle.amp.auto_cast(
                        level=amp_level,
                        enable=True,
                        custom_white_list={
                            "elementwise_add", "batch_norm", "sync_batch_norm"
                        },
                        custom_black_list={'bilinear_interp_v2', 'pad3d'}):
                    logit_dict, loss_dict = ddp_model(
                        data) if nranks > 1 else model(data)

                scaled = scaler.scale(loss_dict['all'])  # scale the loss
                scaled.backward()  # do backward
                scaler.minimize(optimizer, scaled)  # update parameters
            else:
                logit_dict, loss_dict = ddp_model(
                    data) if nranks > 1 else model(data)
                loss_dict['all'].backward()
                optimizer.step()

            lr = optimizer.get_lr()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()

            train_profiler.add_profiler_step(profiler_options)

            model.clear_gradients()

            for key, value in loss_dict.items():
                avg_loss[key] += float(value)
            batch_cost_averager.record(
                time.time() - batch_start, num_samples=batch_size)
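
            # `avg_loss` accumulates every loss component over a window of
            # `log_iters` iterations; it is averaged and reset in the logging
            # block below.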
            if iter % log_iters == 0 and local_rank == 0:
                for key, value in avg_loss.items():
                    avg_loss[key] = value / log_iters
                remain_iters = iters - iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                # loss info
                loss_str = ' ' * 26 + '\t[LOSSES]'
                for key, value in avg_loss.items():
                    if key != 'all':
                        loss_str = loss_str + ' ' + key + '={:.4f}'.format(
                            value)
                logger.info(
                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}\n{}\n"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss['all'], lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(), eta,
                            loss_str))

                if use_vdl:
                    for key, value in avg_loss.items():
                        log_tag = 'Train/' + key
                        log_writer.add_scalar(log_tag, value, iter)
                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)

                    if iter % log_image_iters == 0:
                        vis_dict = {}
                        # ground truth
                        vis_dict['ground truth/img'] = data['img'][0]
                        for key in data['gt_fields']:
                            key = key[0]
                            vis_dict['/'.join(['ground truth', key])] = data[
                                key][0]
                        # predict
                        for key, value in logit_dict.items():
                            vis_dict['/'.join(['predict', key])] = logit_dict[
                                key][0]
                        visual_in_traning(
                            log_writer=log_writer, vis_dict=vis_dict, step=iter)

                for key in avg_loss.keys():
                    avg_loss[key] = 0.
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            # save model
            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)
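
            # Note: `save_models` is a FIFO deque, so at most
            # `keep_checkpoint_max` snapshot directories remain on disk.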

            # eval model
            if eval_begin_iters is None:
                eval_begin_iters = iters // 2
            if (iter % save_interval == 0 or iter == iters) and (
                    val_dataset is not None
            ) and local_rank == 0 and iter >= eval_begin_iters:
                num_workers = 1 if num_workers > 0 else 0
                metrics_data = evaluate(
                    model,
                    val_dataset,
                    num_workers=num_workers,
                    print_detail=True,
                    save_results=False,
                    metrics=metrics,
                    precision=precision,
                    amp_level=amp_level)
                model.train()
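
            # Model selection below uses the first metric in `metrics`; a
            # smaller value is treated as better (SAD, MSE, Grad and Conn are
            # all error measures).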

            # save best model and add evaluation results to vdl
            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                if val_dataset is not None and iter >= eval_begin_iters:
                    if metrics_data[metrics[0]] < best_metrics_data[metrics[0]]:
                        best_iter = iter
                        best_metrics_data = metrics_data.copy()
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                        save_best(best_model_dir, best_metrics_data, iter)

                        show_list = []
                        for key, value in best_metrics_data.items():
                            show_list.append((key, value))
                        log_str = '[EVAL] The model with the best validation {} ({:.4f}) was saved at iter {}.'.format(
                            show_list[0][0], show_list[0][1], best_iter)
                        if len(show_list) > 1:
                            log_str += " While"
                            for i in range(1, len(show_list)):
                                log_str = log_str + ' {}: {:.4f},'.format(
                                    show_list[i][0], show_list[i][1])
                            log_str = log_str[:-1]
                        logger.info(log_str)

                    if use_vdl:
                        for key, value in metrics_data.items():
                            log_writer.add_scalar('Evaluate/' + key, value,
                                                  iter)

            batch_start = time.time()

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()