# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from collections import deque, defaultdict
import pickle
import shutil

import numpy as np
import paddle
import paddle.nn.functional as F
from paddleseg.utils import TimeAverager, calculate_eta, resume, logger, train_profiler

from .val import evaluate


def visual_in_traning(log_writer, vis_dict, step):
    """
    Visual in vdl: write every tensor of ``vis_dict`` as an image.

    Each tensor is reduced to 1 channel if it does not have 1 or 3 channels,
    transposed to (H, W, C), rescaled to [0, 255] and cast to uint8 before
    being passed to ``log_writer.add_image``.

    Args:
        log_writer (LogWriter): The log writer of vdl.
        vis_dict (dict): Dict of tensor. The shape of tensor is (C, H, W).
        step (int): The step the images are recorded at.
    """
    for key, value in vis_dict.items():
        # Only 1- or 3-channel images can be displayed; otherwise keep the
        # first channel only.
        if value.shape[0] not in [1, 3]:
            value = value[0].unsqueeze(0)
        value = paddle.transpose(value, (1, 2, 0))
        min_v = paddle.min(value)
        max_v = paddle.max(value)
        if (min_v > 0) and (max_v < 1):
            # Values already look like [0, 1] intensities.
            value = value * 255
        elif (min_v < 0 and min_v >= -1) and (max_v <= 1):
            # Values look like [-1, 1] intensities.
            value = (1 + value) / 2 * 255
        elif max_v > min_v:
            # Arbitrary range: min-max normalise.
            value = (value - min_v) / (max_v - min_v) * 255
        else:
            # Constant image outside the ranges above (e.g. all zeros):
            # min-max normalisation would divide by zero and produce NaNs,
            # so show it as black instead.
            value = paddle.zeros_like(value)

        value = value.astype('uint8')
        value = value.numpy()
        log_writer.add_image(tag=key, img=value, step=step)


def save_best(best_model_dir, metrics_data, iter):
    """Record the best metrics and the iteration at which they were reached.

    Writes ``<best_model_dir>/best_metrics.txt`` with one ``"<name> <value>"``
    line per entry of ``metrics_data`` (in dict order), followed by a final
    ``"iter <iter>"`` line. The file is overwritten if it already exists.

    Args:
        best_model_dir (str): Directory that holds the best model.
        metrics_data (dict): Metric name -> value. Names must be strings.
        iter (int): Iteration the metrics were obtained at.
    """
    target_path = os.path.join(best_model_dir, 'best_metrics.txt')
    with open(target_path, 'w') as out_file:
        rows = [
            ' '.join((metric_name, str(metric_value))) + '\n'
            for metric_name, metric_value in metrics_data.items()
        ]
        rows.append(' '.join(('iter', str(iter))) + '\n')
        out_file.writelines(rows)


def get_best(best_file, metrics, resume_model=None):
    '''Get best metrics and iter from file.

    When resuming (``resume_model`` is not None) and ``best_file`` exists,
    the ``"<name> <value>"`` lines written by ``save_best`` are parsed back.
    The ``iter`` line gives the best iteration and is not returned as a
    metric. Any name in ``metrics`` without a recorded value (including
    when not resuming) starts at ``np.inf``.

    Args:
        best_file (str): Path of the best-metrics text file.
        metrics (list[str]): Metric names expected by training.
        resume_model (str, optional): Resume path; the file is only read
            when this is not None.

    Returns:
        tuple: ``(best_metrics_data, best_iter)`` where ``best_metrics_data``
        maps metric name -> float and ``best_iter`` is -1 when unknown.
    '''
    best_metrics_data = {}
    best_iter = -1
    if resume_model is not None and os.path.exists(best_file):
        with open(best_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing lines.
                    continue
                key, value = line.rsplit(maxsplit=1)
                # Parse numerically rather than with eval(): the file is
                # plain data and must not be able to execute code.
                if key == 'iter':
                    best_iter = int(value)
                else:
                    best_metrics_data[key] = float(value)
    # Metrics with no recorded best value start at +inf so the first
    # evaluation (lower is better) becomes the best one.
    for key in metrics:
        best_metrics_data.setdefault(key, np.inf)
    return best_metrics_data, best_iter


def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          log_image_iters=1000,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          eval_begin_iters=None,
          metrics='sad',
          precision='fp32',
          amp_level='O1',
          profiler_options=None):
    """
    Launch training.

    Checkpoints are written to ``<save_dir>/iter_<N>`` (``model.pdparams`` and
    ``model.pdopt``) every ``save_interval`` iters and at the last iter. The
    model with the lowest value of the first metric is saved to
    ``<save_dir>/best_model``.

    Args:
        model(nn.Layer): A matting model. It is called with a batch dict and
            returns ``(logit_dict, loss_dict)``; ``loss_dict['all']`` is the
            loss that is back-propagated.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of resume model.
        save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
        log_image_iters (int, optional): Log image to vdl. Default: 1000.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict, optional): A dict of loss, refer to the loss function of the model for details. Default: None.
            NOTE(review): not read by this function's body.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
        eval_begin_iters (int): The iters begin evaluation. It will evaluate at iters/2 if it is None. Default: None.
        metrics(str|list|tuple, optional): The metrics to evaluate, it may be the combination of ("sad", "mse", "grad", "conn").
            The first one decides which model is the best (lower is better). Default: 'sad'.
        precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', the training is normal.
        amp_level (str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision,
            the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators
            parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
        profiler_options (str, optional): The option of train profiler.

    Raises:
        ValueError: If the training loader yields no batch on this card (the
            sampler drops the last incomplete batch, so a dataset smaller than
            batch_size times the number of cards gives an empty epoch, which
            would otherwise make the training loop spin forever).
    """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    # Use amp
    if precision == 'fp16':
        logger.info('use AMP to train. AMP level = {}'.format(amp_level))
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        if amp_level == 'O2':
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level='O2',
                save_dtype='float32')

    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
        ddp_model = paddle.DataParallel(model)

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    iters_per_epoch = len(batch_sampler)
    if iters_per_epoch == 0:
        # With drop_last=True an empty epoch never advances `iter`, so the
        # `while iter < iters` loop below would never terminate.
        raise ValueError(
            'The training loader yields no batch: the dataset has {} samples, '
            'fewer than batch_size ({}) x number of cards ({}).'.format(
                len(train_dataset), batch_size, nranks))

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True, )

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    if isinstance(metrics, str):
        metrics = [metrics]
    elif isinstance(metrics, tuple):
        metrics = list(metrics)
    elif not isinstance(metrics, list):
        metrics = ['sad']
    best_metrics_data, best_iter = get_best(
        os.path.join(save_dir, 'best_model', 'best_metrics.txt'),
        metrics,
        resume_model=resume_model)

    if eval_begin_iters is None:
        eval_begin_iters = iters // 2
    # Evaluation uses a single worker process when training uses worker
    # processes, and the main process otherwise.
    eval_num_workers = 1 if num_workers > 0 else 0

    avg_loss = defaultdict(float)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    batch_start = time.time()

    iter = start_iter
    while iter < iters:
        for data in loader:
            iter += 1
            if iter > iters:
                break
            reader_cost_averager.record(time.time() - batch_start)

            if precision == 'fp16':
                with paddle.amp.auto_cast(
                        level=amp_level,
                        enable=True,
                        custom_white_list={
                            "elementwise_add", "batch_norm", "sync_batch_norm"
                        },
                        custom_black_list={'bilinear_interp_v2', 'pad3d'}):
                    logit_dict, loss_dict = ddp_model(
                        data) if nranks > 1 else model(data)

                scaled = scaler.scale(loss_dict['all'])  # scale the loss
                scaled.backward()  # do backward
                scaler.minimize(optimizer, scaled)  # update parameters
            else:
                logit_dict, loss_dict = ddp_model(
                    data) if nranks > 1 else model(data)
                loss_dict['all'].backward()
                optimizer.step()

            lr = optimizer.get_lr()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()

            train_profiler.add_profiler_step(profiler_options)

            model.clear_gradients()

            for key, value in loss_dict.items():
                avg_loss[key] += float(value)
            batch_cost_averager.record(
                time.time() - batch_start, num_samples=batch_size)

            if iter % log_iters == 0 and local_rank == 0:
                for key, value in avg_loss.items():
                    avg_loss[key] = value / log_iters
                remain_iters = iters - iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                # loss info
                loss_str = ' ' * 26 + '\t[LOSSES]'
                for key, value in avg_loss.items():
                    if key != 'all':
                        loss_str = loss_str + ' ' + key + '={:.4f}'.format(
                            value)
                logger.info(
                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}\n{}\n"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss['all'], lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(
                            ), eta, loss_str))
                if use_vdl:
                    for key, value in avg_loss.items():
                        log_tag = 'Train/' + key
                        log_writer.add_scalar(log_tag, value, iter)

                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)
                    if iter % log_image_iters == 0:
                        vis_dict = {}
                        # ground truth
                        vis_dict['ground truth/img'] = data['img'][0]
                        for key in data['gt_fields']:
                            # After batching, each field name arrives wrapped
                            # in a per-sample sequence; take the first sample's.
                            key = key[0]
                            vis_dict['/'.join(['ground truth', key])] = data[
                                key][0]
                        # predict
                        for key, value in logit_dict.items():
                            vis_dict['/'.join(['predict', key])] = value[0]
                        visual_in_traning(
                            log_writer=log_writer, vis_dict=vis_dict, step=iter)

                for key in avg_loss.keys():
                    avg_loss[key] = 0.
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            is_save_iter = iter % save_interval == 0 or iter == iters

            # save model
            if is_save_iter and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                # Keep at most keep_checkpoint_max snapshots (<= 0 keeps all).
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

            # eval model, save best model and add evaluation results to vdl
            if (is_save_iter and local_rank == 0 and
                    val_dataset is not None and iter >= eval_begin_iters):
                metrics_data = evaluate(
                    model,
                    val_dataset,
                    num_workers=eval_num_workers,
                    print_detail=True,
                    save_results=False,
                    metrics=metrics,
                    precision=precision,
                    amp_level=amp_level)
                model.train()

                # Lower is better; a metric with no recorded best (e.g. a
                # resumed run with a different metric list) counts as +inf.
                if metrics_data[metrics[0]] < best_metrics_data.get(
                        metrics[0], np.inf):
                    best_iter = iter
                    best_metrics_data = metrics_data.copy()
                    best_model_dir = os.path.join(save_dir, "best_model")
                    paddle.save(
                        model.state_dict(),
                        os.path.join(best_model_dir, 'model.pdparams'))
                    save_best(best_model_dir, best_metrics_data, iter)

                show_list = list(best_metrics_data.items())
                log_str = '[EVAL] The model with the best validation {} ({:.4f}) was saved at iter {}.'.format(
                    show_list[0][0], show_list[0][1], best_iter)
                if len(show_list) > 1:
                    log_str += " While" + ','.join(
                        ' {}: {:.4f}'.format(name, metric)
                        for name, metric in show_list[1:])
                logger.info(log_str)

                if use_vdl:
                    for key, value in metrics_data.items():
                        log_writer.add_scalar('Evaluate/' + key, value,
                                              iter)

            batch_start = time.time()

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()