lrnnx.architectures.lru_unet module
Linear Recurrent Unit (LRU)-based U-Net for sequence tasks.
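The LRU named in this module's title is the diagonal linear recurrence of Orvieto et al. (2023), "Resurrecting Recurrent Neural Networks for Long Sequences": h_t = λ ⊙ h_{t-1} + γ ⊙ (B x_t) and y_t = Re(C h_t) + D ⊙ x_t, with complex eigenvalues |λ| < 1. The following is a minimal PyTorch sketch of that recurrence using a plain sequential loop. The class name `LRUSketch`, its arguments, and the initialization ring [r_min, r_max] are illustrative assumptions, not the lrnnx API. The lrnnx LRU layer may be parameterized or computed differently.

```python
import math

import torch
import torch.nn as nn


class LRUSketch(nn.Module):
    """Hypothetical diagonal LRU layer (sequential scan); not the lrnnx implementation."""

    def __init__(self, d_model: int, d_state: int, r_min: float = 0.9, r_max: float = 0.999):
        super().__init__()
        # Sample eigenvalue magnitudes uniformly on the ring r_min <= |lambda| < r_max
        u1, u2 = torch.rand(d_state), torch.rand(d_state)
        nu = -0.5 * torch.log(u1 * (r_max**2 - r_min**2) + r_min**2)
        theta = 2 * math.pi * u2.clamp_min(1e-6)
        # Log-parameters guarantee |lambda| = exp(-exp(nu_log)) < 1 for any values
        self.nu_log = nn.Parameter(torch.log(nu))
        self.theta_log = nn.Parameter(torch.log(theta))
        # Complex input (B) and output (C) matrices stored as real/imaginary parts
        self.B_re = nn.Parameter(torch.randn(d_state, d_model) / math.sqrt(2 * d_model))
        self.B_im = nn.Parameter(torch.randn(d_state, d_model) / math.sqrt(2 * d_model))
        self.C_re = nn.Parameter(torch.randn(d_model, d_state) / math.sqrt(d_state))
        self.C_im = nn.Parameter(torch.randn(d_model, d_state) / math.sqrt(d_state))
        self.D = nn.Parameter(torch.randn(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, T, d_model), real-valued
        lam = torch.exp(torch.complex(-torch.exp(self.nu_log), torch.exp(self.theta_log)))
        # gamma = sqrt(1 - |lambda|^2) rescales the input so state magnitudes stay bounded
        gamma = torch.sqrt(1 - torch.abs(lam) ** 2)
        B = torch.complex(self.B_re, self.B_im)
        C = torch.complex(self.C_re, self.C_im)
        Bu = (x.to(B.dtype) @ B.T) * gamma  # (batch, T, d_state), complex
        h = torch.zeros(x.shape[0], lam.shape[0], dtype=lam.dtype, device=x.device)
        states = []
        for t in range(x.shape[1]):
            h = lam * h + Bu[:, t]  # h_t = lambda * h_{t-1} + gamma * B x_t
            states.append(h)
        hs = torch.stack(states, dim=1)  # (batch, T, d_state)
        return (hs @ C.T).real + x * self.D  # y_t = Re(C h_t) + D * x_t


torch.manual_seed(0)
layer = LRUSketch(d_model=4, d_state=8)
x = torch.randn(2, 16, 4)
y = layer(x)  # (2, 16, 4)
```

The recurrence is causal: y_t depends only on x_1..x_t. The loop takes T sequential steps. Practical implementations usually replace it with a parallel scan, because the recurrence is linear.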
- class LayerNormFeature
Bases: Module
Layer normalization over the feature (channel) dimension.
- Parameters:
num_features (int) – Number of features (channels).
- forward(x: torch.Tensor) → torch.Tensor
Applies layer normalization to the input, independently at each time step.
- Parameters:
x (torch.Tensor) – Input of shape (B, T, C).
- Returns:
Normalized output of shape (B, T, C).
- Return type:
torch.Tensor
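Normalizing "over the feature dimension" of a (B, T, C) input means each (b, t) position is normalized across its C values. The following is a minimal sketch. It assumes that LayerNormFeature behaves like `torch.nn.LayerNorm` over the last dimension, which the source does not confirm, and the name `LayerNormFeatureSketch` is hypothetical. It also shows the transpose needed for channels-first (B, C, T) tensors such as those taken by the pooling layers in this module.

```python
import torch
import torch.nn as nn


class LayerNormFeatureSketch(nn.Module):
    """Hypothetical stand-in for LayerNormFeature: normalize over C at each (b, t)."""

    def __init__(self, num_features: int):
        super().__init__()
        # nn.LayerNorm normalizes over the trailing dimension, which is C for (B, T, C)
        self.norm = nn.LayerNorm(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, C) -> (B, T, C)
        return self.norm(x)


torch.manual_seed(0)
norm = LayerNormFeatureSketch(num_features=8)
x = torch.randn(2, 16, 8)  # (B, T, C)
y = norm(x)

# A channels-first tensor (B, C, T) must be moved to channels-last and back
x_cf = torch.randn(2, 8, 16)
y_cf = norm(x_cf.transpose(1, 2)).transpose(1, 2)  # (B, C, T)
```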
- class DownPool1D
Bases: Module
1D downsampling: a stride-k Conv1d that doubles the number of channels.
- Parameters:
- forward(x: torch.Tensor) → torch.Tensor
Downsamples the input along the time dimension by a factor of k.
- Parameters:
x (torch.Tensor) – Input of shape (B, C, T).
- Returns:
Downsampled output of shape (B, 2C, T/k).
- Return type:
torch.Tensor
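The shape change (B, C, T) → (B, 2C, T/k) follows from the stride-k convolution. The following is a minimal sketch that assumes `kernel_size == stride == k` and no padding. The documentation only specifies the stride, and the constructor arguments `d_input` and `pool` are hypothetical names.

```python
import torch
import torch.nn as nn


class DownPool1DSketch(nn.Module):
    """Hypothetical stand-in for DownPool1D (argument names are assumptions)."""

    def __init__(self, d_input: int, pool: int):
        super().__init__()
        # kernel_size == stride == pool: each output step sees one non-overlapping window
        self.conv = nn.Conv1d(d_input, 2 * d_input, kernel_size=pool, stride=pool)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, T) -> (B, 2C, T // k); trailing T % k steps are dropped
        return self.conv(x)


down = DownPool1DSketch(d_input=8, pool=4)
y = down(torch.randn(2, 8, 64))      # (2, 16, 16)
y_odd = down(torch.randn(2, 8, 66))  # 66 // 4 = 16 steps
```

Without padding, Conv1d produces floor((T - k) / k) + 1 = floor(T / k) output steps. Under this sketch's assumptions, T/k is therefore exact only when T is divisible by k.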
- class UpPool1D
Bases: Module
1D upsampling: a stride-k ConvTranspose1d that halves the number of channels.
- Parameters:
- forward(x: torch.Tensor) → torch.Tensor
Upsamples the input along the time dimension by a factor of k.
- Parameters:
x (torch.Tensor) – Input of shape (B, C, T).
- Returns:
Upsampled output of shape (B, C/2, T*k).
- Return type:
torch.Tensor
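The mirror-image layer can be sketched the same way. This sketch assumes `kernel_size == stride == k` with no padding, and the names `UpPool1DSketch`, `d_input` and `pool` are hypothetical. Under those assumptions the transposed convolution expands each input step into exactly k output steps.

```python
import torch
import torch.nn as nn


class UpPool1DSketch(nn.Module):
    """Hypothetical stand-in for UpPool1D (argument names are assumptions)."""

    def __init__(self, d_input: int, pool: int):
        super().__init__()
        # kernel_size == stride == pool: non-overlapping expansion of each time step
        self.conv = nn.ConvTranspose1d(d_input, d_input // 2, kernel_size=pool, stride=pool)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, T) -> (B, C // 2, T * k)
        return self.conv(x)


up = UpPool1DSketch(d_input=16, pool=4)
y = up(torch.randn(2, 16, 16))  # (2, 8, 64)
y_short = up(torch.randn(2, 16, 5))  # (2, 8, 20)
```

ConvTranspose1d yields (T - 1)·k + (k - 1) + 1 = T·k steps here. An upsampling layer with the same k therefore undoes the shape change of a matching downsampling layer whenever T is divisible by k. This is what allows encoder and decoder features to be combined in a U-Net.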
- class LRU_UNet
Bases: Module
Linear Recurrent Unit (LRU)-based U-Net for sequence tasks.
- Parameters:
- forward(x: torch.Tensor) → torch.Tensor
Forward pass through the U-Net.
- Parameters:
x (torch.Tensor) – Input sequence of shape (B, C_in, T).
- Returns:
Processed sequence of shape (B, C_in, T), the same shape as the input.
- Return type:
torch.Tensor