cyjax.ml.HNetMLP
- class cyjax.ml.HNetMLP(basis_size, layer_sizes, dropout_rates=None, powers=(1, 2, 3), sig_suppress=True, init_fluctuation=0.001, activation=<PjitFunction of <function sigmoid>>, feature_angle=True, feature_abs=True, feature_parts=True, parent=<flax.linen.module._Sentinel object>, name=None)
Dense network for learning the moduli dependence of the H matrix.
Given moduli \(\psi_i\) as inputs, a number of features are first built from them and then fed into a chosen number of dense linear layers. Finally, a linear layer without activation function is used to construct the parameters of the H matrix. If sig_suppress is true, two copies of the outputs are produced in this way and the final outputs are out[0] * sigmoid(out[1]).
The input features are constructed as follows. First, raise each modulus parameter to all chosen powers \(p_j\), giving \(\psi_i^{p_j}\). Then, in the case of multiple moduli, take all possible products of these. Calling the resulting products \(f_n\), the input features are a concatenation of
- \(|f_n|\) if feature_abs is true.
- \(\arg(f_n)\) if feature_angle is true.
- \(\mathrm{Re}[f_n]\) and \(\mathrm{Im}[f_n]\) if feature_parts is true.
Note that 0 is always included as a power if there are multiple moduli, even if not explicitly chosen, so that the products also contain the powers of each individual modulus.
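The feature construction described above can be sketched in NumPy for a single input point. This is a hypothetical re-implementation, not cyjax's code: the name moduli_features, the ordering of the concatenated features, the reading of "all possible products" as one power choice per modulus, and keeping the constant product (all powers 0) are all assumptions.

```python
import itertools

import numpy as np


def moduli_features(psi, powers=(1, 2, 3), feature_abs=True,
                    feature_angle=True, feature_parts=True):
    """Hypothetical sketch of the feature construction for one input point.

    psi: complex moduli of shape (num_moduli,).
    """
    psi = np.asarray(psi, dtype=complex)
    powers = list(powers)
    num_moduli = psi.shape[0]
    if num_moduli > 1 and 0 not in powers:
        # With several moduli, power 0 is always included.
        powers = [0] + powers
    # raised[i, j] = psi_i ** p_j
    raised = psi[:, None] ** np.asarray(powers)[None, :]
    # Assumption: f_n = prod_i psi_i ** p_{j_i}, one power choice per modulus.
    # The constant product (all powers 0) is kept here for simplicity.
    rows = np.arange(num_moduli)
    f = np.array([np.prod(raised[rows, list(idx)])
                  for idx in itertools.product(range(len(powers)),
                                               repeat=num_moduli)])
    feats = []
    if feature_abs:
        feats.append(np.abs(f))
    if feature_angle:
        feats.append(np.angle(f))
    if feature_parts:
        feats.extend([f.real, f.imag])
    return np.concatenate(feats)
```

With a single modulus and the default powers (1, 2, 3) this yields three products and hence 3 + 3 + 2 · 3 = 12 features; with two moduli and powers (1, 2), the implicit power 0 gives 3² = 9 products.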
- Parameters:
  - basis_size (int) – Number of sections used.
  - layer_sizes (Sequence[int]) – Sizes of the hidden dense layers; the number of entries sets the number of hidden layers.
  - dropout_rates (Optional[Sequence[Optional[float]]]) – Dropout rates per layer. Either None, or a list whose entries are floats or None, where None indicates no dropout for that layer.
  - powers (Sequence[Union[int, float]]) – Powers used to extract features from the moduli parameter inputs.
  - sig_suppress (bool) – Whether to learn an additional output which is passed through a sigmoid and multiplies the entries of the H matrix.
  - init_fluctuation (float) – Initial fluctuation of the diagonal entries around 1.
  - activation (Callable[[Array], Array]) – Activation function for the hidden layers.
  - feature_angle (bool) – Whether to include the complex angle of the moduli powers as features.
  - feature_abs (bool) – Whether to include the absolute value of the moduli powers as features.
  - feature_parts (bool) – Whether to include the real and imaginary parts of the moduli powers as features.
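The dense part of the network, including the optional sigmoid suppression, can be sketched in plain NumPy. This is a hypothetical illustration rather than cyjax's implementation: init_mlp and mlp_forward are made-up names, the weights are random instead of coming from cyjax's initializers, dropout is assumed to act after each hidden activation, and the final layer's 2 · num_out outputs are assumed to split into halves forming out[0] and out[1]. How the outputs map onto H-matrix entries (and the role of init_fluctuation) is not modeled.

```python
import numpy as np


def sigmoid(x):
    """Logistic sigmoid, the documented default activation."""
    return 1.0 / (1.0 + np.exp(-x))


def init_mlp(seed, in_dim, layer_sizes, num_out, sig_suppress=True, scale=0.1):
    """Random (weight, bias) pairs for a hypothetical stand-in network."""
    rng = np.random.default_rng(seed)
    sizes = [in_dim, *layer_sizes]
    hidden = [(rng.normal(0.0, scale, (a, b)), np.zeros(b))
              for a, b in zip(sizes[:-1], sizes[1:])]
    # The final linear layer emits two copies of the outputs if sig_suppress is set.
    n_final = 2 * num_out if sig_suppress else num_out
    final = (rng.normal(0.0, scale, (sizes[-1], n_final)), np.zeros(n_final))
    return hidden, final


def mlp_forward(params, x, dropout_rates=None, rng=None,
                sig_suppress=True, activation=sigmoid):
    """Forward pass for one feature vector x of shape (in_dim,)."""
    hidden, (w_out, b_out) = params
    if dropout_rates is None:
        # None means no dropout in any layer.
        dropout_rates = [None] * len(hidden)
    for (w, b), rate in zip(hidden, dropout_rates):
        x = activation(x @ w + b)
        if rate is not None and rng is not None:
            # Inverted dropout (assumed placement), only when an RNG is passed.
            keep = rng.random(x.shape) >= rate
            x = np.where(keep, x / (1.0 - rate), 0.0)
    out = x @ w_out + b_out  # final linear layer, no activation
    if sig_suppress:
        # Assumed split: first half is out[0], second half is out[1].
        out = out.reshape(2, -1)
        return out[0] * sigmoid(out[1])
    return out
```

Because the sigmoid lies in (0, 1), each suppressed output has magnitude at most that of the corresponding entry of out[0], which is what allows individual matrix elements to be damped towards zero.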
- __init__(basis_size, layer_sizes, dropout_rates=None, powers=(1, 2, 3), sig_suppress=True, init_fluctuation=0.001, activation=<PjitFunction of <function sigmoid>>, feature_angle=True, feature_abs=True, feature_parts=True, parent=<flax.linen.module._Sentinel object>, name=None)
Methods
__init__(basis_size, layer_sizes[, ...])
activation() – Sigmoid activation function.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone]) – Creates a clone of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
make_rng(name) – Returns a new RNG key from a given RNG sequence for this Module.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
Attributes
dropout_rates – Dropout rates per layer.
feature_abs – Whether to use absolute values as input features.
feature_angle – Whether to use the complex angle as an input feature.
feature_parts – Whether to use real and imaginary parts as input features.
init_fluctuation – Fluctuation of the initialization around the identity (of the diagonal entries).
name
parent
powers – Powers used to extract features from the moduli parameter inputs.
scope
sig_suppress – Whether to use learnable sigmoid suppression of the matrix elements.
variables – Returns the variables in this module.
basis_size – Number of sections used.
layer_sizes – Sizes of the hidden dense linear layers.