cyjax.ml.HNetMLP
- class cyjax.ml.HNetMLP(basis_size, layer_sizes, dropout_rates=None, powers=(1, 2, 3), sig_suppress=True, init_fluctuation=0.001, activation=<PjitFunction of <function sigmoid>>, feature_angle=True, feature_abs=True, feature_parts=True, parent=<flax.linen.module._Sentinel object>, name=None)
Dense network for learning moduli dependence of the H matrix.
Given moduli \(\psi_i\) as inputs, a number of features are first built from them and then fed into a chosen number of dense layers. Finally, a linear layer without activation function constructs the parameters of the H matrix. If sig_suppress is true, two copies are output in this way and the final outputs are out[0] * sigmoid(out[1]).
The input features are constructed as follows. First, raise each moduli parameter to all chosen powers \(p_j\), giving \(\psi_i^{p_j}\). Then, take all possible products of these (in the case of multiple moduli). Calling the resulting products \(f_n\), the input features are a concatenation of
- \(|f_n|\) if feature_abs is true,
- \(\arg(f_n)\) if feature_angle is true,
- \(\mathrm{Re}[f_n]\) and \(\mathrm{Im}[f_n]\) if feature_parts is true.
Note that 0 is always included as a power if there are multiple moduli, even if not explicitly chosen.
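The feature construction described above can be sketched in NumPy for a single moduli point. This is a minimal sketch, not the cyjax implementation: the function name moduli_features, the ordering of the concatenated blocks (absolute values, angles, then real and imaginary parts), and the choice to drop the constant product in which every power is 0 are all assumptions.

```python
import itertools

import numpy as np


def moduli_features(psi, powers=(1, 2, 3), feature_abs=True,
                    feature_angle=True, feature_parts=True):
    """Hypothetical sketch of the input features for one moduli point psi (complex, shape (n,))."""
    psi = np.atleast_1d(np.asarray(psi, dtype=complex))
    powers = [float(p) for p in powers]
    if psi.shape[0] > 1 and 0.0 not in powers:
        # With several moduli, 0 is always added as a power.
        powers = [0.0] + powers
    # raised[i, j] = psi_i ** p_j
    raised = psi[:, None] ** np.asarray(powers)[None, :]
    f = []
    # Choose one power index per modulus and multiply the chosen powers together.
    for idx in itertools.product(range(len(powers)), repeat=psi.shape[0]):
        if all(powers[j] == 0.0 for j in idx):
            continue  # assumption: skip the constant product 1
        f.append(np.prod([raised[i, j] for i, j in enumerate(idx)]))
    f = np.asarray(f)
    parts = []
    if feature_abs:
        parts.append(np.abs(f))
    if feature_angle:
        parts.append(np.angle(f))
    if feature_parts:
        parts.extend([f.real, f.imag])
    return np.concatenate(parts)
```

Under the assumptions of this sketch, two moduli with the default powers (1, 2, 3) gain the power 0, giving 4² − 1 = 15 products and 60 features when all three feature flags are enabled.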
- Parameters:
  - basis_size (int) – Number of sections used.
  - layer_sizes (Sequence[int]) – Sizes of the hidden dense layers, one entry per layer.
  - dropout_rates (Optional[Sequence[Optional[float]]]) – Dropout rates per layer. Either None, or a list of floats/None, where None indicates no dropout for that layer.
  - powers (Sequence[Union[int, float]]) – Powers used to extract features from the moduli parameter inputs.
  - sig_suppress (bool) – Whether to learn an output which is transformed by a sigmoid and multiplies the entries of the H matrix.
  - init_fluctuation (float) – Initial fluctuation of the diagonal around 1.
  - activation (Callable[[Array], Array]) – Activation function for the hidden layers.
  - feature_angle (bool) – Whether to include the complex angle of the moduli powers as features.
  - feature_abs (bool) – Whether to include the absolute value of the moduli powers as features.
  - feature_parts (bool) – Whether to include the real and imaginary parts of the moduli powers as features.
- __init__(basis_size, layer_sizes, dropout_rates=None, powers=(1, 2, 3), sig_suppress=True, init_fluctuation=0.001, activation=<PjitFunction of <function sigmoid>>, feature_angle=True, feature_abs=True, feature_parts=True, parent=<flax.linen.module._Sentinel object>, name=None)
Methods
__init__(basis_size, layer_sizes[, ...])
activation() – Sigmoid activation function.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent, _deep_clone]) – Creates a clone of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with name name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
make_rng(name) – Returns a new RNG key from a given RNG sequence for this Module.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
Attributes
dropout_rates – Dropout rates per layer.
feature_abs – Whether to use absolute values as input features.
feature_angle – Whether to use complex angles as input features.
feature_parts – Whether to use real and imaginary parts as input features.
init_fluctuation – Fluctuation of the initialization around the identity (for the diagonal entries).
name
parent
powers – Powers used to extract features from the moduli parameter inputs.
scope
sig_suppress – Whether to use learnable sigmoid suppression of the matrix elements.
variables – Returns the variables in this module.
basis_size – Number of sections used.
layer_sizes – Sizes of the hidden dense layers, one entry per layer.