# Source code for dials.algorithms.indexing.lattice_search.low_res_spot_match

import copy
import logging
import math
import operator

import libtbx.phil
from cctbx import miller
from dxtbx.model import Crystal
from scitbx import matrix
from scitbx.math import least_squares_plane, superpose

from dials.algorithms.indexing import DialsIndexError
from dials.array_family import flex

from .strategy import Strategy

logger = logging.getLogger(__name__)

# Angle constants (radians) used for the clock-angle comparison when pairing
# candidate spots (see LowResSpotMatch._pairs_with_seed).
TWO_PI = 2.0 * math.pi
FIVE_DEG = TWO_PI * 5.0 / 360.0  # five degrees expressed in radians


class CompleteGraph:
    """A complete graph over matched spot/index vertices.

    Each vertex is a dict carrying at least ``spot_id`` and ``miller_index``
    keys; edge weights record deviations between observed and expected
    reciprocal-space distances. Instances are treated as immutable (grown
    only via :meth:`factory_add_vertex`) so that ``__hash__``/``__eq__`` are
    stable and duplicate graphs can be filtered with ``set``.
    """

    def __init__(self, seed_vertex):
        self.vertices = [seed_vertex]
        self.weight = [{0: 0.0}]
        self.total_weight = 0.0

    def factory_add_vertex(self, vertex, weights_to_other):
        """Return a new graph as a copy of this with ``vertex`` added.

        ``weights_to_other`` gives the edge weight from each existing vertex
        (in current vertex order) to the new one. This is a factory rather
        than a change in-place because CompleteGraph ought to be immutable
        to implement __hash__.
        """
        g = copy.deepcopy(self)

        current_len = len(g.vertices)
        assert len(weights_to_other) == current_len
        g.vertices.append(vertex)
        node = current_len

        # Update distances from other nodes to the new one
        for i, w in enumerate(weights_to_other):
            g.weight[i][node] = w

        # Add distances to other nodes from this one. Work on a copy so the
        # caller's list is not mutated in place.
        weights = list(weights_to_other) + [0.0]
        g.weight.append(dict(enumerate(weights)))

        # Update the total weight
        g.total_weight += sum(weights)

        # Sort the vertices and weights together by spot_id
        pairs = sorted(zip(g.vertices, g.weight), key=lambda v_w: v_w[0]["spot_id"])
        v, w = zip(*pairs)
        g.vertices = list(v)
        g.weight = list(w)

        return g

    def __hash__(self):
        h = tuple((e["spot_id"], e["miller_index"]) for e in self.vertices)
        return hash(h)

    def __eq__(self, other):
        # Graphs with different vertex counts are never equal. Without this
        # check, zip() silently truncates and a graph matching a prefix of a
        # larger graph would compare equal while hashing differently,
        # breaking the __hash__/__eq__ contract.
        if len(self.vertices) != len(other.vertices):
            return False
        for a, b in zip(self.vertices, other.vertices):
            if a["spot_id"] != b["spot_id"]:
                return False
            if a["miller_index"] != b["miller_index"]:
                return False
        return True

    def __ne__(self, other):
        return not self == other


# PHIL parameter definitions for the low-resolution spot-match lattice search
# strategy. This string is parsed into LowResSpotMatch.phil_scope below; its
# content is runtime behaviour and must stay in sync with the attribute
# accesses on self._params in that class.
low_res_spot_match_phil_str = """\
candidate_spots
{
    limit_resolution_by = *n_spots d_min
    .type = choice

    d_min = 15.0
    .type = float(value_min=0)

    n_spots = 10
    .type = int

    d_star_tolerance = 4.0
    .help = "Number of sigmas from the centroid position for which to "
            "calculate d* bands"
    .type = float
}

use_P1_indices_as_seeds = False
    .type = bool

search_depth = *triplets quads
    .type = choice

bootstrap_crystal = False
    .type = bool

max_pairs = 200
    .type = int

max_triplets = 600
    .type = int

max_quads = 600
    .type = int
"""


class LowResSpotMatch(Strategy):
    """A lattice search strategy matching low resolution spots to candidate indices.

    The match is based on resolution and reciprocal space distance between
    observed spots. A prior unit cell and space group are required and
    solutions are assessed by matching the low resolution spots against
    candidate reflection positions predicted from the known cell.

    This lattice search strategy is a special case designed to work for
    electron diffraction still images, in which one typically only collects
    reflections from the zero-order Laue zone. In principle, it is not limited
    to this type of data, but probably works best with narrow wedges, good
    initial geometry and a small beam-stop shadow so that a good number of
    low-order reflections are collected.
    """

    phil_help = (
        "A lattice search strategy that matches low resolution spots to candidate "
        "indices based on a known unit cell and space group. Designed primarily for "
        "electron diffraction still images."
    )

    phil_scope = libtbx.phil.parse(low_res_spot_match_phil_str)

    def __init__(
        self, target_symmetry_primitive, max_lattices, params=None, *args, **kwargs
    ):
        """Construct a LowResSpotMatch object.

        Args:
            target_symmetry_primitive (cctbx.crystal.symmetry): The target
                crystal symmetry and unit cell
            max_lattices (int): The maximum number of lattice models to find
        """
        super().__init__(params=params, *args, **kwargs)
        self._target_symmetry_primitive = target_symmetry_primitive
        self._max_lattices = max_lattices

        if target_symmetry_primitive is None:
            raise DialsIndexError(
                "Target unit cell and space group must be provided for low_res_spot_match"
            )

        # Set reciprocal space orthogonalisation matrix
        uc = self._target_symmetry_primitive.unit_cell()
        self.Bmat = matrix.sqr(uc.fractionalization_matrix()).transpose()

    def find_crystal_models(self, reflections, experiments):
        """Find a list of candidate crystal models.

        Args:
            reflections (dials.array_family.flex.reflection_table):
                The found spots centroids and associated data
            experiments (dxtbx.model.experiment_list.ExperimentList):
                The experimental geometry models
        """
        # Take a subset of the observations at the same resolution and
        # calculate some values that will be needed for the search
        self._calc_obs_data(reflections, experiments)

        # Construct a library of candidate low res indices with their d* values
        self._calc_candidate_hkls()

        # First search: match each observation with candidate indices within
        # the acceptable resolution band
        self._calc_seeds_and_stems()
        if self._params.use_P1_indices_as_seeds:
            seeds = self.stems
        else:
            seeds = self.seeds
        logger.info("Using %s seeds", len(seeds))

        # Second search: match seed spots with another spot from a different
        # reciprocal lattice row, such that the observed reciprocal space
        # distances are within tolerances
        pairs = []
        for seed in seeds:
            pairs.extend(self._pairs_with_seed(seed))
        logger.info("Found %s pairs", len(pairs))

        pairs = list(set(pairs))  # filter duplicates

        if self._params.max_pairs:
            pairs.sort(key=operator.attrgetter("total_weight"))
            idx = self._params.max_pairs
            pairs = pairs[0:idx]
        logger.info("Using %s highest-scoring pairs", len(pairs))

        # Further search iterations: extend to more spots within tolerated
        # distances
        triplets = []
        for pair in pairs:
            triplets.extend(self._extend_by_candidates(pair))
        logger.info("Found %s triplets", len(triplets))

        triplets = list(set(triplets))  # filter duplicates
        if self._params.max_triplets:
            triplets.sort(key=operator.attrgetter("total_weight"))
            idx = self._params.max_triplets
            triplets = triplets[0:idx]
        logger.info("Using %s highest-scoring triplets", len(triplets))

        branches = triplets
        if self._params.search_depth == "quads":
            quads = []
            for triplet in triplets:
                quads.extend(self._extend_by_candidates(triplet))
            logger.info("%s quads", len(quads))

            quads = list(set(quads))  # filter duplicates
            if self._params.max_quads:
                quads.sort(key=operator.attrgetter("total_weight"))
                idx = self._params.max_quads
                quads = quads[0:idx]
            logger.info("Using %s highest-scoring quads", len(quads))
            branches = quads

        # Sort branches by total deviation of observed distances from expected
        branches.sort(key=operator.attrgetter("total_weight"))

        candidate_crystal_models = []
        for branch in branches:
            model = self._fit_crystal_model(branch)
            if model:
                candidate_crystal_models.append(model)
            if len(candidate_crystal_models) == self._max_lattices:
                break

        self.candidate_crystal_models = candidate_crystal_models
        return self.candidate_crystal_models

    def _calc_candidate_hkls(self):
        # First a list of indices that fill 1 ASU
        hkl_list = miller.build_set(
            self._target_symmetry_primitive,
            anomalous_flag=False,
            d_min=self._params.candidate_spots.d_min,
        )
        rt = flex.reflection_table()
        rt["miller_index"] = hkl_list.indices()
        rt["d_star"] = 1.0 / hkl_list.d_spacings().data()
        rt["rlp_datum"] = self.Bmat.elems * rt["miller_index"].as_vec3_double()
        self.candidate_hkls = rt

        # Now P1 indices with separate Friedel pairs
        hkl_list = miller.build_set(
            self._target_symmetry_primitive,
            anomalous_flag=True,
            d_min=self._params.candidate_spots.d_min,
        )
        hkl_list_p1 = hkl_list.expand_to_p1()
        rt = flex.reflection_table()
        rt["miller_index"] = hkl_list_p1.indices()
        rt["d_star"] = 1.0 / hkl_list_p1.d_spacings().data()
        rt["rlp_datum"] = self.Bmat.elems * rt["miller_index"].as_vec3_double()
        self.candidate_hkls_p1 = rt
        return

    def _calc_obs_data(self, reflections, experiments):
        """Calculates a set of low resolution observations to try to match to
        indices. Each observation will record its d* value as well as
        tolerated d* bands and a 'clock angle'"""

        spot_d_star = reflections["rlp"].norms()
        if self._params.candidate_spots.limit_resolution_by == "n_spots":
            n_spots = self._params.candidate_spots.n_spots
            n_spots = min(n_spots, len(reflections) - 1)
            d_star_max = flex.sorted(spot_d_star)[n_spots - 1]
            self._params.candidate_spots.d_min = 1.0 / d_star_max

        # First select low resolution spots only
        spot_d_star = reflections["rlp"].norms()
        d_star_max = 1.0 / self._params.candidate_spots.d_min
        sel = spot_d_star <= d_star_max
        self.spots = reflections.select(sel)
        self.spots["d_star"] = spot_d_star.select(sel)

        # XXX In what circumstance might there be more than one experiment?
        detector = experiments.detectors()[0]
        beam = experiments.beams()[0]

        # Lab coordinate of the beam centre, using the first spot's panel
        panel = detector[self.spots[0]["panel"]]
        bc = panel.get_ray_intersection(beam.get_s0())
        bc_lab = panel.get_lab_coord(bc)

        # Lab coordinate of each spot
        spot_lab = flex.vec3_double(len(self.spots))
        pnl_ids = set(self.spots["panel"])
        for pnl in pnl_ids:
            sel = self.spots["panel"] == pnl
            panel = detector[pnl]
            obs = self.spots["xyzobs.mm.value"].select(sel)
            x_mm, y_mm, _ = obs.parts()
            spot_lab.set_selected(
                sel, panel.get_lab_coord(flex.vec2_double(x_mm, y_mm))
            )

        # Radius vectors for each spot
        radius = spot_lab - bc_lab

        # Usually the radius vectors would all be in a single plane, but this
        # might not be the case if the spots are on different panels. To put
        # them on the same plane, project onto fast/slow of the panel used to
        # get the beam centre
        df = flex.vec3_double(len(self.spots), detector[0].get_fast_axis())
        ds = flex.vec3_double(len(self.spots), detector[0].get_slow_axis())
        clock_dirs = (radius.dot(df) * df + radius.dot(ds) * ds).each_normalize()

        # From this, find positive angles of each vector around a clock, using
        # the fast axis as 12 o'clock
        angs = clock_dirs.angle(detector[0].get_fast_axis())
        dots = clock_dirs.dot(detector[0].get_slow_axis())
        sel = dots < 0  # select directions in the second half of the clock face
        angs.set_selected(sel, (TWO_PI - angs.select(sel)))
        self.spots["clock_angle"] = angs

        # Project radius vectors onto fast/slow of the relevant panels
        df = flex.vec3_double(len(self.spots))
        ds = flex.vec3_double(len(self.spots))
        for pnl in pnl_ids:
            sel = self.spots["panel"] == pnl
            panel = detector[pnl]
            df.set_selected(sel, panel.get_fast_axis())
            ds.set_selected(sel, panel.get_slow_axis())
        panel_dirs = (radius.dot(df) * df + radius.dot(ds) * ds).each_normalize()

        # Calc error along each panel direction with simple error propagation
        # that assumes no covariance between x and y centroid errors.
        x = panel_dirs.dot(df)
        y = panel_dirs.dot(ds)
        x2, y2 = flex.pow2(x), flex.pow2(y)
        r2 = x2 + y2
        sig_x2, sig_y2, _ = self.spots["xyzobs.mm.variance"].parts()
        var_r = (x2 / r2) * sig_x2 + (y2 / r2) * sig_y2
        sig_r = flex.sqrt(var_r)

        # Pixel coordinates at limits of the band
        tol = self._params.candidate_spots.d_star_tolerance
        outer_spot_lab = spot_lab + panel_dirs * (tol * sig_r)
        inner_spot_lab = spot_lab - panel_dirs * (tol * sig_r)

        # Set d* at band limits
        inv_lambda = 1.0 / beam.get_wavelength()
        s1_outer = outer_spot_lab.each_normalize() * inv_lambda
        s1_inner = inner_spot_lab.each_normalize() * inv_lambda
        self.spots["d_star_outer"] = (s1_outer - beam.get_s0()).norms()
        self.spots["d_star_inner"] = (s1_inner - beam.get_s0()).norms()
        self.spots["d_star_band2"] = flex.pow2(
            self.spots["d_star_outer"] - self.spots["d_star_inner"]
        )

    def _calc_seeds_and_stems(self):
        # As the first stage of search, determine a list of seed spots for
        # further stages. Order these by distance of observed d* from the
        # candidate reflection's canonical d*

        # First the 'seeds' (in 1 ASU)
        self.seeds = []
        for i, spot in enumerate(self.spots.rows()):
            sel = (self.candidate_hkls["d_star"] <= spot["d_star_outer"]) & (
                self.candidate_hkls["d_star"] >= spot["d_star_inner"]
            )
            cands = self.candidate_hkls.select(sel)

            for c in cands.rows():
                r_dst = abs(c["d_star"] - spot["d_star"])
                self.seeds.append(
                    {
                        "spot_id": i,
                        "miller_index": c["miller_index"],
                        "rlp_datum": matrix.col(c["rlp_datum"]),
                        "residual_d_star": r_dst,
                        "clock_angle": spot["clock_angle"],
                    }
                )

        self.seeds.sort(key=operator.itemgetter("residual_d_star"))

        # Now the 'stems' to use in second search level, using all indices in P 1
        self.stems = []
        for i, spot in enumerate(self.spots.rows()):
            sel = (self.candidate_hkls_p1["d_star"] <= spot["d_star_outer"]) & (
                self.candidate_hkls_p1["d_star"] >= spot["d_star_inner"]
            )
            cands = self.candidate_hkls_p1.select(sel)

            for c in cands.rows():
                r_dst = abs(c["d_star"] - spot["d_star"])
                self.stems.append(
                    {
                        "spot_id": i,
                        "miller_index": c["miller_index"],
                        "rlp_datum": matrix.col(c["rlp_datum"]),
                        "residual_d_star": r_dst,
                        "clock_angle": spot["clock_angle"],
                    }
                )

        self.stems.sort(key=operator.itemgetter("residual_d_star"))

    def _pairs_with_seed(self, seed):
        seed_rlp = matrix.col(self.spots[seed["spot_id"]]["rlp"])

        result = []
        for cand in self.stems:
            # Don't check the seed spot itself
            if cand["spot_id"] == seed["spot_id"]:
                continue

            # Skip spots at a very similar clock angle, which probably belong
            # to the same line of indices from the origin
            angle_diff = cand["clock_angle"] - seed["clock_angle"]
            angle_diff = abs(((angle_diff + math.pi) % TWO_PI) - math.pi)
            if angle_diff < FIVE_DEG:
                continue

            # Calculate the plane normal for the plane containing the seed and
            # stem. Skip pairs of Miller indices that belong to the same line
            seed_vec = seed["rlp_datum"]
            cand_vec = cand["rlp_datum"]
            try:
                seed_vec.cross(cand_vec).normalize()
            except ZeroDivisionError:
                continue

            # Compare expected reciprocal space distance with observed distance
            cand_rlp = matrix.col(self.spots[cand["spot_id"]]["rlp"])
            obs_dist = (cand_rlp - seed_rlp).length()
            exp_dist = (seed_vec - cand_vec).length()
            r_dist = abs(obs_dist - exp_dist)

            # If the distance difference is larger than the sum in quadrature
            # of the tolerated d* bands then reject the candidate
            sq_band1 = self.spots[seed["spot_id"]]["d_star_band2"]
            sq_band2 = self.spots[cand["spot_id"]]["d_star_band2"]
            if r_dist > math.sqrt(sq_band1 + sq_band2):
                continue

            # Store the seed-stem match as a 2-node graph
            g = CompleteGraph(
                {
                    "spot_id": seed["spot_id"],
                    "miller_index": seed["miller_index"],
                    "rlp_datum": seed["rlp_datum"],
                }
            )
            g = g.factory_add_vertex(
                {
                    "spot_id": cand["spot_id"],
                    "miller_index": cand["miller_index"],
                    "rlp_datum": cand["rlp_datum"],
                },
                weights_to_other=[r_dist],
            )
            result.append(g)
        return result

    def _extend_by_candidates(self, graph):
        existing_ids = [e["spot_id"] for e in graph.vertices]
        obs_relps = [matrix.col(self.spots[e]["rlp"]) for e in existing_ids]
        exp_relps = [e["rlp_datum"] for e in graph.vertices]

        result = []
        for cand in self.stems:
            # Don't check spots already matched
            if cand["spot_id"] in existing_ids:
                continue

            # Compare expected reciprocal space distances with observed distances
            cand_rlp = matrix.col(self.spots[cand["spot_id"]]["rlp"])
            cand_vec = cand["rlp_datum"]

            obs_dists = [(cand_rlp - rlp).length() for rlp in obs_relps]
            exp_dists = [(vec - cand_vec).length() for vec in exp_relps]

            residual_dist = [abs(a - b) for (a, b) in zip(obs_dists, exp_dists)]

            # If any of the distance differences is larger than the sum in
            # quadrature of the tolerated d* bands then reject the candidate
            sq_candidate_band = self.spots[cand["spot_id"]]["d_star_band2"]
            bad_candidate = False
            for r_dist, spot_id in zip(residual_dist, existing_ids):
                sq_relp_band = self.spots[spot_id]["d_star_band2"]
                if r_dist > math.sqrt(sq_relp_band + sq_candidate_band):
                    bad_candidate = True
                    break
            if bad_candidate:
                continue

            # Calculate co-planarity of the relps, including the origin
            points = flex.vec3_double(exp_relps + [cand_vec, (0.0, 0.0, 0.0)])
            plane = least_squares_plane(points)
            plane_score = flex.sum_sq(
                points.dot(plane.normal) - plane.distance_to_origin
            )

            # Reject if the group of relps are too far from lying in a single
            # plane. This cut-off was determined by trial and error using
            # simulated images.
            if plane_score > 6e-7:
                continue

            # Construct a graph including the accepted candidate node
            g = graph.factory_add_vertex(
                {
                    "spot_id": cand["spot_id"],
                    "miller_index": cand["miller_index"],
                    "rlp_datum": cand["rlp_datum"],
                },
                weights_to_other=residual_dist,
            )
            result.append(g)
        return result

    @staticmethod
    def _fit_U_from_superposed_points(reference, other):
        # Add the origin to both sets of points
        reference.append((0, 0, 0))
        other.append((0, 0, 0))

        # Find U matrix that takes ideal relps to the reference
        fit = superpose.least_squares_fit(reference, other)
        return fit.r

    def _fit_crystal_model(self, graph):
        vertices = graph.vertices

        # Reciprocal lattice points of the observations
        sel = flex.size_t([e["spot_id"] for e in vertices])
        reference = self.spots["rlp"].select(sel)

        # Ideal relps from the known cell
        other = flex.vec3_double([e["rlp_datum"] for e in vertices])

        U = self._fit_U_from_superposed_points(reference, other)
        UB = U * self.Bmat

        if self._params.bootstrap_crystal:
            # Attempt to index the low resolution spots
            from dials_algorithms_indexing_ext import AssignIndices

            phi = self.spots["xyzobs.mm.value"].parts()[2]
            UB_matrices = flex.mat3_double([UB])
            result = AssignIndices(self.spots["rlp"], phi, UB_matrices, tolerance=0.3)
            hkl = result.miller_indices()
            sel = hkl != (0, 0, 0)
            hkl_vec = hkl.as_vec3_double().select(sel)

            # Use the result to get a new UB matrix
            reference = self.spots["rlp"].select(sel)
            other = self.Bmat.elems * hkl_vec
            U = self._fit_U_from_superposed_points(reference, other)
            UB = U * self.Bmat

        # Calculate RMSD of the fit
        rms = reference.rms_difference(U.elems * other)

        # Construct a crystal model
        xl = Crystal(A=UB, space_group_symbol="P1")

        # Monkey-patch crystal to return rms of the fit (useful?)
        xl.rms = rms
        return xl