Source code for epsearch._branching

from collections.abc import Callable, Mapping, Sequence
from typing import Any, Generic, Literal, Protocol, TypeVar

import attrs
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from colormap_complex import colormap
from matplotlib_multicolored_line import colored_line
from numpy.typing import NDArray

from epsearch._cycle import Cycles, TNumber, get_cycles

# Generic key type identifying one boundary (``Circle`` / ``Rect`` in this module).
TKey = TypeVar("TKey")
# Covariant scalar type of the boundary points produced by a boundary generator.
# NOTE(review): bounded by ``np.number``, yet the concrete generators in this
# module are parameterized with the builtin ``complex``; the bound presumably
# only matters to static type checkers.
TNumber_co_ = TypeVar("TNumber_co_", bound=np.number, covariant=True)


def count_duplicate(a: NDArray[Any], /, *, eps: float = 1e-3) -> NDArray[Any]:
    """
    Count the elements that have at least one near-duplicate.

    Parameters
    ----------
    a : NDArray[Any]
        Values of shape (..., n). Distances are measured with ``np.abs`` of
        the difference, so complex values are supported.
    eps : float, optional
        Two elements closer than ``eps`` count as duplicates, by default 1e-3.

    Returns
    -------
    NDArray[Any]
        For each sequence along the last axis, the number of elements lying
        within ``eps`` of some *other* element of the same sequence; shape
        (...). Empty sequences yield 0.

    """
    n = a.shape[-1]
    # Pairwise distances, shape (..., n, n).
    dist = np.abs(a[..., :, None] - a[..., None, :])
    # Ignore each element's zero distance to itself. ``inf`` cannot be undercut
    # by any ``eps``, unlike a large finite constant.
    dist = np.where(np.eye(n, dtype=bool), np.inf, dist)
    # ``initial`` keeps the reduction well-defined when n == 0.
    nearest = dist.min(axis=-1, initial=np.inf)
    return (nearest < eps).sum(axis=-1)


def contour_integral(
    values: NDArray[TNumber], /, type: Literal["circle", "square"] = "circle"
) -> NDArray[TNumber]:
    """
    Calculate a contour integral up to a constant nonzero factor.

    Parameters
    ----------
    values: NDArray[TNumber]
        Values of shape (..., num_points), sampled along the contour.
    type : Literal["circle", "square"], optional
        The type of contour, by default "circle".
        If "circle", assume the points are equispaced on a circle,
        z_j = c + r * exp(2j * pi * j / num_points); the result is then
        ``mean(values * exp(2j * pi * j / num_points))``, i.e. the contour
        integral divided by ``2j * pi * r``.
        If "square", assume that first quarter of points are
        one side of the square, second quarter are the side which
        is connected to the previous side, and so on. The sides are assumed
        to have equal length and to be traversed counter-clockwise; which
        side comes first only changes the constant factor.

    Returns
    -------
    NDArray[TNumber]
        The (scaled) contour integral of shape (...). Only its magnitude
        relative to the values is meaningful.

    Raises
    ------
    ValueError
        If ``type`` is neither "circle" nor "square".

    """
    n_points = values.shape[-1]
    if type == "square":
        # Each side turns the traversal direction by +90 degrees, so the
        # direction of side k is proportional to 1j**k: 1, 1j, -1, -1j.
        l_side = n_points // 4
        return (
            np.mean(values[..., :l_side], axis=-1)
            + 1j * np.mean(values[..., l_side : 2 * l_side], axis=-1)
            - np.mean(values[..., 2 * l_side : 3 * l_side], axis=-1)
            - 1j * np.mean(values[..., 3 * l_side :], axis=-1)
        )
    if type != "circle":
        raise ValueError(f'type must be "circle" or "square", got {type!r}')
    # On the circle, dz is proportional to exp(2j*pi*phase) d(phase).
    phase = np.arange(n_points) / n_points
    weights = np.exp(2j * np.pi * phase)
    return np.mean(values * weights, axis=-1)


def is_analytic(
    values: NDArray[TNumber],
    *,
    rtol: float | None = None,
    integral: NDArray[TNumber] | None = None,
    type: Literal["circle", "square"] = "circle",
) -> bool:
    """
    Check if Cauchy's integral theorem is fulfilled (the integral vanishes).

    Parameters
    ----------
    values : NDArray[TNumber]
        Values of shape (..., num_points).
    rtol : float, optional
        The relative tolerance, by default 1e-3. The integral of each
        sequence is compared with ``rtol`` times the largest magnitude of
        that sequence.
    integral : NDArray[TNumber] | None, optional
        The contour integral of shape (...). Computed with
        ``contour_integral(values, type=type)`` if None.
    type : Literal["circle", "square"], optional
        The type of contour, by default "circle".

    Returns
    -------
    bool
        Whether ``|integral| <= rtol * max|values|`` holds for all
        sequences. Identically-zero sequences count as analytic.

    """
    rtol = 1e-3 if rtol is None else rtol
    if integral is None:
        integral = contour_integral(values, type=type)
    scale = np.max(np.abs(values), axis=-1)
    # ``<=`` rather than ``<``: for an identically-zero sequence both sides are
    # exactly 0, and a zero function is trivially analytic.
    return bool(np.all(np.abs(integral) <= rtol * scale))


class BoundaryGenerator(Protocol[TKey, TNumber_co_]):
    """
    A protocol for boundary generators.

    ``find_branching_points_recursively`` calls a generator first with an
    empty mapping to obtain the initial boundaries, then, after evaluating
    each boundary, with a single-entry mapping ``{key: go_further}``.
    """

    def __call__(
        self, go_further: Mapping[TKey, bool], /
    ) -> tuple[Mapping[TKey, Sequence[TNumber_co_]], Sequence[TKey]]:
        """
        Divide-and-conquer search step.

        Parameters
        ----------
        go_further : Mapping[TKey, bool]
            Whether to search further inside each boundary. The
            implementations in this module return the initial boundary when
            this mapping is empty.

        Returns
        -------
        tuple[Mapping[TKey, Sequence[TNumber_co_]], Sequence[TKey]]
            The new boundaries (key -> sampled points on the boundary) and
            the keys that are final candidates.

        """
        ...

    def keys_to_values(self, keys: Sequence[TKey], /) -> Sequence[TNumber_co_]:
        """
        Convert the keys of the final candidates into candidate points.

        Parameters
        ----------
        keys : Sequence[TKey]
            The keys of the final candidates.

        Returns
        -------
        Sequence[TNumber_co_]
            The final candidates.

        """
        ...
@attrs.frozen(kw_only=True)
class FindExceptionalPointsRecursivelyResult(Generic[TKey, TNumber]):
    """
    Outcome of ``find_branching_points_recursively``.

    For every evaluated boundary key it stores the sampled boundary points,
    the eigenvalues computed on them, the resulting cycles and the
    generation (0 for the initial boundaries, +1 per subdivision). It also
    stores the final keys, the branching points derived from them and the
    boundary generator that produced everything.
    """

    boundaries: Mapping[TKey, Sequence[TNumber]]
    """The boundaries returned by the boundary generator."""
    eigvals: Mapping[TKey, Sequence[Sequence[TNumber]]]
    """The eigenvalues for each boundary."""
    cycles: Mapping[TKey, Cycles[TNumber]]
    """The cycles for each boundary."""
    generations: Mapping[TKey, int]
    """The generation of each boundary."""
    keys: Sequence[TKey]
    """The keys of the final candidates."""
    branching_points: Sequence[TNumber]
    """The branching points found."""
    #: The boundary generator used during the search.
    f_boundary: BoundaryGenerator[TKey, TNumber]

    def plot(
        self,
        text_contour_integral: bool = True,
        text_additional: Callable[[Cycles[TNumber]], str] | None = None,
        set_limits: bool = True,
    ) -> None:
        """
        Plot the boundaries and the eigenvalues.

        Delegates to the module-level ``plot`` with all stored results. The
        contour type is "circle" when ``f_boundary`` is a ``CirclesBoundary``
        and "square" otherwise.

        Parameters
        ----------
        text_contour_integral : bool, optional
            Forwarded to ``plot``, by default True.
        text_additional : Callable[[Cycles[TNumber]], str] | None, optional
            Forwarded to ``plot``, by default None.
        set_limits : bool, optional
            Forwarded to ``plot``, by default True.

        """
        uses_circles = isinstance(self.f_boundary, CirclesBoundary)
        contour_type: Literal["circle", "square"] = "circle" if uses_circles else "square"
        plot(
            boundaries=self.boundaries,
            cycles=self.cycles,
            generations=self.generations,
            branching_points=self.branching_points,
            text_contour_integral=text_contour_integral,
            text_additional=text_additional,
            set_limits=set_limits,
            type=contour_type,
        )
def plot(
    *,
    boundaries: Mapping[TKey, Sequence[TNumber]] | Sequence[TNumber],
    cycles: Mapping[TKey, Cycles[TNumber]] | Cycles[TNumber],
    generations: Mapping[TKey, int] | None = None,
    branching_points: Sequence[TNumber] | None = None,
    text_contour_integral: bool = True,
    text_additional: Callable[[Cycles[TNumber]], str] | None = None,
    set_limits: bool = False,
    type: Literal["circle", "square"] = "circle",
) -> None:
    """
    Plot the boundaries and the eigenvalues.

    Creates a 2x2 figure:

    0. the boundary points in the parameter plane (plus ``branching_points``),
    1. a scatter plot of ``incomplete_eigvals`` of each boundary's cycles,
    2. traces of the eigenvalue cycles with period > 1,
    3. traces of the eigenvalue cycles with period 1.

    Parameters
    ----------
    boundaries : Mapping[TKey, Sequence[TNumber]] | Sequence[TNumber]
        The sampled boundary points, keyed, or a single boundary.
    cycles : Mapping[TKey, Cycles[TNumber]] | Cycles[TNumber]
        The cycles of each boundary (same keys as ``boundaries``), or a
        single ``Cycles`` object.
    generations : Mapping[TKey, int] | None, optional
        The generation of each boundary; missing keys count as 0.
    branching_points : Sequence[TNumber] | None, optional
        Points marked with "x" in the parameter plane.
    text_contour_integral : bool, optional
        Whether to annotate each cycle with the magnitude of the contour
        integral of a symmetric function of its eigenvalues, by default True.
    text_additional : Callable[[Cycles[TNumber]], str] | None, optional
        If given, its result is added to the label drawn at the first point
        of every boundary, by default None.
    set_limits : bool, optional
        Whether to copy the axis limits of the eigenvalue scatter plot to
        the two trace plots, by default False.
    type : Literal["circle", "square"], optional
        The contour type passed to ``contour_integral``, by default "circle".

    """
    sns.set_theme()
    _, axes = plt.subplots(2, 2, figsize=(20, 20), layout="constrained")
    ax: Sequence[plt.Axes] = axes.reshape(-1)  # type: ignore

    # Accept a single boundary / Cycles object as well as keyed mappings.
    if isinstance(boundaries, Mapping):
        boundaries_: Mapping[Any, Sequence[TNumber]] = boundaries
    else:
        boundaries_ = {None: boundaries}
    if isinstance(cycles, Cycles):
        cycles_: Mapping[Any, Cycles[TNumber]] = {None: cycles}
    else:
        cycles_ = cycles
    del boundaries, cycles

    cmap = colormap(type="oklch")
    # Generation of every plotted boundary (0 when unknown).
    generation_of = {
        k: (generations.get(k, 0) if generations is not None else 0) for k in boundaries_
    }
    has_multiple_generations = len(set(generation_of.values())) > 1
    max_generation = max(generation_of.values(), default=0)
    has_multiple = len(boundaries_) > 1

    for ik, k in enumerate(boundaries_):
        i = generation_of[k]
        boundary_ = np.asarray(boundaries_[k])
        cycle_: Cycles[TNumber] = cycles_[k]
        if has_multiple_generations:
            # Shade by depth: generation 0 -> 1, deepest generation -> 0.
            # max_generation >= 1 here because at least two generations exist.
            color = cmap(
                np.linspace(0, 1, len(boundary_)),
                1 - i / max_generation,
            )
        else:
            color = plt.get_cmap("twilight")(np.linspace(0, 1, len(boundary_)))[:, :3]
        ax[0].scatter(
            boundary_.real,
            boundary_.imag,
            c=color,
        )

        # Tag labels with the boundary index / generation only when several
        # boundaries / generations share the same axes.
        prefix = ""
        if has_multiple:
            prefix = f"B{ik}-{prefix}"
        if has_multiple_generations:
            prefix = f"G{i}-{prefix}"

        label_lines: list[str] = []
        if cycle_.max_cycle_length > 1:
            label_lines.append(f"{prefix}{cycle_.max_cycle_length}")
        if text_additional is not None:
            label_lines.append(text_additional(cycle_))
        if label_lines:
            ax[0].text(
                boundary_[0].real,
                boundary_[0].imag,
                "\n".join(label_lines),
                fontsize=8,
            )

        incomplete = cycle_.incomplete_eigvals
        ax[1].scatter(
            incomplete.real.flatten(),
            incomplete.imag.flatten(),
            marker="o",
            c=np.broadcast_to(color[None, :, :], (*incomplete.shape[:2], 3)).reshape(-1, 3),
        )
        for cycle in cycle_.cycles:
            period = cycle.shape[0]
            target = ax[3 if period == 1 else 2]
            colored_line(
                cycle.real.T,
                cycle.imag.T,
                c=color[:, None, :],
                ax=target,
            )
            for j in range(period):
                text = f"{prefix}C{period}-{j}"
                if text_contour_integral and j == 0:
                    # A symmetric function of the cycle's eigenvalues is unchanged
                    # by their permutation around the contour; presumably a large
                    # integral hints at a further singularity inside.
                    symmetric = np.sum(cycle, axis=0) ** len(cycle) - np.prod(
                        cycle * len(cycle), axis=0
                    )
                    contour_integral_abs = np.abs(contour_integral(symmetric, type=type))
                    text += f"\n∫: {contour_integral_abs:.3g}"
                target.text(
                    cycle[j, 0].real,
                    cycle[j, 0].imag,
                    text,
                    fontsize=8,
                )

    if branching_points is not None:
        ax[0].plot(np.real(branching_points), np.imag(branching_points), "x")
    ax[0].set_title("Trace of the parameter (p)")
    ax[0].set_xlabel("Re p")
    ax[0].set_ylabel("Im p")
    ax[1].set_title("Scatter plot of the eigenvalues (λ)")
    ax[2].set_title("Trace of the eigenvalues \nwhich period > 1 (λ)")
    ax[3].set_title("Trace of the eigenvalues \nwhich period = 1 (λ)")
    for ax_ in ax[1:]:
        ax_.set_xlabel("Re λ")
        ax_.set_ylabel("Im λ")
    if set_limits:
        # Share the eigenvalue-scatter limits with both trace plots.
        xlim = ax[1].get_xlim()
        ylim = ax[1].get_ylim()
        for ax_ in ax[2:]:
            ax_.set_xlim(xlim)
            ax_.set_ylim(ylim)
def find_branching_points_recursively(
    f_eigvals: Callable[[Sequence[TNumber]], Sequence[Sequence[TNumber]]],
    f_boundary: BoundaryGenerator[TKey, TNumber],
    /,
    *,
    f_go_further: Callable[[Cycles[TNumber]], bool] | None = None,
    f_final: Callable[[Cycles[TNumber]], bool] | None = None,
    f_plot: Callable[[int | None, int | None], None] | None = None,
    eigvals_analytic: bool = True,
    rtol_analytic: float | None = None,
    depth_first: bool = False,
    depth_first_and_break: bool = False,
) -> FindExceptionalPointsRecursivelyResult[TKey, TNumber]:
    """
    Search for branching points recursively.

    Parameters
    ----------
    f_eigvals : Callable[[Sequence[TNumber]], Sequence[Sequence[TNumber]]]
        A function that takes a batch of parameters (the boundary points)
        and returns one sequence of eigenvalues per point.
    f_boundary : BoundaryGenerator[TKey, TNumber]
        Called with an empty mapping to obtain the initial boundaries, then
        with ``{key: go_further}`` after each boundary has been evaluated.
        It returns the new boundaries and the keys accepted as final.
    f_go_further : Callable[[Cycles[TNumber]], bool], optional
        A function that takes the cycles and returns whether to search
        further inside the boundary. By default, whether any cycle is
        longer than 1.
    f_final : Callable[[Cycles[TNumber]], bool], optional
        A function that takes the cycles and returns whether the boundary is
        final. It is only consulted when ``f_go_further`` returned True. A
        final boundary's key becomes a candidate without further
        subdivision. By default None (never final).
    f_plot : Callable[[int | None, int | None], None], optional
        If given, each evaluated boundary is drawn with ``plot`` and then
        ``f_plot(generation, number_of_boundaries_evaluated_so_far)`` is
        called. After the search the whole result is plotted and
        ``f_plot(None, None)`` is called. By default None.
    eigvals_analytic : bool, optional
        Whether the eigenvalues are supposed to be analytic on the region
        except for the branching points, by default True.
        If True and ``f_boundary`` is a ``CirclesBoundary`` or
        ``RectsBoundary``, a boundary on which the eigenvalues fail the
        ``is_analytic`` check is searched further as well.
    rtol_analytic : float, optional
        The relative tolerance for the analytic check, by default None
        (``is_analytic``'s default).
    depth_first : bool, optional
        Whether to use depth-first search for finding branching points,
        by default False.
        If True, the search will be depth-first, otherwise it will be breadth-first.
    depth_first_and_break : bool, optional
        Whether to use depth-first search and break when the first branching
        point is found, by default False. Implies ``depth_first``.

    Returns
    -------
    FindExceptionalPointsRecursivelyResult[TKey, TNumber]
        The evaluated boundaries, eigenvalues, cycles, generations, the
        deduplicated final keys and the branching points derived from them.

    Raises
    ------
    ValueError
        If ``f_eigvals`` does not return one entry per boundary point.

    """
    if f_go_further is None:

        def f_go_further(cycles: Cycles[TNumber]) -> bool:
            return cycles.max_cycle_length > 1

    # Breaking at the first final key is documented as a depth-first search.
    depth_first = depth_first or depth_first_and_break
    # The analytic check needs a known contour shape; both are loop-invariant.
    check_analytic = eigvals_analytic and isinstance(f_boundary, (CirclesBoundary, RectsBoundary))
    contour_type: Literal["circle", "square"] = (
        "circle" if isinstance(f_boundary, CirclesBoundary) else "square"
    )

    boundaries = dict(f_boundary({})[0])
    # Items are popped from the end; new boundaries are pushed to the end
    # (depth-first) or to the front (breadth-first).
    boundaries_stack: list[tuple[TKey, Sequence[TNumber]]] = list(boundaries.items())
    eigvals: dict[TKey, Sequence[Sequence[TNumber]]] = {}
    cycles: dict[TKey, Cycles[TNumber]] = {}
    generations = dict.fromkeys(boundaries.keys(), 0)
    final_keys: list[TKey] = []
    while boundaries_stack:
        k, boundary = boundaries_stack.pop()
        eigval = f_eigvals(boundary)
        if len(eigval) != len(boundary):
            raise ValueError(
                "f_eigvals must return the same number of eigenvalues "
                "as the number of points, but "
                f"len(f(points))={len(eigval)} != len(points)={len(boundary)}"
            )
        cycle: Cycles[TNumber] = get_cycles(eigval)
        cycles[k] = cycle
        eigvals[k] = eigval
        generation = generations[k]

        has_inside = f_go_further(cycle)
        final = (has_inside and f_final(cycle)) if f_final is not None else False
        print(f"Generation {generation}, key {k}: {has_inside}")

        # A failed analyticity check also means something lies inside.
        if check_analytic:
            has_inside = has_inside or not is_analytic(
                cycle.eigvals,
                rtol=rtol_analytic,
                type=contour_type,
            )
            print(f"Generation {generation}, key {k}: {has_inside}")

        if f_plot is not None:
            plot(
                boundaries=boundary,
                cycles=cycle,
                generations=None,
                type=contour_type,
            )
            f_plot(generation, len(cycles))

        if final:
            final_keys.append(k)
            boundaries_new: Mapping[TKey, Sequence[TNumber]] = {}
        else:
            # Subdivide (or finalize) based on this boundary's result.
            boundaries_new, final_keys_based_on_f_boundary = f_boundary({k: has_inside})
            final_keys.extend(final_keys_based_on_f_boundary)

        if depth_first_and_break and final_keys:
            break

        boundaries.update(boundaries_new)
        if depth_first:
            boundaries_stack.extend(boundaries_new.items())
        else:
            boundaries_stack = list(boundaries_new.items()) + boundaries_stack
        generations.update(dict.fromkeys(boundaries_new.keys(), generation + 1))

    # Deduplicate once and derive the branching points from the same list.
    unique_keys = list(dict.fromkeys(final_keys))
    result = FindExceptionalPointsRecursivelyResult(
        boundaries={key: points for key, points in boundaries.items() if key in eigvals},
        eigvals=eigvals,
        cycles=cycles,
        generations=generations,
        keys=unique_keys,
        branching_points=list(f_boundary.keys_to_values(unique_keys)),
        f_boundary=f_boundary,
    )
    if f_plot is not None:
        result.plot()
        f_plot(None, None)
    return result
@attrs.frozen(kw_only=True)
class Circle:
    """
    A circle identifying one search region of ``CirclesBoundary``.

    Frozen attrs instances are hashable, so circles serve as mapping keys.
    """

    radius: float
    """The nominal radius of the circle (``CirclesBoundary`` samples its
    points on a circle enlarged by ``extra_ratio``)."""
    center: complex
    """The center of the circle."""
@attrs.frozen(kw_only=True)
class CirclesBoundary(BoundaryGenerator[Circle, complex]):
    """
    Divide-and-conquer boundary generator based on circles.

    Every circle circumscribes a square search region: for a circle with
    the given center and radius, the region is
    [Re center - radius/sqrt(2), Re center + radius/sqrt(2)]
    x [Im center - radius/sqrt(2), Im center + radius/sqrt(2)].
    Splitting a circle replaces its square by the four quadrants, each
    circumscribed by a circle of half the radius.

    Parameters
    ----------
    center : complex
        The center of the initial circle.
    radius : float
        The radius of the initial circle.
    radius_min : float
        Circles with a radius below this value are final instead of split.
    n_points : int
        The number of points sampled on each circle.
    extra_ratio : float, optional
        The sampled circle is enlarged by this ratio so that the corners of
        the square are not missed, by default 0.1. Must be positive or zero.

    """

    center: complex
    radius: float
    radius_min: float
    n_points: int
    extra_ratio: float = 0.1

    def _circle(self, *, center: complex, radius: float) -> tuple[Circle, Sequence[complex]]:
        """Return the key of a circle and its (enlarged) equispaced sample points."""
        enlarged_radius = radius * (1 + self.extra_ratio)
        angles = 2j * np.pi * np.arange(self.n_points) / self.n_points
        points = center + enlarged_radius * np.exp(angles)
        return Circle(center=center, radius=radius), points

    def __call__(
        self, go_further: Mapping[Circle, bool], /
    ) -> tuple[Mapping[Circle, Sequence[complex]], Sequence[Circle]]:
        """
        Divide-and-conquer search using circles.

        An empty ``go_further`` yields the initial circle. Otherwise every
        circle flagged True is either split into four quarter circles or,
        when already smaller than ``radius_min``, reported as final.
        """
        if not go_further:
            root, root_points = self._circle(center=self.center, radius=self.radius)
            return {root: root_points}, []

        children: dict[Circle, Sequence[complex]] = {}
        finished: list[Circle] = []
        for parent, should_split in go_further.items():
            if not should_split:
                continue
            if parent.radius < self.radius_min:
                finished.append(parent)
                continue
            # Quadrant centers are offset by half the square's half-side.
            step = parent.radius / 2 / np.sqrt(2)
            for dx in (-1, 1):
                for dy in (-1, 1):
                    child, child_points = self._circle(
                        center=parent.center + step * (dx + dy * 1j),
                        radius=parent.radius / 2,
                    )
                    children[child] = child_points
        return children, finished

    def keys_to_values(self, keys: Sequence[Circle], /) -> Sequence[complex]:
        """
        Get the final candidates.

        A circle's center is kept only if the circle does not overlap any
        circle that precedes it in ``keys``.

        Parameters
        ----------
        keys : Sequence[Circle]
            The keys of the final candidates.

        Returns
        -------
        Sequence[complex]
            The final candidates.

        """
        centers = np.asarray([c.center for c in keys])
        radii = np.asarray([c.radius for c in keys])
        return [
            center
            for idx, center in enumerate(centers)
            if np.all(np.abs(center - centers[:idx]) > radii[idx] + radii[:idx])
        ]
@attrs.frozen(kw_only=True)
class Rect:
    """
    A rectangle identifying one search region of ``RectsBoundary``.

    Frozen attrs instances are hashable, so rectangles serve as mapping keys.
    """

    # The half size of the rectangle, width/2 + 1j * height/2.
    half_size: complex
    # The center of the rectangle.
    center: complex

    @property
    def radius(self) -> float:
        """Half the length of the rectangle's diagonal, ``abs(half_size)``."""
        return abs(self.half_size)
@attrs.frozen(kw_only=True)
class RectsBoundary(BoundaryGenerator[Rect, complex]):
    """
    Divide-and-conquer search using rectangles.

    Splitting a rectangle replaces it by its four quadrants.

    Parameters
    ----------
    center : complex
        The center of the rectangle.
    half_size : complex
        The half size of the rectangle (width/2 + 1j * height/2).
    half_size_min : complex
        The half size threshold to stop the recursion: a rectangle is final
        once both its half width and its half height are below the
        respective parts of this value.
    n_points_per_side : int
        The number of points per side on the rectangle.
    extra_ratio : float, optional
        The extra ratio to enlarge the sampled contour so that the edges and
        corners are not missed, by default 0.1. Must be positive or zero.
        Only the sampled points are enlarged; the ``Rect`` keys keep the
        nominal size, so the enlargement does not compound across
        generations.

    """

    center: complex
    half_size: complex
    half_size_min: complex
    n_points_per_side: int
    extra_ratio: float = 0.1

    def _rect(self, *, center: complex, half_size: complex) -> tuple[Rect, Sequence[complex]]:
        """
        Return the key of a rectangle and its sampled, enlarged contour.

        The ``4 * n_points_per_side`` points run counter-clockwise: right side
        upwards, top side leftwards, left side downwards, bottom side
        rightwards.
        """
        # Enlarge the sampled contour only; the key keeps the nominal size so
        # that subdivision stays inside the parent rectangle.
        enlarged = half_size * (1 + self.extra_ratio)
        # n values in [-1, 1); each side's end point is the next side's start.
        arranged = 2 * np.arange(self.n_points_per_side) / self.n_points_per_side - 1
        points = center + np.concatenate(
            [
                # right side, bottom -> top
                enlarged.real + 1j * enlarged.imag * arranged,
                # top side, right -> left
                1j * enlarged.imag - enlarged.real * arranged,
                # left side, top -> bottom
                -enlarged.real - 1j * enlarged.imag * arranged,
                # bottom side, left -> right
                -1j * enlarged.imag + enlarged.real * arranged,
            ],
            axis=0,
        )
        return Rect(center=center, half_size=half_size), points

    def __call__(
        self, go_further: Mapping[Rect, bool], /
    ) -> tuple[Mapping[Rect, Sequence[complex]], Sequence[Rect]]:
        """
        Divide-and-conquer search using rectangles.

        Parameters
        ----------
        go_further : Mapping[Rect, bool]
            Whether to search further inside each rectangle. An empty mapping
            requests the initial rectangle.

        Returns
        -------
        tuple[Mapping[Rect, Sequence[complex]], Sequence[Rect]]
            The quadrants (with their sampled contours) of every rectangle
            flagged True, and the flagged rectangles that are final because
            they are already smaller than ``half_size_min``.

        """
        if not go_further:
            return dict([self._rect(center=self.center, half_size=self.half_size)]), []

        result: dict[Rect, Sequence[complex]] = {}
        final_keys: list[Rect] = []
        for rect, branching in go_further.items():
            if not branching:
                continue
            if (
                rect.half_size.real < self.half_size_min.real
                and rect.half_size.imag < self.half_size_min.imag
            ):
                final_keys.append(rect)
                continue
            child_half_size = rect.half_size / 2
            for i in (-1, 1):
                for j in (-1, 1):
                    key, points = self._rect(
                        center=rect.center
                        + child_half_size.real * i
                        + 1j * child_half_size.imag * j,
                        half_size=child_half_size,
                    )
                    result[key] = points
        return result, final_keys

    def keys_to_values(self, keys: Sequence[Rect], /) -> Sequence[complex]:
        """Return the centers of the final rectangles as the candidates."""
        return [rect.center for rect in keys]