#!/usr/bin/env python3

import os
import sys
import argparse
import logging
import time
import re
import gi
from tofu import config, __version__

# Request version 0.0 of the 'Ufo' namespace from PyGObject.
# NOTE(review): presumably this has to run before anything imports Ufo from
# gi.repository -- confirm against the tofu sub-modules.
gi.require_version('Ufo', '0.0')


# Logger named 'tofu'; main() sets its level and attaches the handlers.
LOG = logging.getLogger('tofu')


def init(args):
    """Create a configuration file at ``args.config`` via ``config.write``.

    Raises:
        RuntimeError: if something already exists at ``args.config``; in that
            case ``config.write`` is not called.
    """
    target = args.config
    if os.path.exists(target):
        raise RuntimeError("{0} already exists".format(target))

    config.write(target)


def run_tomo(args):
    """Hand *args* to ``tofu.reco.tomo`` (the ``tomo`` command)."""
    # The backend module is imported at call time, not at module load.
    from tofu import reco as reco_module

    reco_module.tomo(args)


def run_lamino(args):
    """Hand *args* to ``tofu.lamino.lamino`` (the ``lamino`` command)."""
    # The backend module is imported at call time, not at module load.
    from tofu import lamino as lamino_module

    lamino_module.lamino(args)


def run_genreco(args):
    """Hand *args* to ``tofu.genreco.genreco`` (the ``reco`` command)."""
    # The backend module is imported at call time, not at module load.
    from tofu import genreco as genreco_module

    genreco_module.genreco(args)


def run_flat_correct(args):
    """Hand *args* to ``tofu.preprocess.run_flat_correct`` (``flatcorrect``)."""
    # The backend module is imported at call time, not at module load.
    from tofu import preprocess as preprocess_module

    preprocess_module.run_flat_correct(args)


def run_preprocessing(args):
    """Hand *args* to ``tofu.preprocess.run_preprocessing`` (``preprocess``)."""
    # The backend module is imported at call time, not at module load.
    from tofu import preprocess as preprocess_module

    preprocess_module.run_preprocessing(args)


def run_sinos(args):
    """Hand *args* to ``tofu.preprocess.run_sinogram_generation`` (``sinos``)."""
    # The backend module is imported at call time, not at module load.
    from tofu import preprocess as preprocess_module

    preprocess_module.run_sinogram_generation(args)


def run_ez(args):
    """Hand *args* to ``main_qt`` of the ez GUI launcher (the ``ez`` command)."""
    # The launcher is imported at call time, not at module load.
    from tofu.ez.GUI.ezufo_launcher import main_qt as launch_ez

    launch_ez(args)


def get_ipython_shell(config=None):
    """Return an embeddable IPython shell suited to the installed IPython.

    The class and its import location are selected by ``IPython.__version__``:

    * older than 0.11: ``IPython.Shell.IPShellEmbed()`` (*config* is not used)
    * older than 1.0: ``IPython.frontend.terminal.embed.InteractiveShellEmbed``
    * otherwise: ``IPython.terminal.embed.InteractiveShellEmbed``

    Args:
        config: forwarded as ``config=`` to ``InteractiveShellEmbed``. Note
            that inside this function the name shadows the module-level
            ``config`` import.

    Returns:
        The shell object; calling it starts the interactive session.
    """
    import IPython

    def cmp_versions(v1, v2):
        """Compare two version numbers and return cmp compatible result"""
        def normalize(v):
            # Keep only the leading numeric release segment so that
            # pre-release/dev versions ('8.0.0.dev', '1.0.0rc1') do not make
            # int() raise ValueError.
            parts = []
            for piece in v.split("."):
                digits = re.match(r'\d+', piece)
                if digits is None:
                    break
                parts.append(int(digits.group()))
                if digits.end() != len(piece):
                    # e.g. '0rc1': keep the 0, ignore the suffix and the rest
                    break

            # Trailing zero components are insignificant ('1.0.0' == '1');
            # the first component is always kept, as before.
            while len(parts) > 1 and parts[-1] == 0:
                parts.pop()
            return parts

        n1 = normalize(v1)
        n2 = normalize(v2)
        return (n1 > n2) - (n1 < n2)

    version = IPython.__version__

    if cmp_versions(version, '0.11') < 0:
        from IPython.Shell import IPShellEmbed
        return IPShellEmbed()

    if cmp_versions(version, '1.0') < 0:
        from IPython.frontend.terminal.embed import \
            InteractiveShellEmbed
    else:
        from IPython.terminal.embed import InteractiveShellEmbed

    return InteractiveShellEmbed(config=config, banner1='')


def run_shell(args):
    """Start an embedded IPython shell (the ``interactive`` command).

    The shell object comes from ``get_ipython_shell()`` and the session is
    started by calling it.
    """
    # NOTE(review): `reco` is not referenced below; presumably it is imported
    # so that it is available inside the interactive session -- confirm before
    # removing or renaming any local here.
    from tofu import reco

    shell = get_ipython_shell()
    shell()


def run_find_large_spots(args):
    """Dispatch the ``find-large-spots`` command on ``args.method``.

    ``'grow'`` selects ``find_large_spots``; every other value selects
    ``find_large_spots_median``. The chosen function is called with *args*.
    """
    from tofu.find_large_spots import find_large_spots, find_large_spots_median

    detect = find_large_spots if args.method == 'grow' else find_large_spots_median
    detect(args)


def run_inpaint(args):
    """Hand *args* to ``tofu.inpaint.run`` (the ``inpaint`` command)."""
    # The backend module is imported at call time, not at module load.
    from tofu import inpaint as inpaint_module

    inpaint_module.run(args)


def gui(args):
    """Start the reconstruction GUI by calling ``tofu.gui.main(args)``.

    An ``ImportError`` raised while importing the GUI module or while running
    its ``main`` is logged as an error on ``LOG`` and not re-raised.
    """
    try:
        # Aliased so the module does not share this function's name.
        from tofu import gui as gui_module
        gui_module.main(args)
    except ImportError as err:
        LOG.error(str(err))


def run_flow(args):
    """Call ``tofu.flow.main.main`` (the ``flow`` command).

    *args* is not forwarded; the flow entry point is called without arguments.
    """
    from tofu.flow.main import main as launch_flow

    launch_flow()


def estimate(params):
    """Print the result of ``reco.estimate_center(params)``.

    With ``params.verbose`` set the value is embedded in a
    ``'>>> Best axis of rotation: ...'`` line; otherwise the bare value is
    printed.
    """
    from tofu import reco as reco_module

    center = reco_module.estimate_center(params)
    print('>>> Best axis of rotation: {}'.format(center) if params.verbose else center)


def perf(args):
    """Time ``reco.tomo`` over a grid of widths, heights and projection counts.

    ``args.projections`` and ``args.sinograms`` are reset to ``None`` and
    ``args.dry_run`` is set to ``True``. For every combination taken from
    ``range(*args.width_range)``, ``range(*args.height_range)`` and
    ``range(*args.num_projection_range)`` the values are stored in
    ``args.width``, ``args.height`` and ``args.number``, ``reco.tomo(args)`` is
    called ``args.num_runs`` times and one summary line is written to stdout.
    """
    from tofu import reco

    def measure(args):
        # Values returned by reco.tomo() and wall-clock durations of each call.
        reported = []
        elapsed = []

        for _ in range(args.num_runs):
            began = time.time()
            reported.append(reco.tomo(args))
            elapsed.append(time.time() - began)

        exec_time = sum(reported) / len(reported)
        total_time = sum(elapsed) / len(elapsed)
        overhead = (total_time / exec_time - 1.0) * 100

        # NOTE: `rows` and `n_proj` are read from the enclosing loops below,
        # not from `args`.
        input_bandwidth = args.width * args.height * n_proj * 4 / exec_time / 1024. / 1024.
        output_bandwidth = args.width * args.width * rows * 4 / exec_time / 1024. / 1024.
        slice_bandwidth = args.height / exec_time

        # Four bytes of our output bandwidth constitute one slice pixel, for each
        # pixel we have to do roughly n * 6 floating point ops (2 mad, 1 add, 1
        # interpolation)
        flops = output_bandwidth / 4 * 6 * n_proj / 1024

        template = ("width={:<6d} height={:<6d} n_proj={:<6d}  "
                    "exec={:.4f}s  total={:.4f}s  overhead={:.2f}%  "
                    "bandwidth_i={:.2f}MB/s  bandwidth_o={:.2f}MB/s slices={:.2f}/s  "
                    "flops={:.2f}GFLOPs\n")
        line = template.format(args.width, args.height, args.number,
                               exec_time, total_time, overhead,
                               input_bandwidth, output_bandwidth,
                               slice_bandwidth, flops)

        sys.stdout.write(line)
        sys.stdout.flush()

    args.projections = None
    args.sinograms = None
    args.dry_run = True

    for cols in range(*args.width_range):
        for rows in range(*args.height_range):
            for n_proj in range(*args.num_projection_range):
                args.width = cols
                args.height = rows
                args.number = n_proj
                measure(args)


def main():
    """Command-line entry point: parse arguments and run the chosen command.

    Builds one sub-parser per command, each populated by
    ``config.Params(sections=...)``, parses the command line through
    ``config.parse_known_args``, configures the ``tofu`` logger (stdout handler
    always, file handler when ``args.log`` is set) and finally calls the
    function registered for the selected command.

    A ``RuntimeError`` raised while logging the values or running the command
    is logged as an error and the process exits with status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', **config.SECTIONS['general']['config'])
    parser.add_argument('--version', action='version',
                        version='%(prog)s {}'.format(__version__))

    sino_params = ('flat-correction', 'sinos')
    tomo_params = config.TOMO_PARAMS
    lamino_params = config.LAMINO_PARAMS
    gui_params = tomo_params + ('gui', )

    # (command name, handler, config sections, help text)
    cmd_parsers = [
        ('init',        init,           (),                             "Create configuration file"),
        ('preprocess',  run_preprocessing, config.PREPROC_PARAMS,       "Run preprocessing"),
        ('flatcorrect', run_flat_correct, ('flat-correction',),         "Run flat field correction"),
        ('sinos',       run_sinos,      sino_params,                    "Generate sinograms from projections"),
        ('tomo',        run_tomo,       tomo_params,                    "Run tomographic reconstruction"),
        ('lamino',      run_lamino,     lamino_params,                  "Run laminographic reconstruction"),
        ('reco',        run_genreco,    config.GEN_RECO_PARAMS,         "Run general projection-based "
                                                                        "reconstruction for tomographic/"
                                                                        "laminographic cone/parallel beam"),
        ('gui',         gui,            gui_params,                     "GUI for tomographic reconstruction"),
        ('flow',        run_flow,       (),                             "Visual flow creation"),
        ('ez',          run_ez,         (),                             "GUI for making ufo-kit data processing pipelines"),
        ('estimate',    estimate,       tomo_params + ('estimate',),    "Estimate center of rotation"),
        ('perf',        perf,           tomo_params + ('perf',),        "Check reconstruction performance"),
        ('interactive', run_shell,      tomo_params,                    "Run interactive mode"),
        ('find-large-spots', run_find_large_spots, ('find-large-spots',), "Find large spots on images"),
        ('inpaint',     run_inpaint,    ('inpaint',),                   "Inpaint images"),
    ]

    # `required` is only accepted by add_subparsers() from Python 3.7 on.
    # Compare version tuples: comparing the sys.version string is
    # lexicographic and treats '3.10' as older than '3.7'.
    subparser_options = {'title': "Commands", 'dest': 'commands'}
    if sys.version_info >= (3, 7):
        subparser_options['required'] = True
    subparsers = parser.add_subparsers(**subparser_options)

    for cmd, func, sections, text in cmd_parsers:
        cmd_params = config.Params(sections=sections)
        cmd_parser = subparsers.add_parser(cmd, help=text, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        cmd_parser = cmd_params.add_arguments(cmd_parser)
        # The handler is looked up again below as args._func.
        cmd_parser.set_defaults(_func=func)

    args = config.parse_known_args(parser, subparser=True)

    log_level = logging.DEBUG if args.verbose else logging.INFO
    LOG.setLevel(log_level)

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
    LOG.addHandler(stream_handler)

    if args.log:
        file_handler = logging.FileHandler(args.log)
        file_handler.setFormatter(logging.Formatter('[%(asctime)s] %(name)s:%(levelname)s: %(message)s'))
        LOG.addHandler(file_handler)

    try:
        config.log_values(args)
        args._func(args)
    except RuntimeError as e:
        LOG.error(str(e))
        sys.exit(1)


# Run the command-line interface when this file is executed as a script.
if __name__ == '__main__':
    main()

# vim: ft=python
