from __future__ import annotations

import os
import pathlib

import pytest

from dxtbx.serialize import load

import dials.command_line.split_still_data as split


def _image_number(expt) -> str:
    """Return the trailing ``_``-separated field of the first image filename of ``expt``.

    e.g. ``.../merlin0047_17000.cbf`` -> ``"17000"``. ``Path.stem`` is used to
    drop the extension. ``str.rstrip(".cbf")`` would instead strip any trailing
    run of the characters ``.``, ``c``, ``b`` and ``f``, not the suffix.
    """
    return pathlib.Path(expt.imageset.get_path(0)).stem.split("_")[-1]


@pytest.mark.xfail(
    os.name == "nt",
    reason="Failures due to translated paths; see https://github.com/cctbx/dxtbx/issues/613",
)
@pytest.mark.parametrize("use_yaml", [True, False])
def test_split_still_data(dials_data, run_in_tmp_path, use_yaml):
    """Split processed CuNiR serial stills into two groups.

    The split is driven either by a YAML grouping file (``use_yaml=True``) or by
    ``series_repeat=2``. Both routes must produce the same two groups, with the
    same images in each.
    """
    data = dials_data("cunir_serial_processed")
    args = [
        os.fspath(data / "integrated.expt"),
        os.fspath(data / "integrated.refl"),
        "nproc=1",
    ]
    if use_yaml:
        # Attach a "timepoint" metadata item ('repeat=2') to the CBF image
        # template and group by it; expected to match series_repeat=2 below.
        images = os.fspath(dials_data("cunir_serial"))
        yml = f"""
---
metadata:
  timepoint:
    {images}/merlin0047_#####.cbf : 'repeat=2'
grouping:
  group_by:
    values:
      - timepoint
"""
        test_yaml = run_in_tmp_path / "tmp.yaml"
        with open(test_yaml, "w") as f:
            f.write(yml)
        args.append(f"grouping={test_yaml}")
    else:
        args.append("series_repeat=2")
    split.run(args=args)

    out = pathlib.Path(run_in_tmp_path)
    for name in (
        "group_0_0.expt",
        "group_0_0.refl",
        "group_1_0.expt",
        "group_1_0.refl",
    ):
        assert (out / name).is_file(), name

    # Not old-style experiment list data structures (no scan, single imageset),
    # so identify each experiment by the number in its first image filename.
    # Comparing whole lists checks both the count and the order of experiments.
    expts1 = load.experiment_list(out / "group_0_0.expt", check_format=False)
    assert [_image_number(expt) for expt in expts1] == ["17000", "17002", "17004"]

    expts2 = load.experiment_list(out / "group_1_0.expt", check_format=False)
    assert [_image_number(expt) for expt in expts2] == ["17001", "17003"]


def test_split_still_data_h5(dials_data, run_in_tmp_path):
    """Split processed lysozyme SSX (HDF5) stills into two groups with ``series_repeat=2``.

    Each output experiment is expected to carry a single-image scan. The test
    checks exactly which images (by scan image range) land in each group.
    """
    data = dials_data("lysozyme_ssx_processed")
    args = [
        os.fspath(data / "integrated.expt"),
        os.fspath(data / "integrated.refl"),
        "nproc=1",
        "series_repeat=2",
    ]
    split.run(args=args)

    out = pathlib.Path(run_in_tmp_path)
    for name in (
        "group_0_0.expt",
        "group_0_0.refl",
        "group_1_0.expt",
        "group_1_0.refl",
    ):
        assert (out / name).is_file(), name

    # Comparing whole lists checks count and order together, and pytest shows
    # the full difference on failure rather than stopping at the first mismatch.
    expts1 = load.experiment_list(out / "group_0_0.expt", check_format=False)
    first_images = [
        99,
        133,
        247,
        565,
        633,
        763,
        811,
        819,
        823,
        847,
        859,
        929,
        933,
        971,
        973,
    ]
    assert [expt.scan.get_image_range() for expt in expts1] == [
        (i, i) for i in first_images
    ]

    expts2 = load.experiment_list(out / "group_1_0.expt", check_format=False)
    second_images = [414, 472, 602, 878, 884, 920]
    assert [expt.scan.get_image_range() for expt in expts2] == [
        (i, i) for i in second_images
    ]
