step.models.transcriptformer
Overview
- Linear2D – A linear layer with a 3D weight matrix.
- Readout – Readout module for the TranscriptFormer model.
- GeneModuler – Takes gene expression as input and outputs gene modules.
- TranscriptFormer – A gene expression model based on the Transformer architecture.
Classes
- class Linear2D(input_dim, hidden_dim, n_modules, bias=False)
Bases: torch.nn.Module
Linear2D module consists of a linear layer with a 3D weight matrix.
- Parameters:
input_dim (int) – The input dimension of the Linear2D module.
hidden_dim (int) – The hidden dimension of the Linear2D module.
n_modules (int) – The number of modules of the Linear2D module.
bias (bool, optional) – Whether to use bias. Defaults to False.
Members
- forward(x)
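To make the 3D weight matrix concrete, here is a minimal PyTorch sketch of the idea, not the step implementation. It assumes the weight is laid out as (n_modules, input_dim, hidden_dim) and that forward maps a (batch, input_dim) tensor to (batch, n_modules, hidden_dim), so each module gets its own independent linear map. The class name Linear2DSketch, the initialization scale, and these shapes are all assumptions.

```python
import torch
import torch.nn as nn


class Linear2DSketch(nn.Module):
    """Hypothetical Linear2D: one (input_dim x hidden_dim) map per module.

    Assumption: the weights are stored as a single 3D tensor of shape
    (n_modules, input_dim, hidden_dim).
    """

    def __init__(self, input_dim, hidden_dim, n_modules, bias=False):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(n_modules, input_dim, hidden_dim) * input_dim ** -0.5
        )
        self.bias = nn.Parameter(torch.zeros(n_modules, hidden_dim)) if bias else None

    def forward(self, x):
        # x: (batch, input_dim) -> (batch, n_modules, hidden_dim)
        out = torch.einsum("bi,mih->bmh", x, self.weight)
        if self.bias is not None:
            out = out + self.bias  # broadcasts over the batch dimension
        return out


layer = Linear2DSketch(input_dim=2000, hidden_dim=8, n_modules=16)
y = layer(torch.randn(4, 2000))
print(y.shape)  # torch.Size([4, 16, 8])
```

Keeping all per-module weights in one tensor lets a single torch.einsum call replace a Python loop over modules.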
- class Readout(input_dim, output_dim, variational=True)
Bases: torch.nn.Module
Readout module for the TranscriptFormer model.
- net
The sequential neural network.
- Type:
nn.Sequential
- variational
Whether to use variational encoding.
- Type:
bool
- out
The sequential neural network for the output.
- Type:
nn.Sequential
- mean
The linear layer for the mean.
- Type:
nn.Linear
- logvar
The linear layer for the log-variance.
- Type:
nn.Linear
Initializes the Readout module.
- Parameters:
input_dim (int) – The input dimension of the Readout module.
output_dim (int) – The output dimension of the Readout module.
variational (bool, optional) – Whether to use variational encoding. Defaults to True.
Overview
- forward(x) – Forward pass of the Readout module.
- kl_loss() – Computes the KL divergence loss.
- clear()
Members
- forward(x)
Forward pass of the Readout module.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The output tensor.
- Return type:
torch.Tensor
- kl_loss()
Computes the KL divergence loss.
- Returns:
The KL divergence loss.
- Return type:
torch.Tensor
- clear()
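The attributes above (net, mean, logvar, variational) suggest a standard variational readout. The sketch below is one hypothetical way such a module could be wired, using the reparameterization trick and the closed-form KL divergence to a standard normal prior. The layer sizes, the caching of the mean and log-variance for kl_loss, and what clear does are assumptions, not documented behavior; ReadoutSketch is an illustrative name.

```python
import torch
import torch.nn as nn


class ReadoutSketch(nn.Module):
    """Hypothetical variational readout; sizes and caching are assumptions."""

    def __init__(self, input_dim, output_dim, variational=True):
        super().__init__()
        self.variational = variational
        self.net = nn.Sequential(nn.Linear(input_dim, input_dim), nn.ReLU())
        self.out = nn.Linear(input_dim, output_dim)     # deterministic head
        self.mean = nn.Linear(input_dim, output_dim)    # posterior mean head
        self.logvar = nn.Linear(input_dim, output_dim)  # posterior log-variance head
        self.clear()

    def forward(self, x):
        h = self.net(x)
        if not self.variational:
            return self.out(h)
        mu, logvar = self.mean(h), self.logvar(h)
        self._mu, self._logvar = mu, logvar  # cached so kl_loss() can use them
        if self.training:
            # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
            return mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        return mu

    def kl_loss(self):
        if self._mu is None:
            return torch.tensor(0.0)
        # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over dims, averaged over batch
        kl = -0.5 * (1 + self._logvar - self._mu.pow(2) - self._logvar.exp())
        return kl.sum(dim=-1).mean()

    def clear(self):
        # Assumed behavior: drop the statistics cached by the last forward pass
        self._mu, self._logvar = None, None


readout = ReadoutSketch(input_dim=256, output_dim=32)
z = readout(torch.randn(8, 256))  # training mode: a sampled latent
loss = readout.kl_loss()          # scalar KL term for the training objective
```

In evaluation mode the sketch returns the mean directly, which is a common choice when a deterministic embedding is wanted at inference time.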
- class GeneModuler(input_dim=2000, hidden_dim=8, n_modules=16)
Bases: torch.nn.Module
GeneModuler takes gene expression as input and outputs gene modules.
- input_dim
The input dimension of the GeneModuler model.
- Type:
int
- hidden_dim
The hidden dimension of the GeneModuler model.
- Type:
int
- n_modules
The number of modules of the GeneModuler model.
- Type:
int
- layernorm
The layer normalization layer.
- Type:
nn.LayerNorm
- Parameters:
input_dim (int, optional) – dimension of input. Defaults to 2000.
hidden_dim (int, optional) – dimension of hidden layer. Defaults to 8.
n_modules (int, optional) – number of modules. Defaults to 16.
Members
- forward(x, batch=None)
- demodule(x)
- random_permute(x)
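With its defaults (input_dim=2000, hidden_dim=8, n_modules=16), GeneModuler turns one expression vector into 16 small module embeddings. The hypothetical sketch below shows that shape flow using a single bias-free nn.Linear whose output is reshaped into modules, which is mathematically equivalent to a 3D weight of shape (n_modules, input_dim, hidden_dim), followed by layernorm applied per module. GeneModulerSketch and proj are illustrative names, and demodule and random_permute are not modeled.

```python
import torch
import torch.nn as nn


class GeneModulerSketch(nn.Module):
    """Hypothetical gene moduler: genes -> n_modules embeddings of size hidden_dim."""

    def __init__(self, input_dim=2000, hidden_dim=8, n_modules=16):
        super().__init__()
        self.input_dim, self.hidden_dim, self.n_modules = input_dim, hidden_dim, n_modules
        # A bias-free Linear(input_dim, n_modules * hidden_dim) holds exactly the weights of a
        # (n_modules, input_dim, hidden_dim) tensor; reshaping its output splits it by module.
        self.proj = nn.Linear(input_dim, n_modules * hidden_dim, bias=False)
        self.layernorm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        # x: (batch, input_dim) -> (batch, n_modules, hidden_dim)
        modules = self.proj(x).reshape(x.size(0), self.n_modules, self.hidden_dim)
        return self.layernorm(modules)  # normalize each module embedding independently


moduler = GeneModulerSketch()
out = moduler(torch.randn(4, 2000))
print(out.shape)  # torch.Size([4, 16, 8])
```

Normalizing over hidden_dim only keeps modules on a comparable scale without mixing information between them.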
- class TranscriptFormer(decoder_type='zinb', use_pe=True, use_smooth=False, use_skip=False, input_dim=2000, module_dim=30, decoder_input_dim=None, hidden_dim=256, n_modules=16, nhead=8, n_enc_layer=3, dec_norm='batch', variational=True, smoother='GCN', n_glayers=3, dec_hidden_dim=None, n_dec_hid_layers: int = 1, edge_clip=2, use_l_scale: bool = False, num_batches: int = 1, activation: Literal['softplus', 'softmax'] | None = None)
Bases: torch.nn.Module
TranscriptFormer is a gene expression model based on the Transformer architecture.
- input_dim
The input dimension of the TranscriptFormer model.
- Type:
int
- module_dim
The module dimension of the TranscriptFormer model.
- Type:
int
- hidden_dim
The hidden dimension of the TranscriptFormer model.
- Type:
int
- n_modules
The number of modules of the TranscriptFormer model.
- Type:
int
- moduler
The GeneModuler object.
- Type:
GeneModuler
- expand
The linear layer for expanding the module.
- Type:
nn.Linear
- module
The TransformerEncoder object.
- Type:
nn.TransformerEncoder
- cls_token
The classification token.
- Type:
nn.Parameter
- px_r
The parameter for the zero-inflated negative binomial distribution.
- Type:
torch.nn.Parameter
- decoder
The ProbDecoder object.
- Type:
ProbDecoder
- decoder_type
The type of the decoder.
- Type:
str
- _smooth
Whether to use smoothing.
- Type:
bool
- smoother_type
The type of the smoother.
- Type:
str
- args
The arguments for the TranscriptFormer model.
- Type:
dict
- gargs
The arguments for the GCN object.
- Type:
dict
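The attributes expand, cls_token and module (an nn.TransformerEncoder) point to a CLS-token design: module embeddings become tokens, a learned classification token is prepended, and the Transformer output at that position summarizes the whole expression profile. A minimal sketch of that encode_ts-style step follows, assuming module tokens of size module_dim are projected to hidden_dim. Positional encoding (use_pe), batch representations, and the readout_ step that encode adds on top are left out, and CLSTransformerSketch is a hypothetical name.

```python
import torch
import torch.nn as nn


class CLSTransformerSketch(nn.Module):
    """Hypothetical encode_ts-style encoder: module tokens plus a learned CLS token."""

    def __init__(self, module_dim=30, hidden_dim=256, nhead=8, n_enc_layer=3):
        super().__init__()
        self.expand = nn.Linear(module_dim, hidden_dim)  # lift each module token to hidden_dim
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, batch_first=True)
        self.module = nn.TransformerEncoder(layer, num_layers=n_enc_layer)

    def forward(self, modules):
        # modules: (batch, n_modules, module_dim), e.g. the output of a gene moduler
        tokens = self.expand(modules)                        # (batch, n_modules, hidden_dim)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)  # (batch, 1, hidden_dim)
        h = self.module(torch.cat([cls, tokens], dim=1))     # (batch, 1 + n_modules, hidden_dim)
        return h[:, 0]                                       # output at the CLS position


encoder = CLSTransformerSketch()
cls_rep = encoder(torch.randn(2, 16, 30))
print(cls_rep.shape)  # torch.Size([2, 256])
```

Reading out a single CLS position gives a fixed-size representation regardless of n_modules, which is what a downstream readout such as readout_(cls_rep) would consume.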
Initializes the TranscriptFormer model.
- Parameters:
grids (None, optional) – Grids. Defaults to None.
decoder_type (str, optional) – Decoder type. Defaults to ‘zinb’.
use_pe (bool, optional) – Whether to use positional encoding. Defaults to True.
use_smooth (bool, optional) – Whether to use smoothing. Defaults to False.
use_skip (bool, optional) – Whether to use skip connections. Defaults to False.
input_dim (int, optional) – Input dimension. Defaults to 2000.
module_dim (int, optional) – Module dimension. Defaults to 30.
decoder_input_dim (int, optional) – Decoder input dimension. Defaults to None.
hidden_dim (int, optional) – Hidden dimension. Defaults to 256.
n_modules (int, optional) – Number of modules. Defaults to 16.
nhead (int, optional) – Number of attention heads. Defaults to 8.
n_enc_layer (int, optional) – Number of encoder layers. Defaults to 3.
dec_norm (str, optional) – Decoder normalization. Defaults to ‘batch’.
variational (bool, optional) – Whether to use variational encoding. Defaults to True.
smoother (str, optional) – Smoother type. Defaults to ‘GCN’.
n_glayers (int, optional) – Number of graph layers. Defaults to 3.
dec_hidden_dim (int, optional) – Decoder hidden dimension. Defaults to None.
n_dec_hid_layers (int, optional) – Number of decoder hidden layers. Defaults to 1.
edge_clip (int, optional) – Edge clip value. Defaults to 2.
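With decoder_type='zinb', counts are modeled by a zero-inflated negative binomial (ZINB) distribution, p(x) = π·1[x = 0] + (1 − π)·NB(x; μ, θ). The function below sketches the ZINB negative log-likelihood in the mean/inverse-dispersion parameterization used by scVI-style decoders. It assumes a learned dispersion such as px_r enters as θ after a positivity transform (for example exp), which this page does not confirm, and zinb_nll is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def zinb_nll(x, mu, theta, pi_logits, eps=1e-8):
    """Hypothetical ZINB negative log-likelihood, summed over the last (gene) axis.

    x: observed counts; mu: NB mean (> 0); theta: inverse dispersion (> 0);
    pi_logits: logit of the zero-inflation probability pi. Shapes broadcast.
    """
    log_theta_mu = torch.log(theta + mu + eps)
    # log NB(0 | mu, theta) = theta * log(theta / (theta + mu))
    log_nb_zero = theta * (torch.log(theta + eps) - log_theta_mu)
    # Full log NB(x | mu, theta), including the gamma-function normalizer
    log_nb = (
        log_nb_zero
        + x * (torch.log(mu + eps) - log_theta_mu)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )
    log_pi = F.logsigmoid(pi_logits)             # log(pi)
    log_one_minus_pi = F.logsigmoid(-pi_logits)  # log(1 - pi)
    # Zero counts can come from either component: log(pi + (1 - pi) * NB(0))
    case_zero = torch.logaddexp(log_pi, log_one_minus_pi + log_nb_zero)
    # Positive counts can only come from the NB component
    case_nonzero = log_one_minus_pi + log_nb
    log_lik = torch.where(x < 0.5, case_zero, case_nonzero)
    return -log_lik.sum(dim=-1)


counts = torch.tensor([[0.0, 3.0, 0.0]])
nll = zinb_nll(counts, mu=torch.tensor([[0.5, 2.0, 1.0]]), theta=torch.ones(3),
               pi_logits=torch.zeros(1, 3))
print(nll)
```

Working in log space with logsigmoid and logaddexp avoids the underflow that computing π and NB(0) directly would cause for extreme logits.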
Overview
- get_px_r(batch_label)
- init_smoother_with_builtin()
- init_smoother(n_glayers)
- local_smooth(h, g) – Local smoothing function.
- encode_ts(x, batch_rep) – Encode the input tensor with only the transformer.
- readout_(cls_rep) – Readout function.
- encode(x, bacth_rep) – Encode the input tensor with the transformer and the readout function.
- decode_ts(rep_ts, x_gd, batch_rep) – Decoding process starting from the non-standardized representation.
- decode(cls_rep, x_gd, batch_rep)
- forward(x)
- copy(with_state)
Members
- get_px_r(batch_label)
- init_smoother_with_builtin()
- init_smoother(n_glayers=None)
- local_smooth(h, g: dgl.DGLGraph | None = None)
Local smoothing function.
- Parameters:
h (torch.Tensor) – The input tensor.
g (Optional[dgl.DGLGraph], optional) – The graph. Defaults to None.
- encode_ts(x, batch_rep=None) -> torch.Tensor
Encode the input tensor with the transformer only.
- Parameters:
x (torch.Tensor) – The input tensor.
batch_rep (torch.Tensor, optional) – Representation tensor of the batch indicator. Defaults to None.
- Returns:
The encoded tensor, referred to as the non-standardized representation.
- Return type:
torch.Tensor
- readout_(cls_rep) -> torch.Tensor
Readout function.
- Parameters:
cls_rep (torch.Tensor) – The input tensor.
- Returns:
The output tensor.
- Return type:
torch.Tensor
- encode(x, bacth_rep=None) -> torch.Tensor
Encode the input tensor with the transformer and the readout function.
- Parameters:
x (torch.Tensor) – The input tensor.
bacth_rep (torch.Tensor, optional) – Representation tensor of the batch indicator. The keyword is spelled bacth_rep in the signature, so pass it positionally or under that name. Defaults to None.
- Returns:
The encoded tensor, referred to as the standardized representation.
- Return type:
torch.Tensor
- decode_ts(rep_ts, x_gd, batch_rep=None)
Decode, starting from the non-standardized representation.
- Parameters:
rep_ts (torch.Tensor) – The non-standardized representation, as returned by encode_ts.
x_gd (torch.Tensor) – The input tensor.
batch_rep (torch.Tensor, optional) – Representation tensor of the batch indicator. Defaults to None.
- decode(cls_rep, x_gd, batch_rep=None)
- forward(x)
- copy(with_state=True)
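local_smooth takes a dgl.DGLGraph, and the default smoother is 'GCN' with n_glayers=3. To show what graph smoothing does without depending on DGL, the sketch below applies the symmetric-normalized GCN propagation D^{-1/2}(A + I)D^{-1/2} n_glayers times with no learned weights, which is closer to SGC-style propagation than to a trainable GCN. The edge-list input, the parameter-free propagation, and the name local_smooth_sketch are assumptions rather than the step implementation.

```python
import torch


def local_smooth_sketch(h, edge_index, n_glayers=3):
    """Hypothetical parameter-free GCN-style smoothing (no learned weights).

    h: (n_nodes, dim) node features, e.g. per-spot representations.
    edge_index: (2, n_edges) long tensor listing neighbor pairs.
    Each step applies h <- D^{-1/2} (A + I) D^{-1/2} h.
    """
    n = h.size(0)
    adj = torch.zeros(n, n, dtype=h.dtype)
    adj[edge_index[0], edge_index[1]] = 1.0
    adj[edge_index[1], edge_index[0]] = 1.0  # treat edges as undirected
    adj.fill_diagonal_(1.0)                  # self-loops keep each node's own signal
    deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    norm_adj = deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]
    for _ in range(n_glayers):
        h = norm_adj @ h
    return h


ring = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # 4-node cycle
feats = torch.tensor([[1.0], [0.0], [0.0], [0.0]])
# On this 2-regular ring, one step replaces each node by the mean of itself and its two neighbors
print(local_smooth_sketch(feats, ring, n_glayers=1))
```

Each step mixes a node with its normalized neighborhood, so more layers spread information further across the graph at the cost of blurring local differences.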
Attributes
- drop_edge