In [1]:
# Enable 64-bit (double precision) mode for JAX; do this at startup, before any
# JAX arrays are created. `jax.config.update` replaces the deprecated
# `from jax.config import config` import, which recent JAX versions no longer support.
import jax
jax.config.update("jax_enable_x64", True)

In [2]:
# import the usual libraries
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

In [3]:
import cyjax

# Fixed seed so that every random draw in this notebook is reproducible.
SEED = 42
# Random-number sequence; later cells draw from it with next(rns)
# (presumably yielding JAX PRNG keys — see cyjax.util.PRNGSequence).
rns = cyjax.util.PRNGSequence(SEED)

# Algebraic varieties
The implementation of varieties delegates most of the work to the `HomPoly` class, which is itself a `Poly`.
By default, these classes build on [sympy](https://sympy.org/) and add the extra functionality needed in this application.
Other internal representations (e.g. ones that do not use sympy at all) could be implemented by creating new subclasses of `Poly` and `HomPoly`, respectively.

Besides the `Dwork` and `Fermat` varieties, which are provided for convenience, other varieties can easily be specified by giving their defining polynomial as a sympy-style string.

In [4]:
# The Dwork family of quintics in CP^4, with a single complex parameter psi.
cyjax.Dwork()

Dwork(dim=3) in CP^4: DworkPoly([z_0, z_1, z_2, z_3, z_4], params=[psi]) = -5*psi*z0*z1*z2*z3*z4 + z0**5 + z1**5 + z2**5 + z3**5 + z4**5

In [5]:
cyjax.Fermat()  # Fermat quintic: special case of Dwork where psi is always 0, so it has no parameters

Fermat(dim=3) in CP^4: FermatPoly([z_0, z_1, z_2, z_3, z_4]) = z0**5 + z1**5 + z2**5 + z3**5 + z4**5

## Polynomials
Internally, the varieties rely heavily on the `Poly` and `HomPoly` classes.
In many cases, however, there is no need to interact with them directly, since most of their functionality is wrapped by the variety class and managed automatically.

In [6]:
# Build the Dwork quintic from a sympy-style string.
# Always create new instances with *.from_sympy -- it is the most convenient way!
poly = cyjax.HomPoly.from_sympy(
    'z_0**5 + z_1**5 + z_2**5 + z_3**5 + z_4**5 - 5 * psi * z_0 * z_1 * z_2 * z_3 * z_4',
    # parameter names; if omitted, they are detected automatically and ordered alphabetically
    parameters=['psi'],
    # base name of the variables and their dimension (here z_0 ... z_4);
    # the dimension could also be detected automatically in this case
    variable_names=['z'],
    variable_dim=[5],
)

poly

HomPoly([z_0, z_1, z_2, z_3, z_4], params=[psi]) = -5*psi*z0*z1*z2*z3*z4 + z0**5 + z1**5 + z2**5 + z3**5 + z4**5

In [7]:
# Affine polynomials are defined implicitly and generated on the fly.
# In patch 1 the coordinate z_1 is set to 1 and dropped from the variables
# (compare the output: z_1 is missing and a constant term 1 appears).
poly.affine_poly(patch=1)

Poly([z_0, z_2, z_3, z_4], params=[psi]) = -5*psi*z0*z2*z3*z4 + z0**5 + z2**5 + z3**5 + z4**5 + 1

In [8]:
# Evaluate at a single point: homogeneous coordinates z_0..z_4 and the parameter psi.
z, params = np.array([1, 2, 3, 2, 4]), np.array([3])
poly(z, params)

612.0

In [9]:
# Evaluate in affine patch 0: the coordinate z_0 is fixed to 1 and omitted from
# the input, so this is the same point as in the previous cell (same value).
z = np.array([2, 3, 2, 4])
params = np.array([3])
value_affine = poly(z, params, patch=0)
# Sanity check: reinserting z_0 = 1 and evaluating homogeneously must agree.
assert np.isclose(value_affine, poly(np.insert(z, 0, 1), params))
value_affine

612.0

When a `Poly` object is called to evaluate it, the inputs can be numeric or symbolic.
However, they must be a single set of values and cannot have a batch dimension.
A JAX-compatible evaluation function that properly handles a batch dimension can be generated with the `transform_eval` method.
Varieties do this automatically and expose the fast function as `defining`.
The underlying polynomial of a variety can still be accessed via the `defining_poly` attribute.

## Custom varieties
Although `HomPoly` can handle multiple complex projective spaces, currently only varieties with a single defining equation in a single ambient projective space are supported.

Any defining polynomial can be used, with an arbitrary combination of parameters (corresponding to complex moduli).

In [10]:
# A hand-written version of the Dwork variety with one extra term, a * z_4**4 * z_0,
# so the variety has two parameters (a and psi).
mnfd = cyjax.VarietySingle.from_sympy(
    'z_0**5 + z_1**5 + z_2**5 + z_3**5 + z_4**5 - 5 * psi * z_0 * z_1 * z_2 * z_3 * z_4 + a * z_4**4 * z_0'
)

mnfd

VarietySingle(dim=3) in CP^4: HomPoly([z_0, z_1, z_2, z_3, z_4], params=[a, psi]) = a*z0*z4**4 - 5*psi*z0*z1*z2*z3*z4 + z0**5 + z1**5 + z2**5 + z3**5 + z4**5

In [11]:
# The most basic thing to do is to evaluate the defining equation.
# First draw random points in the ambient projective space to evaluate it on.
# NOTE(review): `patch` presumably records the affine patch of each sample —
# confirm against cyjax.random.uniform_projective.
n_samples = 200
zs, patch = cyjax.random.uniform_projective(next(rns), n_samples, mnfd.dim_projective)
# Parameter values follow the variety's parameter order, which is alphabetical
# ([a, psi], see the repr above): a = 0.1, psi = 1.0 + 0.5j.
params = jnp.array([0.1, 1.0 + 0.5j])

In [12]:
# Batched evaluation with the JAX-compatible `defining` function: one value per sampled point.
mnfd.defining(zs, params, patch).shape

(200,)