General Introduction
FlashMLA is an efficient MLA (Multi-head Latent Attention) decoding kernel developed by DeepSeek AI. It is optimized for NVIDIA Hopper-architecture GPUs and designed to speed up variable-length sequence processing. The project is open-sourced on GitHub and freely available to developers. It supports BF16 precision and a paged KV cache with a block size of 64, and performs well on the H800 SXM5, reaching up to 3000 GB/s of bandwidth in memory-intensive configurations and up to 580 TFLOPS in compute-intensive configurations. FlashMLA was inspired by FlashAttention 2&3 and the CUTLASS project, and is ready for production use out of the box. With this open-source project, DeepSeek AI has demonstrated its innovative capabilities in AI technology and attracted wide attention.
Feature List
- Efficient MLA decoding: optimized for Hopper GPUs, significantly speeding up processing of variable-length sequences.
- BF16 precision support: uses 16-bit bfloat16 floating-point operations to improve computational efficiency while maintaining accuracy.
- Paged KV cache: a paging mechanism with a block size of 64 manages memory effectively and improves inference performance.
- High performance: up to 3000 GB/s of memory bandwidth and 580 TFLOPS of compute on the H800 GPU.
- Open source: the full source code is provided, supporting custom modification and integration by developers.
Usage Guide
Installation process
FlashMLA is an open-source project hosted on GitHub. Before using it, make sure your environment meets the requirements and complete the installation. The detailed steps are as follows:
1. Environment preparation
- Operating system: Linux (Ubuntu 20.04 or above recommended).
- Hardware: an NVIDIA Hopper-architecture GPU (such as the H800 SXM5).
- Software dependencies:
  - CUDA 12.6 or above (see the NVIDIA website for installation instructions).
  - PyTorch 2.0 or above (can be installed via pip install torch).
  - Python 3.8 or above.
- Tools: make sure Git is installed so you can download the code from GitHub.
2. Downloading the source code
- Open a terminal and enter the following command to clone the FlashMLA repository:
git clone https://github.com/deepseek-ai/FlashMLA.git
- Go to the project directory:
cd FlashMLA
3. Installation of dependencies
The project depends on PyTorch and CUDA, which can be installed with the following commands:
pip install -r requirements.txt
If there is no requirements.txt file, simply make sure PyTorch is installed:
pip install torch torchvision
Verify that CUDA is available:
python -c "import torch; print(torch.cuda.is_available())"
If the output is True, the environment is configured correctly.
4. Compilation and testing
FlashMLA's CUDA kernels are built as a PyTorch extension, so make sure the build matches your local CUDA version:
- In the source directory, run the build script (if one is provided):
python setup.py install
- To verify that the installation succeeded, run the sample code:
python example.py
If no errors are reported, the installation is complete.
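For an extra sanity check beyond the sample script, the minimal sketch below (assuming the package imports as flash_mla, as in the usage section further down) confirms that the extension loads and that the GPU reports Hopper's compute capability (9.x):

```python
import torch
import flash_mla  # fails here if the CUDA extension was not built or installed correctly

assert torch.cuda.is_available(), "No CUDA device is visible"
major, minor = torch.cuda.get_device_capability()
print(f"GPU: {torch.cuda.get_device_name()} (sm_{major}{minor})")
# FlashMLA targets the Hopper architecture, i.e. compute capability 9.x.
assert major == 9, "FlashMLA requires a Hopper-architecture GPU"
```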
How to use
FlashMLA's core function is to provide efficient MLA decoding for AI model inference tasks. The main usage steps are as follows:
Feature 1: Loading and running FlashMLA
- Import the module:
  Bring the FlashMLA core functions into your Python script:
  from flash_mla import get_mla_metadata, flash_mla_with_kvcache
- Prepare the input data:
  - cache_seqlens: the sequence lengths of the KV cache.
  - q_i: the query tensor.
  - kvcache_i: the KV cache data.
  - block_table: the block table for the paged cache.
- Get the metadata:
  tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
- Run the decoding:
  o_i, lse_i = flash_mla_with_kvcache(q_i, kvcache_i, block_table, cache_seqlens, dv, tile_scheduler_metadata, num_splits, causal=True)
  The output o_i is the decoding result and lse_i is the log-sum-exp value. A runnable sketch that combines these steps is shown after this list.
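The following is a minimal sketch that puts the steps above together. The tensor shapes, head counts, and dtype are illustrative assumptions for this sketch, not requirements of the API; adjust them to match your model.

```python
import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

# Illustrative sizes (assumptions for this sketch, not mandated by FlashMLA).
b, s_q = 2, 1            # batch size, query tokens per decoding step
h_q, h_kv = 128, 1       # query heads, latent KV heads
d, dv = 576, 512         # head dim of q/kv, head dim of the value/output
block_size = 64          # FlashMLA's paged-KV block size
max_seqlen = 1024        # longest cached sequence, a multiple of block_size

device, dtype = "cuda", torch.bfloat16

# Per-sequence KV cache lengths (one int32 entry per batch element).
cache_seqlens = torch.tensor([1024, 512], dtype=torch.int32, device=device)

# Paged KV cache of shape (num_blocks, block_size, h_kv, d) and the block
# table that maps each sequence to its blocks.
num_blocks = b * max_seqlen // block_size
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).view(b, -1)
kvcache_i = torch.randn(num_blocks, block_size, h_kv, d, dtype=dtype, device=device)

# Query tensor for the current decoding step: (b, s_q, h_q, d).
q_i = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)

# Metadata is computed once per decoding step and reused for every layer.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, s_q * h_q // h_kv, h_kv
)

# Run the decoding kernel; o_i has shape (b, s_q, h_q, dv).
o_i, lse_i = flash_mla_with_kvcache(
    q_i, kvcache_i, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)
print(o_i.shape, lse_i.shape)
```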
Feature 2: Optimizing variable-length sequence processing
- Purpose: when handling input sequences of dynamic length, FlashMLA reduces the memory footprint by paging the KV cache.
- Steps:
  - Configure the paging parameters: the block size is fixed at 64, and the sequence lengths are controlled through cache_seqlens (an illustrative block-table construction is sketched after this list).
  - Specify causal=True at call time to ensure the causal attention mechanism takes effect.
- Result: up to 3000 GB/s of memory bandwidth on the H800 for large-scale inference tasks.
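As an illustration of how variable-length sequences map onto the 64-entry blocks, the helper below rounds each cached length up to whole blocks and hands out consecutive block IDs. The helper and its allocation scheme are assumptions for this sketch, not part of the FlashMLA API; real serving systems typically allocate blocks from a shared free pool.

```python
import torch

BLOCK_SIZE = 64  # FlashMLA's fixed paged-KV block size

def build_block_table(cache_seqlens: torch.Tensor, block_size: int = BLOCK_SIZE):
    """Illustrative allocator: give every sequence the same (padded) number of
    consecutive blocks, so the block table is a rectangular int32 tensor.

    cache_seqlens: int32 tensor of per-sequence KV lengths, shape (batch,).
    Returns (block_table, total_blocks).
    """
    batch = cache_seqlens.shape[0]
    # Round the longest sequence up to a whole number of blocks.
    max_blocks = int((cache_seqlens.max().item() + block_size - 1) // block_size)
    total_blocks = batch * max_blocks
    block_table = torch.arange(
        total_blocks, dtype=torch.int32, device=cache_seqlens.device
    ).view(batch, max_blocks)
    return block_table, total_blocks

# Example: three sequences of different lengths.
cache_seqlens = torch.tensor([130, 64, 1000], dtype=torch.int32)
block_table, total_blocks = build_block_table(cache_seqlens)
print(block_table.shape, total_blocks)  # torch.Size([3, 16]) 48
```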
Feature 3: Performance testing
- Test method:
  - Edit the sample script (e.g. example.py) to increase the size of the input data.
  - Use the following code to measure the elapsed time (see the note after this list on synchronizing the GPU before reading the clock):
import time
start = time.time()
# Run the decoding code
o_i, lse_i = flash_mla_with_kvcache(...)
print(f"Time: {time.time() - start} seconds")
- Expected results: Nearly 3000 GB/s for memory-intensive tasks and 580 TFLOPS for compute-intensive tasks.
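One caveat about the timing snippet above: CUDA kernels launch asynchronously, so time.time() alone can measure only the launch overhead rather than the kernel's runtime. The sketch below reuses the tensors defined in the Feature 1 example (an assumption of this sketch) and synchronizes the device around a batch of iterations before reading the clock:

```python
import time
import torch
from flash_mla import flash_mla_with_kvcache

# Warm-up call so one-time setup cost does not skew the measurement.
# q_i, kvcache_i, block_table, cache_seqlens, dv, tile_scheduler_metadata and
# num_splits are the tensors defined in the Feature 1 sketch.
o_i, lse_i = flash_mla_with_kvcache(
    q_i, kvcache_i, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)

iters = 100
torch.cuda.synchronize()              # make sure earlier GPU work has finished
start = time.time()
for _ in range(iters):
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True,
    )
torch.cuda.synchronize()              # wait for all decode kernels to complete
elapsed = (time.time() - start) / iters
print(f"Average time per decode step: {elapsed * 1000:.3f} ms")
```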
Notes
- Hardware compatibility: only Hopper GPUs are supported; an H800 or equivalent is recommended.
- Debugging tips: if you run into CUDA errors, check that the versions match, or seek community support in the project's GitHub Issues.
- Production use: FlashMLA can be integrated directly into an existing model inference pipeline; make sure the input data format matches what FlashMLA expects.
With the steps above, users can quickly get started with FlashMLA and benefit from the performance gains of its efficient decoding. The complete code and documentation are available in the GitHub repository; it is recommended to tune the parameters to the actual needs of your project.