"""Reads a multi-tile CBF image, discovering its detector geometry automatically"""

from __future__ import annotations

import sys

import numpy
import pycbf

from scitbx import matrix
from scitbx.array_family import flex

from dxtbx.format.FormatCBF import FormatCBF
from dxtbx.format.FormatCBFFull import FormatCBFFull
from dxtbx.format.FormatStill import FormatStill
from dxtbx.model.detector import Detector
from dxtbx.model.detector_helpers import find_undefined_value, find_underload_value


def angle_and_axis(basis):
    """Convert the orientation quaternion of a basis object to angle/axis form.

    The quaternion is normalized first, so non-unit input is accepted.

    @param basis object exposing an ``orientation`` quaternion
    @return (angle, axis) with the angle expressed in degrees
    """
    unit_q = matrix.col(basis.orientation).normalize()
    angle_axis = unit_q.unit_quaternion_as_axis_and_angle(deg=True)
    return angle_axis


class cbf_wrapper(pycbf.cbf_handle_struct):
    """Wrapper class that provides convenience functions for working with cbflib"""

    def add_category(self, name, columns):
        """Create a new category and populate it with column names.

        @param name category name (str; encoded to bytes for cbflib)
        @param columns iterable of column names (str; each encoded to bytes)
        """
        self.new_category(name.encode())
        for column in columns:
            self.new_column(column.encode())

    def add_row(self, data):
        """Add a row to the current category.

        Values are written into successive columns, starting from the first
        column. If data contains more entries than there are columns in this
        category, then the remainder is truncated.
        Use '.' (or b'.') for an empty value in a row; it is written with the
        cbflib "null" value type.

        @param data iterable of str or bytes values
        """
        self.new_row()
        self.rewind_column()
        for item in data:
            # str values are encoded for cbflib; values without .encode()
            # (e.g. bytes) are passed through unchanged. The try only covers
            # the encode so an AttributeError from set_value is not masked.
            try:
                value = item.encode()
            except AttributeError:
                value = item
            self.set_value(value)
            # Compare the encoded form so both "." and b"." are treated as null
            if value == b".":
                self.set_typeofvalue(b"null")
            try:
                self.next_column()
            except Exception:
                # No further columns: truncate the remaining data
                break

    def has_sections(self):
        """True if the cbf has the array_structure_list_section table, which
        changes how its data is stored in the binary sections.

        Errors other than CBF_NOTFOUND from the category lookup are re-raised.
        """
        try:
            self.find_category(b"array_structure_list_section")
            return True
        except Exception as e:
            if "CBF_NOTFOUND" in str(e):
                return False
            raise

    def add_frame_shift(self, basis, axis_settings):
        """Add an axis representing a frame shift (a rotation axis with an offset).

        Writes one row into the current category with the columns
        (name, "rotation", "detector", depends_on, vector x/y/z, offset x/y/z,
        equipment_component), and appends the matching
        [name, "FRAME1", angle, "0"] entry to axis_settings.

        @param basis object with orientation, translation, include_translation,
            axis_name, depends_on and equipment_component attributes
        @param axis_settings list that receives the per-frame axis setting row
        """
        angle, axis = angle_and_axis(basis)

        # A zero rotation has no meaningful axis direction; use z as a
        # placeholder so a valid vector is still written.
        if angle == 0:
            axis = (0, 0, 1)

        if basis.include_translation:
            translation = basis.translation
        else:
            translation = (0, 0, 0)

        self.add_row(
            [
                basis.axis_name,
                "rotation",
                "detector",
                basis.depends_on,
                str(axis[0]),
                str(axis[1]),
                str(axis[2]),
                str(translation[0]),
                str(translation[1]),
                str(translation[2]),
                basis.equipment_component,
            ]
        )

        axis_settings.append([basis.axis_name, "FRAME1", str(angle), "0"])


class FormatCBFMultiTile(FormatCBFFull):
    """An image reading class for multi-tile CBF files"""

    @staticmethod
    def understand(image_file):
        """Check to see if this looks like a multi-tile CBF image, i.e. a CBF
        file that cbflib can read and that contains more than one array
        element.

        @param image_file path to the candidate file (str)
        @return True if the file has more than one element, False otherwise
        """
        try:
            cbf_handle = pycbf.cbf_handle_struct()
            cbf_handle.read_widefile(image_file.encode(), pycbf.MSG_DIGEST)
            # check if multiple arrays
            return cbf_handle.count_elements() > 1
        except Exception:
            # Any failure to read or inspect the file means we do not
            # understand it; always answer with an explicit bool.
            return False

    def _start(self):
        """Open the image file as a cbf file handle, and keep this somewhere
        safe."""
        FormatCBF._start(self)  # Note, skip up an inheritance level

    def detectorbase_start(self):
        pass

    def _get_cbf_handle(self):
        """Return the cached cbf_wrapper for this image, reading the file on
        first use."""
        try:
            return self._cbf_handle
        except AttributeError:
            self._cbf_handle = cbf_wrapper()
            self._cbf_handle.read_widefile(self._image_file.encode(), pycbf.MSG_DIGEST)
            return self._cbf_handle

    def _detector(self):
        """Return a working detector instance with one panel per CBF element.

        Each panel is named after the array_id of its element, which
        get_raw_data relies on to match panels to binary data rows.
        """

        cbf = self._get_cbf_handle()

        d = Detector()

        for i in range(cbf.count_elements()):
            # Map the detector element to its array_id via diffrn_data_frame
            ele_id = cbf.get_element_id(i)
            cbf.find_category(b"diffrn_data_frame")
            cbf.find_column(b"detector_element_id")
            cbf.find_row(ele_id)
            cbf.find_column(b"array_id")
            array_id = cbf.get_value()

            # code adapted below from dxtbx.model.detector.DetectorFactory.imgCIF_H
            cbf_detector = cbf.construct_detector(i)
            try:
                pixel = (
                    cbf_detector.get_inferred_pixel_size(1),
                    cbf_detector.get_inferred_pixel_size(2),
                )

                axes = cbf_detector.get_detector_axes()
                fast = axes[0:3]
                slow = axes[3:6]
                origin = cbf_detector.get_pixel_coordinates_fs(0, 0)
            finally:
                # Free the SWIG-owned detector even if a geometry query fails
                cbf_detector.__swig_destroy__(cbf_detector)
                del cbf_detector

            # Size of this element (previously element 0 was used for every
            # panel); reversed to give (fast, slow) order.
            size = tuple(reversed(cbf.get_image_size(i)))

            try:
                min_trusted_value = find_underload_value(cbf)
            except Exception:
                try:
                    # By convention, if underload is not set, then assume the minimum
                    # trusted pixel is 1 more than the undefined pixel
                    min_trusted_value = find_undefined_value(cbf) + 1
                except Exception:
                    min_trusted_value = 0
            try:
                # The overload value itself is not trusted
                max_trusted_value = cbf.get_overload(i) - 1
            except Exception:
                max_trusted_value = 1.0e6

            trusted_range = (min_trusted_value, max_trusted_value)

            try:
                # Select array_intensities explicitly (imgCIF defines
                # _array_intensities.gain) instead of relying on whichever
                # category the preceding cbflib calls left current.
                cbf.find_category(b"array_intensities")
                cbf.find_column(b"gain")
                cbf.select_row(i)
                gain = cbf.get_doublevalue()
            except Exception as e:
                if "CBF_NOTFOUND" not in str(e):
                    raise
                gain = 1.0

            p = d.add_panel()
            p.set_name(array_id)
            p.set_local_frame(fast, slow, origin)

            p.set_pixel_size(tuple(map(float, pixel)))
            p.set_image_size(size)
            p.set_trusted_range(tuple(map(float, trusted_range)))
            p.set_gain(gain)
            # p.set_px_mm_strategy(px_mm) FIXME

        return d

    def _beam(self):
        """Return a working beam instance."""

        return self._beam_factory.imgCIF_H(self._get_cbf_handle())

    def get_raw_data(self):
        """Return the per-panel image data as a tuple of flex.double arrays.

        The rows of the array_data category are read in order and checked
        against the detector panel names (array_id). The result is cached on
        self._raw_data once all panels have been read.

        @return tuple of flex.double, one per panel, or None if the file has
            no array_data category
        """
        if self._raw_data is None:
            cbf = self._get_cbf_handle()

            # find the data
            cbf.select_category(0)
            while cbf.category_name().lower() != "array_data":
                try:
                    cbf.next_category()
                except Exception:
                    # Nothing is cached, so a later call will retry
                    return None
            cbf.select_column(0)
            cbf.select_row(0)

            d = self.get_detector()

            raw_data = []
            for panel in d:
                name = panel.get_name()
                cbf.find_column(b"array_id")
                assert name == cbf.get_value()

                cbf.find_column(b"data")
                assert cbf.get_typeofvalue().find(b"bnry") > -1

                image_string = cbf.get_realarray_as_string()
                # numpy.float was an alias for the builtin float (float64) and
                # was removed in NumPy 1.24; use the explicit dtype.
                image = flex.double(numpy.frombuffer(image_string, numpy.float64))

                # presumably (slow, fast) dimensions from the wdims parameters
                parameters = cbf.get_realarrayparameters_wdims_fs()
                image_size = (parameters[6], parameters[5])

                image.reshape(flex.grid(*image_size))

                raw_data.append(image)

                try:
                    cbf.next_row()
                except Exception:
                    break
            assert len(d) == len(raw_data)
            # Cache only after every panel was read successfully
            self._raw_data = raw_data

        return tuple(self._raw_data)


class FormatCBFMultiTileStill(FormatStill, FormatCBFMultiTile):
    """An image reading class for full CBF format images i.e. those from
    a variety of cameras which support this format. Custom derived from
    FormatStill to handle images without a goniometer or scan."""

    @staticmethod
    def understand(image_file):
        """Accept the image only when its CBF header describes no goniometer.

        @param image_file path to the candidate file
        @return False if the header mentions diffrn_measurement_axis, else True
        """
        cbf_header = FormatCBFMultiTile.get_cbf_header(image_file)

        # According to ImageCIF, "Data items in the DIFFRN_MEASUREMENT_AXIS
        # category associate axes with goniometers."
        # http://www.iucr.org/__data/iucr/cifdic_html/2/cif_img.dic/Cdiffrn_measurement_axis.html
        has_goniometer_axes = "diffrn_measurement_axis" in cbf_header
        return not has_goniometer_axes


if __name__ == "__main__":
    # Report, for each path given on the command line, whether it is
    # understood as a multi-tile CBF image.
    for image_path in sys.argv[1:]:
        understood = FormatCBFMultiTile.understand(image_path)
        print(understood)
