import functools
import logging
import os
from dataclasses import dataclass

from torch.utils.tensorboard import SummaryWriter

# Logger names recognised by `setup_logging`: when a log file is requested,
# the name selects which CSV header (per-epoch or per-batch columns) is used.
CSV_EPOCH_LOGGER = 'csv_epoch_logger'
CSV_BATCH_LOGGER = 'csv_batch_logger'


class FileHandlerWithHeader(logging.FileHandler):
    """`logging.FileHandler` that puts a header line at the top of the file.

    The header is written every time the underlying file is opened and turns
    out to be empty: a brand-new file, an existing zero-length file, or a file
    that was just truncated by a ``'w'`` mode. A file that already has content
    is appended to without repeating the header, including when the handler
    is closed and later reopened by a new record.

    Args:
        filename: Path of the log file.
        header: Header text; a trailing newline is added when it is written.
        mode: File mode passed to `logging.FileHandler`.
        encoding: Text encoding passed to `logging.FileHandler`.
        delay: If true, the file is opened (and the header written) on the
            first emitted record instead of at construction time.
        errors: Encoding error policy passed to `logging.FileHandler`.
    """

    def __init__(self, filename, header, mode='a',
                 encoding=None, delay=False, errors=None):
        # Both attributes must be set before the base initialiser runs: with
        # `delay=False` it calls `_open()` straight away.
        self.header = header
        # Retained for callers that read it; also the fallback signal when
        # the stream position cannot be queried (see `_open`).
        self.file_pre_exists = os.path.exists(filename)

        super().__init__(filename, mode, encoding, delay, errors)

    def _open(self):
        """Open the log file and write the header if the file is empty."""
        stream = super()._open()
        try:
            # Append mode positions the stream at the end of the file, so a
            # position of 0 means there is no content yet.
            needs_header = stream.tell() == 0
        except OSError:
            # Non-seekable target (e.g. a pipe): fall back to whether the
            # path existed when the handler was created.
            needs_header = not self.file_pre_exists
        if needs_header:
            stream.write(f'{self.header}\n')
        return stream


def setup_logging(name="log",
                  filename=None,
                  stream_log_level="INFO",
                  file_log_level="INFO"):
    """Configure the named logger with a stream handler and optional CSV file.

    The logger's own level is fixed at INFO, so handler levels below INFO have
    no additional effect. Every call attaches new handlers to the logger;
    calling this twice for the same name results in duplicated output.

    Args:
        name: Logger name. When `filename` is given it must be
            `CSV_BATCH_LOGGER` or `CSV_EPOCH_LOGGER`, which selects the CSV
            header written to the file.
        filename: Path of the CSV log file, or None for stream output only.
        stream_log_level: Name of the `logging` level for the stream handler.
        file_log_level: Name of the `logging` level for the file handler.

    Returns:
        The configured `logging.Logger`.

    Raises:
        NotImplementedError: If `filename` is given and `name` is not one of
            the known CSV logger names.
    """
    logger = logging.getLogger(name)
    logger.setLevel("INFO")
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(getattr(logging, stream_log_level))
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if filename is not None:
        # The first two columns are produced by `formatter`; the remaining
        # ones describe the comma-separated message of each record.
        header = 'time,logger,'
        if name == CSV_BATCH_LOGGER:
            header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
        elif name == CSV_EPOCH_LOGGER:
            header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
        else:
            raise NotImplementedError(f"Logger '{name}' is not implemented.")

        # A bare file name has an empty dirname, and os.makedirs('') raises
        # FileNotFoundError, so only create the parent when there is one.
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        file_handler = FileHandlerWithHeader(filename, header)
        file_handler.setLevel(getattr(logging, file_log_level))
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger


def training_log(name):
    """Build a decorator that logs a function's returned mapping as a CSV row.

    The decorated function must return an object with a `values()` method
    (e.g. a dict). After each call the values are converted with `str`, joined
    with commas and sent at INFO level to the logger called `name`; the return
    value is passed through unchanged.

    Args:
        name: Name of the logger that receives the rows.

    Returns:
        A decorator to apply to the function whose output should be logged.
    """
    def log_this(function):
        logger = logging.getLogger(name)

        # Keep the wrapped function's name, docstring and signature metadata.
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            output = function(*args, **kwargs)
            logger.info(','.join(map(str, output.values())))
            return output

        return wrapper

    return log_this


@dataclass
class BaseBatchLogRecord:
    """Bookkeeping fields shared by per-batch log records.

    `csv_logger` sends instances to `Loggers.csv_batch`, writing every
    attribute as one comma-separated row. `tensorboard_logger` skips the
    fields declared here when logging scalars and uses `global_batch` (plus
    one) as the step. Presumably subclassed to add the actual metric fields.
    """
    batch: int  # presumably the batch index within the current epoch
    num_batches: int  # presumably the number of batches per epoch
    global_batch: int  # step counter used by tensorboard_logger
    epoch: int  # presumably the current epoch index
    num_epochs: int  # presumably the total number of epochs

@dataclass
class BaseEpochLogRecord:
    """Bookkeeping fields shared by per-epoch log records.

    `csv_logger` sends instances to `Loggers.csv_epoch`, writing every
    attribute as one comma-separated row. `tensorboard_logger` skips the
    fields declared here when logging scalars and uses `epoch` (plus one) as
    the step. Presumably subclassed to add the actual metric fields.
    """
    epoch: int  # step counter used by tensorboard_logger
    num_epochs: int  # presumably the total number of epochs

@dataclass
class Loggers:
    """Bundle of the logging sinks read by `csv_logger` and `tensorboard_logger`."""
    csv_batch: logging.Logger  # used by csv_logger for BaseBatchLogRecord rows
    csv_epoch: logging.Logger | None  # used by csv_logger for BaseEpochLogRecord rows
    tensorboard: SummaryWriter  # used by tensorboard_logger for add_scalar calls


def init_csv_logger(name="log",
                    filename="log.csv",
                    metric_names=None,
                    stream_log_level="INFO",
                    file_log_level="INFO"):
    """Configure the named logger to write CSV rows to a stream and a file.

    The logger's own level is fixed at INFO, so handler levels below INFO have
    no additional effect. Every call attaches new handlers to the logger;
    calling this twice for the same name results in duplicated output.

    Args:
        name: Logger name.
        filename: Path of the CSV log file. Its parent directory is created
            when the path contains one.
        metric_names: Optional column names appended to the fixed
            ``time,logger`` columns of the file header.
        stream_log_level: Name of the `logging` level for the stream handler.
        file_log_level: Name of the `logging` level for the file handler.

    Returns:
        The configured `logging.Logger`.
    """
    logger = logging.getLogger(name)
    logger.setLevel("INFO")
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(getattr(logging, stream_log_level))
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    # The first two columns are produced by `formatter`; the metric columns
    # describe the comma-separated message of each record.
    header = ['time', 'logger']
    if metric_names:
        header += metric_names
    header = ','.join(header)

    # A bare file name such as the default "log.csv" has an empty dirname,
    # and os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when there is one.
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    file_handler = FileHandlerWithHeader(filename, header)
    file_handler.setLevel(getattr(logging, file_log_level))
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger


def csv_logger(function):
    """Decorator that writes the metrics returned by `function` as a CSV row.

    `function` must return a ``(loggers, metrics)`` pair. `metrics` is routed
    by type: `BaseEpochLogRecord` instances go to `loggers.csv_epoch`,
    `BaseBatchLogRecord` instances go to `loggers.csv_batch`. All attributes
    of `metrics` are converted with `str`, joined with commas and logged at
    INFO level. If the selected logger is None the row is skipped.

    Args:
        function: Callable returning ``(loggers, metrics)``.

    Returns:
        A wrapper that logs and then returns the same ``(loggers, metrics)``.

    Raises:
        NotImplementedError: From the wrapper, if `metrics` is neither a
            `BaseEpochLogRecord` nor a `BaseBatchLogRecord`.
    """
    # Keep the wrapped function's name, docstring and signature metadata.
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        if isinstance(metrics, BaseEpochLogRecord):
            logger = loggers.csv_epoch
        elif isinstance(metrics, BaseBatchLogRecord):
            logger = loggers.csv_batch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")

        # `Loggers.csv_epoch` is annotated as optional; treat a missing
        # logger as "this kind of CSV logging is disabled" rather than
        # failing with AttributeError on None.
        if logger is not None:
            logger.info(','.join(map(str, vars(metrics).values())))
        return loggers, metrics

    return wrapper


def tensorboard_logger(function):
    """Decorator that sends the metrics returned by `function` to TensorBoard.

    `function` must return a ``(loggers, metrics)`` pair. Every attribute of
    `metrics` that is not one of the base-record bookkeeping fields is logged
    with `loggers.tensorboard.add_scalar`, tagged with the attribute name.
    The step is ``metrics.global_batch + 1`` for `BaseBatchLogRecord`
    instances and ``metrics.epoch + 1`` for `BaseEpochLogRecord` instances.

    Args:
        function: Callable returning ``(loggers, metrics)``.

    Returns:
        A wrapper that logs and then returns the same ``(loggers, metrics)``.

    Raises:
        NotImplementedError: From the wrapper, if `metrics` is neither a
            `BaseBatchLogRecord` nor a `BaseEpochLogRecord`, or if a metric
            value is not a `float`.
    """
    # Keep the wrapped function's name, docstring and signature metadata.
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        if isinstance(metrics, BaseBatchLogRecord):
            metrics_exclude = BaseBatchLogRecord.__annotations__.keys()
            global_step = metrics.global_batch
        elif isinstance(metrics, BaseEpochLogRecord):
            metrics_exclude = BaseEpochLogRecord.__annotations__.keys()
            global_step = metrics.epoch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")

        writer = loggers.tensorboard
        for metric_name, metric_value in vars(metrics).items():
            if metric_name in metrics_exclude:
                # Bookkeeping field (batch/epoch counters), not a metric.
                continue
            if not isinstance(metric_value, float):
                # Previously the exception was built but never raised, so
                # unsupported values were dropped without any signal.
                raise NotImplementedError(
                    f"Unsupported type: '{type(metric_value)}'"
                )
            writer.add_scalar(metric_name, metric_value, global_step + 1)
        return loggers, metrics

    return wrapper