深入解析整数量化技术
速览
整数量化是一种将模型参数从高精度浮点数转换为低精度整数的技术。该技术能显著降低模型存储需求和计算开销,提升推理速度。在边缘设备和移动端部署大模型时,整数量化是实现高效推理的关键手段。
AI 深度解读
Integer Quantization: Deep Dive 深度解读
背景
过去几年,Transformer 模型的量化技术取得了显著进展。从最初难以在不严重损害精度的情况下将 7B 参数模型量化为 INT8,到如今能够 routinely(常规地)在单张 GPU 上将 70B 参数模型以 4-bit 精度运行。然而,目前关于该主题的指导资料往往碎片化:要么专注于特定技术,要么仅介绍如何使用某个库。
作者长期致力于定点硬件上的整数量化工作,本系列文章旨在填补这一空白:通过仔细构建核心概念,并追溯该领域的演变历程,展示每种技术是如何由前一个技术所面临的问题所驱动的。本文作为系列的第一篇,涵盖基础内容:什么是量化、为什么量化很难,以及其背后的数学原理。
核心内容
什么是量化以及为什么值得关注?
量化是使用更少的位来表示高精度值的过程。在实践中,这意味着以较低的精度存储权重和(可选的)激活值(例如,使用 int8 而不是 fp16),从而引入微小的近似误差。
量化最直接且易于实现的收益是内存减少。根据经验法则,拥有 N 亿参数的模型在 16-bit 精度下存储时,大约需要 2 × N GB 的内存。量化到 8-bit 或 4-bit 分别可以将内存占用减少 2 倍和 4 倍。
此外,还有硬件优势。2014年,斯坦福大学的 Mark Horowitz 发表了一篇名为《Computing’s Energy Problem》的论文,研究了浮点运算与整数运算的区别:
注:原文引用了一张关于 45nm CMOS 节点上各种操作能耗的图表。
整数算术消耗的能量更少,具体而言,int8 加法比 fp32 加法消耗少 30 倍的能量,int8 乘法比 fp32 乘法消耗少 18 倍的能量。较低精度的硬件也比浮点硬件更快,且占用的硅片面积更小。
这些优势如何转化为实际收益?这取决于系统的瓶颈所在:
- 计算密集型工作负载(例如 CNN、LLM 预填充阶段):量化可以通过更快的低精度算术运算和更低的能耗来提高吞吐量。
- 内存带宽密集型工作负载(例如 LLM 解码阶段):量化减少了移动的数据量,通过降低内存带宽压力来提高性能。
至此,量化的动机已十分明确:量化减少内存占用、降低能耗,并可能提升性能。接下来,我们将看看执行定点算术的硬件单元。
乘加单元 (Multiply Accumulate Unit, MAC)
神经网络中的主导操作是矩阵乘法。现代硬件加速器使用称为乘加单元(MAC Units)的专用单元来优化这一操作。
该图表示神经网络加速器中典型的矩阵-向量乘法单元。这是矩阵乘法和卷积的基础构建块。两个基本组件是处理元素 $C_{n,m}$ 和累加器 $A_n$。
计算过程如下:
- 首先用偏置值 $b_n$ 初始化累加器。
- 在下一个周期,加载权重 $W_{n,m}$ 和输入值 $x_m$。
- 在每个处理元素上计算它们的乘积: $$C_{n,m} = W_{n,m} \cdot x_m$$
- 然后对结果进行累加: $$A_n = b_n + \sum_{m} C_{n,m}$$
如何进行量化?
从实值向量 $x$ 开始,我们将其映射到整数网格 ${x_{\text{int}}^{\min}, \ldots, x_{\text{int}}^{\max}}$:
$$x_{\text{int}} = \text{clamp}\left(\left\lfloor \frac{x}{s} + z \right\rceil, q_{\min}, q_{\max}\right)$$
其中:
- $s$ 是缩放因子 (scale)
- $z$ 是零点 (zero-point,即偏移量)
- $\lfloor \cdot \rceil$ 表示四舍五入到最近的整数
Clamp 操作确保结果位于有效的整数范围内:
$$q_{\min} \le x_{\text{int}} \le q_{\max}$$
因此,核心思想是缩放和平移浮点值,然后将其夹紧以适合整数网格。
量化模拟 (Fake Quantization)
我们通常不在目标硬件上直接运行量化模型,而是使用 PyTorch 等高级框架在通用硬件上模拟量化。这通常被称为假量化 (fake quantization)。
关键思想很简单:我们模仿量化的效果,同时仍以浮点数执行操作。这允许我们在不需要专用硬件的情况下研究精度并进行如量化感知训练 (QAT) 等实验。
为此,我们:
- 将输入量化到整数网格
- 将其反量化回浮点数
- 在标准硬件(如 GPU)上以浮点数执行所有计算
反量化步骤将整数映射回实值:
$$x_{\text{dequant}} = s \cdot (x_{\text{int}} - z)$$
结合量化和反量化,我们得到:
$$x_{\text{dequant}} = s \cdot (\text{clamp}(\lfloor \frac{x}{s} + z \rceil, q_{\min}, q_{\max}) - z)$$
在实践中,框架会在模型图的运算周围插入这些量化-反量化 (Q/DQ) 对。虽然计算仍以浮点数进行,但值被限制在离散集合中,从而有效地模拟推理期间的量化效果。
$q_{\min}$ 和 $q_{\max}$ 是什么?
此时,建立一些关于量化公式实际作用的直觉会有所帮助。
考虑量化操作可能产生的最小值。这对应于 $x_{\text{int}}^{\min}$。将其代入反量化公式:
$$q_{\min_val} = s \cdot (x_{\text{int}}^{\min} - z)$$
类似地,最大值对应于 $x_{\text{int}}^{\max}$:
$$q_{\max_val} = s \cdot (x_{\text{int}}^{\max} - z)$$
这里需要注意的关键点是:我们不再在连续的浮点范围内操作,而是在一组离散的 $2^b$ 个值上操作,每个值之间由缩放因子 $s$ 分隔。
量化误差
我们的量化公式中有两个“坏家伙”:
- 舍入运算符,引入舍入误差;
- Clamp 运算符,引入裁剪误差。
这就是量化网络偏离原始网络的原因,即量化误差(或量化噪声)。正如我们将看到的,改善其中一个往往会恶化另一个,因此量化实际上是关于平衡这两者。
- 舍入误差:对于值 $x$,它是原始浮点值与其映射到的量化网格值之间的差异。
- 裁剪误差:当值超出可表示范围并被裁剪到最小值或最大值时发生。
让我们再次思考极端情况。舍入误差何时最大?最坏情况发生在浮点值恰好位于两个网格点中间时。在这种情况下,舍入误差最大为 $s/2$。
因此,你可能会想:“为什么不使 $s$ 非常小以最小化每个值的误差?”。不幸的是,量化没有免费的午餐:减小 $s$ 会使 $[q_{\min}, q_{\max}]$ 的范围变小,从而增加裁剪误差。
如何最优地选择量化参数?
如何计算量化参数 (scale 和 offset)?
最简单的方法是 min-max 量化,其中我们使用无符号整数网格 $(0 \text{ 到 } 2^b - 1)$ 并设置缩放因子,使得整个浮点范围适合其中,避免裁剪。也就是说,你想要:
$q_{\max}$ 映射到 $fp_{\max}$ 且 $q_{\min}$ 映射到 $fp_{\min}$
因此你求解:
$$s = \frac{fp_{\max} - fp_{\min}}{2^b - 1}$$
