Source code for epsearch._diff_zeros

import warnings
from collections.abc import Callable, Sequence
from typing import Any, Literal

import attrs
import numpy as np
import scipy.optimize
from numpy.typing import NDArray
from ss_hankel import SSHKwargs, ss_h_circle

from ._branching import (
    BoundaryGenerator,
    Circle,
    FindExceptionalPointsRecursivelyResult,
    find_branching_points_recursively,
    is_analytic,
)
from ._cycle import Cycles, TNumber, get_cycles


def find_branching_points_using_zeros_ssh(
    f_eigvals: Callable[[NDArray[TNumber]], NDArray[TNumber]],
    /,
    **kwargs: SSHKwargs,
) -> Sequence[TNumber]:
    """
    Search for branching points using zero-finding methods.

    Parameters
    ----------
    f_eigvals : Callable[[NDArray[Any]], NDArray[Any]]
        A function that takes a batch of parameters and returns the eigenvalues.
    **kwargs : SSHKwargs
        Keyword arguments for the Sakurai-Sugiura method.

    Returns
    -------
    Sequence[TNumber]
        The branching points.

    """

    def f(p: NDArray[TNumber], /) -> NDArray[TNumber]:
        cycles: Cycles[TNumber] = get_cycles(f_eigvals(p))
        if cycles.max_cycle_length == 1:
            raise ValueError("No branching points found.")
        cycle_diffs = np.stack(
            [
                np.sum(cycle, axis=0) ** len(cycle) - np.prod(cycle * len(cycle), axis=0)
                for cycle in cycles.cycles
                if cycle.shape[0] > 1
            ],
            axis=-1,
        )
        return cycle_diffs[..., np.diag(np.arange(cycle_diffs.shape[-1]))]

    return ss_h_circle(f, **kwargs).eigval


def find_branching_points_using_zeros_scipy(
    f_eigvals: Callable[[NDArray[TNumber]], NDArray[TNumber]],
    x0: TNumber,
    eigval: TNumber,
    multiplicity: int,
    /,
    **kwargs: Any,
) -> TNumber:
    """
    Search for branching points using scipy's root-finding methods.

    Parameters
    ----------
    f_eigvals : Callable[[NDArray[Any]], NDArray[Any]]
        A function that takes a batch of parameters and returns the eigenvalues.
    x0 : TNumber
        The initial guess for the root-finding method.
    eigval : TNumber
        The expected eigenvalue at the branching point.
    multiplicity : int
        The multiplicity of the eigenvalue at the branching point.
    **kwargs : Any
        Keyword arguments for ``scipy.optimize.root``.

    Returns
    -------
    Sequence[TNumber]
        The branching points.

    """

    def f(p: NDArray[TNumber], /) -> NDArray[TNumber]:
        p = p[0] + 1j * p[1]
        eigvals = f_eigvals(p[None])[0]
        argsort = np.argsort(np.abs(eigvals - eigval))
        eigvals_near = eigvals[argsort[:multiplicity]]
        objective = np.abs(
            np.sum(eigvals_near, axis=0) ** multiplicity
            - np.prod(eigvals_near * multiplicity, axis=0)
        )
        return np.stack((objective.real, objective.imag), axis=0)

    res = scipy.optimize.root(f, np.stack((x0.real, x0.imag)), **kwargs).x
    return res[0] + 1j * res[1]


[docs] def find_branching_points_recursively_hybrid( f_eigvals: Callable[[Sequence[TNumber]], Sequence[Sequence[TNumber]]], f_boundary: BoundaryGenerator[Circle, TNumber], /, f_plot: Callable[[int | None, int | None], None] | None = None, rtol_analytic: float | None = None, ssh_kwargs: SSHKwargs | None = None, scipy_kwargs: Any | None = None, method: Literal["ssh", "scipy"] = "scipy", depth_first: bool = False, depth_first_and_break: bool = False, ) -> FindExceptionalPointsRecursivelyResult[Circle, TNumber]: """ Search for branching points recursively. Parameters ---------- f_eigvals : Callable[[Sequence[TNumber]], Sequence[Sequence[TNumber]]] A function that takes a batch of parameters and returns the eigenvalues. f_boundary : BoundaryGenerator[Circle] A function that takes the mapping key and whether a branching point is found inside the boundary, and returns the new boundaries f_plot : Callable[[int | None, int | None], None], optional A function that takes the iteration number and boundary key and plots the boundaries, by default None rtol_analytic : float, optional The relative tolerance for the analytic check, by default None. ssh_kwargs : SSHKwargs, optional Keyword arguments for the Sakurai-Sugiura method. scipy_kwargs : Any, optional Keyword arguments for the scipy root-finding method. method : Literal["ssh", "scipy"], optional The method to use for finding branching points, by default "scipy". If "ssh", the Sakurai-Sugiura method is used. If "scipy", scipy's root-finding method is used. depth_first : bool, optional Whether to use depth-first search for finding branching points, by default False. If True, the search will be depth-first, otherwise it will be breadth-first. depth_first_and_break : bool, optional Whether to use depth-first search and break when the first branching point is found, by default False. If True, the search will be depth-first and will break when the first branching point is found. Returns ------- FindExceptionalPointsRecursivelyResult[Circle] The branching points. """ def f_final(cycles: Cycles[TNumber]) -> bool: # check all cycle diffs ** cycle_length # are analytic for cycle in cycles.cycles: if cycle.shape[0] < 2: continue if cycle.shape[0] > 2: return False # TODO: remove if not is_analytic( np.sum(cycle, axis=0) ** len(cycle) - np.prod(cycle * len(cycle), axis=0), rtol=rtol_analytic, ): return False return True result = find_branching_points_recursively( f_eigvals, f_boundary, f_final=f_final, f_plot=f_plot, eigvals_analytic=True, depth_first=depth_first, depth_first_and_break=depth_first_and_break, ) branching_points: list[TNumber] = [] for circle in result.keys: if method == "ssh": try: branching_points.extend( find_branching_points_using_zeros_ssh( f_eigvals, circle_center=circle.center, circle_radius=circle.radius, **(ssh_kwargs or {}), ) ) except ValueError as e: warnings.warn(f"No branching points found for {circle}.", source=e, stacklevel=2) elif method == "scipy": try: for connected_cycle in result.cycles[circle].connected_cycles: if connected_cycle.shape[0] == 1: continue branching_points.append( find_branching_points_using_zeros_scipy( f_eigvals, np.mean(result.boundaries[circle]), np.mean(connected_cycle[0]), connected_cycle.shape[0], **(scipy_kwargs or {}), ) ) except ValueError as e: warnings.warn(f"No branching points found for {circle}.", source=e, stacklevel=2) return attrs.evolve(result, branching_points=np.sort_complex(branching_points))