from __future__ import absolute_import, division, print_function

import iota.threads.iota_threads

"""
Author      : Lyubimov, A.Y.
Created     : 10/12/2014
Last Changed: 11/21/2019
Description : Interprets command line arguments. Initializes all IOTA starting
              parameters. Starts main log. Options for a variety of running
              modes, including resuming an aborted run.
"""

import os
import copy
import time

assert time

import iota.init.iota_input as inp
from iota.utils import utils as util
from iota.base.info import ProcInfo


def initialize_interface(args, phil_args=None, gui=False):
    """Read and process input, create PHIL.

    :param args: parsed command-line arguments; ``args.path`` is handed to
                 ``process_mixed_input`` and ``args`` itself to
                 ``inp.process_input``
    :param phil_args: list of PHIL argument strings. Paths that could not be
                      processed are appended to it in place, so that they can
                      be tried as PHIL arguments (a new list is created if
                      None was passed and there is something to append)
    :param gui: if True, missing image files are not treated as an error
    :return: input_dict: processed input dictionary (None on error)
             iota_phil: IOTA PHIL object (None on error)
             msg: message string; empty if there is nothing to report
    """
    input_dict = iota.threads.iota_threads.ginp.process_mixed_input(args.path)
    if input_dict and not gui and not input_dict["imagefiles"]:
        return None, None, "IOTA_INIT_ERROR: No readable image files in path(s)!"

    # Move args that were included in paths and not processed into phil_args,
    # to try and interpret them as PHIL args
    if input_dict["badpaths"]:
        if phil_args is None:
            phil_args = []
        phil_args.extend(input_dict["badpaths"])

    # Read in parameters, make IOTA PHIL
    iota_phil, bad_args = inp.process_input(
        args=args, phil_args=phil_args, paramfile=input_dict["paramfile"], gui=gui
    )

    # Collect message fragments and join them once at the end; the message is
    # returned as a single string, same as in the early error return above
    msg_parts = []

    # Check if any PHIL args not read into the PHIL were in fact bad paths
    if bad_args:
        input_dict["badpaths"] = [a for a in bad_args if a in input_dict["badpaths"]]
        if input_dict["badpaths"]:
            msg_parts.append("Files or directories not found:")
            for badpath in input_dict["badpaths"]:
                msg_parts.append("\n{}".format(badpath))
        bad_args = [a for a in bad_args if a not in input_dict["badpaths"]]
        if bad_args:
            msg_parts.append("\nThese arguments could not be interpreted: ")
            for arg in bad_args:
                msg_parts.append("\n{}".format(arg))

    return input_dict, iota_phil, "".join(msg_parts)


def _reconstitute_glob(filenames):
    """Collapse a list of filenames into a single glob-style pattern.

    Characters that are identical at a given position in every filename are
    kept; positions where the filenames differ become "*" (consecutive
    asterisks are merged into one).

    :param filenames: list of file basenames
    :return: glob pattern string (empty string if filenames is empty)
    """
    columns = [set(chars) for chars in zip(*filenames)]
    pattern = "".join(next(iter(col)) if len(col) == 1 else "*" for col in columns)
    while "**" in pattern:
        pattern = pattern.replace("**", "*")

    # zip() stops at the shortest filename, so when the names differ in length
    # the tail of the longer ones is not represented in the pattern; cover it
    # with a trailing wildcard so that the pattern still matches every file
    if len(set(len(f) for f in filenames)) > 1 and not pattern.endswith("*"):
        pattern += "*"
    return pattern


def initialize_new_run(phil, input_dict=None, target_phil=None):
    """Create base integration folder; save phil, input, and info to file.

    :param phil: IOTA PHIL object for this run
    :param input_dict: processed input dictionary; the "imagepaths" and
                       "imagefiles" entries are used to populate the input
                       line of the saved parameters (optional)
    :param target_phil: backend settings as a PHIL string (optional); if not
                        given, the target file named in the IOTA parameters is
                        used, or else backend defaults are generated
    :return: (True, info, message) on success, where info is the new INFO
             object; (False, None, error message) if anything went wrong
    """
    try:
        params = phil.extract()
        int_base, run_no = util.set_base_dir(
            dirname="integration", out_dir=params.output, get_run_no=True
        )
        if not os.path.isdir(int_base):
            os.makedirs(int_base)

        # Create input list file and populate param input line
        input_list_file = None
        if input_dict:
            if len(input_dict["imagepaths"]) >= 25:
                # Too many input paths for the param file: write every image
                # file into a list file and use that as the input instead
                input_list_file = os.path.join(int_base, "input.lst")
                with open(input_list_file, "w") as lf:
                    for f in input_dict["imagefiles"]:
                        lf.write("{}\n".format(f))
                params.input = [input_list_file]
            elif len(input_dict["imagefiles"]) >= 25:
                # If there are too many imagefiles, re-constitute the "glob"
                # format by matching filepaths and replacing non-matching
                # characters with asterisks
                input_paths = []
                for path in input_dict["imagepaths"]:
                    basenames = [
                        os.path.basename(f)
                        for f in input_dict["imagefiles"]
                        if path in f
                    ]
                    input_paths.append(
                        os.path.join(path, _reconstitute_glob(basenames))
                    )
                params.input = input_paths
            else:
                params.input = input_dict["imagefiles"]

        # Generate default backend PHIL, write to file, and update params
        target_fp = os.path.join(int_base, "target.phil")
        if target_phil:
            target_phil = inp.write_phil(
                phil_str=target_phil, dest_file=target_fp, write_target_file=True
            )
        elif params.cctbx_xfel.target:
            target_phil = inp.write_phil(
                phil_file=params.cctbx_xfel.target,
                dest_file=target_fp,
                write_target_file=True,
            )
        else:
            method = params.advanced.processing_backend
            target_phil, _ = inp.write_defaults(
                method=method, write_param_file=False, filepath=target_fp
            )
        params.cctbx_xfel.target = target_fp

        # Save PHIL for this run in base integration folder
        paramfile = os.path.join(int_base, "iota_r{}.param".format(run_no))
        phil = phil.format(python_object=params)

        with open(paramfile, "w") as philf:
            philf.write(phil.as_str())

        # Path of the main log for this run
        logfile = os.path.abspath(os.path.join(int_base, "iota.log"))

        # Initialize proc.info object and save to file
        info = ProcInfo.from_args(
            iota_phil=phil.as_str(),
            target_phil=target_phil.as_str(),
            int_base=int_base,
            input_list_file=input_list_file,
            info_file=os.path.join(int_base, "proc.info"),
            cluster_info_file=os.path.join(int_base, "cluster.info"),
            paramfile=paramfile,
            logfile=logfile,
            run_number=run_no,
            description=params.description,
            status="initialized",
            have_results=False,
            errors=[],
            init_proc=False,
        )
        info.export_json()
        return True, info, "IOTA_XTERM_INIT: Initialization complete!"
    except Exception as e:
        # Any failure is reported to the caller as a message, not raised
        msg = "IOTA_INIT_ERROR: Could not initialize run! {}".format(e)
        return False, None, msg


def initialize_processing(paramfile, run_no):
    """Prepare folders, bookkeeping containers and main log for a run.

    :param paramfile: text file with IOTA parameters
    :param run_no: number of the processing run
    :return: (info, params) on success, i.e. the INFO object and the
             extracted IOTA params; (False, message) if either the parameters
             or the INFO object could not be read
    """
    try:
        phil, _ = inp.get_input_phil(paramfile=paramfile)
    except Exception as e:
        return False, "IOTA_PROC_ERROR: Cannot import IOTA parameters! {}".format(e)
    params = phil.extract()

    # Rebuild the integration base path for this run and load its info object
    int_base = os.path.join(params.output, "integration/{:03d}".format(run_no))
    try:
        info_file = os.path.join(int_base, "proc.info")
        info = ProcInfo.from_json(filepath=info_file)
    except Exception as e:
        return False, "IOTA_PROC_ERROR: Cannot import INFO object! {}".format(e)

    # Make sure there is an input list, then keep only entries found on disk
    if not hasattr(info, "input_list"):
        info.generate_input_list(params=params)
    existing = []
    for entry in info.input_list:
        if isinstance(entry, (list, tuple)):
            # Compound entry: usable only if exactly one member is a real path
            found = [member for member in entry if os.path.exists(str(member))]
            if len(found) == 1:
                existing.append(found[0])
        elif os.path.exists(entry):
            existing.append(entry)

    # Input base: folder of the prefix shared by all existing input files...
    input_base = os.path.abspath(os.path.dirname(os.path.commonprefix(existing)))

    # ...narrowed to the prefix shared with the first input entry, if that
    # entry is a directory
    first_input = os.path.abspath(params.input[0])
    if os.path.isdir(first_input):
        shared = os.path.commonprefix([first_input, input_base])
        if shared not in ("", "."):
            input_base = shared

    # Generate subfolder paths (and the folders themselves, where missing)
    subfolders = (
        ("obj_base", "image_objects"),
        ("fin_base", "final"),
        ("log_base", "logs"),
        ("dials_log_base", "logs/dials_logs"),
        ("viz_base", "visualization"),
        ("tmp_base", "tmp"),
    )
    paths = {}
    for key, subfolder in subfolders:
        folder = os.path.join(int_base, subfolder)
        if not os.path.isdir(folder):
            os.makedirs(folder)
        paths[key] = folder
    # The input base is recorded along with the rest, but never created
    paths["input_base"] = input_base
    info.update(paths)

    # Generate filepaths for various info files
    info.update(
        dict(
            obj_list_file=os.path.join(info.tmp_base, "finished_objects.lst"),
            idx_file=os.path.join(info.int_base, "observations.pickle"),
        )
    )

    # Initialize stat containers
    info = generate_stat_containers(info=info, params=params)

    # Start the main log: banner, IOTA settings, then backend settings
    banner = "{:*^80} \n".format(" IOTA MAIN LOG ")
    settings_header = "{:-^80} \n".format(" SETTINGS FOR THIS RUN ")
    backend_header = "{:-^80} \n".format("BACKEND SETTINGS")
    util.main_log(info.logfile, banner)
    util.main_log(info.logfile, settings_header)
    util.main_log(info.logfile, info.iota_phil)
    util.main_log(info.logfile, backend_header)
    util.main_log(info.logfile, info.target_phil)

    info.export_json()

    return info, params


def resume_processing(info):
    """Set up run parameters for an existing run (e.g. for resuming a
    terminated run or re-submitting with new images)

    A run whose processing has not been initialized yet is passed on to
    initialize_processing(), whose return value is returned unchanged.

    :param info: INFO object
    :return: info: Updated INFO object (None if parameters can't be read)
             params: IOTA params (None if parameters can't be read)
    """
    # Processing was never initialized for this run: do the full setup
    if not info.init_proc:
        return initialize_processing(info.paramfile, info.run_number)

    # Otherwise only the parameters need to be read back in
    try:
        phil, _ = inp.get_input_phil(paramfile=info.paramfile)
    except Exception:
        return None, None

    info.status = "processing"
    return info, phil.extract()


def initialize_single_image(
    img, paramfile, output_file=None, output_dir=None, min_bragg=10
):
    """Set up INFO object and params for processing one image.

    :param img: image to process; becomes the only entry of params.input
    :param paramfile: text file with IOTA parameters
    :param output_file: name of the output file (optional); stored on the
                        INFO object as obj_list_file
    :param output_dir: folder for the output file (optional); if omitted,
                       output_file is made absolute as given
    :param min_bragg: value for the minimum_Bragg_peaks triage parameter
    :return: info: INFO object
             params: IOTA params
    """
    input_phil, _ = inp.get_input_phil(paramfile=paramfile)
    params = input_phil.extract()

    # Override settings for a one-image, one-processor run
    params.input = [img]
    params.mp.n_processors = 1
    params.data_selection.image_triage.minimum_Bragg_peaks = min_bragg
    input_phil = input_phil.format(python_object=params)

    info = ProcInfo.from_args(iota_phil=input_phil.as_str(), paramfile=paramfile)

    # Initialize output
    if output_file is None:
        output = None
    elif output_dir is None:
        output = os.path.abspath(output_file)
    else:
        output = os.path.join(os.path.abspath(output_dir), output_file)
    info.obj_list_file = output

    info.generate_input_list(params=params)
    info = generate_stat_containers(info=info, params=params)

    return info, params


def generate_stat_containers(info, params):
    """Attach fresh containers for processing information to the INFO object.

    :param info: INFO object; its "unprocessed" attribute is copied into the
                 "total" and "not_processed" categories
    :param params: IOTA params; only advanced.processing_backend is read
    :return: info: the same INFO object, updated
    """
    # Image categories; each value is a tuple of
    # (container, label, list filename, color)
    categories = dict(
        total=(
            copy.deepcopy(info.unprocessed),
            "images read in",
            "full_input.lst",
            None,
        )
    )
    # NOTE(review): failed_grid_search uses the same list filename and color
    # as failed_integration -- confirm that this is intended
    empty_categories = (
        ("have_diffraction", "have diffraction", "have_diffraction.lst", None),
        ("failed_triage", "failed triage", "failed_triage.lst", "#d73027"),
        (
            "failed_spotfinding",
            "failed spotfinding",
            "failed_spotfinding.lst",
            "#f46d43",
        ),
        ("failed_indexing", "failed indexing", "failed_indexing.lst", "#fdae61"),
        (
            "failed_refinement",
            "failed refinement",
            "failed_refinement.lst",
            "#fdae6b",
        ),
        (
            "failed_grid_search",
            "failed grid search",
            "failed_integration.lst",
            "#fee090",
        ),
        (
            "failed_integration",
            "failed integration",
            "failed_integration.lst",
            "#fee090",
        ),
        ("failed_filter", "failed filter", "failed_filter.lst", "#ffffbf"),
        ("integrated", "integrated", "integrated.lst", "#4575b4"),
    )
    for cat_key, cat_label, cat_file, cat_color in empty_categories:
        categories[cat_key] = ([], cat_label, cat_file, cat_color)
    categories["not_processed"] = (
        copy.deepcopy(info.unprocessed),
        "not processed",
        "not_processed.lst",
        "#e0f3f8",
    )

    # Generate containers for processing information
    info.update(
        bookmark=0,
        merged_indices={},
        b_factors=[],
        final_objects=[],
        finished_objects=[],
        status_summary={"nonzero": [], "names": [], "patches": []},
        cluster_iterable=[],
        clusters=[],
        prime_info=[],
        user_sg="P1",
        best_pg=None,
        best_uc=None,
        msg="",
        categories=categories,
        stats={},
        pointers={},
        pixel_size=None,
        status="processing",
        init_proc=True,
        have_results=False,
    )

    # Grid search stats (HA14 - deprecated)
    # NOTE(review): gs_stats is created but left empty; the grid search
    # entries are stored in info.stats -- confirm that this is intended
    if params.advanced.processing_backend == "ha14":
        info.gs_stats = {}
        gs_stat_labels = (
            ("s", "Signal Height"),
            ("h", "Spot Height"),
            ("a", "Spot Area"),
        )
        for stat_key, stat_label in gs_stat_labels:
            info.stats[stat_key] = {
                "lst": [],
                "mean": 0,
                "std": 0,
                "max": 0,
                "min": 0,
                "cons": 0,
                "label": stat_label,
            }

    # Statistics dictionary
    stat_labels = (
        ("res", "Resolution"),
        ("lres", "Low Resolution"),
        ("strong", "Number of spots"),
        ("mos", "Mosaicity"),
        ("wavelength", "X-ray Wavelength"),
        ("distance", "Detector Distance"),
        ("beamX", "BeamX (mm)"),
        ("beamY", "BeamY (mm)"),
    )
    for stat_key, stat_label in stat_labels:
        info.stats[stat_key] = {
            "lst": [],
            "median": 0,
            "mean": 0,
            "std": 0,
            "max": 0,
            "min": 0,
            "cons": 0,
            "label": stat_label,
        }

    return info
