cyjax.ml.BatchSampler
- class cyjax.ml.BatchSampler(seed, variety, params_sampler, batch_size_params=5, batch_size=100, buffer_size=20, backend='cpu', device=None)
Bases: object
- __init__(seed, variety, params_sampler, batch_size_params=5, batch_size=100, buffer_size=20, backend='cpu', device=None)
Iterable buffered sample generator.
Once initialized, new samples can be obtained as:
>>> params, zs, patch, weights = next(batch_sampler)
Creating the sampler can take several seconds, because the sampling function is jit-compiled and the first batch is sampled during initialization.
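The iterable, buffered behaviour described above can be illustrated with a minimal pure-Python sketch. It is not the cyjax implementation: the class name `BufferedSampler`, the `sample_batch` callback, and the use of `random.Random` in place of a JAX random key are all hypothetical, and the sketch assumes that one refill produces `buffer_size` batches, which is a reading of the parameter description rather than confirmed behaviour.

```python
import random
from collections import deque


class BufferedSampler:
    """Hypothetical sketch of an iterable, buffered batch sampler.

    `sample_batch(rng)` stands in for the jit-compiled sampling function.
    Assumption: each refill calls it `buffer_size` times.
    """

    def __init__(self, seed, sample_batch, buffer_size=20):
        self._rng = random.Random(seed)  # stand-in for a JAX PRNG key
        self._sample_batch = sample_batch
        self._buffer_size = buffer_size
        self._buffer = deque()
        # Filling eagerly mirrors the start-up cost noted above; this sketch
        # fills the whole buffer, the real class may differ.
        self.new_buffer()

    def new_buffer(self):
        """Generate and buffer `buffer_size` new batches."""
        self._buffer.extend(
            self._sample_batch(self._rng) for _ in range(self._buffer_size)
        )

    def __iter__(self):
        return self

    def __next__(self):
        # Refill only when the buffer has run dry, so most calls are cheap pops.
        if not self._buffer:
            self.new_buffer()
        return self._buffer.popleft()


# Toy usage: a sampling function that records how often it is called.
calls = []


def toy_batch(rng):
    """Hypothetical sampling function returning a batch of three floats."""
    calls.append(1)
    return [rng.random() for _ in range(3)]


sampler = BufferedSampler(seed=0, sample_batch=toy_batch, buffer_size=2)
batches = [next(sampler) for _ in range(3)]
```

With `buffer_size=2`, construction triggers two sampling calls, and the third `next` empties the buffer and triggers a refill, so `toy_batch` runs four times in total. Because the sampler never raises `StopIteration`, a bare `for batch in sampler` loop would run forever; drawing a fixed number of batches with `next` is the safer pattern.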
- Parameters:
seed – Random key or integer.
variety – Variety to sample for.
params_sampler (Callable[[Union[PRNGKeyArray, Array], int], Union[Array, ndarray, bool_, number]]) – Function which samples new complex moduli values given a random key and an integer batch size.
batch_size_params (int) – Number of complex moduli values to sample for each batch.
batch_size (int) – Number of points on the variety per moduli value.
buffer_size (int) – Number of samples to keep in the buffer.
backend – Backend used for sampling.
device – Device used for training. Samples are transferred to this device after they are generated.
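A `params_sampler` must accept a random key and an integer batch size and return an array of complex moduli values. The sketch below is hypothetical: it assumes a one-parameter family (one modulus per row, shape `(n, 1)`) with real and imaginary parts drawn uniformly from [-1, 1); both the shape and the distribution are illustrative choices, not requirements stated by cyjax. It uses only `jax.random` operations, so it should trace under `jax.jit` as long as `n` is a static Python int. It also moves the result to a target device with `jax.device_put`, one standard JAX way to perform the kind of transfer the `device` parameter describes.

```python
import jax
import jax.numpy as jnp


def params_sampler(key, n):
    """Hypothetical sampler returning n complex moduli of shape (n, 1).

    Assumes a one-parameter family; real and imaginary parts are drawn
    uniformly from [-1, 1).
    """
    key_re, key_im = jax.random.split(key)
    re = jax.random.uniform(key_re, (n, 1), minval=-1.0, maxval=1.0)
    im = jax.random.uniform(key_im, (n, 1), minval=-1.0, maxval=1.0)
    return re + 1j * im


# Draw batch_size_params=5 values, the default from the signature above.
key = jax.random.PRNGKey(0)
psi = params_sampler(key, 5)

# Move the batch to the first available device (a stand-in for `device`).
target = jax.devices()[0]
psi_on_device = jax.device_put(psi, target)
```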
Methods
__init__(seed, variety, params_sampler[, ...]) – Iterable buffered sample generator.
new_buffer() – Generate and buffer new samples.
- new_buffer()
Generate and buffer new samples.