Source code for lrnnx.ops.triton.selective_state_update

# Copyright (c) 2024, Tri Dao, Albert Gu.
# Modified by SAiDL.

"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this"""

import math

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange, repeat

from lrnnx.ops.triton.softplus import softplus


@triton.heuristics(
    {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}
)
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"HAS_DELTAA": lambda args: args["deltaA_ptr"] is not None})
@triton.heuristics(
    {
        "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
        is not None
    }
)
@triton.heuristics(
    {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr,
    x_ptr,
    dt_ptr,
    dt_bias_ptr,
    A_ptr,
    B_ptr,
    C_ptr,
    D_ptr,
    z_ptr,
    deltaA_ptr,
    out_ptr,
    state_batch_indices_ptr,
    # Matrix dimensions
    batch,
    nheads,
    dim,
    dstate,
    nheads_ngroups_ratio,
    # Strides
    stride_state_batch,
    stride_state_head,
    stride_state_dim,
    stride_state_dstate,
    stride_x_batch,
    stride_x_head,
    stride_x_dim,
    stride_dt_batch,
    stride_dt_head,
    stride_dt_dim,
    stride_dt_bias_head,
    stride_dt_bias_dim,
    stride_A_head,
    stride_A_dim,
    stride_A_dstate,
    stride_B_batch,
    stride_B_group,
    stride_B_dstate,
    stride_C_batch,
    stride_C_group,
    stride_C_dstate,
    stride_D_head,
    stride_D_dim,
    stride_z_batch,
    stride_z_head,
    stride_z_dim,
    stride_deltaA_batch,
    stride_deltaA_head,
    stride_deltaA_dim,
    stride_out_batch,
    stride_out_head,
    stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_DELTAA: tl.constexpr,
    HAS_STATE_BATCH_INDICES: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    DISCRETIZATION: tl.constexpr,
):
    """
    Triton JIT kernel for the selective state update.
    """
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)

    if HAS_STATE_BATCH_INDICES:
        state_batch_indices_ptr += pid_b
        state_batch_idx = tl.load(state_batch_indices_ptr)
        state_ptr += (
            state_batch_idx * stride_state_batch + pid_h * stride_state_head
        )
    else:
        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head

    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    B_ptr += (
        pid_b * stride_B_batch
        + (pid_h // nheads_ngroups_ratio) * stride_B_group
    )
    C_ptr += (
        pid_b * stride_C_batch
        + (pid_h // nheads_ngroups_ratio) * stride_C_group
    )
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    if HAS_DELTAA:
        deltaA_ptr += pid_b * stride_deltaA_batch + pid_h * stride_deltaA_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (
        offs_m[:, None] * stride_state_dim
        + offs_n[None, :] * stride_state_dstate
    )
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (
        offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
    )
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    if HAS_DELTAA:
        deltaA_ptrs = deltaA_ptr + offs_m * stride_deltaA_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(
        state_ptrs,
        mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
        other=0.0,
    )
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    if not TIE_HDIM:
        A = tl.load(
            A_ptrs,
            mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
            other=-1.0,
        ).to(tl.float32)
    else:
        A = tl.load(A_ptr).to(tl.float32)

    if HAS_DELTAA:
        if not TIE_HDIM:
            deltaA = tl.load(deltaA_ptrs, mask=offs_m < dim, other=0.0).to(
                tl.float32
            )
            if DISCRETIZATION == 4:
                dA = tl.exp(deltaA[:, None] * tl.log(A))
            else:
                dA = tl.exp(A * deltaA[:, None])
        else:
            deltaA = tl.load(deltaA_ptr).to(tl.float32)
            if DISCRETIZATION == 4:
                dA = tl.exp(deltaA * tl.log(A))
            else:
                dA = tl.exp(A * deltaA)
    else:
        if not TIE_HDIM:
            dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
            if HAS_DT_BIAS:
                dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(
                    tl.float32
                )
            if DT_SOFTPLUS:
                dt = tl.where(dt <= 20.0, softplus(dt), dt)
            if DISCRETIZATION == 4:
                dA = tl.exp(dt[:, None] * tl.log(A))
            elif DISCRETIZATION == 5:  # S7: A_bar = 1 - 1/(A_raw² + 0.5)
                dt_sq_half = dt[:, None] * dt[:, None] + 0.5
                dA = 1.0 - 1.0 / dt_sq_half
            else:
                dA = tl.exp(A * dt[:, None])
        else:
            dt = tl.load(dt_ptr).to(tl.float32)
            if HAS_DT_BIAS:
                dt += tl.load(dt_bias_ptr).to(tl.float32)
            if DT_SOFTPLUS:
                dt = tl.where(dt <= 20.0, softplus(dt), dt)
            if DISCRETIZATION == 4:
                dA = tl.exp(dt * tl.log(A))
            elif DISCRETIZATION == 5:  # S7
                dt_sq_half = dt * dt + 0.5
                dA = 1.0 - 1.0 / dt_sq_half
            else:
                dA = tl.exp(A * dt)

    if not TIE_HDIM:
        if not HAS_DELTAA:
            pass
        else:
            dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
            if HAS_DT_BIAS:
                dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(
                    tl.float32
                )
            if DT_SOFTPLUS:
                dt = tl.where(dt <= 20.0, softplus(dt), dt)
    else:
        if not HAS_DELTAA:
            pass
        else:
            dt = tl.load(dt_ptr).to(tl.float32)
            if HAS_DT_BIAS:
                dt += tl.load(dt_bias_ptr).to(tl.float32)
            if DT_SOFTPLUS:
                dt = tl.where(dt <= 20.0, softplus(dt), dt)

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    if DISCRETIZATION == 0:
        if not TIE_HDIM:
            A_dt = A * dt[:, None]
            expm1_A_dt = tl.exp(A_dt) - 1.0
            B_tilde = expm1_A_dt / A
            dB = B[None, :] * B_tilde
        else:
            A_dt = A * dt
            expm1_A_dt = tl.exp(A_dt) - 1.0
            B_tilde = expm1_A_dt / A
            dB = B * B_tilde
    elif DISCRETIZATION == 1:
        if not TIE_HDIM:
            v = 0.5 * dt[:, None] * A
            den_inv = 1.0 / (1.0 - v)
            dB = B[None, :] * den_inv * dt[:, None]
        else:
            v = 0.5 * dt * A
            den_inv = 1.0 / (1.0 - v)
            dB = B * den_inv * dt
    elif DISCRETIZATION == 2:
        if not TIE_HDIM:
            dB = B[None, :]  # (1, dstate), will broadcast with x[:, None]
        else:
            dB = B
    elif DISCRETIZATION == 4:
        sqrt_term = tl.sqrt(1.0 - dA * dA)
        if not TIE_HDIM:
            dB = B[None, :] * sqrt_term
        else:
            dB = B * sqrt_term
    elif DISCRETIZATION == 5:  # S7: identity (Bu pre-computed)
        if not TIE_HDIM:
            dB = B[None, :]
        else:
            dB = B
    else:
        if not TIE_HDIM:
            dB = B[None, :] * dt[:, None]
        else:
            dB = B * dt

    state = state * dA + dB * x[:, None]
    tl.store(
        state_ptrs,
        state,
        mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
    )
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)


[docs] def selective_state_update( state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, deltaA=None, state_batch_indices=None, discretization="mamba", ): """ Triton-accelerated single-step state update for selective state space models. Args: state (torch.Tensor): Hidden state of shape ``(batch, dim, dstate)`` or ``(batch, nheads, dim, dstate)``. x (torch.Tensor): Input tensor of shape ``(batch, dim)`` or ``(batch, nheads, dim)``. dt (torch.Tensor): Timestep tensor of shape ``(batch, dim)`` or ``(batch, nheads, dim)``. A (torch.Tensor): State transition matrix of shape ``(dim, dstate)`` or ``(nheads, dim, dstate)``. B (torch.Tensor): Input projection matrix of shape ``(batch, dstate)`` or ``(batch, ngroups, dstate)``. C (torch.Tensor): Output projection matrix of shape ``(batch, dstate)`` or ``(batch, ngroups, dstate)``. D (torch.Tensor, optional): Skip connection vector of shape ``(dim,)`` or ``(nheads, dim)``. Defaults to None. z (torch.Tensor, optional): Gating tensor of shape ``(batch, dim)`` or ``(batch, nheads, dim)``. Defaults to None. dt_bias (torch.Tensor, optional): Bias for dt of shape ``(dim,)`` or ``(nheads, dim)``. Defaults to None. dt_softplus (bool, optional): Whether to apply softplus to dt. Defaults to False. deltaA (torch.Tensor, optional): Timestep for A discretization (dtA) in asymmetric mode, shape ``(batch, dim)`` or ``(batch, nheads, dim)``. Defaults to None. state_batch_indices (torch.Tensor, optional): Indices to select states for the batch, shape ``(batch,)``. Defaults to None. discretization (str, optional): Discretization method ('zoh', 'bilinear', 'dirac', 'mamba', 'rglru', 's7'). Defaults to "mamba". Returns: torch.Tensor: The output tensor of shape ``(batch, dim)`` or ``(batch, nheads, dim)``. """ has_heads = state.dim() > 3 if state.dim() == 3: state = state.unsqueeze(1) if x.dim() == 2: x = x.unsqueeze(1) if dt.dim() == 2: dt = dt.unsqueeze(1) if A.dim() == 2: A = A.unsqueeze(0) if B.dim() == 2: B = B.unsqueeze(1) if C.dim() == 2: C = C.unsqueeze(1) if D is not None and D.dim() == 1: D = D.unsqueeze(0) if z is not None and z.dim() == 2: z = z.unsqueeze(1) if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) if deltaA is not None and deltaA.dim() == 2: deltaA = deltaA.unsqueeze(1) _, nheads, dim, dstate = state.shape batch = x.shape[0] if x.shape != (batch, nheads, dim): print(f"{state.shape} {x.shape} {batch} {nheads} {dim}") assert x.shape == (batch, nheads, dim) assert dt.shape == x.shape assert A.shape == (nheads, dim, dstate) ngroups = B.shape[1] assert nheads % ngroups == 0, "nheads must be divisible by ngroups" assert B.shape == (batch, ngroups, dstate) assert C.shape == B.shape if D is not None: assert D.shape == (nheads, dim) if z is not None: assert z.shape == x.shape if dt_bias is not None: assert dt_bias.shape == (nheads, dim) if deltaA is not None: assert deltaA.shape == (batch, nheads, dim) if state_batch_indices is not None: assert state_batch_indices.shape == (batch,) # map discretization string to integer disc_map = { "zoh": 0, "bilinear": 1, "dirac": 2, "mamba": 3, "rglru": 4, "s7": 5, } disc_int = disc_map.get(discretization.lower(), 3) out = torch.empty_like(x) grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) z_strides = ( (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) ) deltaA_strides = ( (deltaA.stride(0), deltaA.stride(1), deltaA.stride(2)) if deltaA is not None else (0, 0, 0) ) # We don't want autotune since it will overwrite the state # We instead tune by hand. BLOCK_SIZE_M, num_warps = ( (32, 4) if dstate <= 16 else ( (16, 4) if dstate <= 32 else ( (8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))) ) ) ) tie_hdim = ( A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and (dt_bias is not None and dt_bias.stride(-1) == 0) ) with torch.cuda.device(x.device.index): _selective_scan_update_kernel[grid]( state, x, dt, dt_bias, A, B, C, D, z, deltaA, out, state_batch_indices, batch, nheads, dim, dstate, nheads // ngroups, state.stride(0), state.stride(1), state.stride(2), state.stride(3), x.stride(0), x.stride(1), x.stride(2), dt.stride(0), dt.stride(1), dt.stride(2), *( (dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0) ), A.stride(0), A.stride(1), A.stride(2), B.stride(0), B.stride(1), B.stride(2), C.stride(0), C.stride(1), C.stride(2), *((D.stride(0), D.stride(1)) if D is not None else (0, 0)), z_strides[0], z_strides[1], z_strides[2], deltaA_strides[0], deltaA_strides[1], deltaA_strides[2], out.stride(0), out.stride(1), out.stride(2), dt_softplus, tie_hdim, BLOCK_SIZE_M, DISCRETIZATION=disc_int, num_warps=num_warps, ) if not has_heads: out = out.squeeze(1) return out
[docs] def selective_state_update_ref( state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, deltaA=None, discretization="mamba", ): """ Reference (pure PyTorch) implementation of the single-step selective state update. Args: state (torch.Tensor): Hidden state of shape ``(batch, dim, dstate)`` or ``(batch, nheads, dim, dstate)``. x (torch.Tensor): Input tensor of shape ``(batch, dim)`` or ``(batch, nheads, dim)``. dt (torch.Tensor): Timestep tensor of shape ``(batch, dim)`` or ``(batch, nheads, dim)``. A (torch.Tensor): State transition matrix of shape ``(dim, dstate)`` or ``(nheads, dim, dstate)``. B (torch.Tensor): Input projection matrix of shape ``(batch, dstate)`` or ``(batch, ngroups, dstate)``. C (torch.Tensor): Output projection matrix of shape ``(batch, dstate)`` or ``(batch, ngroups, dstate)``. D (torch.Tensor, optional): Skip connection vector of shape ``(dim,)`` or ``(nheads, dim)``. Defaults to None. z (torch.Tensor, optional): Gating tensor of shape ``(batch, dim)`` or ``(batch, nheads, dim)``. Defaults to None. dt_bias (torch.Tensor, optional): Bias for dt of shape ``(dim,)`` or ``(nheads, dim)``. Defaults to None. dt_softplus (bool, optional): Whether to apply softplus to dt. Defaults to False. deltaA (torch.Tensor, optional): Timestep for A discretization (dtA) in asymmetric mode, shape ``(batch, dim)`` or ``(batch, nheads, dim)``. Defaults to None. discretization (str, optional): Discretization method ('zoh', 'bilinear', 'dirac', 'mamba', 'rglru', 's7'). Defaults to "mamba". Returns: torch.Tensor: The output tensor of shape ``(batch, dim)`` or ``(batch, nheads, dim)``. """ has_heads = state.dim() > 3 if state.dim() == 3: state = state.unsqueeze(1) if x.dim() == 2: x = x.unsqueeze(1) if dt.dim() == 2: dt = dt.unsqueeze(1) if A.dim() == 2: A = A.unsqueeze(0) if B.dim() == 2: B = B.unsqueeze(1) if C.dim() == 2: C = C.unsqueeze(1) if D is not None and D.dim() == 1: D = D.unsqueeze(0) if z is not None and z.dim() == 2: z = z.unsqueeze(1) if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) if deltaA is not None and deltaA.dim() == 2: deltaA = deltaA.unsqueeze(1) batch, nheads, dim, dstate = state.shape assert x.shape == (batch, nheads, dim) assert dt.shape == x.shape assert A.shape == (nheads, dim, dstate) ngroups = B.shape[1] assert nheads % ngroups == 0, "nheads must be divisible by ngroups" assert B.shape == (batch, ngroups, dstate) assert C.shape == B.shape if D is not None: assert D.shape == (nheads, dim) if z is not None: assert z.shape == x.shape if dt_bias is not None: assert dt_bias.shape == (nheads, dim) if deltaA is not None: assert deltaA.shape == (batch, nheads, dim) dt = dt + dt_bias if dt_bias is not None else dt dt = F.softplus(dt) if dt_softplus else dt if discretization.lower() == "rglru": gate = rearrange( deltaA if deltaA is not None else dt, "b h d -> b h d 1" ) dA = A**gate # (batch, nheads, dim, dstate) elif discretization.lower() == "s7": dt_exp = rearrange(dt, "b h d -> b h d 1") dA = 1.0 - 1.0 / (dt_exp * dt_exp + 0.5) elif deltaA is not None: dA = torch.exp( rearrange(deltaA, "b h d -> b h d 1") * A ) # (batch, nheads, dim, dstate) else: dA = torch.exp( rearrange(dt, "b h d -> b h d 1") * A ) # (batch, nheads, dim, dstate) B = repeat( B, "b g n -> b (g h) n", h=nheads // ngroups ) # (batch, nheads, dstate) C = repeat( C, "b g n -> b (g h) n", h=nheads // ngroups ) # (batch, nheads, dstate) dt_expanded = rearrange(dt, "b h d -> b h d 1") B_expanded = rearrange(B, "b h n -> b h 1 n") if discretization.lower() == "zoh": A_dt = dt_expanded * A # (batch, nheads, dim, dstate) expm1_A_dt = torch.exp(A_dt) - 1.0 B_tilde = expm1_A_dt / A # (batch, nheads, dim, dstate) dB = B_expanded * B_tilde elif discretization.lower() == "bilinear": v = 0.5 * dt_expanded * A den_inv = 1.0 / (1.0 - v) dB = B_expanded * den_inv * dt_expanded elif discretization.lower() in ("dirac", "s7"): dB = B_expanded.expand_as(dA) elif discretization.lower() == "rglru": sqrt_term = torch.sqrt(1.0 - dA * dA) dB = B_expanded * sqrt_term else: dB = dt_expanded * B_expanded state.copy_( state * dA + dB * rearrange(x, "b h d -> b h d 1") ) # (batch, dim, dstate) out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) if D is not None: out += (x * D).to(out.dtype) out = (out if z is None else out * F.silu(z)).to(x.dtype) if not has_heads: out = out.squeeze(1) return out