Source code for lrnnx.architectures.lru_unet

"""Linear Recurrent Unit (LRU) based U-Net for sequence tasks."""

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from lrnnx.models.lti.lru import LRU


[docs] class LayerNormFeature(nn.Module): """ Layer normalization over the feature (channel) dimension. Args: num_features (int): Number of features (channels). """ def __init__(self, num_features: int): super().__init__() self.norm = nn.LayerNorm(num_features)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Applies normalization to input. Args: x (torch.Tensor): Input of shape ``(B, T, C)``. Returns: torch.Tensor: Normalized output. """ return self.norm(x)
[docs] class DownPool1D(nn.Module): """ 1D downsampling: stride-k Conv1d that doubles channels. Args: in_channels (int): Number of input channels. downsample_factor (int, optional): Stride factor. Defaults to 2. """ def __init__(self, in_channels: int, downsample_factor: int = 2): super().__init__() self.factor = downsample_factor self.conv = nn.Conv1d( in_channels=in_channels, out_channels=in_channels * 2, kernel_size=downsample_factor, stride=downsample_factor, padding=0, )
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Downsample input. Args: x (torch.Tensor): Input of shape ``(B, C, T)``. Returns: torch.Tensor: Downsampled output of shape ``(B, 2C, T/k)``. """ return self.conv(x)
[docs] class UpPool1D(nn.Module): """ 1D upsampling: stride-k ConvTranspose1d that halves channels. Args: in_channels (int): Number of input channels. upsample_factor (int, optional): Upsampling stride. Defaults to 2. """ def __init__(self, in_channels: int, upsample_factor: int = 2): super().__init__() self.factor = upsample_factor self.conv = nn.ConvTranspose1d( in_channels=in_channels, out_channels=in_channels // 2, kernel_size=upsample_factor, stride=upsample_factor, padding=0, )
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Upsample input. Args: x (torch.Tensor): Input of shape ``(B, C, T)``. Returns: torch.Tensor: Upsampled output of shape ``(B, C/2, T*k)``. """ return self.conv(x)
[docs] class LRU_UNet(nn.Module): """ Linear Recurrent Unit (LRU) based U-Net for sequence tasks. Args: d_model (int): Input feature dimension. d_state (int): Hidden state dimension for the LRU layers. n_layers (int): Number of downsampling/upsampling stages. downsample_factor (int, optional): Factor for each stage. Defaults to 2. """ def __init__( self, d_model: int, d_state: int, n_layers: int, downsample_factor: int = 2, ): super().__init__() self.n_layers = n_layers self.total_downsample = downsample_factor**n_layers self.down_ssms = nn.ModuleList() curr_dim = d_model for _ in range(n_layers): self.down_ssms.append( nn.ModuleList( [ LRU(curr_dim, d_state), DownPool1D(curr_dim, downsample_factor), ] ) ) curr_dim *= 2 self.hid_ssms = LRU(curr_dim, d_state) self.up_ssms = nn.ModuleList() for _ in range(n_layers): self.up_ssms.append( nn.ModuleList( [ UpPool1D(curr_dim, downsample_factor), LRU(curr_dim // 2, d_state), ] ) ) curr_dim //= 2
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through the U-Net. Args: x (torch.Tensor): Input sequence of shape ``(B, C_in, T)``. Returns: torch.Tensor: Processed sequence of shape ``(B, C_in, T)``. """ # Handle padding for stride alignment T = x.shape[2] original_length = T pad_amount = ( self.total_downsample - (T % self.total_downsample) ) % self.total_downsample if pad_amount > 0: x = F.pad(x, (0, pad_amount)) x = x.permute(0, 2, 1) # (B, T, C_in) skips = [] # Encoder for lru_layer, pool_layer in self.down_ssms: skips.append(x) x = lru_layer(x) x_conv = x.permute(0, 2, 1) x_conv = pool_layer(x_conv) x = x_conv.permute(0, 2, 1) x = self.hid_ssms(x) # Decoder for pool_layer, lru_layer in self.up_ssms: x_conv = x.permute(0, 2, 1) x_conv = pool_layer(x_conv) x = x_conv.permute(0, 2, 1) skip = skips.pop() x = lru_layer(x + skip) x = x.permute(0, 2, 1) if pad_amount > 0: x = x[..., :original_length] return x