FlashMLA: Hopper GPU를 위한 MLA 디코딩 커널 최적화(딥시크 오픈 소스 주간 1일차)

최신 AI 리소스게시됨 6 개월 전 AI 공유 서클
1.8K 00

일반 소개

FlashMLA는 DeepSeek AI는 가변 길이 시퀀스 처리의 성능을 개선하기 위해 설계된 NVIDIA Hopper 아키텍처 GPU에 최적화된 효율적인 MLA(Multi-head Latent Attention) 디코딩 커널을 개발했습니다. 이 프로젝트는 GitHub에서 오픈 소스로 제공되며 개발자는 무료로 사용할 수 있습니다. BF16 정밀 연산과 페이징된 KV 캐싱(블록 크기 64)을 지원하며, 메모리 집약적 구성에서 최대 3000GB/s 대역폭과 컴퓨팅 집약적 구성에서 최대 580 TFLOPS로 H800 SXM5에서 우수한 성능을 발휘합니다. FlashMLA는 FlashAttention 2&3과 Cutlass 프로젝트에서 영감을 받았습니다. 딥시크 AI는 이 오픈 소스 프로젝트를 통해 AI 기술의 혁신을 입증하며 많은 주목을 받고 있습니다.

FlashMLA:优化Hopper GPU的MLA解码内核(DeepSeek 开源周第一天)

 

기능 목록

  • 효율적인 MLA 디코딩가변 길이 시퀀스의 처리 속도를 크게 향상시키기 위해 Hopper GPU에 최적화되었습니다.
  • BF16 정확도 지원반정밀도 부동 소수점 연산을 활용하여 정밀도를 유지하면서 계산 효율성을 개선합니다.
  • 페이징 KV 캐싱블록 크기가 64인 페이징 메커니즘을 사용하여 메모리를 효과적으로 관리하고 추론 성능을 개선합니다.
  • 고성능H800 GPU에서 최대 3000GB/s의 메모리 대역폭과 580 TFLOPS의 컴퓨팅 성능을 제공합니다.
  • 오픈 소스개발자의 사용자 지정 수정 및 통합을 지원하기 위해 전체 소스 코드가 제공됩니다.

 

도움말 사용

설치 프로세스

FlashMLA는 GitHub를 기반으로 하는 오픈 소스 프로젝트이므로 사용 전에 환경이 요구 사항을 충족하는지 확인하고 설치를 완료해야 합니다. 자세한 단계는 다음과 같습니다:

1. 환경 대비

  • 운영 체제Linux 시스템 지원(우분투 20.04 이상 권장).
  • 하드웨어 요구 사항NVIDIA 호퍼 아키텍처 GPU(예: H800 SXM5)가 필요합니다.
  • 소프트웨어 종속성::
    • CUDA 12.6 이상(설치 지침은 NVIDIA 웹사이트 참조).
    • PyTorch 2.0 이상(다음을 통해 권장) pip install torch (설치).
    • Python 3.8 이상.
  • 검사 도구GitHub에서 코드를 다운로드하기 위해 Git이 설치되어 있는지 확인합니다.

2. 소스 코드 다운로드

  1. 터미널을 열고 다음 명령을 입력하여 FlashMLA 리포지토리를 복제합니다:
    git clone https://github.com/deepseek-ai/FlashMLA.git
  1. 프로젝트 카탈로그로 이동합니다:
    cd FlashMLA
    

3. 종속성 설치

이 프로젝트는 다음 명령어로 설치할 수 있는 PyTorch 및 CUDA에 의존합니다:

pip install -r requirements.txt

그렇지 않은 경우 requirements.txt 파일에 파이토치가 설치되어 있는지 확인하는 것은 간단합니다:

pip install torch torchvision

CUDA를 사용할 수 있는지 확인합니다:

python -c "import torch; print(torch.cuda.is_available())"

수출 True 환경 구성이 성공했음을 나타냅니다.

4. 컴파일 및 테스트

FlashMLA는 사전 컴파일된 CUDA 플러그인을 제공하지만, 사용 중인 CUDA 버전과 일치해야 합니다:

  1. 소스 디렉토리로 이동하여 컴파일 스크립트(있는 경우)를 실행합니다:
    python setup.py install
    
  2. 설치가 성공적으로 완료되었는지 테스트하고 샘플 코드를 실행합니다:
    python example.py
    

오류가 보고되지 않으면 설치가 완료된 것입니다.

사용 방법

FlashMLA의 핵심 기능은 AI 모델 추론 작업을 위한 효율적인 MLA 디코딩 지원을 제공하는 것입니다. 단계는 다음과 같습니다:

기능 1: 플래시MLA 로드 및 실행

  1. 모듈 가져오기::
    Python 스크립트에서 FlashMLA 핵심 기능을 소개합니다:

    from flash_mla import get_mla_metadata, flash_mla_with_kvcache
    
  2. 데이터 입력 준비::
    • cache_seqlens: KV 캐시의 시퀀스 길이를 정의합니다.
    • q_i쿼리 텐서.
    • kvcache_iKV 캐시된 데이터.
    • block_table: 페이징 캐시용 블록 테이블.
  3. 메타데이터 가져오기::
    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
    
  4. 디코딩 실행 중::
    o_i, lse_i = flash_mla_with_kvcache(q_i, kvcache_i, block_table, cache_seqlens, dv, tile_scheduler_metadata, num_splits, causal=True)
    

    수출 o_i 는 디코딩 결과입니다.lse_i 는 로그 합계 값입니다.

기능 2: 가변 길이 시퀀스 처리 최적화

  • take동적 길이 입력 시퀀스를 처리할 때 플래시MLA는 KV 캐시를 페이징하여 메모리 풋프린트를 줄입니다.
  • rig::
    1. 페이징 매개변수 구성: 블록 크기는 64로 고정되어 있으며 다음과 같이 조정하여 조정할 수 있습니다. cache_seqlens 시퀀스의 길이를 제어합니다.
    2. 런타임에 지정 causal=True(c) 인과관계 주의 메커니즘이 작동하는지 확인합니다.
  • 효과대규모 추론 작업을 위한 H800의 3000GB/s 메모리 대역폭.

기능 3: 성능 테스트

  • 테스트 방법::
    1. 샘플 스크립트 편집(예 example.py)를 사용하여 입력 데이터의 크기를 늘립니다.
    2. 다음 코드를 사용하여 실적을 기록하세요:
      import time
      start = time.time()
      # 运行解码代码
      o_i, lse_i = flash_mla_with_kvcache(...)
      print(f"耗时: {time.time() - start} 秒")
      
  • 예상 결과메모리 집약적 작업의 경우 약 3000GB/s, 컴퓨팅 집약적 작업의 경우 580TFLOPS.

주의

  • 하드웨어 호환성호퍼 GPU만 지원되며, H800 또는 이와 동등한 제품을 권장합니다.
  • 디버깅 팁CUDA 오류가 발생하면 버전이 일치하는지 확인하거나 GitHub 이슈에서 커뮤니티 지원을 요청하세요.
  • 프로덕션 환경기존 모델 추론 프로세스에 직접 통합하여 입력 데이터 형식이 FlashMLA 요구 사항과 일치하도록 보장합니다.

위의 단계를 통해 사용자는 플래시MLA를 빠르게 시작하고 효율적인 디코딩으로 인한 성능 향상을 누릴 수 있습니다. 전체 코드와 문서는 GitHub 리포지토리에서 확인할 수 있으며, 실제 프로젝트 요구 사항에 따라 파라미터를 조정하는 것이 좋습니다.

© 저작권 정책
AiPPT

관련 문서

댓글 없음

댓글에 참여하려면 로그인해야 합니다!
지금 로그인
없음
댓글 없음...