# manually enable 64-bit (double) precision
from jax.config import config
config.update("jax_enable_x64", True)
# import the usual libraries
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import cyjax

# random number sequence
rns = cyjax.util.PRNGSequence(42)

Geometric quantities

Fubini-Study metric

First, sample some points in \(\mathbb{CP}^2\); with affine=True we get affine coordinates together with the index of the patch each point lies in…

zs, patches = cyjax.random.uniform_projective(next(rns), 6, 2, affine=True)
zs.shape, patches.shape
((6, 2), (6,))

We have a function for the Kähler potential.

cyjax.fs_potential(zs)
Array([0.73866237+0.j, 0.58205621+0.j, 0.49285102+0.j, 0.41209627+0.j,
       0.3525678 +0.j, 0.81256455+0.j], dtype=complex128)
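In an affine patch, the Fubini-Study potential is \(K = \log(1 + |z|^2)\) up to normalization conventions; assuming this is the convention used here (it matches the values above), we can check by hand:

# manual FS potential in the affine patch (assumed normalization, no prefactor)
k_manual = jnp.log(1 + jnp.sum(jnp.abs(zs) ** 2, axis=-1))
jnp.allclose(k_manual, cyjax.fs_potential(zs))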

There is also a function for the FS metric…

gs = cyjax.fs_metric(zs)
print(gs.shape)
gs[0]
(6, 2, 2)
Array([[ 0.36113571-2.65888522e-18j, -0.05611089-1.11124124e-01j],
       [-0.05611089+1.11124124e-01j,  0.34486433-5.41330631e-18j]],      dtype=complex128)

Numerical vs analytic

Instead of the explicit expression implemented in fs_metric, we can also construct the metric as the complex Hessian of the potential.

numerical_fs_metric = cyjax.complex_hessian(cyjax.fs_potential)

Note that we need to pass a single point as input; we don’t want to differentiate all potential values with respect to all points! Also, for the differentiation to work, the conjugate coordinates must be passed explicitly as a second argument.
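Schematically, such a complex Hessian can be built from two holomorphic Jacobians, one in \(z\) and one in the independent conjugate argument (a minimal sketch, not the actual cyjax internals; it assumes the potential accepts the conjugate coordinates as a second argument, as above):

# sketch of a complex Hessian: differentiate holomorphically in z,
# then in the independent conjugate argument
def complex_hessian_sketch(f):
    df = jax.jacfwd(f, argnums=0, holomorphic=True)     # d/dz_i
    return jax.jacfwd(df, argnums=1, holomorphic=True)  # d/dzbar_j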

gs_num = numerical_fs_metric(zs[0], zs[0].conj())
gs_num
Array([[ 0.36113571-3.13239563e-18j, -0.05611089-1.11124124e-01j],
       [-0.05611089+1.11124124e-01j,  0.34486433+5.08178031e-18j]],      dtype=complex128)
# numerical and explicit expressions are very close
print(jnp.allclose(gs[0], gs_num))
jnp.abs(gs[0] - gs_num)
True
Array([[1.11023312e-16, 1.38777878e-17],
       [0.00000000e+00, 5.64945551e-17]], dtype=float64)

In this particular case, both methods are comparable in speed…

# vmap to process multiple points as a batch
fs_metric_vmap = jax.jit(jax.vmap(numerical_fs_metric))
# warm-up calls so that compilation time is not included in the timings
fs_metric_vmap(zs, zs.conj())
cyjax.fs_metric(zs, zs.conj())
%timeit -n 100 fs_metric_vmap(zs, zs.conj())
%timeit -n 100 cyjax.fs_metric(zs, zs.conj())
9.64 µs ± 311 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.82 µs ± 999 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)

Metric on variety

The variety has one dimension fewer than the ambient projective space, and eventually we want a metric on the variety itself. The variety classes therefore expose methods for pullbacks, including the pullback of the FS metric onto the manifold.

dwork = cyjax.Dwork(3)
psi = jnp.array([0.1])
zs, patch = dwork.sample_intersect(next(rns), psi, 2, True)

fs, dep = dwork.induced_fs(zs, psi, patch)  # automatically determines best dependent variable
fs.shape
(2, 3, 3)
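Under the hood, the pullback contracts the ambient metric with the holomorphic Jacobian of the embedding, which comes from solving the defining equation for the dependent coordinate. Schematically (not the cyjax internals; jac is a hypothetical Jacobian of shape (ambient, intrinsic)):

# schematic pullback of an ambient metric g_amb to the variety, given a
# hypothetical holomorphic Jacobian jac with jac[i, a] = dz^i / dx^a
def pullback_metric(g_amb, jac):
    return jnp.einsum('ia,ij,jb->ab', jac, g_amb, jac.conj())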

Donaldson’s algebraic metric

The central approach here is to use the algebraic Ansatz for the Kähler potential, and thus the metric, used in Donaldson’s algorithm. Specifically, given a set of basis sections \(s_\alpha(z)\) represented as homogeneous polynomials/monomials, the Kähler potential is

\[ K(z) = \frac{1}{\pi k} \log \left( \sum_{\alpha\bar{\beta}} s_\alpha(z) H^{\alpha \bar{\beta}} s_{\bar{\beta}}(\bar{z}) \right) \]
Unlike in Donaldson’s algorithm, in the machine learning context it is not important that the \(s_\alpha\) are independent and form a basis.

degree = 4  # homogeneous polynomial degree of s_alpha
# object representing s_alpha
mon_basis = cyjax.donaldson.MonomialBasisFull(dwork.dim_projective, degree)
# algebraic metric
metric = cyjax.donaldson.AlgebraicMetric(dwork, mon_basis)
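There are \(\binom{n+k}{k}\) monomials of degree \(k\) in the \(n+1\) homogeneous coordinates of \(\mathbb{CP}^n\). As a quick sanity check of the basis size (assuming metric.sections.size counts the sections, as its use below suggests):

from math import comb
# number of degree-k monomials on CP^n, compared to the section count
print(comb(dwork.dim_projective + degree, degree), metric.sections.size)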

The AlgebraicMetric object defines functions for a number of geometric objects. For faster code, these functions can be jit-compiled.

h = jnp.eye(metric.sections.size)
# returns the local metric, the patch, the dependent coordinate
g, g_patch, g_dep = metric.metric(h, zs[0], psi, patch[0])
print(g.shape, g_patch, g_dep)
(3, 3) 3 2

Note that the dependent coordinate is given as an index into the affine coordinate vector, so it can numerically coincide with the patch index even though the two refer to different homogeneous coordinates. For example, in patch 3 the affine coordinates are \((z_0, z_1, z_2, z_4)\), so a dependent index of 2 refers to \(z_2\), while a dependent index of 3 would refer to \(z_4\).
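Since these are plain JAX functions, they can also be jit-compiled directly, as the timing comparison below does for metric.ricci; a minimal sketch for the metric (assuming metric.metric traces cleanly under jit):

# jit-compile for repeated evaluation; the first call triggers compilation
metric_fn = jax.jit(metric.metric)
g_jit, _, _ = metric_fn(h, zs[0], psi, patch[0])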

# note that numeric conventions for prefactors might differ
metric.ricci_scalar(h, zs[0], psi, patch[0])
Array(-10.61229456-3.48819611e-15j, dtype=complex128)

Internals

The Ricci tensor and scalar, the eta accuracy, and the local metric can all be computed by methods of the AlgebraicMetric object. The calculations of these quantities share several intermediate results. To avoid code duplication and to make each step testable, the computation internally relies on a GeometricObjects class, which represents a lazily evaluated computational graph: functions for (intermediate) quantities are defined and evaluated recursively, depending on which objects are requested. While this indirection already adds little overhead, its cost can be fully absorbed by defining and jit-compiling particular functions which internally create the GeometricObjects graph. That is what the AlgebraicMetric class does.
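The lazy, cached pattern can be sketched with functools.cached_property (a simplified illustration, not the actual cyjax implementation):

from functools import cached_property

class LazyObjects:
    # simplified illustration of a lazily evaluated computational graph
    def __init__(self, x):
        self.x = x

    @cached_property
    def square(self):
        # computed once on first access, then cached on the instance
        return self.x ** 2

    @cached_property
    def quartic(self):
        # accessing this recursively evaluates (and caches) `square` first
        return self.square ** 2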

# note that the object itself is not automatically vmap-ed so a single point z should be passed
obj = cyjax.donaldson.GeometricObjects(h, zs[0], psi, dwork, metric.sections, patch=patch[0])
# quantities are computed as requested and cached to avoid re-computation
obj.dependent  # affine index of the dependent coordinate
Array(2, dtype=int64)
obj.psi  # the argument inside the log(...); not to be confused with the modulus psi
Array(9.79441253+4.07741942e-17j, dtype=complex128)
obj.eta  # the eta accuracy mentioned above
Array(0.01508826-1.26295717e-19j, dtype=complex128)

Comparison with autodiff

The metric and Ricci tensor are both given in terms of (anti-)holomorphic derivatives of the Kähler potential. Because all functions have to take explicit holomorphic and anti-holomorphic inputs for holomorphic derivatives to work, autodiff has to compute duplicate intermediate quantities (which we can identify as complex conjugates of each other). The manual implementation in GeometricObjects is thus more efficient. Nonetheless, we can use automatic differentiation to check the implementation.

z = zs[0]
patch = patch[0]
z_c = jnp.conj(z)
# define an explicit Kahler potential function

@jax.jit
def kahler(z, z_c, patch, h, psi):
    s = metric.sections(z, patch)
    # we know the internal coefficients are real; otherwise we
    # would also have to conjugate them here
    s_c = metric.sections(z_c, patch)
    psi = jnp.einsum('i,ij,j', s, h, s_c)
    return jnp.log(psi) / (jnp.pi * metric.degree)
kahler(z, z_c, patch, h, psi)
Array(0.18158084+3.31281459e-19j, dtype=complex128)

The metric is now given by the Hessian \(\hat{g} = \partial_i \partial_{\bar{\jmath}} K(z)\). However, if these derivatives are taken with respect to ambient projective coordinates, we still need to compute the pullback via the Jacobian of the embedding of the variety (given by the defining equation).

@jax.jit
def metric_loc(z, z_c, patch, h, psi):
    metric_proj = cyjax.complex_hessian(kahler)(z, z_c, patch, h, psi)
    metric_loc, _ = dwork.induced_metric(metric_proj, z, psi, patch, z_c)
    return metric_loc
g = metric_loc(z, z_c, patch, h, psi)
# compare with other implementation
jnp.allclose(g, obj.g_loc)
Array(True, dtype=bool)

The Ricci curvature tensor is now given by \(R_{i\bar{\jmath}} = -\partial_i \partial_{\bar{\jmath}} \log \det g\).

def logdet(z, z_c, patch, h, psi):
    g = metric_loc(z, z_c, patch, h, psi)
    return jnp.log(jnp.linalg.det(g))

@jax.jit
def ricci_tensor(z, z_c, patch):
    ricci_proj = -cyjax.complex_hessian(logdet)(z, z_c, patch, h, psi)
    ricci_loc, _ = dwork.induced_metric(ricci_proj, z, psi, patch, z_c)
    return ricci_loc
ricci = ricci_tensor(z, z_c, patch)
jnp.allclose(ricci, obj.ricci_loc)
Array(True, dtype=bool)
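As a further cross-check, the Ricci scalar is the trace of the Ricci tensor with respect to the metric. Up to the convention-dependent prefactor noted earlier, contracting the autodiff results should reproduce metric.ricci_scalar (a hedged sketch):

# trace of the Ricci tensor with the inverse metric; the overall
# prefactor is convention-dependent, as noted above
jnp.einsum('ab,ba->', jnp.linalg.inv(g), ricci), metric.ricci_scalar(h, z, psi, patch)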

Finally, we can compare the speed of the manual implementation with the autodiff version.

ricci_fn = jax.jit(metric.ricci)
ricci_fn(h, z, psi, patch)  # warm-up call so compilation is not timed

print('autodiff time:')
%timeit -n 1000 ricci_tensor(z, z_c, patch).block_until_ready()
print('manual implementation:')
%timeit -n 1000 ricci_fn(h, z, psi, patch)[0].block_until_ready()
autodiff time:
551 µs ± 2.53 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
manual implementation:
329 µs ± 2.14 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)