lrnnx.architectures.lru_unet module

Linear Recurrent Unit (LRU) based U-Net for sequence tasks.

class LayerNormFeature

Bases: Module

Layer normalization over the feature (channel) dimension.

Parameters:

num_features (int) – Number of features (channels).

forward(x: torch.Tensor) → torch.Tensor

Apply layer normalization to the input across its feature (last) dimension.

Parameters:

x (torch.Tensor) – Input of shape (B, T, C).

Returns:

Normalized output with the same shape as x.

Return type:

torch.Tensor
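The normalization described above behaves like `torch.nn.LayerNorm` applied to the trailing channel axis of a (B, T, C) tensor. The sketch below assumes exactly that; `LayerNormFeatureSketch` is a hypothetical name and this is not the lrnnx source.

```python
import torch
import torch.nn as nn


class LayerNormFeatureSketch(nn.Module):
    """Hypothetical stand-in: LayerNorm over the last (feature) axis of (B, T, C)."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        # nn.LayerNorm normalizes over the trailing dimension of size num_features.
        self.norm = nn.LayerNorm(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C); mean and variance are computed per (b, t) position across C.
        return self.norm(x)


torch.manual_seed(0)
x = torch.randn(2, 16, 8)  # (B, T, C)
y = LayerNormFeatureSketch(8)(x)
print(y.shape)  # torch.Size([2, 16, 8])
```

With the default affine parameters (weight 1, bias 0), each feature vector in the output has approximately zero mean and unit variance.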

class DownPool1D

Bases: Module

1D downsampling: a Conv1d with stride k (k = downsample_factor) that doubles the number of channels.

Parameters:
  • in_channels (int) – Number of input channels.

  • downsample_factor (int, optional) – Downsampling factor k, used as the convolution stride. Defaults to 2.

forward(x: torch.Tensor) → torch.Tensor

Downsample input.

Parameters:

x (torch.Tensor) – Input of shape (B, C, T).

Returns:

Downsampled output of shape (B, 2C, T/k).

Return type:

torch.Tensor
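A minimal sketch of such a downsampling layer, assuming the kernel size equals the stride k and no padding is used (both assumptions; the lrnnx kernel size is not stated here). `DownPool1DSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn


class DownPool1DSketch(nn.Module):
    """Hypothetical: Conv1d with kernel_size == stride == k, mapping C -> 2C channels."""

    def __init__(self, in_channels: int, downsample_factor: int = 2) -> None:
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            2 * in_channels,
            kernel_size=downsample_factor,  # assumption: kernel size equals stride
            stride=downsample_factor,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T) -> (B, 2C, T // k); with no padding, trailing steps that
        # do not fill a full window of k are dropped.
        return self.conv(x)


x = torch.randn(4, 16, 64)  # (B, C, T)
print(DownPool1DSketch(16, downsample_factor=2)(x).shape)  # torch.Size([4, 32, 32])
```

Under these assumptions the output length is exactly T/k only when T is divisible by k.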

class UpPool1D

Bases: Module

1D upsampling: a ConvTranspose1d with stride k (k = upsample_factor) that halves the number of channels.

Parameters:
  • in_channels (int) – Number of input channels.

  • upsample_factor (int, optional) – Upsampling stride. Defaults to 2.

forward(x: torch.Tensor) → torch.Tensor

Upsample input.

Parameters:

x (torch.Tensor) – Input of shape (B, C, T).

Returns:

Upsampled output of shape (B, C/2, T*k).

Return type:

torch.Tensor
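The mirror-image upsampling step can be sketched with `nn.ConvTranspose1d`, again assuming kernel size equal to the stride k and no padding (assumptions, not confirmed by this page). With that choice the output length is (T - 1)·k + k = T·k. `UpPool1DSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn


class UpPool1DSketch(nn.Module):
    """Hypothetical: ConvTranspose1d with kernel_size == stride == k, mapping C -> C/2."""

    def __init__(self, in_channels: int, upsample_factor: int = 2) -> None:
        super().__init__()
        if in_channels % 2 != 0:
            raise ValueError("in_channels must be even so it can be halved")
        self.deconv = nn.ConvTranspose1d(
            in_channels,
            in_channels // 2,
            kernel_size=upsample_factor,  # assumption: kernel size equals stride
            stride=upsample_factor,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T) -> (B, C // 2, T * k)
        return self.deconv(x)


x = torch.randn(4, 32, 32)  # (B, C, T)
print(UpPool1DSketch(32)(x).shape)  # torch.Size([4, 16, 64])
```

Chaining the downsampling sketch with this one returns the original (B, C, T) shape whenever T is divisible by k.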

class LRU_UNet

Bases: Module

Linear Recurrent Unit (LRU) based U-Net for sequence tasks.

Parameters:
  • d_model (int) – Input feature dimension.

  • d_state (int) – Hidden state dimension for the LRU layers.

  • n_layers (int) – Number of downsampling/upsampling stages.

  • downsample_factor (int, optional) – Downsampling (and matching upsampling) factor applied at each stage. Defaults to 2.
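To illustrate what d_state controls, here is a minimal sketch of a diagonal linear recurrence in the style of the LRU: h_t = λ ⊙ h_{t-1} + B x_t and y_t = Re(C h_t) + D ⊙ x_t, where λ holds d_state complex eigenvalues with |λ| < 1. This is a simplified, hypothetical layer (`LRUSketch`), not the lrnnx implementation: it keeps λ fixed, uses a plain uniform-radius initialization, omits the exponential parameterization and input normalization of the original LRU formulation, and expects (B, T, d_model) input.

```python
import math

import torch
import torch.nn as nn


class LRUSketch(nn.Module):
    """Hypothetical minimal LRU-style layer: complex diagonal recurrence of size d_state."""

    def __init__(self, d_model: int, d_state: int, r_min: float = 0.9, r_max: float = 0.999) -> None:
        super().__init__()
        # Eigenvalues lambda = r * exp(i * theta) with r < 1 so the recurrence is stable.
        r = r_min + (r_max - r_min) * torch.rand(d_state)
        theta = 2 * math.pi * torch.rand(d_state)
        self.register_buffer("lam", torch.polar(r, theta))  # kept fixed for brevity
        # Complex input/output projections and a real skip term D.
        self.B = nn.Parameter(torch.randn(d_state, d_model, dtype=torch.cfloat) / math.sqrt(2 * d_model))
        self.C = nn.Parameter(torch.randn(d_model, d_state, dtype=torch.cfloat) / math.sqrt(d_state))
        self.D = nn.Parameter(torch.randn(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, d_model) -> y: (B, T, d_model); hidden state h: (B, d_state), complex.
        batch, seq_len, _ = x.shape
        h = torch.zeros(batch, self.lam.shape[0], dtype=torch.cfloat, device=x.device)
        bu = x.to(torch.cfloat) @ self.B.T  # (B, T, d_state)
        outputs = []
        for t in range(seq_len):
            h = self.lam * h + bu[:, t]  # elementwise (diagonal) state update
            outputs.append((h @ self.C.T).real)  # project back to d_model
        return torch.stack(outputs, dim=1) + self.D * x


layer = LRUSketch(d_model=8, d_state=16)
print(layer(torch.randn(2, 10, 8)).shape)  # torch.Size([2, 10, 8])
```

A sequential loop is used for clarity; because the recurrence is linear, a parallel scan is the usual way to speed it up.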

forward(x: torch.Tensor) → torch.Tensor

Forward pass through the U-Net.

Parameters:

x (torch.Tensor) – Input sequence of shape (B, C_in, T).

Returns:

Processed sequence of shape (B, C_in, T).

Return type:

torch.Tensor
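The shape bookkeeping across stages can be sketched as follows, assuming each of the n_layers encoder stages applies a DownPool1D-style step (C → 2C, T → T/k) and each decoder stage applies the mirrored UpPool1D-style step. `unet_stage_shapes` is a hypothetical helper, and the divisibility check reflects the floor-division assumption from the pooling sketches, not a documented lrnnx requirement.

```python
def unet_stage_shapes(d_model: int, seq_len: int, n_layers: int, factor: int = 2):
    """Hypothetical helper: (channels, length) at each U-Net stage.

    Assumes every encoder stage maps (C, T) -> (2C, T // factor) and every
    decoder stage reverses one encoder stage.
    """
    if seq_len % factor**n_layers != 0:
        # Otherwise lengths round down while downsampling and the output no longer has length T.
        raise ValueError("seq_len must be divisible by factor ** n_layers")
    encoder = [(d_model, seq_len)]
    for _ in range(n_layers):
        channels, length = encoder[-1]
        encoder.append((2 * channels, length // factor))
    # The decoder retraces the encoder shapes in reverse, ending at (d_model, seq_len).
    decoder = encoder[::-1]
    return encoder, decoder


enc, dec = unet_stage_shapes(d_model=64, seq_len=1024, n_layers=3)
print(enc)  # [(64, 1024), (128, 512), (256, 256), (512, 128)]
print(dec)  # [(512, 128), (256, 256), (128, 512), (64, 1024)]
```

The last decoder shape matches the input, which is consistent with forward returning a tensor of shape (B, C_in, T).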