

Mamba: Selective State Space Modeling

An introduction to Mamba models: faster and better* than transformers

Jul 22, 2024

This is a blog post regarding a talk I gave at K-Scale Labs to help everyone understand Mamba. The slides are here. State Space Models (SSMs) are one of the few machine learning architectures that are competitive with transformers in both inference speed and effectiveness, not to mention they come with some pretty cool intuition!

Papers + Resources

Some sources to look at for more insight + math:

Introduction

What’s the problem with transformers?

There are a few:

  1. Attention is quadratic in sequence length, so training on long sequences gets expensive fast.
  2. Inference cost and memory also grow with context, since every new token has to attend over (and cache) everything that came before.
  3. Together, these effectively cap the usable context window.

If we can address these problems with a different architecture, we can hopefully be better at long-range tasks, synthesize speech and video (which are packed with context), train models more quickly, and run inference on smaller devices.

What are these “State Space Models (SSMs)”?

SSMs are characterized by two equations: \begin{aligned} x'(t) &= Ax(t) + Bu(t) \\ y(t) &= Cx(t) + Du(t) \end{aligned}

However, we can simply interpret this as a function-to-function mapping u(t) \mapsto y(t) parameterized by A, B, C, and D (fixed “latent” parameters). Then, x(t) is a latent representation satisfying the ODE. The important aspect of SSMs is that they have three views:

[Figure: the three views (representations) of an SSM]

  1. Continuous: This is simply our original equation, a function-to-function mapping u(t) \mapsto y(t) following the equation denoted above.
  2. Recurrent: We can map our linear ODE to discrete steps in an RNN, with similar intuition to how we approach Flow Matching (see slide 4). We essentially turn our function-to-function mapping into a sequence-to-sequence mapping by discretizing the dynamics at fixed step sizes:

\begin{aligned} x_k &= \overline{A}x_{k-1} + \overline{B}u_k \\ y_k &= Cx_k + Du_k \end{aligned}

  3. Convolutional: We unroll our recurrent view, solving N steps ahead with the recurrence to get a kernel of size N (a quick numerical check of this view follows below), plugging values back in such that

x0=Bu0x1=ABu0+Bu1x2=A2Bu0+ABu1+Bu2 x_0 = \overline{B}u_0 \quad x_1 = \overline{A}B u_0 + \overline{B}u_1 \quad x_2 = \overline{A}^2 B u_0 + \overline{A}B u_1 + \overline{B}u_2 \quad \cdots

Each of these views has tradeoffs, but the key version we want to focus on is the recurrent view. It has effectively no fixed context window (unlike the convolutional view), which addresses one of the issues we have with transformers. On the other hand, it has efficient constant-time inference per step (unlike the continuous view), which once again addresses another problem we had with transformers.

However, recurrent networks are not easily parallelizable in training, since each step depends on the previous time-step. Additionally, we’re at risk of vanishing/exploding gradients. How do we fix these problems?

Let’s set the state space model framing aside for a moment: at this point, an SSM is just an RNN that has its roots in continuous space and can also be unrolled into a convolution. What matters is that, like any RNN, it models information statefully, merging all of its previously seen context into one state.

Linear RNNs

Parallelizing

Having an activation function makes our calculations way too complicated and difficult to parallelize. What we really want is a way to calculate the accumulation of all an RNN’s layer multiplications quickly and easily. For each iteration with

h_t = f_W(h_{t-1}, x_t)

we can’t possibly expect to condense these calculations quickly with a nonlinear activation function like \tanh. Instead, let’s remove the activation function so we’re left with

f_W(h, x) = W_h h + W_x x
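To see why dropping the nonlinearity matters, note that two consecutive linear steps collapse into a single affine map of the earlier state, which is exactly what a \tanh in between would prevent. A quick numerical check with made-up random matrices:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3
Wh, Wx = rng.standard_normal((d, d)), rng.standard_normal((d, d))
h0, x1, x2 = rng.standard_normal(d), rng.standard_normal(d), rng.standard_normal(d)

# Two linear steps applied one after another...
h2 = Wh @ (Wh @ h0 + Wx @ x1) + Wx @ x2

# ...are the same as one affine map of h0 whose coefficients can be precomputed.
h2_collapsed = (Wh @ Wh) @ h0 + (Wh @ Wx @ x1 + Wx @ x2)

assert np.allclose(h2, h2_collapsed)  # with a tanh between steps, no such collapse exists
```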

How do we parallelize this? If we think about the Blelloch algorithm for parallel prefix sum, we can find some inspiration. With Blelloch, we are able to take the prefix sum of an array of length N in O(\log(N)) sequential steps through parallelization, simply because addition is associative. We do this by accumulating the sum over different step sizes.

[Figure: the Blelloch parallel prefix sum scan]

Luckily, there’s an associative function we can utilize for our new activation-free RNN layers.

f((W_1, x_1), (W_2, x_2)) = (W_2 W_1, W_2 x_1 + x_2)

Now, we are able to apply this function to our recurrence in O(N\log(N)) (folding W_h and W_x together, so each step t contributes the pair (W_h, W_x x_t)). But there is an additional problem: the scan has to store intermediate products W_i W_j \dots W_k, which means keeping a \mathbb{R}^{d \times d} matrix for every input — O(Nd^2) memory. Luckily, because W_h \in \mathbb{R}^{d \times d}, we can diagonalize it, so every intermediate product stays diagonal and the cache drops to a single vector per input, O(Nd).1
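Here is a minimal sketch of that associative trick (dense matrices rather than the diagonalized form, and a power-of-two sequence length, both simplifications I’m assuming for brevity): the combine below composes two affine maps, and because it is associative, reducing the per-step maps in a balanced tree reproduces the sequential recurrence.

```python
import numpy as np

rng = np.random.default_rng(1)
d, N = 3, 8                                      # N is a power of two for a clean tree
Wh = 0.5 * rng.standard_normal((d, d))           # recurrent weight (illustrative)
Wx = rng.standard_normal((d, d))                 # input weight
xs = [Wx @ rng.standard_normal(d) for _ in range(N)]  # the folded-in W_x x_t terms

def combine(a, b):
    """Associative op on affine maps h -> W h + x (apply a first, then b)."""
    (Wa, xa), (Wb, xb) = a, b
    return Wb @ Wa, Wb @ xa + xb

# Sequential recurrence: h_t = W_h h_{t-1} + W_x x_t, starting from h_0 = 0.
h = np.zeros(d)
for x in xs:
    h = Wh @ h + x

# Tree-shaped reduction (the shape a parallel scan exploits): pair up neighbours
# and combine; each level's combines are independent, giving log(N) levels.
level = [(Wh, x) for x in xs]
while len(level) > 1:
    level = [combine(level[i], level[i + 1]) for i in range(0, len(level), 2)]
W_total, x_total = level[0]

assert np.allclose(h, x_total)  # composite map applied to h_0 = 0 matches the loop
```

A full Blelloch scan would also recover every intermediate h_t rather than just the final state, but the point is the same: associativity is what lets the work be arranged into a log-depth tree.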

Boom! Now we have a quick and parallelizable RNN. To add nonlinearity (which we still want), we can simply add an additional layer to our RNN which is applied to the recurrent part’s outputs (just like the feed forward layer after self-attention layers).
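A rough sketch of what such a block could look like (the shapes, the ReLU choice, and the absence of residuals/normalization are all simplifying assumptions on my part): a linear, scan-friendly recurrence followed by a position-wise nonlinear layer.

```python
import numpy as np

def linear_rnn_block(x, Wh, Wx, W1, W2):
    """x: (T, d_in). Linear recurrence (parallelizable via the scan above),
    then a position-wise nonlinearity, much like attention + feed-forward."""
    h = np.zeros(Wh.shape[0])
    hs = []
    for x_t in x:                                  # linear part: no activation
        h = Wh @ h + Wx @ x_t
        hs.append(h)
    hs = np.stack(hs)                              # (T, d_hidden)
    return np.maximum(0.0, hs @ W1.T) @ W2.T       # nonlinear part, applied per position

# Toy usage with arbitrary shapes:
rng = np.random.default_rng(0)
T, d_in, d_h, d_out = 5, 4, 8, 4
y = linear_rnn_block(rng.standard_normal((T, d_in)),
                     0.3 * rng.standard_normal((d_h, d_h)), rng.standard_normal((d_h, d_in)),
                     rng.standard_normal((16, d_h)), rng.standard_normal((d_out, 16)))
```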

Avoiding Exploding/Vanishing Gradients

RNNs are incredibly sensitive, and small errors can compound into exploding/vanishing gradients such that our model never converges. Further, we can’t make use of gradient truncation, as it would effectively make our model short-range.

The solution lies in how we initialize our weights. If we constrain their magnitudes to sit just below 1, repeated multiplication through time neither explodes nor vanishes. The creators of the Linear RNN initialize each parameter w such that

w = e^{-e^a}e^{ib}, \quad e^{-e^a} \sim \text{Uniform}([0.999, 1.0]), \quad b \sim \text{Uniform}([0, \tfrac{\pi}{10}])

On the other hand, with weights this close to 1, the hidden state is also very sensitive to the magnitude of the inputs, so we want to scale them down. So, we multiply each input by \Delta = \sqrt{1 - e^{-e^a}}. From there, we’re done! 2
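A small sketch of that sampling scheme (the ranges are the ones quoted above; how \Delta enters a full layer is my assumption here):

```python
import numpy as np

def init_linear_rnn_channel(d, seed=2):
    """Sample diagonal recurrent weights w and the input scaling Delta
    per channel, following the ranges described above."""
    rng = np.random.default_rng(seed)
    mag = rng.uniform(0.999, 1.0, d)        # |w| = e^{-e^a}, just below 1
    b = rng.uniform(0.0, np.pi / 10, d)     # phase b
    w = mag * np.exp(1j * b)                # w = e^{-e^a} e^{ib}
    delta = np.sqrt(1.0 - mag)              # Delta = sqrt(1 - e^{-e^a})
    return w, delta

w, delta = init_linear_rnn_channel(16)
# One diagonal recurrent step would then look like: h_t = w * h_{t-1} + delta * x_t
```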

Mamba

Mamba improves upon this Linear RNN/SSM setup using one key idea: selection. With just a recurrent network, after many iterations, h_t is easily forced to hold too much information. The RNN should have some notion of attention: what should be retained in h_t and what should be removed?

The elegant solution, as expected, is gates. We want W_h and W_x (preserving notation from traditional RNNs) to depend on the input x_t itself. Below is the exact pseudocode for the selective SSM algorithm.3

[Figure: pseudocode for the selective SSM algorithm]

The most important thing about this algorithm is that the “selecting” functions and latent parameters4 (specifically A and B) can be chosen such that A = -1, B = 1, s_\Delta = \text{Linear}(x_t), and \tau_\Delta = \text{softplus}, with the result that each channel of our RNN has a gate characterized by

\begin{aligned} g_t &= \sigma(\text{Linear}(x_t)) \\ h_t &= (1 - g_t)h_{t-1} + g_tx_t \end{aligned}

This gives us an intuitive gate and shows that we are really just deciding how much we want to “remember” the current input at each timestep. Now, we can selectively recall information across very long context windows. Not to mention, our recurrence is still only O(N \log N) with the parallelization step.5 Now, we have Mamba!
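A toy single-channel version of that gate (the scalar weight and bias are made-up values, and this is the plain sequential form rather than the parallel scan):

```python
import numpy as np

def gated_recurrence(x, W=4.0, b=-2.0):
    """g_t = sigmoid(W x_t + b), h_t = (1 - g_t) h_{t-1} + g_t x_t."""
    h, hs = 0.0, []
    for x_t in x:
        g_t = 1.0 / (1.0 + np.exp(-(W * x_t + b)))  # sigma(Linear(x_t))
        h = (1.0 - g_t) * h + g_t * x_t             # input-dependent gate
        hs.append(h)
    return np.array(hs)

# With these (hypothetical) weights, a large input opens the gate and overwrites
# the state, while near-zero inputs keep the gate mostly closed, so the state is retained.
print(gated_recurrence(np.array([1.0, 0.0, 0.0, 0.0, 5.0, 0.0])))
```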

  1. Something pretty cool: P and P^{-1} in the diagonalization PDP^{-1} are actually learned matrices to add more expressivity, while not having to worry about any matrix-inverting problems. 

  2. SSMs, in converting the continuous-time view into a recurrent view, end up in essentially the same place but with a different initialization (discretization) scheme. There, weights are initialized as w = e^{\Delta(a+bi)} and inputs are multiplied by (\Delta(a+bi))^{-1}(w-1) \circ \Delta, where \Delta \in [0.0001, 0.1]. Otherwise, our intuition that the recurrent SSM view is the same as a Linear RNN holds. 

  3. Notice that the parameter DD is ignored because it is easily computable as a skip connection. 

  4. The Mamba paper offers interesting intuition regarding each of the parameters in Section 3.5.2 (Interpretation of Selection Mechanisms). 

  5. It’s pretty interesting to explore a hardware-aware understanding of runtime. There are cool optimizations, such as recomputing intermediate states during the backward pass in order to reduce memory use during training. Additionally, similar to FlashAttention, we want to do all of the heavy linear algebra (the recurrence) in SRAM, since data transfer between GPU memory and SRAM ends up being a larger bottleneck to training speed than the recurrence itself.