はじめに
DeepGEMMは ディープシーク チームによって開発されたオープンソースのFP8 GEMM (Generalised Matrix Multiplication)ライブラリは、効率的な行列演算サポートを提供することに重点を置いています。NVIDIA HopperアーキテクチャのTensor Core用に特別に設計されており、一般的な行列演算と、混合エキスパートモデル(MoE)用のグループ化されたGEMM演算の両方をサポートしています。CUDAで記述されたこのライブラリは、軽量のJIT(Just-In-Time)コンパイルを使用してランタイム・カーネルコンパイルされているため、インストール時のプリコンパイルが不要になり、導入プロセスが大幅に簡素化されます。DeepGEMMは、クリーンなコードを維持しながら卓越した性能を発揮し、Hopper GPU上で1,350 TFLOPS以上のFP8計算能力を達成します。機械学習モデルのトレーニングや推論アクセラレーションに適しているだけでなく、オープンソースでありアクセスしやすいため、FP8行列最適化を学習するための優れたリソースでもあります。
機能一覧
-FP8行列演算をサポートハイパフォーマンス・コンピューティング・シナリオのための効率的なFP8一般化行列乗算(GEMM)を提供します。
-MoEモデルの最適化ハイブリッドエキスパートモデルのグループ化GEMMをサポート。M軸のみをグループ化し、適応エキスパートが同じシーンの形状を共有する。
-ジャスト・イン・タイム(JIT)コンパイルプリコンパイルなしで異なるハードウェア環境に適応できるように、実行時にカーネルをコンパイルします。
-ハイパフォーマンス・コンピューティング(HPC)NVIDIA Hopper GPUでFP8の計算スループット1350TFLOPS以上を達成。
-シンプルなコード設計約300行のコア・コードで、習得しやすく、二次開発も容易です。
-高い互換性通常のGEMMとマスク付きパケットGEMMの両方をサポートしており、様々な推論シナリオに適している。
-オープンソースで無料MITのプロトコルの下、研究および商用利用を目的に発行。
ヘルプの使用
DeepGEMMは、開発者向けに設計されたオープンソースの行列演算ライブラリで、主にCUDAプログラミングと機械学習の基本的なバックグラウンドを持つユーザーを対象としています。以下は、すぐに使い始め、プロジェクトに統合するのに役立つ詳細なガイドです。
設置プロセス
DeepGEMMは、複雑なプリコンパイルプロセスを必要とせず、わずか数ステップでインストールと実行環境の設定が可能です:
1.環境準備::
- システム要件:NVIDIA Hopperアーキテクチャ(H100など)をサポートするGPU。
- ソフトウェア依存:CUDA Toolkit(推奨バージョン11.8以上)とPython(3.8以上)をインストールしてください。
- ハードウェアサポート:お使いのデバイスに、少なくとも40GBのビデオメモリを搭載したNVIDIA GPUが搭載されていることをご確認ください。
2.クローン倉庫::
ターミナルで以下のコマンドを実行して、DeepGEMM リポジトリをローカルにダウンロードします:
git clone https://github.com/deepseek-ai/DeepGEMM.git**
cd DeepGEMM**
- 依存関係をインストールします:
Pythonのパッケージ管理ツールを使って、必要な依存関係をインストールします:
pip install torch numpy
DeepGEMM自体は、オンザフライのコンパイル技術に依存しており、すべてのカーネルが実行時に自動的に生成されるため、追加のコンパイルは必要ありません。
4.インストールを確認する:
提供されたテストスクリプトを実行し、環境が正しく設定されていることを確認する:
python test/deep_gemm_test.py
出力が正常な行列演算結果を示していれば、インストールは成功である。
主な機能
1.基本的なFP8 GEMM操作を行う
DeepGEMMは、グループ化されていないFP8行列乗算を実行するための使いやすいインターフェイスを提供します:
- 操作手順:
- ライブラリと関数のインポート
インポートトーチ
from deep_gemm import gemm_fp8_fp8_bf16_nt
- 入力データ(行列AとB、FP8形式でなければならない)を準備する:
A = torch.randn(1024, 512, dtype=torch.float8_e4m3fn).cuda()
B = torch.randn(512, 1024, dtype=torch.float8_e4m3fn).cuda()
- 行列の乗算を行う関数を呼び出す:
C = gemm_fp8_fp8_bf16_nt(A, B)
print(C)
- 警告だ:
- 入力行列はGPU上にあり、FP8フォーマット(E4M3またはE5M2)である必要があります。
- 出力はBF16形式で、その後の計算や保存に適している。
2.MoEモデルをサポートするGEMMのグループ化
MoE モデルを扱う必要があるユーザーのために、DeepGEMM はグループ化された GEMM をサポートしています:
- 操作手順:
- GEMM関数をグループ化してインポートする:
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
- 連続レイアウト用の入力データを準備する:
A = torch.randn(4096, 512, dtype=torch.float8_e4m3fn).cuda() # 複数エキスパートの入力スプライシング
B = torch.randn(512, 1024, dtype=torch.float8_e4m3fn).cuda()
group_sizes = [1024, 1024, 1024, 1024] # 各エキスパートの トークン 頻繁に
- グループGEMMを行う:
C = m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(A, B, group_sizes)
print(C)
- 警告だ:
- 入力行列AのM軸はエキスパートによってグループ化される必要があり、各グループのサイズはGEMMのMブロックサイズ(利用可能)に合わせる必要がある。
get_m_alignment_for_contiguous_layout()
(アクセス)。 - BマトリックスのN軸とK軸は固定する必要がある。
- 入力行列AのM軸はエキスパートによってグループ化される必要があり、各グループのサイズはGEMMのMブロックサイズ(利用可能)に合わせる必要がある。
3.推論段階でのマスクグループ化 GEMM
DeepGEMMは推論復号フェーズにおいて、動的トークン割り当てのためのマスクを用いたグループ化GEMMをサポートしている:
- 操作手順:
- マスクのグループ化機能をインポートする:
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
- 入力データとマスクを準備する:
A = torch.randn(4096, 512, dtype=torch.float8_e4m3fn).cuda()
B = torch.randn(512, 1024, dtype=torch.float8_e4m3fn).cuda()
mask = torch.ones(4096, dtype=torch.bool).cuda() #マスクは有効なトークンを示す。
- マスクのグルーピングGEMMを行う:
C = m_grouped_gemm_fp8_fp8_bf16_nt_masked(A, B, mask)
print(C)
- 警告だ:
- マスクは、計算が必要なトークンを指定するために使用され、CUDAグラフが有効な場合に動的推論に適している。
注目機能 操作手順
高性能の最適化とデバッグ
DeepGEMMの核となる強みは、その効率性とシンプルさであり、開発者は以下のステップを踏むことで、さらに最適化し、デバッグすることができる:
- パフォーマンスデータを見る
テストスクリプト実行時にTFLOPSを監視するためのログ出力を追加:インポート・ロギング logging.basicConfig(level=logging.INFO) C = gemm_fp8_fp8_bf16_nt(A, B)
調整パラメーター*:
データの移動と計算のオーバーラップを最適化するために、特定のハードウェアに合わせてブロックサイズ(TMAパラメータ)を調整します。
学習と拡大*:
コアコードはdeep_gemm/gemm_kernel.cuにあり、約300行ある。開発者はこれを直接読んで、カスタム要件に合うように修正することができる。
使用上の推奨事項
ハードウェア要件*: 現在、NVIDIA HopperアーキテクチャGPUのみをサポートしています。
文書参照*: 関数の詳細な説明とサンプルコードは、GitHubリポジトリのREADME.mdとtest/フォルダにあります。
地域支援*問題が発生した場合は、GitHub Issues ページにフィードバックを送信してください。
以上の手順で、DeepGEMMを機械学習プロジェクトに簡単に統合し、その効率的なFP8行列計算能力を享受することができる。