cyjax.ml.BatchSampler

class cyjax.ml.BatchSampler(seed, variety, params_sampler, batch_size_params=5, batch_size=100, buffer_size=20, backend='cpu', device=None)

Bases: object

__init__(seed, variety, params_sampler, batch_size_params=5, batch_size=100, buffer_size=20, backend='cpu', device=None)

Iterable buffered sample generator.

Once initialized, new samples can be obtained with:

>>> params, zs, patch, weights = next(batch_sampler)

Note that creating the sampler can take several seconds, because the sampling function is JIT-compiled and the first batch is sampled during construction.

Parameters:
  • seed – Random key or integer.

  • variety – Variety to sample for.

  • params_sampler (Callable[[Union[PRNGKeyArray, Array], int], Union[Array, ndarray, bool_, number]]) – Function which samples new complex moduli values given a random key and an integer batch size.

  • batch_size_params (int) – Number of complex moduli values to sample for each batch.

  • batch_size (int) – Number of points on the variety to sample per moduli value.

  • buffer_size (int) – Number of samples to keep in the buffer.

  • backend – Backend used for sampling (default 'cpu').

  • device – Device used for training. Samples are transferred to this device after they are generated.
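To illustrate the params_sampler argument, here is a minimal sketch of a function with the documented signature (random key, integer batch size). It assumes the variety has a single complex modulus and that an array of shape (batch_size, 1) is acceptable; the name sample_moduli and the uniform distribution on [-1, 1) + i[-1, 1) are hypothetical choices, not part of cyjax.

```python
import jax
import jax.numpy as jnp


def sample_moduli(key, batch_size):
    """Hypothetical params_sampler: draw `batch_size` complex moduli.

    Assumption: one complex modulus per sample, with real and imaginary
    parts drawn uniformly from [-1, 1). Adjust the shape to your variety.
    """
    key_re, key_im = jax.random.split(key)  # independent keys for Re and Im
    re = jax.random.uniform(key_re, (batch_size, 1), minval=-1.0, maxval=1.0)
    im = jax.random.uniform(key_im, (batch_size, 1), minval=-1.0, maxval=1.0)
    return re + 1j * im  # complex array of shape (batch_size, 1)


key = jax.random.PRNGKey(0)
psi = sample_moduli(key, 5)
print(psi.shape)
```

Check the expected number of moduli, and hence the trailing dimension, against the variety before passing such a function as params_sampler.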

Methods

__init__(seed, variety, params_sampler[, ...])

Iterable buffered sample generator.

new_buffer()

Generate and buffer new samples.

new_buffer()

Generate and buffer new samples.
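The buffering pattern can be illustrated with a toy, stdlib-only class. ToyBufferedSampler is hypothetical and is not cyjax's implementation: it draws plain Python floats instead of points on a variety, and it assumes the buffer is refilled by new_buffer() only once it runs empty. It mirrors the behaviour described above: the constructor fills the buffer eagerly, and each next() call returns one stored batch.

```python
import random
from collections import deque


class ToyBufferedSampler:
    """Hypothetical stand-in for an iterable buffered sampler (not cyjax)."""

    def __init__(self, seed, batch_size=4, buffer_size=3):
        self.rng = random.Random(seed)  # seeded for reproducible batches
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.buffer = deque()
        self.new_buffer()  # fill the buffer eagerly at construction time

    def _sample_batch(self):
        # One "batch": a list of batch_size uniform floats.
        return [self.rng.random() for _ in range(self.batch_size)]

    def new_buffer(self):
        """Generate buffer_size batches at once and append them to the buffer."""
        self.buffer.extend(self._sample_batch() for _ in range(self.buffer_size))

    def __iter__(self):
        return self

    def __next__(self):
        if not self.buffer:
            self.new_buffer()  # assumption: refill only when the buffer is empty
        return self.buffer.popleft()


sampler = ToyBufferedSampler(seed=0)
first = next(sampler)
print(len(first), len(sampler.buffer))
```

Producing several batches per call amortizes per-call overhead, which is a plausible motivation for buffering when each call goes through a jit-compiled sampling function.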