cyjax.ml.HNetMLP

class cyjax.ml.HNetMLP(basis_size, layer_sizes, dropout_rates=None, powers=(1, 2, 3), sig_suppress=True, init_fluctuation=0.001, activation=<PjitFunction of <function sigmoid>>, feature_angle=True, feature_abs=True, feature_parts=True, parent=<flax.linen.module._Sentinel object>, name=None)

Dense network for learning moduli dependence of the H matrix.

Given moduli \(\psi_i\) as inputs, a number of features is first built from them, which are then fed into a chosen number of dense linear layers. Finally, a linear layer without an activation function constructs the parameters of the H matrix. If sig_suppress is true, two copies are output in this way and the final output is out[0] * sigmoid(out[1]).
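As a minimal numpy sketch (not the cyjax implementation itself), the sig_suppress combination of the two output copies looks like:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Two stacked copies of the raw H-matrix parameters, as produced by
# the final linear layer when sig_suppress is true.
out = np.array([[2.0, -1.0, 0.5],
                [0.0, 3.0, -2.0]])

# Elementwise sigmoid-suppressed output: out[0] * sigmoid(out[1]).
suppressed = out[0] * sigmoid(out[1])
```

Since the sigmoid maps into (0, 1), the learned second copy can only damp the magnitude of each parameter, never amplify it.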

The input features are constructed as follows. First, raise each modulus parameter to all chosen powers \(p_j\), giving \(\psi_i^{p_j}\). Then, take all possible products of these (in the case of multiple moduli). Calling the resulting products \(f_n\), the input features are the concatenation of

  • \(|f_n|\) if feature_abs is true.

  • \(\arg(f_n)\) if feature_angle is true.

  • \(Re[f_n]\) and \(Im[f_n]\) if feature_parts is true.

Note that 0 is always included as a power if there are multiple moduli, even if not explicitly chosen.
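The feature construction above can be sketched in plain numpy. The helper below is illustrative only and not part of the cyjax API; its name and signature are hypothetical:

```python
import numpy as np

def build_features(moduli, powers=(1, 2, 3),
                   feature_abs=True, feature_angle=True, feature_parts=True):
    """Illustrative sketch of the input-feature construction described above."""
    moduli = np.atleast_1d(np.asarray(moduli, dtype=complex))
    if len(moduli) > 1:
        # 0 is always included as a power when there are multiple moduli.
        powers = tuple(sorted(set(powers) | {0}))
    # Raise each modulus psi_i to every chosen power p_j.
    raised = [psi ** np.asarray(powers) for psi in moduli]
    # Take all possible products across the moduli.
    products = raised[0]
    for r in raised[1:]:
        products = np.outer(products, r).ravel()
    # Concatenate the selected real-valued feature types.
    parts = []
    if feature_abs:
        parts.append(np.abs(products))
    if feature_angle:
        parts.append(np.angle(products))
    if feature_parts:
        parts.extend([products.real, products.imag])
    return np.concatenate(parts)

# Single modulus with powers (1, 2, 3): 3 products -> 3 * 4 = 12 real features.
feats = build_features(0.5 + 0.5j)
```

With two moduli, the power 0 is added automatically, so each modulus contributes 4 powers and the outer product yields 16 products, i.e. 64 features.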

Parameters:
  • basis_size (int) – Number of sections used.

  • layer_sizes (Sequence[int]) – Sizes of the hidden dense layers, one entry per layer.

  • dropout_rates (Optional[Sequence[Optional[float]]]) – Dropout rates per layer. Either None (no dropout anywhere), or a sequence with one entry per layer, where a None entry disables dropout for that layer.

  • powers (Sequence[Union[int, float]]) – Powers to use to extract features from moduli parameter inputs.

  • sig_suppress (bool) – Whether to learn an output which is transformed by sigmoid and multiplies the entries of the H-matrix.

  • init_fluctuation (float) – Initial fluctuation of the diagonal around 1.

  • activation (Callable[[Array], Array]) – Activation function for hidden layers.

  • feature_angle (bool) – Whether to include the complex angle of the moduli powers as feature.

  • feature_abs (bool) – Whether to include the absolute value of the moduli powers as features.

  • feature_parts (bool) – Whether to use real and imaginary parts of moduli powers as features.

__init__(basis_size, layer_sizes, dropout_rates=None, powers=(1, 2, 3), sig_suppress=True, init_fluctuation=0.001, activation=<PjitFunction of <function sigmoid>>, feature_angle=True, feature_abs=True, feature_parts=True, parent=<flax.linen.module._Sentinel object>, name=None)

Methods

__init__(basis_size, layer_sizes[, ...])

activation()

Sigmoid activation function.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent, _deep_clone])

Creates a clone of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng(name)

Returns a new RNG key from a given RNG sequence for this Module.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Adds a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or raises an error otherwise.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

Attributes

dropout_rates

Dropout rates per layer.

feature_abs

Whether to use absolute values as input feature.

feature_angle

Whether to use complex angle as input feature.

feature_parts

Whether to use real and imaginary parts as input features.

init_fluctuation

Fluctuation of initialization around the identity (of diagonal entries).

name

parent

powers

Powers to use to extract features from moduli parameter inputs.

scope

sig_suppress

Whether to use learnable sigmoid-suppression of matrix elements.

variables

Returns the variables in this module.

basis_size

Number of sections used.

layer_sizes

Sizes of the hidden dense linear layers.