# manually enable 64-bit mode (do this before creating any JAX arrays)
import jax
jax.config.update("jax_enable_x64", True)
# import the usual libraries
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import cyjax
# random number sequence
rns = cyjax.util.PRNGSequence(42)
Algebraic varieties
The implementation of varieties delegates most work to the HomPoly class, which is itself a Poly.
By default, these classes build on sympy and add some functionality needed in this application.
Other internal representations, e.g. ones that do not use sympy at all, could be implemented by creating new subclasses of Poly and HomPoly, respectively.
Besides the Dwork and Fermat varieties, which are provided for convenience, other varieties can easily be specified by giving their defining polynomial as a sympy-style string.
cyjax.Dwork()
cyjax.Fermat() # special case where psi is always 0
Polynomials
Internally, the varieties rely heavily on the Poly and HomPoly classes.
However, it is often unnecessary to interact with them directly, as most functionality is wrapped by the variety class and managed automatically.
# Always use *.from_sympy to create new instances in the most convenient way!
poly = cyjax.HomPoly.from_sympy(
    'z_0**5 + z_1**5 + z_2**5 + z_3**5 + z_4**5 - 5 * psi * z_0 * z_1 * z_2 * z_3 * z_4',
    variable_names=['z'],
    # can also be automatically detected in this case
    variable_dim=[5],
    # if omitted, automatically detected and ordered alphabetically
    parameters=['psi'],
)
poly
# affine polynomials are implicitly defined
# and generated on the fly
poly.affine_poly(patch=1)
# can evaluate for single set of values
z = np.array([1, 2, 3, 2, 4])
params = np.array([3])
poly(z, params)
612.0
# evaluate in affine patch
z = np.array([2, 3, 2, 4])
params = np.array([3])
poly(z, params, patch=0)
612.0
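The two evaluations above agree because, judging by these outputs, the affine coordinates in patch 0 are the homogeneous coordinates with z_0 fixed to 1. The following is a minimal pure-Python sketch of that relationship, independent of cyjax. The helper names `dwork` and `to_homogeneous` are hypothetical, and the convention "set the coordinate at index `patch` to 1" is an assumption inferred from the example rather than taken from the library.

```python
def dwork(z, psi):
    """Dwork quintic: sum_i z_i**5 - 5 * psi * prod_i z_i (hypothetical helper)."""
    prod = 1
    for zi in z:
        prod *= zi
    return sum(zi ** 5 for zi in z) - 5 * psi * prod


def to_homogeneous(z_affine, patch):
    """Insert a 1 at index `patch` to recover homogeneous coordinates.

    Assumed convention: the affine patch is the chart where z_patch = 1.
    """
    z = list(z_affine)
    return z[:patch] + [1] + z[patch:]


# evaluate once in homogeneous coordinates, once via the affine patch
homogeneous_value = dwork([1, 2, 3, 2, 4], psi=3)
affine_value = dwork(to_homogeneous([2, 3, 2, 4], patch=0), psi=3)
print(homogeneous_value, affine_value)
```

Both values come out equal, matching the two outputs shown above.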
When calling a Poly object to evaluate it, inputs can be numeric or symbolic.
However, they must be a single set of values and cannot have a batch dimension.
A JAX-compatible evaluation function that properly handles batch dimensions can be generated using the transform_eval method.
Varieties do this automatically and expose the fast function as defining.
The underlying polynomial of a variety can still be accessed via the defining_poly attribute.
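To illustrate what handling a batch dimension means, here is a minimal sketch in plain JAX that lifts a single-point evaluation to a batch with `jax.vmap`. It is not how `transform_eval` is implemented (the source does not say), and the function names `dwork_single` and `dwork_batched` are hypothetical.

```python
import jax
import jax.numpy as jnp


def dwork_single(z, psi):
    """Dwork quintic at a single point z of shape (5,) (hypothetical helper)."""
    return jnp.sum(z ** 5) - 5 * psi * jnp.prod(z)


# map over the leading axis of z, share the same psi across the batch
dwork_batched = jax.vmap(dwork_single, in_axes=(0, None))

zs = jnp.array([
    [1.0, 2.0, 3.0, 2.0, 4.0],
    [1.0, 1.0, 1.0, 1.0, 1.0],
])
values = dwork_batched(zs, 3.0)
print(values.shape)
```

The result has one value per input point, which is the same pattern as the shape returned by a variety's `defining` function for a batch of points.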
Custom varieties
Although HomPoly can handle multiple complex projective spaces, currently only varieties with a single defining equation in a single ambient projective space are supported.
Any defining polynomial can be used, with an arbitrary combination of parameters (corresponding to complex moduli).
# this is a manual version of the Dwork varieties plus an additional parameter
mnfd = cyjax.VarietySingle.from_sympy(
    'z_0**5 + z_1**5 + z_2**5 + z_3**5 + z_4**5 - 5 * psi * z_0 * z_1 * z_2 * z_3 * z_4 + a * z_4**4 * z_0')
mnfd
# the most basic thing to do is to evaluate the defining equation
zs, patch = cyjax.random.uniform_projective(next(rns), 200, mnfd.dim_projective)
# parameter values; assuming the alphabetical ordering mentioned earlier, these are (a, psi)
params = jnp.array([0.1, 1.0 + 0.5j])
mnfd.defining(zs, params, patch).shape
(200,)
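Since the variety lives in projective space, its zero locus is only well defined if the defining polynomial is homogeneous, i.e. f(λz) = λ^d f(z). In the custom example, the extra term a * z_4**4 * z_0 has degree 5 like all the other terms. The following pure-Python sketch checks this numerically at one sample point. It does not use cyjax; the function name `custom_poly` and the chosen sample values are hypothetical.

```python
def custom_poly(z, a, psi):
    """Dwork quintic plus the extra term a * z_4**4 * z_0 (hypothetical helper)."""
    prod = 1
    for zi in z:
        prod *= zi
    return sum(zi ** 5 for zi in z) - 5 * psi * prod + a * z[4] ** 4 * z[0]


# an arbitrary complex sample point, parameter values, and scaling factor
z = [1 + 2j, 0.5 - 1j, 2.0, -1 + 0.3j, 0.7j]
a, psi = 0.1, 1.0 + 0.5j
lam = 1.3 - 0.4j

# homogeneity of degree 5: f(lam * z) should equal lam**5 * f(z)
lhs = custom_poly([lam * zi for zi in z], a, psi)
rhs = lam ** 5 * custom_poly(z, a, psi)
is_homogeneous = abs(lhs - rhs) < 1e-9 * max(1.0, abs(rhs))
print(is_homogeneous)
```

A term of a different degree, such as a * z_4**4 alone, would break this equality, so a check of this kind can be a useful sanity test before passing a new string to from_sympy.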