step.models.extension
Overview
- BatchAwareScale – Applies the batch-aware scaling introduced in the paper.
- BatchAwareLayerNorm – Applies the batch-aware layer normalization introduced in the paper.
- NrmlsBC – An extension of the TranscriptFormer model that supports batch-aware normalization and scaling to eliminate batch effects.
Classes
- class BatchAwareScale(input_dim, output_dim, act=None)
Bases:
torch.nn.Module
Applies the batch-aware scaling introduced in the paper.
- net
The neural network.
- Type:
nn.Sequential
- act
The activation function.
- Type:
nn.Module
Members
- forward(x, batch)
Scales the input x according to its batch.
- Parameters:
x (Tensor) – The input data.
batch (Tensor) – The batch data.
- Returns:
The output data after scaling.
- Return type:
Tensor
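The page does not say how the scaling is computed. As a rough illustration only, the sketch below assumes that net maps a batch representation (for example a one-hot vector of width input_dim) to one factor per output feature, optionally passes it through act, and multiplies x by it. The class name BatchAwareScaleSketch and the single-layer net are hypothetical and are not the library's implementation.

```python
import torch
from torch import nn


class BatchAwareScaleSketch(nn.Module):
    """Hypothetical sketch of batch-aware scaling; not the library's code.

    Assumes `batch` has shape (N, input_dim), e.g. one-hot batch labels,
    and `x` has shape (N, output_dim).
    """

    def __init__(self, input_dim, output_dim, act=None):
        super().__init__()
        # Maps the batch representation to one scale factor per output feature.
        self.net = nn.Sequential(nn.Linear(input_dim, output_dim))
        # Optional activation on the scale factors; identity when act is None.
        self.act = act if act is not None else nn.Identity()

    def forward(self, x, batch):
        scale = self.act(self.net(batch))
        # Element-wise rescaling of each sample by its batch-specific factors.
        return x * scale
```

Passing, for example, act=nn.Softplus() would keep the hypothetical scale factors positive.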
- class BatchAwareLayerNorm(input_dim, output_dim, act='relu')
Bases:
torch.nn.Module
Applies the batch-aware layer normalization introduced in the paper.
- mean
The mean layer.
- Type:
nn.Linear
- scale
The scale layer.
- Type:
nn.Linear
- layernorm
The layer normalization layer.
- Type:
nn.LayerNorm
- act
The activation function.
- Type:
nn.Module
Initialize the BatchAwareLayerNorm module.
- Parameters:
input_dim (int) – The input dimension of the module.
output_dim (int) – The output dimension of the module.
act (str, optional) – The activation function to use. Defaults to 'relu'.
Members
- forward(x, batch)
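The attributes (mean and scale as nn.Linear, layernorm as nn.LayerNorm, plus act) suggest a conditional layer normalization in which the batch representation predicts the shift and scale applied after normalizing x. The following is a minimal sketch under that assumption, not the library's code. The input shapes, the use of elementwise_affine=False, and the class name BatchAwareLayerNormSketch are illustrative choices.

```python
import torch
from torch import nn


class BatchAwareLayerNormSketch(nn.Module):
    """Hypothetical batch-conditioned layer norm; not the library's code.

    Assumes `batch` has shape (N, input_dim) and `x` has shape (N, output_dim).
    """

    def __init__(self, input_dim, output_dim, act="relu"):
        super().__init__()
        # Batch-dependent shift and scale, predicted from the batch representation.
        self.mean = nn.Linear(input_dim, output_dim)
        self.scale = nn.Linear(input_dim, output_dim)
        # Plain normalization without its own affine weights; the batch supplies them.
        self.layernorm = nn.LayerNorm(output_dim, elementwise_affine=False)
        if act == "relu":
            self.act = nn.ReLU()
        elif act is None:
            self.act = nn.Identity()
        else:
            raise ValueError(f"unsupported activation in this sketch: {act!r}")

    def forward(self, x, batch):
        normed = self.layernorm(x)
        # Re-scale and re-center the normalized features per batch.
        out = normed * self.scale(batch) + self.mean(batch)
        return self.act(out)
```

Letting the batch predict the affine parameters means samples from different batches share one normalization but receive batch-specific re-scaling.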
- class NrmlsBC(num_batches: int, num_classes: int = 1, dispersion: Literal['gene', 'batch-gene'] = 'batch-gene', use_l_scale=False, **kwargs)
Bases:
step.models.transcriptformer.TranscriptFormer
NrmlsBC is an extension of the TranscriptFormer model that supports batch-aware normalization and scaling to eliminate batch effects.
- num_batches
The number of batches.
- Type:
int
- batch_emb_dim
The batch embedding dimension.
- Type:
int
- smoother
The smoother module.
- Type:
Optional[nn.Module]
- batch_embedding
The batch embedding parameter.
- Type:
nn.Parameter
- moduler
The moduler module.
- batch_readout
The batch readout module.
- args
The arguments of the model.
- Type:
Dict[str, Any]
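The batch_embedding parameter is documented only as an nn.Parameter alongside num_batches and batch_emb_dim. For illustration, the sketch below assumes it is a (num_batches, batch_emb_dim) table that is multiplied by one-hot batch labels to give each sample a batch representation. The sizes, the initialization, and the variable names are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn

num_batches, batch_emb_dim = 3, 16  # illustrative sizes

# Hypothetical learnable table: one embedding row per batch.
batch_embedding = nn.Parameter(torch.randn(num_batches, batch_emb_dim) * 0.02)

batch_ids = torch.tensor([0, 2, 1, 2])                 # batch label of each sample
one_hot = F.one_hot(batch_ids, num_batches).float()   # shape (4, num_batches)
batch_rep = one_hot @ batch_embedding                 # shape (4, batch_emb_dim)
```

The one-hot product selects the same rows as batch_embedding[batch_ids]. The matrix form also accepts soft (fractional) batch assignments.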
Initialize the NrmlsBC model.
- Parameters:
num_batches (int) – The number of batches.
num_classes (int, optional) – The number of classes. Defaults to 1.
**kwargs – Additional keyword arguments.
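The signature accepts dispersion='gene' or 'batch-gene', but the parameter list does not explain them. By analogy with scVI-style count models (an assumption that this page does not confirm), 'gene' would mean one dispersion value per gene shared by all batches, and 'batch-gene' would mean a separate value for each batch and gene pair. The helpers make_dispersion and dispersion_for and the log-space storage are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


def make_dispersion(kind: str, n_genes: int, num_batches: int) -> nn.Parameter:
    """Hypothetical log-dispersion parameter for each supported mode."""
    if kind == "gene":
        # One value per gene, shared across batches.
        return nn.Parameter(torch.zeros(n_genes))
    if kind == "batch-gene":
        # One value per (batch, gene) pair.
        return nn.Parameter(torch.zeros(num_batches, n_genes))
    raise ValueError(f"unknown dispersion: {kind!r}")


def dispersion_for(px_r_log: torch.Tensor, batch_ids: torch.Tensor) -> torch.Tensor:
    """Return a (num_samples, n_genes) positive dispersion for each sample."""
    if px_r_log.dim() == 1:
        # Broadcast the shared per-gene values to every sample.
        return px_r_log.exp().expand(batch_ids.shape[0], -1)
    # Select each sample's batch row with a one-hot product.
    one_hot = F.one_hot(batch_ids, px_r_log.shape[0]).float()
    return (one_hot @ px_r_log).exp()
```

Storing the values in log space and exponentiating keeps the dispersion strictly positive during unconstrained optimization.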
Overview
- encode_ts(x, batch_rep) – Encodes the output of transformer encoders using the transformer model.
- readout_(tsfmr_out) – Applies smoothing to the transformer output if enabled, then performs readout.
- readout_batch(rep_ts, batch_rep) – Reads out the representation using the batch representation.
- encode(x, batch_rep) – Encodes the input data x using the specified batch representation batch_rep.
- decode(cls_rep, x_gd, batch_rep, rep_ts) – Decodes the given input representation into an output representation.
- decode_ts(rep_ts, x_gd, batch_rep) – Decodes the representation tensor rep_ts into a prediction tensor.
- decode_skip(cls_rep, rep_ts, x_gd, batch_rep) – Decodes the input data using the skip model.
- decode_(cls_rep, x_gd, batch_rep)
- forward(x, batch_rep, return_exp)
- init_anchor(num_classes, new_anchors) – Initializes the anchor module.
- copy(with_state) – Creates a copy of the current object.
- copy_dec() – Creates a copy of the model with the specified parameters.
Members
- encode_ts(x, batch_rep)
Encodes the output of transformer encoders using the transformer model.
- Parameters:
x (Tensor) – The output of transformer encoders.
batch_rep (bool) – Whether to return the representation for each time step in the batch.
- Returns:
The encoded representation of the output of transformer encoders.
- Return type:
Tensor
- readout_(tsfmr_out)
Apply smoothing to the transformer output if enabled, and then perform readout.
- Parameters:
tsfmr_out – The transformer output.
- Returns:
The result of the readout operation.
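readout_ smooths the transformer output only when a smoother is configured, since the smoother attribute is typed Optional[nn.Module]. The sketch below shows that optional-module pattern on a hypothetical stand-in class. The (N, L, D) input shape and the mean-pooling readout are assumptions, because the page does not say how the readout itself is performed.

```python
from typing import Optional

import torch
from torch import nn


class ReadoutSketch(nn.Module):
    """Hypothetical stand-in for readout_: optional smoothing, then readout."""

    def __init__(self, smoother: Optional[nn.Module] = None):
        super().__init__()
        self.smoother = smoother  # None disables smoothing

    def readout_(self, tsfmr_out: torch.Tensor) -> torch.Tensor:
        # tsfmr_out: (N, L, D), one D-dimensional vector per token.
        if self.smoother is not None:
            tsfmr_out = self.smoother(tsfmr_out)
        # Illustrative readout: average over the token axis -> (N, D).
        return tsfmr_out.mean(dim=1)
```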
- readout_batch(rep_ts, batch_rep)
Reads out a class representation from rep_ts using the batch representation batch_rep.
- Parameters:
rep_ts – The representation tensor.
batch_rep – The batch representation.
- Returns:
The class representation.
- encode(x, batch_rep)
Encodes the input data x using the specified batch representation batch_rep.
- Parameters:
x – The input data to be encoded.
batch_rep – The batch representation to be used for encoding.
- Returns:
The encoded representation of the input data.
- decode(cls_rep, x_gd, batch_rep, rep_ts=None)
Decodes the given input representation into an output representation.
- Parameters:
cls_rep – The class representation.
x_gd – The input representation.
batch_rep – The batch representation (optional).
rep_ts – The representation tensor (optional). Defaults to None.
- Returns:
The decoded output representation.
- decode_ts(rep_ts, x_gd, batch_rep=None)
Decode the given representation tensor rep_ts into a prediction tensor.
- Parameters:
rep_ts (torch.Tensor) – The representation tensor.
x_gd (torch.Tensor) – The input tensor.
batch_rep – The batch representation tensor. Defaults to None.
- decode_skip(cls_rep, rep_ts, x_gd, batch_rep)
Decodes the input data using the skip model.
- Parameters:
cls_rep – The class representation.
rep_ts – The representation time series.
x_gd – The input data.
batch_rep – The batch representation.
- Returns:
A dictionary containing the decoded values:
px_rate: The rate of the decoded values.
px_dropout: The dropout of the decoded values.
px_scale: The scale of the decoded values.
px_r: The px_r value.
decoder_type: The type of decoder.
x: The input data.
- Return type:
dict
- decode_(cls_rep, x_gd, batch_rep=None)
- forward(x, batch_rep, return_exp=True)
- init_anchor(num_classes: int | None = None, new_anchors=True)
Initializes the anchor module.
- Parameters:
num_classes (Optional[int]) – The number of classes. If provided and greater than 0, the classification head and anchors are initialized accordingly.
new_anchors (bool) – Whether to initialize new anchors.
- Returns:
True if the anchor module is successfully initialized, False otherwise.
- Return type:
bool
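The contract described above can be illustrated with a toy module. The method initializes a classification head and anchors only when num_classes is greater than 0, keeps existing anchors when new_anchors is False, and reports success as a bool. AnchorSketch, hidden_dim, cls_head, and the anchor shape (num_classes, hidden_dim) are assumptions made for this sketch. The model's real internals are not documented here.

```python
from typing import Optional

import torch
from torch import nn


class AnchorSketch(nn.Module):
    """Hypothetical illustration of the init_anchor contract; names are invented."""

    def __init__(self, hidden_dim: int = 8, num_classes: int = 0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.cls_head = None
        self.register_parameter("anchors", None)

    def init_anchor(self, num_classes: Optional[int] = None, new_anchors: bool = True) -> bool:
        if num_classes is not None:
            self.num_classes = num_classes
        if self.num_classes <= 0:
            return False  # nothing to initialize
        # Classification head over the hidden representation.
        self.cls_head = nn.Linear(self.hidden_dim, self.num_classes)
        needs_new = self.anchors is None or self.anchors.shape[0] != self.num_classes
        if new_anchors or needs_new:
            # One learnable anchor vector per class.
            self.anchors = nn.Parameter(torch.randn(self.num_classes, self.hidden_dim))
        return True
```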
- copy(with_state=True)
Creates a copy of the current object.
- Returns:
A new instance of the NrmlsBC class with the same arguments and state.
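copy returns a new NrmlsBC with the same arguments and state, and the class keeps its constructor arguments in args. The sketch below shows that pattern on a hypothetical toy module, CopyableSketch, whose in_dim and out_dim arguments are invented. It rebuilds the object from args and, when with_state is True, loads the original state dict.

```python
from copy import deepcopy
from typing import Any, Dict

import torch
from torch import nn


class CopyableSketch(nn.Module):
    """Hypothetical model that records its constructor arguments in self.args."""

    def __init__(self, in_dim: int = 4, out_dim: int = 2):
        super().__init__()
        self.args: Dict[str, Any] = {"in_dim": in_dim, "out_dim": out_dim}
        self.linear = nn.Linear(in_dim, out_dim)

    def copy(self, with_state: bool = True) -> "CopyableSketch":
        # Rebuild from the stored arguments; this gives freshly initialized weights.
        new = type(self)(**deepcopy(self.args))
        if with_state:
            # Copy parameter and buffer values into the new module's own tensors.
            new.load_state_dict(self.state_dict())
        return new
```

Because load_state_dict copies values into the new module's own tensors, training the copy does not modify the original.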
- copy_dec()
Creates a copy of the model with the specified parameters.
- Returns:
A copy of the model with the registered parameters.
- Return type:
nn.ModuleDict