step.models.extension

Overview

Classes

BatchAwareScale

BatchAwareScale is a module that applies the batch-aware scaling operation introduced in the paper.

BatchAwareLayerNorm

BatchAwareLayerNorm is a module that applies the batch-aware layer normalization introduced in the paper.

NrmlsBC

NrmlsBC is an extension of the TranscriptFormer model that supports batch-aware normalization and scaling to eliminate batch effects.

Classes

class BatchAwareScale(input_dim, output_dim, act=None)

Bases: torch.nn.Module

BatchAwareScale is a module that applies the batch-aware scaling operation introduced in the paper.

net

The neural network.

Type:

nn.Sequential

act

The activation function.

Type:

nn.Module
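The exact computation inside net is not documented here. As a framework-free sketch, assume net is a single linear layer that maps a one-hot (or soft) batch vector to one scale factor per feature, which then multiplies x elementwise, with act applied afterwards when given. Plain Python lists stand in for torch tensors, and the helpers linear and batch_aware_scale are hypothetical, not part of the package.

```python
from typing import Callable, List, Optional

Vector = List[float]
Matrix = List[List[float]]  # stored as (out_dim, in_dim)


def linear(weight: Matrix, bias: Vector, v: Vector) -> Vector:
    """Dense layer y = W v + b (plain-Python stand-in for nn.Linear)."""
    return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weight, bias)]


def batch_aware_scale(
    x: Vector,
    batch: Vector,
    weight: Matrix,
    bias: Vector,
    act: Optional[Callable[[float], float]] = None,
) -> Vector:
    """Hypothetical batch-aware scaling: predict one scale factor per feature
    from the batch encoding and multiply it elementwise into x."""
    scale = linear(weight, bias, batch)          # plays the role of `net` (assumed)
    out = [xi * si for xi, si in zip(x, scale)]  # batch-specific rescaling
    if act is not None:                          # optional activation, mirroring act=None
        out = [act(o) for o in out]
    return out


# Three features, two batches encoded one-hot; column j holds batch j's scales.
W = [[1.0, 2.0], [1.0, 0.5], [1.0, 1.0]]
bias0 = [0.0, 0.0, 0.0]
x = [1.0, 2.0, 3.0]
print(batch_aware_scale(x, [1.0, 0.0], W, bias0))  # [1.0, 2.0, 3.0]
print(batch_aware_scale(x, [0.0, 1.0], W, bias0))  # [2.0, 1.0, 3.0]
```

The same input is rescaled differently depending only on which batch it came from, which is the property a batch-effect correction layer needs.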

Overview

Methods

forward(x, batch)

Apply batch-aware scaling to x, conditioned on batch.

Members

forward(x, batch)

Apply batch-aware scaling to x, conditioned on batch.

Parameters:
  • x (Tensor) – The input data.

  • batch (Tensor) – The batch data.

Returns:

The output data after scaling.

Return type:

Tensor

class BatchAwareLayerNorm(input_dim, output_dim, act='relu')

Bases: torch.nn.Module

BatchAwareLayerNorm is a module that applies the batch-aware layer normalization introduced in the paper.

mean

The mean layer.

Type:

nn.Linear

scale

The scale layer.

Type:

nn.Linear

layernorm

The layer normalization layer.

Type:

nn.LayerNorm

act

The activation function.

Type:

nn.Module

Initialize the BatchAwareLayerNorm module.

Parameters:
  • input_dim (int) – The input dimension of the module.

  • output_dim (int) – The output dimension of the module.

  • act (str, optional) – The activation function to use. Defaults to ‘relu’.
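A minimal sketch of one plausible reading of the attributes above: layernorm normalizes x without learned affine terms, while scale and mean predict a batch-specific gain and shift from the batch encoding, in the style of conditional layer normalization, followed by ReLU (the documented default act). This is an assumption rather than the confirmed implementation; the function names are hypothetical and plain Python lists stand in for tensors.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(weight: Matrix, bias: Vector, v: Vector) -> Vector:
    """Dense layer y = W v + b (stand-in for nn.Linear)."""
    return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weight, bias)]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Zero-mean, unit-variance normalization over features (no affine terms).
    Uses the biased variance, as nn.LayerNorm does."""
    mu = sum(x) / len(x)
    var = sum((xi - mu) ** 2 for xi in x) / len(x)
    return [(xi - mu) / math.sqrt(var + eps) for xi in x]


def batch_aware_layer_norm(
    x: Vector, batch: Vector,
    mean_w: Matrix, mean_b: Vector,
    scale_w: Matrix, scale_b: Vector,
) -> Vector:
    """Hypothetical: normalize x, then apply a batch-dependent gain (`scale`)
    and shift (`mean`), followed by ReLU."""
    h = layer_norm(x)
    gamma = linear(scale_w, scale_b, batch)  # batch-specific gain (assumed)
    beta = linear(mean_w, mean_b, batch)     # batch-specific shift (assumed)
    return [max(0.0, hi * g + bt) for hi, g, bt in zip(h, gamma, beta)]


# Two features, two one-hot batches: batch 0 -> gain 1, shift 0; batch 1 -> gain 2, shift 1.
scale_w = [[1.0, 2.0], [1.0, 2.0]]
mean_w = [[0.0, 1.0], [0.0, 1.0]]
zeros = [0.0, 0.0]
print(batch_aware_layer_norm([1.0, 3.0], [1.0, 0.0], mean_w, zeros, scale_w, zeros))  # approx. [0.0, 1.0]
print(batch_aware_layer_norm([1.0, 3.0], [0.0, 1.0], mean_w, zeros, scale_w, zeros))  # approx. [0.0, 3.0]
```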

Overview

Methods

forward(x, batch)

Apply batch-aware layer normalization to x, conditioned on batch.

Members

forward(x, batch)

class NrmlsBC(num_batches: int, num_classes: int = 1, dispersion: Literal['gene', 'batch-gene'] = 'batch-gene', use_l_scale=False, **kwargs)

Bases: step.models.transcriptformer.TranscriptFormer

NrmlsBC is an extension of the TranscriptFormer model that supports batch-aware normalization and scaling to eliminate batch effects.

num_batches

The number of batches.

Type:

int

batch_emb_dim

The batch embedding dimension.

Type:

int

smoother

The smoother module.

Type:

Optional[nn.Module]

batch_embedding

The batch embedding parameter.

Type:

nn.Parameter

moduler

The moduler module.

Type:

TranscriptFormer

batch_readout

The batch readout module.

Type:

BatchAwareScale

args

The arguments of the model.

Type:

Dict[str, Any]
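The batch_embedding parameter is documented only by its type. Assuming it is a (num_batches, batch_emb_dim) table and that batch_rep holds one-hot or soft batch assignments per cell, looking up batch embeddings reduces to a matrix product. The sketch below uses hypothetical helper names and plain Python lists in place of tensors.

```python
import random
from typing import List

Matrix = List[List[float]]


def make_batch_embedding(num_batches: int, batch_emb_dim: int, seed: int = 0) -> Matrix:
    """Stand-in for a learnable (num_batches, batch_emb_dim) parameter,
    initialized here from a standard normal distribution (assumed)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(batch_emb_dim)] for _ in range(num_batches)]


def embed_batches(batch_rep: Matrix, embedding: Matrix) -> Matrix:
    """Map (n_cells, num_batches) batch assignments to (n_cells, batch_emb_dim).

    One-hot rows reduce to a table lookup; soft rows give a weighted mix."""
    dim = len(embedding[0])
    return [
        [sum(w * embedding[b][d] for b, w in enumerate(row)) for d in range(dim)]
        for row in batch_rep
    ]


E = make_batch_embedding(num_batches=3, batch_emb_dim=4)
cells = [[0.0, 1.0, 0.0],   # cell from batch 1
         [0.5, 0.5, 0.0]]   # cell split evenly between batches 0 and 1
emb = embed_batches(cells, E)
```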

Initialize the NrmlsBC model.

Parameters:
  • num_batches (int) – The number of batches.

  • num_classes (int, optional) – The number of classes. Defaults to 1.

  • **kwargs – Additional keyword arguments.
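The dispersion argument in the signature accepts 'gene' or 'batch-gene'. Assuming it follows the convention of scVI-style count models (one inverse-dispersion value per gene, versus one per batch and gene), the choice only changes which parameter row a cell reads. The sketch below is hypothetical; select_dispersion and its parameter names are illustrative, not the package's API.

```python
import math
from typing import List


def select_dispersion(
    dispersion: str,
    log_theta_gene: List[float],
    log_theta_batch_gene: List[List[float]],
    batch_index: int,
) -> List[float]:
    """Return per-gene inverse dispersion for a cell from batch `batch_index`.

    Hypothetical layout: 'gene' keeps one value per gene shared by all
    batches; 'batch-gene' keeps a (num_batches, n_genes) table."""
    if dispersion == "gene":
        log_theta = log_theta_gene
    elif dispersion == "batch-gene":
        log_theta = log_theta_batch_gene[batch_index]
    else:
        raise ValueError(f"unknown dispersion mode: {dispersion!r}")
    # Parameters live in log space so the returned values are always positive.
    return [math.exp(v) for v in log_theta]


shared = [0.0, 0.0]                   # 2 genes
per_batch = [[0.0, 1.0], [2.0, 0.5]]  # 2 batches x 2 genes
print(select_dispersion("gene", shared, per_batch, batch_index=1))  # [1.0, 1.0]
print(select_dispersion("batch-gene", shared, per_batch, batch_index=1))
```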

Overview

Methods

encode_ts(x, batch_rep)

Encodes the output of transformer encoders using the transformer model.

readout_(tsfmr_out)

Apply smoothing to the transformer output if enabled, and then perform readout.

readout_batch(rep_ts, batch_rep)

Read out the class representation from rep_ts together with the batch representation.

encode(x, batch_rep)

Encodes the input data x using the specified batch representation batch_rep.

decode(cls_rep, x_gd, batch_rep, rep_ts)

Decodes the given input representation into an output representation.

decode_ts(rep_ts, x_gd, batch_rep)

Decode the given representation tensor rep_ts into a prediction tensor.

decode_skip(cls_rep, rep_ts, x_gd, batch_rep)

Decodes the input data using the skip model.

decode_(cls_rep, x_gd, batch_rep)

-

forward(x, batch_rep, return_exp)

-

init_anchor(num_classes, new_anchors)

Initializes the anchor module.

copy(with_state)

Creates a copy of the current object.

copy_dec()

Creates a copy of the model with the specified parameters.

Members

encode_ts(x, batch_rep)

Encodes the output of transformer encoders using the transformer model.

Parameters:
  • x (Tensor) – The output of transformer encoders.

  • batch_rep (bool) – Whether to return the representation for each time step in the batch.

Returns:

The encoded representation of the output of transformer encoders.

Return type:

Tensor

readout_(tsfmr_out)

Apply smoothing to the transformer output if enabled, and then perform readout.

Parameters:

tsfmr_out – The transformer output.

Returns:

The result of the readout operation.
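The control flow of readout_ (smooth only when a smoother is configured, then read out) can be sketched as below. The moving-average smoother and mean-pooling readout are hypothetical stand-ins; the source does not say which smoother or readout NrmlsBC uses, and smooth_then_readout is an illustrative name.

```python
from typing import Callable, List, Optional

Tokens = List[List[float]]  # (n_tokens, hidden_dim)


def moving_average(tokens: Tokens, window: int = 3) -> Tokens:
    """Hypothetical smoother: average each token with its neighbours."""
    half = window // 2
    out = []
    for i in range(len(tokens)):
        nbrs = tokens[max(0, i - half): i + half + 1]
        out.append([sum(col) / len(nbrs) for col in zip(*nbrs)])
    return out


def mean_readout(tokens: Tokens) -> List[float]:
    """Hypothetical readout: mean-pool over the token axis."""
    return [sum(col) / len(tokens) for col in zip(*tokens)]


def smooth_then_readout(
    tsfmr_out: Tokens,
    smoother: Optional[Callable[[Tokens], Tokens]] = None,
) -> List[float]:
    """Smooth only if a smoother is configured, then perform the readout."""
    if smoother is not None:
        tsfmr_out = smoother(tsfmr_out)
    return mean_readout(tsfmr_out)


tokens = [[0.0], [0.0], [9.0]]
print(smooth_then_readout(tokens))                  # [3.0]
print(smooth_then_readout(tokens, moving_average))  # [2.5]
```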

readout_batch(rep_ts, batch_rep)

Read out the class representation from rep_ts together with the batch representation.

Parameters:
  • rep_ts – The representation tensor.

  • batch_rep – The batch representation.

Returns:

The class representation.

encode(x, batch_rep)

Encodes the input data x using the specified batch representation batch_rep.

Parameters:
  • x – The input data to be encoded.

  • batch_rep – The batch representation to be used for encoding.

Returns:

The encoded representation of the input data.

decode(cls_rep, x_gd, batch_rep, rep_ts=None)

Decodes the given input representation into an output representation.

Parameters:
  • cls_rep – The class representation.

  • x_gd – The input representation.

  • batch_rep – The batch representation (optional).

  • rep_ts – The representation tensor (optional). Defaults to None.

Returns:

The decoded output representation.

decode_ts(rep_ts, x_gd, batch_rep=None)

Decode the given representation tensor rep_ts into a prediction tensor.

Parameters:
  • rep_ts (torch.Tensor) – The representation tensor.

  • x_gd (torch.Tensor) – The input tensor.

  • batch_rep – The batch representation tensor. Defaults to None.

decode_skip(cls_rep, rep_ts, x_gd, batch_rep)

Decodes the input data using the skip model.

Parameters:
  • cls_rep – The class representation.

  • rep_ts – The representation time series.

  • x_gd – The input data.

  • batch_rep – The batch representation.

Returns:

A dictionary containing the decoded values, with the following keys:

  • px_rate: The rate of the decoded values.

  • px_dropout: The dropout of the decoded values.

  • px_scale: The scale of the decoded values.

  • px_r: The px_r value.

  • decoder_type: The type of decoder.

  • x: The input data.

Return type:

dict
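These keys match the parameterization used by scVI-style zero-inflated negative binomial decoders. Assuming that convention, px_scale holds per-gene proportions that sum to one, px_rate is the library size times px_scale, px_r is a positive inverse dispersion, and px_dropout carries zero-inflation logits. The sketch below only shows how such a dictionary could be assembled; assemble_outputs and its inputs are hypothetical.

```python
import math
from typing import Dict, List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)                          # subtract max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def assemble_outputs(
    logits: List[float],
    dropout_logits: List[float],
    log_theta: List[float],
    library_size: float,
    x: List[float],
    decoder_type: str,
) -> Dict[str, object]:
    """Hypothetical scVI-style packing of decoder outputs for one cell."""
    px_scale = softmax(logits)                      # expression proportions, sum to 1
    px_rate = [library_size * s for s in px_scale]  # expected counts per gene
    return {
        "px_rate": px_rate,
        "px_dropout": dropout_logits,               # zero-inflation logits (assumed)
        "px_scale": px_scale,
        "px_r": [math.exp(t) for t in log_theta],   # positive inverse dispersion
        "decoder_type": decoder_type,
        "x": x,
    }


out = assemble_outputs(
    logits=[0.0, 0.0], dropout_logits=[0.0, 0.0], log_theta=[0.0, 0.0],
    library_size=10.0, x=[4.0, 6.0], decoder_type="example",
)
print(out["px_scale"], out["px_rate"])  # [0.5, 0.5] [5.0, 5.0]
```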

decode_(cls_rep, x_gd, batch_rep=None)

forward(x, batch_rep, return_exp=True)

init_anchor(num_classes: int | None = None, new_anchors=True)

Initializes the anchor module.

Parameters:
  • num_classes (Optional[int]) – The number of classes. If provided and greater than 0, the classification head and anchors are initialized accordingly.

  • new_anchors (bool) – Whether to initialize new anchors.

Returns:

True if the anchor module is successfully initialized, False otherwise.

Return type:

bool

copy(with_state=True)

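Creates a copy of the current object.

Parameters:

with_state (bool, optional) – Whether to copy the current state (learned parameters) into the new instance. Defaults to True.

Returns:

A new instance of the NrmlsBC class with the same arguments and state.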
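Because NrmlsBC stores its constructor arguments in args, a copy can be rebuilt from them and then, optionally, given the original state. The sketch below shows that pattern on a hypothetical ToyModel using plain dictionaries; with a torch module the state step would typically be new.load_state_dict(self.state_dict()).

```python
import copy as copy_module
from typing import Any, Dict


class ToyModel:
    """Hypothetical stand-in: keeps constructor arguments in `args` and
    learned values in `state`."""

    def __init__(self, **kwargs: Any) -> None:
        self.args: Dict[str, Any] = dict(kwargs)
        self.state: Dict[str, Any] = {"weight": [0.0] * kwargs.get("num_batches", 1)}

    def copy(self, with_state: bool = True) -> "ToyModel":
        new = type(self)(**self.args)  # rebuild with the same hyperparameters
        if with_state:
            # Deep copy so later updates to either model do not leak into the other.
            new.state = copy_module.deepcopy(self.state)
        return new


m = ToyModel(num_batches=2)
m.state["weight"][0] = 1.5
print(m.copy().state["weight"])                  # [1.5, 0.0]
print(m.copy(with_state=False).state["weight"])  # [0.0, 0.0]
```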

copy_dec()

Creates a copy of the model with the specified parameters.

Returns:

A copy of the model with the registered parameters.

Return type:

nn.ModuleDict