# import the usual libraries
import jax
# manually enable 64-bit mode; set this at startup, before creating any arrays
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import cyjax

# random number sequence
rns = cyjax.util.PRNGSequence(42)

Algebraic varieties

The implementation of varieties delegates most of the work to the HomPoly class, which is itself a subclass of Poly. By default, these classes build on sympy and add the functionality needed in this application. Other internal representations, e.g. ones that do not use sympy at all, could be implemented by creating new subclasses of Poly and HomPoly, respectively.
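To illustrate what a sympy-free representation could look like, here is a minimal sketch that stores a polynomial as a dictionary mapping exponent tuples to coefficients. The class names `DictPoly` and `DictHomPoly` are hypothetical and not part of cyjax; the sketch omits parameters and does not reproduce the real Poly/HomPoly interface. It only shows the split between general and homogeneous polynomials, including how an affine patch follows from setting one coordinate to 1.

```python
import numpy as np


class DictPoly:
    """Hypothetical polynomial stored as {exponent tuple: coefficient}."""

    def __init__(self, terms):
        self.terms = dict(terms)

    def __call__(self, z):
        # evaluate at a single point z (no batch dimension)
        z = np.asarray(z)
        return sum(c * np.prod(z ** np.array(e)) for e, c in self.terms.items())


class DictHomPoly(DictPoly):
    """Hypothetical homogeneous polynomial: all monomials share one degree."""

    def __init__(self, terms):
        super().__init__(terms)
        degrees = {sum(e) for e in self.terms}
        if len(degrees) != 1:
            raise ValueError("polynomial is not homogeneous")
        self.degree = degrees.pop()

    def affine_poly(self, patch):
        # setting z_patch = 1 amounts to dropping that variable's exponent
        affine_terms = {}
        for e, c in self.terms.items():
            e_aff = e[:patch] + e[patch + 1:]
            affine_terms[e_aff] = affine_terms.get(e_aff, 0) + c
        return DictPoly(affine_terms)


# Fermat quintic z_0^5 + z_1^5 + z_2^5 + z_3^5 + z_4^5
fermat = DictHomPoly({tuple(5 * (i == j) for j in range(5)): 1 for i in range(5)})
print(fermat([1, 2, 3, 2, 4]))                    # 1332
print(fermat.affine_poly(patch=0)([2, 3, 2, 4]))  # 1332
```

Checking homogeneity once in the constructor keeps the affine-patch logic trivial, since every monomial is guaranteed to have the same total degree.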

Besides the Dwork and Fermat varieties, which are provided for convenience, other varieties can easily be specified by giving their defining polynomial as a sympy-style string.

cyjax.Dwork()
\[\operatorname{Dwork}(dim=3) \subset \mathbb{CP}^4: \displaystyle \operatorname{DworkPoly}([z_{0},z_{1},z_{2},z_{3},z_{4}], \text{params}=[\psi]) = - 5 \psi z_{0} z_{1} z_{2} z_{3} z_{4} + z_{0}^{5} + z_{1}^{5} + z_{2}^{5} + z_{3}^{5} + z_{4}^{5}\]
cyjax.Fermat()  # special case where psi is always 0
\[\operatorname{Fermat}(dim=3) \subset \mathbb{CP}^4: \displaystyle \operatorname{FermatPoly}([z_{0},z_{1},z_{2},z_{3},z_{4}]) = z_{0}^{5} + z_{1}^{5} + z_{2}^{5} + z_{3}^{5} + z_{4}^{5}\]
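Numerically, the relation between the two displayed polynomials is simple: the Fermat quintic is the Dwork quintic at psi = 0. The plain-NumPy sketch below evaluates both at one point of C^5; the function names `dwork` and `fermat` are illustrative only and are not cyjax API.

```python
import numpy as np


def dwork(z, psi):
    """Dwork quintic sum_i z_i^5 - 5 psi z_0 z_1 z_2 z_3 z_4 at one point z."""
    z = np.asarray(z)
    return np.sum(z**5) - 5 * psi * np.prod(z)


def fermat(z):
    """Fermat quintic: the Dwork polynomial with psi fixed to 0."""
    return dwork(z, psi=0)


z = np.array([1, 2, 3, 2, 4])
print(dwork(z, psi=3))  # 612
print(fermat(z))        # 1332
```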

Polynomials

Internally, the varieties rely heavily on the Poly and HomPoly classes. In many cases, however, it is not necessary to interact with them directly, since most functionality is wrapped by the variety classes and managed automatically.

# *.from_sympy is the most convenient way to create new instances
poly = cyjax.HomPoly.from_sympy(
    'z_0**5 + z_1**5 + z_2**5 + z_3**5 + z_4**5 - 5 * psi * z_0 * z_1 * z_2 * z_3 * z_4',
    variable_names=['z'],
    # can also be automatically detected in this case
    variable_dim=[5],
    # if omitted, automatically detected and ordered alphabetically
    parameters=['psi'],
)

poly
\[\displaystyle \operatorname{HomPoly}([z_{0},z_{1},z_{2},z_{3},z_{4}], \text{params}=[\psi]) = - 5 \psi z_{0} z_{1} z_{2} z_{3} z_{4} + z_{0}^{5} + z_{1}^{5} + z_{2}^{5} + z_{3}^{5} + z_{4}^{5}\]
# affine polynomials are implicitly defined
# and generated on the fly
poly.affine_poly(patch=1)
\[\displaystyle \operatorname{Poly}([z_{0},z_{2},z_{3},z_{4}], \text{params}=[\psi]) = - 5 \psi z_{0} z_{2} z_{3} z_{4} + z_{0}^{5} + z_{2}^{5} + z_{3}^{5} + z_{4}^{5} + 1\]
# can evaluate for single set of values
z = np.array([1, 2, 3, 2, 4])
params = np.array([3])
poly(z, params)
612.0
# evaluate the same point in affine patch 0, where z_0 = 1
z = np.array([2, 3, 2, 4])
params = np.array([3])
poly(z, params, patch=0)
612.0
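Both evaluations give 612 because the affine coordinates in patch 0 are the homogeneous coordinates with z_0 = 1 removed. The minimal sketch below makes that correspondence explicit with a hypothetical helper `to_homogeneous` (not a cyjax function) that reinserts the 1. It also shows that the value itself depends on the patch: rescaling a point by lambda multiplies a degree-5 polynomial by lambda^5, so only whether the polynomial vanishes is patch-independent.

```python
import numpy as np


def dwork(z, psi):
    # Dwork quintic at a single homogeneous point z of length 5
    z = np.asarray(z)
    return np.sum(z**5) - 5 * psi * np.prod(z)


def to_homogeneous(z_affine, patch):
    """Hypothetical helper: insert the fixed coordinate z_patch = 1."""
    return np.insert(np.asarray(z_affine), patch, 1)


# patch 0: the same point as [1, 2, 3, 2, 4]
print(dwork(to_homogeneous([2, 3, 2, 4], patch=0), psi=3))  # 612

# patch 4: divide [1, 2, 3, 2, 4] by z_4 = 4 and drop that coordinate
z_aff = np.array([1, 2, 3, 2]) / 4
print(dwork(to_homogeneous(z_aff, patch=4), psi=3))  # 612 / 4**5 = 0.59765625
```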

When a Poly object is called to evaluate it, the inputs can be numeric or symbolic. However, they must be a single set of values and cannot have a batch dimension. A JAX-compatible evaluation function that properly handles batch dimensions can be generated with the transform_eval method. Varieties do this automatically and expose the resulting fast function as their defining method. The underlying polynomial of a variety remains accessible via the defining_poly attribute.
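The difference between single-point and batched evaluation can be sketched in plain NumPy: the hypothetical `dwork_single` accepts one point, while `dwork_batched` reduces only over the last axis, so a leading batch dimension passes through. This only mimics the behavior described above and is not how transform_eval is implemented; in JAX, for instance, a single-point function can also be mapped over a batch axis with jax.vmap.

```python
import numpy as np


def dwork_single(z, psi):
    # z has shape (5,): one point, no batch dimension
    return np.sum(z**5) - 5 * psi * np.prod(z)


def dwork_batched(zs, psi):
    # zs has shape (..., 5): reduce over the coordinate axis only
    return np.sum(zs**5, axis=-1) - 5 * psi * np.prod(zs, axis=-1)


rng = np.random.default_rng(0)
zs = rng.normal(size=(200, 5)) + 1j * rng.normal(size=(200, 5))
psi = 1.0 + 0.5j

looped = np.array([dwork_single(z, psi) for z in zs])
batched = dwork_batched(zs, psi)
print(batched.shape)                 # (200,)
print(np.allclose(looped, batched))  # True
```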

Custom varieties

Although HomPoly can handle multiple complex projective spaces, currently only varieties with a single defining equation in a single ambient projective space are supported.

Any defining polynomial can be used, with an arbitrary combination of parameters (corresponding to complex moduli).

# this is a manual version of the Dwork variety with an additional parameter a
mnfd = cyjax.VarietySingle.from_sympy(
    'z_0**5 + z_1**5 + z_2**5 + z_3**5 + z_4**5 - 5 * psi * z_0 * z_1 * z_2 * z_3 * z_4 + a * z_4**4 * z_0')
mnfd
\[\operatorname{VarietySingle}(dim=3) \subset \mathbb{CP}^4: \displaystyle \operatorname{HomPoly}([z_{0},z_{1},z_{2},z_{3},z_{4}], \text{params}=[a,\psi]) = a z_{0} z_{4}^{4} - 5 \psi z_{0} z_{1} z_{2} z_{3} z_{4} + z_{0}^{5} + z_{1}^{5} + z_{2}^{5} + z_{3}^{5} + z_{4}^{5}\]
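A single homogeneous equation is what makes the zero locus a well-defined subset of projective space: for a polynomial of degree d, p(lambda z) = lambda^d p(z), so whether p vanishes does not depend on the chosen representative of a point. The sketch below checks this numerically for the polynomial above, passing the parameters in the alphabetical order [a, psi] shown in its output. The function name `defining_single` is illustrative only.

```python
import numpy as np


def defining_single(z, params):
    """Dwork quintic plus a * z_4^4 * z_0 at one point, params ordered [a, psi]."""
    a, psi = params
    return np.sum(z**5) - 5 * psi * np.prod(z) + a * z[4] ** 4 * z[0]


rng = np.random.default_rng(1)
z = rng.normal(size=5) + 1j * rng.normal(size=5)
params = np.array([0.1, 1.0 + 0.5j])
lam = 0.7 - 1.3j  # arbitrary nonzero complex rescaling

lhs = defining_single(lam * z, params)
rhs = lam**5 * defining_single(z, params)
print(np.isclose(lhs, rhs))  # True
```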
# the most basic thing to do is to evaluate the defining equation
zs, patch = cyjax.random.uniform_projective(next(rns), 200, mnfd.dim_projective)
params = jnp.array([0.1, 1.0 + 0.5j])
mnfd.defining(zs, params, patch).shape
(200,)
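How such points might be generated can be sketched as follows. This is an assumption about the general approach, not the implementation of cyjax.random.uniform_projective: complex Gaussian vectors in C^5 have a unitarily invariant distribution, so their images in CP^4 follow the Fubini-Study measure. Each point is then rescaled so that its largest-modulus coordinate equals 1, which also selects the patch. Unlike the cyjax function, whose exact output layout is not shown here, the hypothetical `sample_projective` keeps all five coordinates, and `defining_batched` is an illustrative batched version of the polynomial above.

```python
import numpy as np


def sample_projective(rng, n, dim_projective):
    """Hypothetical sampler: n points of CP^dim_projective, rescaled so z_patch = 1."""
    shape = (n, dim_projective + 1)
    z = rng.normal(size=shape) + 1j * rng.normal(size=shape)
    rows = np.arange(n)
    # pick the coordinate of largest modulus, so all others satisfy |z_i| <= 1 afterwards
    patch = np.argmax(np.abs(z), axis=1)
    z = z / z[rows, patch][:, None]
    z[rows, patch] = 1.0  # remove rounding error in the fixed coordinate
    return z, patch


def defining_batched(zs, params):
    """Batched example polynomial, params ordered [a, psi]."""
    a, psi = params
    return (np.sum(zs**5, axis=-1) - 5 * psi * np.prod(zs, axis=-1)
            + a * zs[..., 4] ** 4 * zs[..., 0])


rng = np.random.default_rng(42)
zs, patch = sample_projective(rng, 200, 4)
values = defining_batched(zs, np.array([0.1, 1.0 + 0.5j]))
print(zs.shape, patch.shape, values.shape)  # (200, 5) (200,) (200,)
```

Choosing the largest coordinate as the patch keeps all affine coordinates inside the unit disk, which avoids numerically large values when evaluating the polynomial.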