DeepSeek发布DeepGEMM|完全 Just-In-Time | 300 行代码胜过专家调优的内核

DeepGEMM 是一个 CUDA 库,专为 DeepSeek-V3 中使用的 FP8 精度 GEMM 运算而设计,特点是细粒度缩放和高性能。 它支持普通 GEMM 和 MoE 模型的 Grouped GEMM,并在运行时使用 JIT 编译内核,无需预先编译,代码简洁易懂,核心内核仅约 300 行代码。DeepGEMM 专为 NVIDIA Hopper 架构的张量核心优化,通过 CUDA 核心两级累加解决 FP8 的精度问题。虽然借鉴了 CUTLASS 和 CuTe 的概念,但 DeepGEMM 专注于简化设计,在多种矩阵形状下性能可媲美或超过专业库,并且采用多种优化技术,例如 TMA 加速、JIT 编译以及 FFMA SASS 交错等。

阅读时长: 6 分钟
共 2992字
作者: eimoon.com

deepseek 第三天发布的技术DeepGEMM。以下是翻译内容:

DeepGEMM

DeepGEMM 是一个专为清晰高效的 FP8 通用矩阵乘法 (GEMM) 设计的库,它采用了 DeepSeek-V3 中提出的细粒度缩放技术。它同时支持普通 GEMM混合专家 (MoE) 分组 GEMM。该库使用 CUDA 编写,无需在安装时进行编译,因为它通过一个轻量级的即时 (JIT) 模块在运行时编译所有内核。

目前,DeepGEMM 仅支持 NVIDIA Hopper 张量核心。为了解决 FP8 张量核心累加的不精确性,它采用了 CUDA 核心的两级累加(提升)。虽然它借鉴了 CUTLASS 和 CuTe 中的一些概念,但避免了对它们的模板或代数的过度依赖。相反,该库的设计注重简洁,只有一个核心内核函数,大约包含 ~300 行代码。这使得它成为学习 Hopper FP8 矩阵乘法和优化技术的清晰易懂的资源。

尽管设计轻量,DeepGEMM 的性能在各种矩阵形状上都能与专家调优的库相媲美或超过。

性能

我们在 H800 SXM5 上使用 NVCC 12.8 测试了 DeepSeek-V3/R1 推理中可能使用的所有形状(包括预填充和解码,但不包括张量并行)。所有加速指标都是与我们内部精心优化的基于 CUTLASS 3.6 的实现进行比较计算得出的。

DeepGEMM 在某些形状上的表现不佳,如果您有兴趣,欢迎提交优化 PR。

普通 GEMM,适用于稠密模型

M N K 计算量 内存带宽 加速
64 2112 7168 206 TFLOPS 1688 GB/s 2.7x
64 24576 1536 289 TFLOPS 2455 GB/s 1.7x
64 32768 512 219 TFLOPS 2143 GB/s 1.8x
64 7168 16384 336 TFLOPS 2668 GB/s 1.4x
64 4096 7168 287 TFLOPS 2320 GB/s 1.4x
64 7168 2048 295 TFLOPS 2470 GB/s 1.7x
128 2112 7168 352 TFLOPS 1509 GB/s 2.4x
128 24576 1536 535 TFLOPS 2448 GB/s 1.6x
128 32768 512 358 TFLOPS 2103 GB/s 1.5x
128 7168 16384 645 TFLOPS 2604 GB/s 1.4x
128 4096 7168 533 TFLOPS 2221 GB/s 2.0x
128 7168 2048 510 TFLOPS 2277 GB/s 1.7x
4096 2112 7168 1058 TFLOPS 527 GB/s 1.1x
4096 24576 1536 990 TFLOPS 786 GB/s 1.0x
4096 32768 512 590 TFLOPS 1232 GB/s 1.0x
4096 7168 16384 1358 TFLOPS 343 GB/s 1.2x
4096 4096 7168 1304 TFLOPS 500 GB/s 1.1x
4096 7168 2048 1025 TFLOPS 697 GB/s 1.1x

MoE 模型的分组 GEMM(连续布局)

#Groups 每组 M 值 N K 计算量 内存带宽 加速
4 8192 4096 7168 1297 TFLOPS 418 GB/s 1.2x
4 8192 7168 2048 1099 TFLOPS 681 GB/s 1.2x
8 4096 4096 7168 1288 TFLOPS 494 GB/s 1.2x
8 4096 7168 2048 1093 TFLOPS 743 GB/s 1.1x

MoE 模型的分组 GEMM(掩码布局)

#Groups 每组 M 值 N K 计算量 内存带宽 加速
1 1024 4096 7168 1233 TFLOPS 924 GB/s 1.2x
1 1024 7168 2048 925 TFLOPS 968 GB/s 1.2x
2 512 4096 7168 1040 TFLOPS 1288 GB/s 1.2x
2 512 7168 2048 916 TFLOPS 1405 GB/s 1.2x
4 256 4096 7168 932 TFLOPS 2064 GB/s 1.1x
4 256 7168 2048 815 TFLOPS 2047 GB/s 1.2x

快速开始

要求

  • Hopper 架构 GPU,必须支持 sm_90a

  • Python 3.8 或以上

  • CUDA 12.3 或以上

    • 但我们强烈建议使用 12.8 或以上版本以获得最佳性能
  • PyTorch 2.1 或以上

  • CUTLASS 3.6 或以上(可以通过 Git 子模块克隆)

开发

# Submodule must be cloned
git clone --recursive [email protected]:deepseek-ai/DeepGEMM.git

# Make symbolic links for third-party (CUTLASS and CuTe) include directories
python setup.py develop

# Test JIT compilation
python tests/test_jit.py

# Test all GEMM implements (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py

安装

python setup.py install

然后,在您的 Python 项目中导入 deep_gemm,尽情享用吧!

接口

注意事项

该库仅包含 GEMM 内核。它要求 LHS 缩放因子与 TMA 对齐并进行转置,并且仅支持 NT 格式(非转置 LHS 和转置 RHS)。对于转置或其他 FP8 转换操作,请独立实现或将它们融合到之前的内核中。虽然该库提供了一些简单的 PyTorch 实用函数,但这些函数可能会导致性能下降,但我们的主要重点是优化 GEMM 内核本身。

普通稠密 GEMM(非分组)

要执行基本的非分组 FP8 GEMM,请调用 deep_gemm.gemm_fp8_fp8_bf16_nt 函数。有关更多详细信息,请参阅函数文档。

分组 GEMM(连续布局)

与 CUTLASS 中的传统分组 GEMM 不同,DeepGEMM 仅对 M 轴进行分组,而 N 轴和 K 轴必须保持固定。此设计专为 MoE 模型中的专家共享相同形状的场景而定制。

对于训练前向传递或推理预填充,其中每个专家可以处理不同数量的令牌,我们将这些令牌连接成一个张量,称为“连续”布局。请注意,每个专家段必须与 GEMM M 块大小对齐 (get_m_alignment_for_contiguous_layout())。

有关更多信息,请参阅 m_grouped_gemm_fp8_fp8_bf16_nt_contiguous 函数文档。

分组 GEMM(掩码布局)

在推理解码阶段,当 CUDA 图已启用且 CPU 不知道每个专家接收到的令牌数量时,我们支持掩码分组 GEMM。通过提供掩码张量,内核仅计算有效部分。

使用 m_grouped_gemm_fp8_fp8_bf16_nt_masked 来实现此目的,并查阅相关文档。一个示例用法是将 DeepEP 中低延迟内核的输出用作输入。

实用工具

除了上述内核之外,该库还提供了一些实用函数:

  • deep_gemm.set_num_sms:设置要使用的最大 SM 计数
  • deep_gemm.get_num_sms:获取当前的 SM 最大计数
  • deep_gemm.get_m_alignment_for_contiguous_layout:获取分组连续布局的组级别对齐要求
  • deep_gemm.get_tma_aligned_size:获取所需的 TMA 对齐大小
  • deep_gemm.get_col_major_tma_aligned_tensor:获取列优先 TMA 对齐张量

该库还提供了一些环境变量,这些变量可能很有用:

  • DG_CACHE_DIR:字符串,用于存储已编译内核的缓存目录,默认为 $HOME/.deep_gemm
  • DG_NVCC_COMPILER:字符串,指定的 NVCC 编译器路径;默认情况下将从 torch.utils.cpp_extension.CUDA_HOME 中查找
  • DG_DISABLE_FFMA_INTERLEAVE:0 或 1,禁用 FFMA 交错优化
  • DG_PTXAS_VERBOSE:0 或 1,显示详细的 PTXAS 编译器输出
  • DG_PRINT_REG_REUSE:0 或 1,打印 FFMA 交错详细信息
  • DG_JIT_PRINT_NVCC_COMMAND:0 或 1,打印 NVCC 编译命令
  • DG_JIT_DEBUG:0 或 1,打印更多调试信息

有关其他示例和详细信息,请参阅测试代码或查看相应的 Python 文档。

优化

我们用 🐳 符号表示 CUTLASS 中排除的技术。

持久 Warp 特化

按照 CUTLASS 的设计,DeepGEMM 中的内核是 Warp 特化的,从而可以重叠数据移动、张量核心 MMA 指令和 CUDA 核心提升。下图简化说明了此过程: alt text

Hopper TMA 特性

张量内存加速器 (TMA) 是 Hopper 架构引入的一项新的硬件特性,专为更快速和异步的数据移动而设计。具体来说,我们利用 TMA 来实现:

  • 用于 LHS、LHS 缩放因子和 RHS 矩阵的 TMA 加载
  • 用于输出矩阵的 TMA 存储
  • TMA 组播(LHS 矩阵独有)
  • TMA 描述符预取

通用细节优化

  • PTX 指令的利用
  • 针对不同 Warp 组量身定制的寄存器计数控制
  • 尽可能多地重叠,例如重叠 TMA 存储和非 TMA RHS 缩放因子加载 🐳

统一且优化的块调度器

  • 一个调度器,适用于所有非分组和分组内核
  • 光栅化以增强 L2 缓存重用

完全 JIT 设计 🐳

DeepGEMM 采用完全即时 (JIT) 设计,无需在安装时进行编译。所有内核都在运行时使用轻量级 JIT 实现进行编译。这种方法具有以下几个优点:

  • GEMM 形状、块大小和流水线阶段数被视为编译时常量

    • 节省寄存器
    • 编译器可以进行更多优化
  • 自动选择块大小、Warp 组数、最佳流水线阶段和 TMA 集群大小

    • 但没有自动调优,最佳值是确定性选择的
  • 完全展开 MMA 流水线,为编译器提供更多优化机会

    • 对于小形状非常重要
    • 有关详细信息,请参阅内核文件中的 launch_k_iterations

总的来说,JIT 显着提高了小形状的性能,类似于 Triton 编译器的方法。

非对齐块大小 🐳

对于某些形状,与 2 的幂对齐的块大小可能导致 SM 利用率不足。例如,对于 M=256、N=7168,典型的块大小分配 BLOCK_M=128、BLOCK_N=128 仅导致 (256 / 128) * (7168 / 128) = 112 个 SM 被利用,总共有 132 个 SM。为了解决这个问题,我们支持非对齐的块大小,如 112,从而使 (256 / 128) * (7168 / 112) = 128 个 SM 可以在这种情况下工作。实施此技术以及细粒度缩放需要仔细优化,但最终会带来性能提升。

FFMA SASS 交错 🐳

我们观察到 NVCC 12.2 和 12.3 之间 CUTLASS FP8 内核的性能有所提高。通过比较已编译的 SASS,我们发现在交错模式中,一系列指令中的一位被翻转。在参考了一些开源 CUDA 汇编器实现后,我们确定该位控制 yield,这可能会增强 Warp 级并行性(只是猜测,产生当前 Warp 并让其他 Warp 工作)。

为了利用这一点,我们开发了一个类似的脚本来修改已编译二进制文件中的 FFMA 指令。除了简单地修改 yield 位之外,我们还翻转了重用位(如果 Warp 已 yield,则无法重用寄存器)。通过创建更多将 MMA 指令与提升 FFMA 指令重叠的机会,此调整提高了细粒度缩放 FP8 GEMM 的性能(在某些情况下提高了 10% 以上)。

致谢

DeepGEMM 的灵感来自 CUTLASS 项目。感谢并致敬开发者们!

许可证

此代码库是在 MIT 许可证下发布的。

关注我获取更多资讯

公众号
📢 公众号
个人号
💬 个人号
使用 Hugo 构建
主题 StackJimmy 设计