An Alternative to the Transformer for Language Modeling
The Transformer architecture is a key component of the success of Large Language Models (LLMs). Almost all large language models in use today employ it, from open-source models such as Mistral to closed-source models such as ChatGPT.
To improve large language models further, new architectures have been developed that may even surpass the Transformer. One such approach is Mamba, a type of state space model.
Mamba was proposed in the paper Mamba: Linear-Time Sequence Modeling with Selective State Spaces¹. You can find its official implementation and model checkpoints in its repository.
In this post, I will introduce the field of state-space modeling in the context of language modeling and step-by-step explore various concepts to help understand the field. We will then discuss how Mamba might challenge the Transformer architecture.
As a visual guide, this article will go through many visualizations to help understand Mamba and state space models!
Part 1: The Problem with Transformers
To illustrate why Mamba is such an interesting architecture, let's first briefly review Transformers and explore one of their shortcomings.
The Transformer sees any textual input as a sequence made up of tokens.
One of the main advantages of Transformers is that, whatever input it receives, it can look back at any earlier token in the sequence to derive its representation.
Core Components of Transformers
Remember that the Transformer consists of two structures, a set of encoder blocks for representing text and a set of decoder blocks for generating text. Combined, these structures can be used for multiple tasks, including translation.
We can adopt this structure to create generative models by using only decoders. The resulting Transformer-based model, the Generative Pre-trained Transformer (GPT), uses decoder blocks to complete some input text.
Let's see how this works!
A Blessing During Training...
A single decoder block consists of two main components: the masked self-attention mechanism and the feedforward neural network.
The self-attention mechanism is an important reason why these models work so well. It allows for an uncompressed view of the entire sequence and fast training.
So how does it work?
It creates a matrix that compares each token with every token that came before it. The weights in this matrix are determined by how relevant the token pairs are to one another.
This matrix is created in one go during training: the attention between "My" and "name" does not need to be computed before the attention between "name" and "is".
This enables parallelization, which greatly speeds up training!
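As a rough sketch of this idea, here is a minimal NumPy version of masked (causal) self-attention; the projection matrices are illustrative placeholders. The point is that the full matrix of scores for every token pair is produced in one matrix product, so training can proceed in parallel:

```python
import numpy as np

def causal_self_attention(x, W_q, W_k, W_v):
    """Minimal masked self-attention over a sequence of token embeddings.

    x: (seq_len, d_model) token embeddings; W_q, W_k, W_v: (d_model, d_head).
    """
    q, k, v = x @ W_q, x @ W_k, x @ W_v             # project all tokens at once
    scores = q @ k.T / np.sqrt(k.shape[-1])         # (seq_len, seq_len) pairwise scores
    mask = np.triu(np.ones_like(scores), k=1)       # hide future tokens (causal mask)
    scores = np.where(mask == 1, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v                              # every row is computed in parallel
```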
The Problem During Inference!
However, there is a drawback. When generating the next token, we need to recompute attention over the entire sequence, even for tokens we have already generated.
Generating a sequence of length L requires roughly L² computations, which becomes expensive as the sequence length increases.
This need to recompute the entire sequence is one of the main bottlenecks of the Transformer architecture.
Let's see how a "classic" technique, Recurrent Neural Networks (RNN), solves this slow inference problem.
Are RNNs the Solution?
A Recurrent Neural Network (RNN) is a sequence-based network. At each time step it takes two inputs: the input at time step t and the hidden state from the previous time step t-1, which it uses to generate the next hidden state and predict the output.
The RNN has a looping mechanism that allows it to pass information from one step to the next. We can "expand" this visualization to make it more explicit.
When generating the output, the RNN only needs to consider the previous hidden state and the current input. It avoids the problem of recomputing all the previous hidden states required by the Transformer.
In other words, an RNN is capable of fast inference because it scales linearly with sequence length! In theory, it could even have unlimited context length.
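As a minimal sketch (NumPy, with hypothetical weight names), a single RNN step only needs the current input and the previous hidden state, which is why generation cost does not grow with the length of the sequence:

```python
import numpy as np

def rnn_step(x_t, h_prev, W_x, W_h, W_y):
    """One RNN step: the new hidden state depends only on the current input
    and the previous hidden state."""
    h_t = np.tanh(W_x @ x_t + W_h @ h_prev)   # update the compressed hidden state
    y_t = W_y @ h_t                           # predict the output for this step
    return h_t, y_t

# During generation, only h_t is carried forward, so each new token costs O(1)
# with respect to the sequence length.
```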
To illustrate this, let's apply the RNN to the input text we used earlier.
Each hidden state is an aggregation of all previous hidden states, usually a compressed view.
However, there is a problem here ......
Notice that when generating the name "Maarten", the last hidden state no longer contains any information about the word "Hello". RNNs tend to forget information over time because they only consider the previous state.
While RNNs are fast at inference, they lack the accuracy that Transformer models can deliver.
This is where state space models come in: they aim to use recurrence (and sometimes convolution) efficiently.
Part 2: The State Space Model (SSM)
State Space Models (SSMs) like Transformers and RNNs deal with sequences of information, such as text and signals. In this section, we introduce the basic concepts of SSMs and how they relate to textual data.
What is state space?
A state space contains the minimum number of variables needed to completely describe a system. It is a way to mathematically represent a problem by defining the possible states of a system.
Let's simplify this. Imagine we are walking through a maze. The "state space" is the map of all possible locations (states). Each point represents a unique location in the maze, with specific details such as how far you are from the exit.
The "state space representation" is a simplified description of this map. It shows your current location (current state), the next locations you can move to (possible future states), and how you can get there (for example, moving right or left).
While the state space model uses equations and matrices to track this behavior, it's really just a way of tracking where you are, where you can go, and how to get there.
The variables that describe a state, in our case the X and Y coordinates and the distance to the exit, can be represented as a "state vector".
Sound familiar? That's because embeddings or vectors are also often used in language models to describe the "state" of an input sequence. For example, a vector describing your current position (a state vector) might look something like this:
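As a purely illustrative example (the numbers are made up), such a state vector for the maze could simply hold the two coordinates and the distance to the exit:

```python
# Hypothetical state vector for the maze example (illustrative values only).
state_vector = [2.0,   # x coordinate
                4.0,   # y coordinate
                3.0]   # distance to the exit
```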
In neural networks, "state" usually refers to the hidden state of the system, which is one of the most important aspects of generating new tokens in large language models.
What is a state space model?
SSMs (State Space Models) are a class of models used to describe these state representations and predict their next state, with predictions based on certain inputs.
Traditionally, at time t, SSMs:
- map an input sequence x(t) (e.g., moving left and down in the maze) to a latent state representation h(t) (e.g., the distance to the exit and the x/y coordinates)
- and derive a predicted output sequence y(t) (e.g., moving left again to reach the exit sooner)
However, instead of using discrete sequences (such as moving left once), SSMs take continuous sequences as input and predict the output sequence.
SSMs assume that dynamic systems, such as an object moving through 3D space, can be predicted from their state at time t through two equations.
By solving these equations, we assume we can uncover the statistical principles needed to predict the state of the system based on observed data (the input sequence and previous states).
The goal is to find this state representation h(t) so that we can go from an input sequence to an output sequence.
These two equations are at the heart of the state space model.
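Written out without the color coding, they are the state equation and the output equation as they commonly appear in the SSM literature:

```latex
\begin{aligned}
  h'(t) &= \mathbf{A}\,h(t) + \mathbf{B}\,x(t) && \text{(state equation)} \\
  y(t)  &= \mathbf{C}\,h(t) + \mathbf{D}\,x(t) && \text{(output equation)}
\end{aligned}
```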
These two equations will be referenced throughout this guide. To visualize them, they are color-coded so you can reference them quickly.
The state equation describes how the state changes (through matrix A) based on how the input influences the state (through matrix B).
As mentioned earlier, h(t) refers to the latent state representation at any given time t, and x(t) refers to some input.
The output equation describes how the state is translated to the output (through matrix C) and how the input influences the output (through matrix D).
Note: matrices A, B, C, and D are also commonly referred to as parameters, since they are learnable.
Visualizing these two equations, we get the following architecture:
Let's take a step-by-step look at how these matrices affect the learning process.
Suppose we have some input signal x(t). This signal is first multiplied by matrix B, which describes how the input influences the system.
The updated state (similar to the hidden state of a neural network) is a latent space that contains the core "knowledge" of the environment. We multiply the state with matrix A, which describes how all the internal states are connected and represents the underlying dynamics of the system.
As you may have noticed, matrix A is applied before the state representation is created and again after it has been updated.
Then, we use matrix C to describe how the state is translated into the output.
Finally, we can use matrix D to provide a direct signal from the input to the output. This is also commonly referred to as a skip connection.
Since matrix D is similar to a skip connection, the SSM is often regarded in the following form without the skip connection.
Returning to our simplified perspective, we can now focus on matrices A, B, and C as the core of the SSM.
We can update the original equations (and add some nice colors) to mark the purpose of each matrix, as we did before.
Together, these two equations aim to predict the state of the system from observed data. Since the input is continuous, the main representation of the SSM is a continuous-time representation.
From continuous to discrete signals
With a continuous signal, finding the state representation h(t) analytically is challenging. Moreover, since we typically work with discrete inputs (such as text sequences), we want to discretize the model.
To do this, we use the zero-order hold technique. It works as follows: whenever we receive a discrete signal, we hold its value until we receive a new discrete signal. This process produces a continuous signal that the SSM can use:
How long we hold each value is represented by a new learnable parameter called the step size ∆. It represents the resolution of the input.
Now that we have generated a continuous signal for the input, we can next generate a continuous output and sample these values based on the time step of the input.
These sampled values are the output of our discretization!
Mathematically, we can apply the zero-order hold as follows:
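Concretely, the zero-order hold yields the following discretized matrices (this is the formulation used in the S4 and Mamba papers), where ∆ is the learnable step size introduced above:

```latex
\begin{aligned}
  \bar{\mathbf{A}} &= \exp(\Delta \mathbf{A}) \\
  \bar{\mathbf{B}} &= (\Delta \mathbf{A})^{-1}\bigl(\exp(\Delta \mathbf{A}) - \mathbf{I}\bigr)\,\Delta \mathbf{B}
\end{aligned}
```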
Together, these allow us to go from a continuous SSM to a discrete SSM. The model is no longer function-to-function, x(t) → y(t), but sequence-to-sequence, xₖ → yₖ:
Here, matrices A and B now represent the discretized parameters of the model.
We use k instead of t to distinguish the discrete SSM from the continuous SSM.
Note: during training we still keep matrix A in its continuous form, not its discretized version; the continuous representation is discretized during training.
Now that we have a formula for the discretized representation, let's explore how we can actually compute the model.
The Recurrent Representation
Our discretized SSM allows us to formulate the problem in terms of specific time steps rather than a continuous signal. The recurrent approach we saw earlier with RNNs is very useful here.
If we consider discrete time steps instead of a continuous signal, we can reformulate the problem using time steps:
At each time step, we compute how the current input (Bxₖ) affects the previous state (Ahₖ₋₁), and then compute the predicted output (Chₖ).
This representation may already look familiar! We can analyze it the same way we treated the RNN earlier.
We can unroll it (or expand it into a sequence of time steps) as follows:
Notice that we can use this discretized version following the underlying methodology of an RNN.
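As a minimal sketch (NumPy, with an illustrative function name), the discretized recurrence is just a loop over tokens, exactly like an RNN; A_bar and B_bar denote the discretized matrices from the previous section:

```python
import numpy as np

def ssm_recurrent(x, A_bar, B_bar, C):
    """Run the discretized SSM as a recurrence: h_k = A_bar h_{k-1} + B_bar x_k, y_k = C h_k.

    x: (seq_len, d_in); A_bar: (N, N); B_bar: (N, d_in); C: (d_out, N).
    """
    h = np.zeros(A_bar.shape[0])
    outputs = []
    for x_k in x:                       # one step per token, like an RNN
        h = A_bar @ h + B_bar @ x_k     # update the hidden state
        outputs.append(C @ h)           # project the state to the output
    return np.stack(outputs)
```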
The Convolution Representation
We can also represent the SSM with convolutions. Remember that in classic image recognition tasks, we apply filters (kernels) to extract aggregate features:
Since we are dealing with text and not images, we need to use a one-dimensional view:
The kernel we use to represent this "filter" is derived from the SSM formula:
Let's see what this kernel does in practice. Like convolution, we can use the SSM kernel to traverse each set of tokens and compute the output:
This also demonstrates the effect that padding may have on the output. I changed the order of the padding to improve the visualization, but we usually apply padding at the end of the sentence.
In the next step, the kernel moves once to perform the next step of the calculation:
In the last step, we can see the full effect of the kernel:
One of the main benefits of representing SSMs as convolutions is that they can be trained in parallel like convolutional neural networks (CNNs). However, due to the fixed kernel size, their inference is not as fast and unconstrained as RNNs.
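To make the correspondence concrete, here is a minimal NumPy sketch for the scalar input/output case (function names are illustrative): it builds the kernel from the discretized A and B matrices and applies it as a causal 1-D convolution, producing the same outputs as the recurrence:

```python
import numpy as np

def ssm_conv_kernel(A_bar, B_bar, C, length):
    """Build the SSM kernel K = (C B_bar, C A_bar B_bar, C A_bar^2 B_bar, ...)."""
    kernel, A_power = [], np.eye(A_bar.shape[0])
    for _ in range(length):
        kernel.append(C @ A_power @ B_bar)   # contribution of a token that far in the past
        A_power = A_bar @ A_power
    return np.array(kernel).squeeze()

def ssm_convolutional(x, A_bar, B_bar, C):
    """Compute the SSM outputs as a causal 1-D convolution over a scalar sequence x."""
    K = ssm_conv_kernel(A_bar, B_bar, C, len(x))
    x_padded = np.concatenate([np.zeros(len(x) - 1), x])   # left-pad so the kernel stays causal
    return np.array([x_padded[k:k + len(K)] @ K[::-1] for k in range(len(x))])
```

Because every output position depends only on the fixed kernel and the input, all positions can be computed in parallel during training.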
Three representations
These three representations, continuous, recurrent, and convolutional, each have different advantages and disadvantages:
Interestingly, we can now use the recurrent SSM for efficient inference and the convolutional SSM for parallelizable training.
With these representations, we can use a clever trick: choose a representation depending on the task. During training we use the convolutional representation, which can be parallelized, while during inference we use the efficient recurrent representation:
These models are known as the Linear State-Space Layer (LSSL).²
These representations share an important property, namely Linear Time Invariance (LTI). LTI states that the SSM's parameters A, B, and C are fixed for all time steps. This means that matrices A, B, and C are the same for every token the SSM generates.
In other words, whatever sequence you give the SSM, the values of A, B, and C remain the same. We have a static representation that is not aware of content.
Before exploring how Mamba addresses this, let's look at the final piece of the puzzle: matrix A.
The Importance of Matrix A
Arguably one of the most important aspects of the SSM formulation is matrix A. As we saw in the recurrent representation, it captures information about the previous state to build the new state.
In essence, matrix A produces the hidden state:
Therefore, how we create matrix A can determine whether we remember only a few previous tokens or capture every token we have seen so far. This matters especially for the recurrent representation, since it only looks back at the previous state.
How can we create matrix A in a way that retains a large memory (context size)?
We use Hungry Hungry Hippo!, or HiPPO,³ which stands for High-order Polynomial Projection Operators. HiPPO attempts to compress all the input signals it has seen so far into a vector of coefficients.
It uses matrix A to build a state representation that captures recent tokens well and decays older tokens. Its formula can be expressed as follows:
Suppose we have a square matrix A, which gives us:
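For reference, this is the HiPPO matrix as written in the Mamba paper (sign and indexing conventions vary slightly across papers):

```latex
\mathbf{A}_{nk} = -
\begin{cases}
  (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\
  n + 1                    & \text{if } n = k \\
  0                        & \text{if } n < k
\end{cases}
```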
Building matrix A with HiPPO was shown to work much better than initializing it as a random matrix. As a result, it reconstructs newer signals (recent tokens) more accurately than older signals (initial tokens).
The core idea of the HiPPO matrix is to generate a hidden state that memorizes its history.
Mathematically, it does so by tracking the coefficients of a Legendre polynomial, which allows it to approximate all of the previous history.⁴
HiPPO was then applied to the recurrent and convolutional representations we saw earlier to handle long-range dependencies. The result was Structured State Space for Sequences (S4), a class of SSMs that can efficiently handle long sequences.⁵
It consists of three parts:
- A state space model
- HiPPO for handling long-range dependencies
- Discretization for creating recurrent and convolutional representations
This class of SSMs has several advantages depending on the representation you choose (recurrent vs. convolutional). It can also handle long sequences of text and store memory efficiently by building on the HiPPO matrix.
Note: if you want to dive deeper into how to compute the HiPPO matrix and build an S4 model yourself, I highly recommend reading the Annotated S4.
Part 3: Mamba, a Selective State Space Model
We have finally covered all the basics needed to understand what makes Mamba special. State space models can be used to model text sequences, but there are still a number of drawbacks that we would like to avoid.
In this section, we will discuss the two main contributions of Mamba:
- A selective scan algorithm, which allows the model to filter (ir)relevant information
- A hardware-aware algorithm that stores (intermediate) results efficiently through parallel scan, kernel fusion, and recomputation
Together, they create the selective state space model, or S6 model, which can be used, like self-attention, to build Mamba blocks.
Before exploring these two main contributions, let us first explore why they are necessary.
What Problem Does It Attempt to Solve?
State space models, even S4 (the structured state space model), perform poorly on certain tasks that are vital in language modeling and generation, namely the ability to focus on or ignore particular inputs.
We can illustrate this with two synthetic tasks: selective copying and induction heads.
In the selective copying task, the goal of the SSM is to copy parts of the input and output them in order:
However, because an SSM is linear time-invariant, it performs poorly on this task. As we saw before, the matrices A, B, and C are the same for every token the SSM generates.
As a result, the SSM cannot perform content-aware reasoning, since the fixed A, B, and C matrices force it to treat every token equally. This is a problem because we want the SSM to reason about the input (the prompt).
The second task an SSM performs poorly on is induction heads, where the goal is to reproduce patterns found in the input:
In the example above, we are essentially performing one-shot prompting, trying to "teach" the model to provide an "A:" response after every "Q:". However, since SSMs are time-invariant, they cannot select which previous tokens to recall from their history.
Let's illustrate this by focusing on matrix B. Regardless of what the input x is, matrix B stays exactly the same and is therefore independent of x:
Likewise, A and C remain fixed regardless of the input. This demonstrates the static nature of the SSMs we have seen so far.
In comparison, these tasks are relatively easy for Transformers, since they dynamically change their attention based on the input sequence. They can selectively "look at" or "attend to" different parts of the sequence.
The poor performance of SSMs on these tasks illustrates the underlying problem with time-invariant SSMs: the static nature of matrices A, B, and C makes them unaware of content.
Selectively Retaining Information
The recurrent representation of an SSM creates a small state that is very efficient because it compresses the entire history. Compared to a Transformer, which does not compress the history at all (it keeps it in the attention matrix), it is much less powerful.
Mamba aims to have the best of both worlds: a small state that is as powerful as the state of a Transformer.
As mentioned above, it does so by compressing data selectively into the state. When you have an input sentence, there is usually information, such as stop words, that does not carry much meaning.
To selectively compress information, we need the parameters to depend on the input. To that end, let's first explore the dimensions of the input and output of an SSM during training:
In a structured state space model (S4), the matrices A, B, and C are independent of the input, since their dimensions N and D are static and do not change.
Instead, Mamba makes matrices B and C, and even the step size ∆, dependent on the input by incorporating the input's sequence length and batch size:
This means that for every input token we now have different B and C matrices, which solves the problem of content awareness!
Note: matrix A remains unchanged, since we want the state itself to stay static while the way it is influenced (through B and C) becomes dynamic.
Together, they selectively choose what to keep in the hidden state and what to ignore, since they now depend on the input.
A smaller step size ∆ results in ignoring specific words and relying more on the previous context, while a larger step size ∆ focuses more on the input words than on the context:
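A minimal sketch of this selection mechanism is shown below (the projection names are hypothetical, and the paper additionally uses a low-rank projection and a learned bias for ∆): each token gets its own B, C, and step size ∆ by projecting the token representation itself:

```python
import numpy as np

def selective_parameters(x, W_B, W_C, W_delta):
    """Sketch of Mamba's selection mechanism: B, C and the step size delta are
    derived from the input via linear projections (simplified).

    x: (batch, seq_len, D) token representations; N is the state size.
    """
    B = x @ W_B                              # (batch, seq_len, N): one B per token
    C = x @ W_C                              # (batch, seq_len, N): one C per token
    delta = np.log1p(np.exp(x @ W_delta))    # (batch, seq_len, D): softplus keeps delta positive
    return B, C, delta
```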
The Scan Operation
Since these matrices are now dynamic, they cannot be computed with the convolution representation, which assumes a fixed kernel. We can only use the recurrent representation, and we lose the parallelization that convolution provides.
To achieve parallelization anyway, let's explore how we compute the output with recurrence:
Each state is the sum of the previous state (multiplied by A) and the current input (multiplied by B). This is called a scan operation and can easily be computed with a for loop.
Parallelization, in contrast, seems impossible, since each state can only be computed once the previous state is available. Mamba, however, makes this possible through the parallel scan algorithm.
It exploits the associative property, assuming the order in which we perform the operations does not matter. This lets us compute the sequence in parts and combine the parts iteratively:
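For the scalar (diagonal-A) case, a minimal sketch of the associative combine behind such a scan looks like this; the sequential loop is only a reference, and it is the associativity of `combine` that allows a tree-shaped, parallel evaluation in O(log L) steps:

```python
def combine(left, right):
    """Associative combine for the linear recurrence h -> a*h + b.

    Composing (a1, b1) and then (a2, b2) gives h -> a2*(a1*h + b1) + b2."""
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2

def scan(elements):
    """Sequential reference scan; a parallel scan combines pairs in a tree instead."""
    out = [elements[0]]
    for e in elements[1:]:
        out.append(combine(out[-1], e))
    return out

# Element k is (A_bar_k, B_bar_k * x_k); the second entry of the k-th result is h_k.
```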
Together, the dynamic matrices B and C and the parallel scan algorithm create the selective scan algorithm, which captures the dynamic and fast nature of using the recurrent representation.
Hardware-aware algorithms
One of the drawbacks of recent GPUs is the limited transfer (IO) speed between their small but efficient SRAM and their large but slightly less efficient DRAM. Frequent copying of information between SRAM and DRAM can become a bottleneck.
Mamba, similar to Flash Attention, tries to limit how often it has to move data from DRAM to SRAM and back. It does so through kernel fusion, which allows the model to avoid writing out intermediate results and to keep computing until it is done.
We can see specific examples of DRAM and SRAM allocation by visualizing the basic architecture of Mamba:
Here, the following are fused into one kernel:
- The discretization step with step size ∆
- The selective scan algorithm
- The multiplication with C
The last piece of the hardware-aware algorithm is recomputation.
The intermediate states are not saved, but they are needed in the backward pass to compute the gradients. Instead, the authors recompute those intermediate states during the backward pass.
While this may seem inefficient, it is much cheaper than reading all this intermediate state from relatively slow DRAM.
We have now covered all the components of its architecture, depicted below using an image from its paper:
Selective SSM. Retrieved from: Gu, Albert, and Tri Dao. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces." arXiv preprint arXiv:2312.00752 (2023).
This architecture is often referred to as the selective SSM or S6 model, since it is essentially an S4 model computed with the selective scan algorithm.
The Mamba Block
The selective SSM we have explored so far can be implemented as a block, in the same way that self-attention forms a block in a decoder.
Like the decoder, we can stack multiple Mamba blocks and use their output as the input for the next Mamba block:
It starts with a linear projection to expand the input embeddings. Then, a convolution is applied before the selective SSM to prevent independent token computations.
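Below is a simplified, illustrative sketch of such a block in PyTorch. The class and parameter names are my own, and the real implementation differs in several ways: it uses HiPPO-based initialization for A, a low-rank projection plus bias for ∆, residual connections and RMSNorm around each block, and the hardware-aware parallel scan instead of this plain Python loop:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlockSketch(nn.Module):
    """Illustrative Mamba block: expand with a linear projection, apply a causal
    depthwise convolution, run a selective SSM with a sequential scan, gate the
    result, and project back down. Not the official implementation."""

    def __init__(self, d_model, d_state=16, expand=2, d_conv=4):
        super().__init__()
        d_inner = expand * d_model
        self.d_inner, self.d_state = d_inner, d_state
        self.in_proj = nn.Linear(d_model, 2 * d_inner)                 # main path + gate
        self.conv = nn.Conv1d(d_inner, d_inner, d_conv,
                              padding=d_conv - 1, groups=d_inner)      # causal depthwise conv
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1).float())
                                  .repeat(d_inner, 1))                 # stand-in for a HiPPO-style init
        self.x_to_B = nn.Linear(d_inner, d_state)                      # input-dependent B
        self.x_to_C = nn.Linear(d_inner, d_state)                      # input-dependent C
        self.x_to_delta = nn.Linear(d_inner, d_inner)                  # input-dependent step size
        self.out_proj = nn.Linear(d_inner, d_model)

    def selective_ssm(self, u):                                        # u: (batch, seq_len, d_inner)
        A = -torch.exp(self.A_log)                                     # (d_inner, d_state), kept negative
        B, C = self.x_to_B(u), self.x_to_C(u)                          # (batch, seq_len, d_state)
        delta = F.softplus(self.x_to_delta(u))                         # (batch, seq_len, d_inner)
        A_bar = torch.exp(delta.unsqueeze(-1) * A)                     # discretize per token
        B_bar = delta.unsqueeze(-1) * B.unsqueeze(2)                   # simplified discretization of B
        h = u.new_zeros(u.size(0), self.d_inner, self.d_state)
        ys = []
        for k in range(u.size(1)):                                     # sequential scan over tokens
            h = A_bar[:, k] * h + B_bar[:, k] * u[:, k].unsqueeze(-1)
            ys.append((h * C[:, k].unsqueeze(1)).sum(-1))
        return torch.stack(ys, dim=1)

    def forward(self, x):                                              # x: (batch, seq_len, d_model)
        u, gate = self.in_proj(x).chunk(2, dim=-1)
        u = self.conv(u.transpose(1, 2))[..., :x.size(1)].transpose(1, 2)
        u = F.silu(u)
        y = self.selective_ssm(u)
        return self.out_proj(y * F.silu(gate))
```

A block maps a (batch, seq_len, d_model) tensor to the same shape, so several of them can be stacked, with the surrounding projections and normalization, to form the full model.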
The selective SSM has the following properties:
- A recurrent SSM created through discretization
- HiPPO initialization of matrix A to capture long-range dependencies
- A selective scan algorithm to selectively compress information
- A hardware-aware algorithm to speed up computation
When we look at the code implementation, we can extend this architecture even more and explore what an end-to-end example would look like:
Note some changes, such as the addition of a normalization layer and a softmax for selecting the output token.
When we put it all together, we get fast inference and training, even with unbounded context. Using this architecture, the authors found that its performance matches, and sometimes exceeds, Transformer models of the same size!
Conclusion
This concludes our journey through state space models and the incredible Mamba architecture with its selective state space model. Hopefully this post gives you a better understanding of state space models and of Mamba in particular. Who knows whether it will replace the Transformer, but for now it is amazing to see such a different architecture get the attention it deserves!
To see more visualizations related to large language models and to support this newsletter, check out the book I co-authored with Jay Alammar.