step.models.transcriptformer

Overview

Classes

Linear2D

The Linear2D module consists of a linear layer with a 3D weight matrix.

Readout

Readout module for the TranscriptFormer model.

GeneModuler

GeneModuler takes gene expression as input and outputs gene modules.

TranscriptFormer

TranscriptFormer is a gene expression model based on the Transformer architecture.

Attributes

drop_edge

-

Classes

class Linear2D(input_dim, hidden_dim, n_modules, bias=False)

Bases: torch.nn.Module

The Linear2D module consists of a linear layer with a 3D weight matrix.

Parameters:
  • input_dim (int) – The input dimension of the Linear2D module.

  • hidden_dim (int) – The hidden dimension of the Linear2D module.

  • n_modules (int) – The number of modules of the Linear2D module.

  • bias (bool, optional) – Whether to use bias. Defaults to False.
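
To make the "3D weight matrix" concrete, here is a minimal sketch of one way such a layer could be written. The class name Linear2DSketch, the weight layout (n_modules, input_dim, hidden_dim), the initializer, and the output shape (batch, n_modules, hidden_dim) are all assumptions for illustration; the library's own implementation may differ.

```python
import torch
from torch import nn


class Linear2DSketch(nn.Module):
    """Hypothetical stand-in for Linear2D: one linear map per module."""

    def __init__(self, input_dim, hidden_dim, n_modules, bias=False):
        super().__init__()
        # 3D weight: an (input_dim x hidden_dim) matrix for each module.
        # The axis order is an assumption, not taken from the library.
        self.weight = nn.Parameter(torch.empty(n_modules, input_dim, hidden_dim))
        nn.init.xavier_normal_(self.weight)
        self.bias = nn.Parameter(torch.zeros(n_modules, hidden_dim)) if bias else None

    def forward(self, x):
        # x: (batch, input_dim) -> out: (batch, n_modules, hidden_dim)
        out = torch.einsum("bi,mih->bmh", x, self.weight)
        if self.bias is not None:
            out = out + self.bias
        return out


layer = Linear2DSketch(input_dim=2000, hidden_dim=8, n_modules=16)
x = torch.randn(4, 2000)
out = layer(x)
print(out.shape)  # torch.Size([4, 16, 8])
```

Each slice out[:, m] equals x @ layer.weight[m], so the layer behaves like n_modules independent linear layers evaluated in a single call.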

Overview

Methods

forward(x)

-

Members

forward(x)

class Readout(input_dim, output_dim, variational=True)

Bases: torch.nn.Module

Readout module for the TranscriptFormer model.

net

The sequential neural network.

Type:

nn.Sequential

variational

Whether to use variational encoding.

Type:

bool

out

The sequential neural network for the output.

Type:

nn.Sequential

mean

The linear layer for the mean.

Type:

nn.Linear

logvar

The linear layer for the logvar.

Type:

nn.Linear

Initializes the Readout module.

Parameters:
  • input_dim (int) – The input dimension of the Readout module.

  • output_dim (int) – The output dimension of the Readout module.

  • variational (bool, optional) – Whether to use variational encoding. Defaults to True.

Overview

Methods

forward(x)

Forward pass of the Readout module.

kl_loss()

Computes the KL divergence loss.

clear()

-

Members

forward(x)

Forward pass of the Readout module.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The output tensor.

Return type:

torch.Tensor
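
The attributes listed above (net, mean, logvar, variational) suggest a standard variational readout. The following is a sketch under that assumption only: the class name ReadoutSketch, the layer sizes, and the choice to return the mean outside of training are hypothetical, and the documented out attribute is omitted because its role is not described here.

```python
import torch
from torch import nn


class ReadoutSketch(nn.Module):
    """Hypothetical readout: deterministic mean or reparameterized sample."""

    def __init__(self, input_dim, output_dim, variational=True):
        super().__init__()
        self.variational = variational
        self.net = nn.Sequential(nn.Linear(input_dim, output_dim), nn.ReLU())
        self.mean = nn.Linear(output_dim, output_dim)
        self.logvar = nn.Linear(output_dim, output_dim)

    def forward(self, x):
        h = self.net(x)
        mu = self.mean(h)
        # Assumption: sample only when variational and in training mode.
        if not self.variational or not self.training:
            return mu
        std = torch.exp(0.5 * self.logvar(h))
        # Reparameterization trick: mu + std * eps keeps gradients flowing.
        return mu + std * torch.randn_like(std)


readout = ReadoutSketch(input_dim=16, output_dim=4)
x = torch.randn(5, 16)
sample = readout(x)           # stochastic in training mode
readout.eval()
point_estimate = readout(x)   # deterministic in eval mode
```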

kl_loss()

Computes the KL divergence loss.

Returns:

The KL divergence loss.

Return type:

torch.Tensor
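
For reference, the KL divergence between a diagonal Gaussian with mean mu and log-variance logvar and a standard normal prior has a closed form. The sketch below assumes that prior and a sum-over-dimensions, mean-over-batch reduction; the function name gaussian_kl is hypothetical, and the reduction used by kl_loss() is not stated on this page.

```python
import torch


def gaussian_kl(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over dims, averaged over batch."""
    kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
    return kl_per_sample.mean()


# With unit variance (logvar = 0) the KL reduces to 0.5 * ||mu||^2 per sample.
kl_value = gaussian_kl(torch.ones(3, 4), torch.zeros(3, 4))
```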

clear()

class GeneModuler(input_dim=2000, hidden_dim=8, n_modules=16)

Bases: torch.nn.Module

GeneModuler takes gene expression as input and outputs gene modules.

input_dim

The input dimension of the GeneModuler model.

Type:

int

hidden_dim

The hidden dimension of the GeneModuler model.

Type:

int

n_modules

The number of modules of the GeneModuler model.

Type:

int

layernorm

The layer normalization layer.

Type:

nn.LayerNorm

extractor

The Linear2D object.

Type:

Linear2D

Parameters:
  • input_dim (int, optional) – dimension of input. Defaults to 2000.

  • hidden_dim (int, optional) – dimension of hidden layer. Defaults to 8.

  • n_modules (int, optional) – number of modules. Defaults to 16.
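
As a rough illustration of how the layernorm and extractor attributes could fit together, the sketch below normalizes a gene-expression vector and projects it into n_modules vectors of size hidden_dim. The class name GeneModulerSketch, the inline einsum standing in for Linear2D, and the output shape are assumptions; the optional batch argument of forward is left out.

```python
import torch
from torch import nn


class GeneModulerSketch(nn.Module):
    """Hypothetical: LayerNorm over genes, then one projection per module."""

    def __init__(self, input_dim=2000, hidden_dim=8, n_modules=16):
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        # Stand-in for the Linear2D extractor (assumed weight layout).
        self.weight = nn.Parameter(torch.randn(n_modules, input_dim, hidden_dim) * 0.01)

    def forward(self, x):
        x = self.layernorm(x)  # (cells, genes)
        # One hidden_dim-sized vector per gene module and cell.
        return torch.einsum("bi,mih->bmh", x, self.weight)


moduler = GeneModulerSketch()
expression = torch.rand(2, 2000)   # 2 cells, 2000 genes (documented default)
gene_modules = moduler(expression)
print(gene_modules.shape)  # torch.Size([2, 16, 8])
```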

Overview

Methods

forward(x, batch)

-

demodule(x)

-

random_permute(x)

-

Members

forward(x, batch=None)

demodule(x)

random_permute(x)

class TranscriptFormer(decoder_type='zinb', use_pe=True, use_smooth=False, use_skip=False, input_dim=2000, module_dim=30, decoder_input_dim=None, hidden_dim=256, n_modules=16, nhead=8, n_enc_layer=3, dec_norm='batch', variational=True, smoother='GCN', n_glayers=3, dec_hidden_dim=None, n_dec_hid_layers: int = 1, edge_clip=2, use_l_scale: bool = False, num_batches: int = 1, activation: Literal['softplus', 'softmax'] | None = None)

Bases: torch.nn.Module

TranscriptFormer is a gene expression model based on the Transformer architecture.

input_dim

The input dimension of the TranscriptFormer model.

Type:

int

module_dim

The module dimension of the TranscriptFormer model.

Type:

int

hidden_dim

The hidden dimension of the TranscriptFormer model.

Type:

int

n_modules

The number of modules of the TranscriptFormer model.

Type:

int

moduler

The GeneModuler object.

Type:

GeneModuler

expand

The linear layer for expanding the module.

Type:

nn.Linear

readout

The Readout object.

Type:

Readout

module

The TransformerEncoder object.

Type:

nn.TransformerEncoder

cls_token

The classification token.

Type:

nn.Parameter

px_r

The parameter for the zero-inflated negative binomial distribution.

Type:

torch.nn.Parameter

decoder

The ProbDecoder object.

Type:

ProbDecoder

decoder_type

The type of the decoder.

Type:

str

_smooth

Whether to use smoothing.

Type:

bool

smoother

The GCN object for smoothing.

Type:

GCN

smoother_type

The type of the smoother.

Type:

str

args

The arguments for the TranscriptFormer model.

Type:

dict

gargs

The arguments for the GCN object.

Type:

dict

Initializes the TranscriptFormer model.

Parameters:
  • grids (None, optional) – Grids. Defaults to None. This parameter does not appear in the constructor signature above.

  • decoder_type (str, optional) – Decoder type. Defaults to 'zinb'.

  • use_pe (bool, optional) – Whether to use positional encoding. Defaults to True.

  • use_smooth (bool, optional) – Whether to use smoothing. Defaults to False.

  • use_skip (bool, optional) – Whether to use skip connections. Defaults to False.

  • input_dim (int, optional) – Input dimension. Defaults to 2000.

  • module_dim (int, optional) – Module dimension. Defaults to 30.

  • decoder_input_dim (None, optional) – Decoder input dimension. Defaults to None.

  • hidden_dim (int, optional) – Hidden dimension. Defaults to 256.

  • n_modules (int, optional) – Number of modules. Defaults to 16.

  • nhead (int, optional) – Number of attention heads. Defaults to 8.

  • n_enc_layer (int, optional) – Number of encoder layers. Defaults to 3.

  • dec_norm (str, optional) – Decoder normalization. Defaults to 'batch'.

  • variational (bool, optional) – Whether to use variational encoding. Defaults to True.

  • smoother (str, optional) – Smoother type. Defaults to 'GCN'.

  • n_glayers (int, optional) – Number of graph layers. Defaults to 3.

  • dec_hidden_dim (None, optional) – Decoder hidden dimension. Defaults to None.

  • n_dec_hid_layers (int, optional) – Number of decoder hidden layers. Defaults to 1.

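  • edge_clip (int, optional) – Edge clip value. Defaults to 2.

  • use_l_scale (bool, optional) – Defaults to False.

  • num_batches (int, optional) – Number of batches. Defaults to 1.

  • activation (Literal['softplus', 'softmax'] | None, optional) – Defaults to None.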
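
A minimal construction sketch, using only keyword arguments and defaults that appear in the signature above. It assumes the step package is installed and importable under the module path named at the top of this page; the guarded import is only there so the snippet still runs where the package is missing.

```python
# Keyword arguments copied from the documented signature and its defaults.
config = dict(
    decoder_type="zinb",
    input_dim=2000,
    hidden_dim=256,
    n_modules=16,
    nhead=8,
    n_enc_layer=3,
    variational=True,
)

try:
    from step.models.transcriptformer import TranscriptFormer
except ImportError:
    # Assumption: the package may not be available in every environment.
    TranscriptFormer = None

model = TranscriptFormer(**config) if TranscriptFormer is not None else None
```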

Overview

Methods

get_px_r(batch_label)

-

init_smoother_with_builtin()

-

init_smoother(n_glayers)

-

local_smooth(h, g)

Local smoothing function.

encode_ts(x, batch_rep)

Encode the input tensor with only the transformer.

readout_(cls_rep)

Readout function.

encode(x, bacth_rep)

Encode the input tensor with the transformer and the readout function.

decode_ts(rep_ts, x_gd, batch_rep)

Decoding process starting from the non-standardized representation.

decode(cls_rep, x_gd, batch_rep)

-

forward(x)

-

copy(with_state)

-

Members

get_px_r(batch_label)

init_smoother_with_builtin()

init_smoother(n_glayers=None)

local_smooth(h, g: dgl.DGLGraph | None = None)

Local smoothing function.

Parameters:
  • h (torch.Tensor) – The input tensor.

  • g (Optional[dgl.DGLGraph], optional) – The graph. Defaults to None.
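
Since the default smoother type is 'GCN', a generic GCN-style propagation step may help to picture what smoothing over a graph means. This is a sketch under stated assumptions, not the library's method: it uses a dense adjacency matrix instead of a dgl.DGLGraph, has no learned weights, and the function name gcn_smooth is hypothetical.

```python
import torch


def gcn_smooth(h, adj):
    """One propagation step: D^{-1/2} (A + I) D^{-1/2} h."""
    a_hat = adj + torch.eye(adj.shape[0], dtype=h.dtype)  # add self-loops
    d_inv_sqrt = a_hat.sum(dim=1).pow(-0.5)               # inverse sqrt degree
    norm_adj = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    return norm_adj @ h


# Two connected nodes: each smoothed row is the average of both rows.
adj = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
h = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
smoothed = gcn_smooth(h, adj)
```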

encode_ts(x, batch_rep=None) -> torch.Tensor

Encode the input tensor with only the transformer.

Parameters:
  • x (torch.Tensor) – The input tensor.

  • batch_rep (torch.Tensor, optional) – Representation tensor of the batch indicator. Defaults to None.

Returns:

The encoded tensor, denoted as non-standardized representation.

Return type:

torch.Tensor

readout_(cls_rep) -> torch.Tensor

Readout function.

Parameters:

cls_rep (torch.Tensor) – The input tensor.

Returns:

The output tensor.

Return type:

torch.Tensor

encode(x, bacth_rep=None) -> torch.Tensor

Encode the input tensor with the transformer and the readout function.

Parameters:
  • x (torch.Tensor) – The input tensor.

  • batch_rep (torch.Tensor, optional) – Representation tensor of the batch indicator (spelled bacth_rep in the signature). Defaults to None.

Returns:

The encoded tensor, denoted as standardized representation.

Return type:

torch.Tensor
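
The descriptions of encode_ts, readout_, and encode suggest that encode is the composition of the other two: the transformer yields a non-standardized representation, and the readout turns it into the standardized one. The toy model below sketches that relationship. Everything in it is an assumption for illustration (the TinyEncoder name, the layer sizes, the single linear layer standing in for GeneModuler and expand, and reading the representation from the position of the classification token).

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical miniature of the encode_ts -> readout_ -> encode flow."""

    def __init__(self, input_dim=50, n_modules=4, module_dim=16, hidden_dim=8, nhead=4):
        super().__init__()
        self.n_modules, self.module_dim = n_modules, module_dim
        # Stand-in for GeneModuler + expand: genes -> n_modules tokens.
        self.tokens = nn.Linear(input_dim, n_modules * module_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, module_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=module_dim, nhead=nhead, dim_feedforward=32, batch_first=True
        )
        self.module = nn.TransformerEncoder(layer, num_layers=1)
        self.readout = nn.Linear(module_dim, hidden_dim)

    def encode_ts(self, x):
        tok = self.tokens(x).view(-1, self.n_modules, self.module_dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        out = self.module(torch.cat([cls, tok], dim=1))
        return out[:, 0]  # transformer output at the classification token

    def readout_(self, cls_rep):
        return self.readout(cls_rep)

    def encode(self, x):
        return self.readout_(self.encode_ts(x))


toy = TinyEncoder().eval()  # eval mode disables dropout
x = torch.randn(3, 50)
with torch.no_grad():
    rep_ts = toy.encode_ts(x)  # shape (3, 16)
    rep = toy.encode(x)        # shape (3, 8)
```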

decode_ts(rep_ts, x_gd, batch_rep=None)

Decoding process starting from the non-standardized representation.

Parameters:
  • rep_ts (torch.Tensor) – The non-standardized representation, as returned by encode_ts.

  • x_gd (torch.Tensor) – The input tensor.

  • batch_rep (torch.Tensor, optional) – Representation tensor of the batch indicator. Defaults to None.

decode(cls_rep, x_gd, batch_rep=None)

forward(x)

copy(with_state=True)

Attributes

drop_edge