Introdução geral
O FlashMLA é um kernel de decodificação MLA (Multi-head Latent Attention) eficiente desenvolvido pela DeepSeek AI, otimizado para GPUs NVIDIA Hopper Architecture e projetado para melhorar o desempenho do processamento de sequências de comprimento variável. O projeto tem código aberto no GitHub e está disponível para desenvolvedores gratuitamente. Ele oferece suporte à computação de precisão BF16 e ao cache KV paginado (tamanho de bloco 64) e apresenta bom desempenho no H800 SXM5, com largura de banda de até 3000 GB/s em configurações com uso intensivo de memória e até 580 TFLOPS em configurações com uso intensivo de computação. O FlashMLA foi inspirado no FlashAttention 2&3 e no projeto Cutlass. A DeepSeek AI demonstrou sua inovação em tecnologia de IA por meio desse projeto de código aberto, que atraiu muita atenção.
Lista de funções
- Decodificação eficiente de MLAOtimizado para GPUs Hopper para acelerar significativamente o processamento de sequências de comprimento variável.
- Suporta a precisão do BF16Utilize operações de ponto flutuante de meia precisão para melhorar a eficiência computacional e, ao mesmo tempo, manter a precisão.
- Paging Cache KVMecanismo de paginação: Um mecanismo de paginação com um tamanho de bloco de 64 é usado para gerenciar a memória de forma eficaz e melhorar o desempenho da inferência.
- Alto desempenhoCapacidade de computação: fornece até 3000 GB/s de largura de banda de memória e 580 TFLOPS de potência de computação na GPU H800.
- código abertoCódigo-fonte: O código-fonte completo é fornecido para dar suporte à modificação e integração personalizadas pelos desenvolvedores.
Usando a Ajuda
Processo de instalação
O FlashMLA é um projeto de código aberto baseado no GitHub. Você precisa garantir que o ambiente atenda aos requisitos e concluir a instalação antes de usar. Aqui estão as etapas detalhadas:
1. preparação ambiental
- sistema operacionalSuporte para sistemas Linux (recomenda-se o Ubuntu 20.04 ou superior).
- Requisitos de hardwareRequer uma GPU de arquitetura NVIDIA Hopper (como a H800 SXM5).
- dependência de software::
- CUDA 12.6 ou superior (consulte o site da NVIDIA para obter instruções de instalação).
- PyTorch 2.0 ou superior (recomendado via
pip install torch
(Instalação). - Python 3.8 ou superior.
- Ferramentas de inspeçãoVerifique se o Git está instalado para fazer download do código do GitHub.
2. download do código-fonte
- Abra um terminal e digite o seguinte comando para clonar o repositório FlashMLA:
git clone https://github.com/deepseek-ai/FlashMLA.git
- Vá para o catálogo de projetos:
cd FlashMLA
3. instalação de dependências
O projeto depende do PyTorch e do CUDA, que podem ser instalados com os seguintes comandos:
pip install -r requirements.txt
Caso contrário requisitos.txt
é simples garantir que o PyTorch esteja instalado:
pip install torch torchvision
Verifique se o CUDA está disponível:
python -c "import torch; print(torch.cuda.is_available())"
exportações Verdadeiro
Indica que a configuração do ambiente foi bem-sucedida.
4. compilação e teste
O FlashMLA fornece plug-ins CUDA pré-compilados, mas certifique-se de que correspondem à versão local do CUDA:
- Vá para o diretório de origem e execute o script de compilação (se houver):
python setup.py install
- Teste se a instalação foi bem-sucedida e execute o código de amostra:
python example.py
Se nenhum erro for relatado, a instalação estará concluída.
Como usar
O principal recurso do FlashMLA é oferecer suporte eficiente à decodificação de MLA para tarefas de inferência de modelos de IA. Aqui estão as etapas:
Função 1: Carregar e executar o FlashMLA
- módulo de importação::
Apresentando as funções principais do FlashMLA em scripts Python:from flash_mla import get_mla_metadata, flash_mla_with_kvcache
- Preparação para inserir dados::
cache_seqlens
Comprimento da sequência: define o comprimento da sequência do cache KV.q_i
: Tensor de consulta.kvcache_i
Dados em cache da KV.block_table
Tabela de blocos para cache de paginação: Tabela de blocos para cache de paginação.
- Obtenção de metadados::
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
- decodificação em execução::
o_i, lse_i = flash_mla_with_kvcache(q_i, kvcache_i, block_table, cache_seqlens, dv, tile_scheduler_metadata, num_splits, causal=True)
exportações
o_i
é o resultado da decodificação.lse_i
é o valor da soma de registros.
Função 2: Otimizar o processamento de sequências de comprimento variável
- tomarSequências de entrada de comprimento dinâmico: Ao lidar com sequências de entrada de comprimento dinâmico, o FlashMLA reduz o espaço ocupado pela memória paginando o cache KV.
- equipamento::
- Configurar os parâmetros de paginação: o tamanho do bloco é fixo em 64 e pode ser ajustado com o ajuste do
cache_seqlens
Controla o comprimento da sequência. - Especificar em tempo de execução
causal=True
(c) Assegurar que o mecanismo de atenção causal esteja em vigor.
- Configurar os parâmetros de paginação: o tamanho do bloco é fixo em 64 e pode ser ajustado com o ajuste do
- efeitoO H800 tem uma largura de banda de memória de 3000 GB/s para tarefas de inferência em grande escala.
Função 3: Teste de desempenho
- Métodos de teste::
- Edite o script de amostra (por exemplo
example.py
), aumentando o tamanho dos dados de entrada. - Use o código a seguir para registrar o desempenho:
importar time start = time.time() # Executar o código de decodificação o_i, lse_i = flash_mla_with_kvcache(...) print(f "Time: {time.time() - start} seconds")
- Edite o script de amostra (por exemplo
- Resultados esperadosQuase 3000 GB/s para tarefas com uso intensivo de memória e 580 TFLOPS para tarefas com uso intensivo de computação.
advertência
- compatibilidade de hardwareGPUs Hopper: Somente as GPUs Hopper são compatíveis, recomendando-se a H800 ou equivalente.
- Dicas de depuraçãoSe você encontrar erros de CUDA, verifique se há correspondências de versão ou procure o suporte da comunidade em Problemas do GitHub.
- ambiente de produçãoIntegração direta com os processos de inferência de modelos existentes, garantindo que os formatos de dados de entrada sejam consistentes com os requisitos do FlashMLA.
Com as etapas acima, os usuários podem começar a usar o FlashMLA rapidamente e aproveitar a melhoria de desempenho proporcionada por sua decodificação eficiente. O código completo e a documentação podem ser encontrados no repositório do GitHub, e é recomendável ajustar os parâmetros de acordo com os requisitos reais do projeto.