标签 大语言模型推理 下的文章

0x01 研究背景

在自回归生成模型(Autoregressive Model)中,LLM每生成一个新token,都会将此前生成的序列作为输入。若每一步都重新计算全部注意力(Q、K、V 矩阵),计算量将随序列长度平方级增长。在长上下文和高并发场景下,这一开销会迅速成为系统瓶颈。为此,主流推理框架普遍引入KV-Cache技术。 KV-Cache通过缓存此前token的Key(K)和Value(V)向量,在下一步生成时只需计算新的Query(Q),即可直接复用前面的K/V,从而显著降低重复计算量。实践中,KV-Cache 通常能在保持模型精度不变的前提下,带来约5-8倍的推理加速。这一机制已经成为vLLM、SGLang、DeepSpeed-Inference等高性能推理引擎,以及Hugging Face generate(use_cache=True)接口的默认能力。

随着2024–2025年多租户推理服务(如vLLM、SGLang、TensorRT-LLM)的大规模部署,系统在单模型、多租户共享的前提下,又进一步引入跨请求的前缀缓存共享(prefix caching)。当不同请求的prompt存在相同前缀时,系统可以直接复用已有KV-Cache,大幅摊薄Prefill成本并提升吞吐。然而,当这种共享与复用机制扩展到多租户并发环境时,KV-Cache不再只是一个“性能优化组件”,而是演变成新的攻击面:攻击者可以通过观测Prefill 时间、TTFT等性能差异发起时序侧信道攻击,通过篡改缓存内容实施History Swapping(生成轨迹劫持),或者通过对Key向量注入扰动发动Cache Corruption(缓存腐败),从而导致跨租户信息泄露、话题漂移甚至下游任务性能显著下降。

0x02 KV缓存工作机制与共享复用原理

下面是KV-Cache工作原理的示意图。

KV-Cache工作原理

KV-Cache工作原理图

接下来我们用文字详细拆解,更深入了解KV缓存工作机制。

2.1 两阶段推理:Prefill与Decode

KV-Cache的核心做法分为两阶段。

(1)Prefill阶段(Prompt阶段)一次性计算输入序列的K/V并写入缓存 模型读取完整输入的prompt,计算出所有token的Key/Value向量并写入缓存。 公式表示为:

image-20251229205745845

此时缓存中的K/V向量构成了后续生成阶段的基础。

(2)Decode阶段(生成阶段)仅对新token计算Q/K/V,并复用历史K/V完成注意力计算 当模型生成新token时,仅需计算该token对应的Q、K、V向量。

image-20251229205757073

然后与缓存中已有的K/V拼接,直接完成注意力计算。这样便避免了重复计算前面N−1个token的注意力结果。

2.2 past_key_values

在Hugging Face Transformers框架中,KV-Cache在接口层面通过 past_key_values 对象实现。该对象并非一个抽象的控制开关,而是模型前向推理过程中实际生成、并可跨生成步骤复用的中间状态。它以分层的结构保存已处理历史Token的Key和Value张量,从而支撑自回归生成的增量计算。

从结构上看,past_key_values通常是一个长度为模型层数的列表或元组,其中每一层对应一对 (K, V)张量。不同模型的具体维度布局可能存在差异,但其核心语义一致:存储历史序列的注意力键值表示,以便后续生成时直接复用。

在推理流程中,Prefill 阶段会对完整的提示词进行计算,并首次生成past_key_values。进入Decode阶段后,若将此缓存作为输入传递给模型,模型通常只需为新输入的Token计算其对应的Key和Value,并将其追加至现有缓存末尾,从而避免了历史部分的重复计算。这种基于past_key_values的复用是框架的原生机制,其带来的加速直接源于注意力计算的真实削减,因此更适合作为评估系统性能及分析相关安全影响的工程基准。相比之下,通过sleep()或人为插桩制造“快慢差异”的方法仅能模拟现象,难以反映实际推理系统的缓存行为。此外,Transformers框架的generate()接口通常通过参数use_cache=True来启用此缓存机制。vLLM、SGLang、DeepSpeed-Inference在系统层面也普遍实现了类似机制,以降低生成延迟并提升吞吐量。

2.3 多租户场景下的前缀缓存与最长前缀匹配

在多请求并发且显存资源受限的推理服务中,为提升吞吐并降低重复的Prefill开销,系统常采用前缀缓存策略。其核心思想是当新请求的提示词(更准确地说是其Token序列)与某条已缓存的序列存在前缀重合时,系统可直接复用该前缀部分对应的KV-Cache,仅需对未命中的后续Token执行增量计算。

当缓存池中存在多个可能的候选前缀时,命中判定通常遵循最长前缀匹配(LPM)原则:在所有缓存条目中,系统会选择与新请求Token序列匹配长度最长的那一条作为复用对象,以最大化缓存利用率,减少重复计算。在工程实现上,这依赖于能够高效进行Token序列前缀匹配的数据结构或索引机制,例如前缀树(Trie)、基于前N个Token的分层哈希,或基于序列哈希值的多级索引。

根据匹配程度,命中效果可分为两类:一是完全命中,即请求的绝大部分或全部前缀已在缓存中,Prefill阶段的计算量显著下降;二是部分命中,即仅能复用较短的前缀,系统仍需对剩余后缀执行完整的Prefill计算。无论是“是否命中”还是“命中长度”,都会直接反映在可观测的系统性能指标上,例如Prefill时间、首Token延迟的分布等。

当前主流引擎(如vLLM的PagedAttention、LMCache)进一步通过分页管理和压缩技术缓解显存碎片,但前缀共享引入的侧信道与内存安全风险依然突出,这也是后续攻击面的根源。

0x03 KV-Cache的主要攻击面原理介绍

在理解KV-Cache的核心优化机制与共享原理后,我们可以看到其高效性背后隐藏的脆弱性。下面详解三大主要攻击面:时序侧信道攻击、操纵攻击与腐败攻击。

3.1 KV-Cache时序侧信道攻击

在共享KV-Cache的系统中,攻击者通过测量响应时间或请求处理顺序,推断缓存是否命中(hit),从而还原其他用户的Prompt(提示词)。

image-20251028173225854

时序侧信道攻击完流程图

设定还原的语句是"Imagine you are an IT expert",攻击者已经成功还原出"Imagine you are",并尝试还原下一个token "an"。下面我们根据上图分步骤拆解一下攻击过程。

步骤1:Generate candidates

攻击者在本地用小模型、模板或启发式方法生成可能的下一个token候选集合,例如:

  • Imagine you are an
  • Imagine you are a
  • Imagine you are the

把未知的victim prompt逐步转化为一系列候选前缀/后缀,便于后续probe。优点是减少搜索空间。


步骤2:Generate dummy

  • Candidate请求:每个请求包含一个候选后缀(比如Imagine you are an)。目标是看哪一个candidate与victim的缓存前缀最长匹配而“命中”缓存。
  • Dummy请求:随机或不相关的prompt(用来制造队列/填充调度槽位),以便控制调度顺序或避免直接暴露自己的probe请求导致缓存污染判断混淆。

步骤3:Send three request batches in turn

攻击者按这个顺序把三组请求发到服务器(可能是同一API key,也可能跨多个短时间窗口发出)。核心就是在调度队列里把candidate放在中间,观察它是否因为缓存命中而更快返回。

步骤4:Observe the returning order

攻击者记录三批请求的返回顺序和时间(TTFT/latency)。若candidate的响应比其前后的dummy显著更快或优先到达,就可推断该candidate是命中了缓存(即victim的prompt与该candidate共享较长前缀)。

3.2 History Swapping 攻击

image-20251230101845847

History Swapping操纵攻击原理图

攻击者通过结构化替换或注入KV-Cache内容,来“劫持”模型的生成轨迹,强制引导输出转向攻击者指定的主题或行为。这种攻击利用KV-Cache编码了不仅仅是上下文,还包括话题规划(topic trajectory)和结构化推理(structural planning)的特性。

设定攻击场景:受害者Prompt为“Give a precise technical explanation of espresso extraction variables”(讨论咖啡萃取),攻击者希望劫持输出到恒星生命周期主题。用户可见Prompt不变。

步骤1: 预生成目标主题KV-Cache

攻击者离线使用相同模型,基于目标主题Prompt生成一段完整KV-Cache块(topic_cache)。

步骤2: 启动正常生成并等待替换点

从受害者Prompt开始自回归生成,监控已生成token数,直到达到预设swap_token(例如序列的20%-60%处)。

步骤3: 执行块级覆盖替换

计算替换段长度(swap_percent,如25%-75%最近timestep),在全层(或指定早/晚层)用topic_cache对应部分直接覆盖当前缓存。

步骤4: 继续生成并观察劫持

模型基于篡改缓存继续输出。常见效果:立即/延迟主题偏移、原主题与攻击主题交替、或生成重复崩溃。

3.3 KV-Cache 腐败攻击

image-20251230101806716

KV-Cache腐败攻击原理图

攻击者通过向KV-Cache注入扰动(perturbation),破坏注意力机制的完整性,导致输出偏差、性能下降或幻觉增加。这种攻击视KV-Cache为“内存腐败”类似漏洞,扰动键向量(Key vectors)即可放大影响。

设定攻击场景:在正常生成或RAG任务中,攻击者向KV-Cache的Key向量注入扰动,导致注意力偏差、性能下降或幻觉增加。

步骤1: 选择目标层与时机

确定最脆弱层(通常中层,如LLaMA-2第12层)和扰动应用频率(连续或间歇)。

步骤2: 选择扰动变体

  • MTI-Gaussian:添加高斯噪声(σ=0.1-5.0)
  • MTI-Zeroing:概率置零Key条目
  • MTI-Rotation:施加正交旋转(15°-90°)
  • 可结合梯度优化以最大化目标影响

步骤3: 注入扰动到Key向量

在生成过程中,按选定策略对Key向量应用扰动δ。

步骤4: 观察输出效果

监控下一token分布偏移(KL散度上升)、下游任务性能下降15–30%、或RAG幻觉率增加5%-12%。中层扰动放大效果最显著。

0x04 代码实现

测试为纯CPU环境下完成,基于Python3.8+的Hugging Face Transformers与PyTorch运行124M参数的gpt2模型。

4.1 KV-Cache时序侧信道攻击

实验1:基础缓存时序测量

验证KV-Cache复用是否产生物理上可观测的时间差异。

我们实现了一个多租户LLM服务的 KVServer 类,支持:

  1. 最长前缀匹配 (LPM):实现类似vLLM的Prefix Caching
  2. 精确计时:仅测量Prefill阶段的KV 计算,排除tokenization开销
  3. 缓存管理:LRU淘汰策略

核心实现

@dataclass
class _CacheEnt:
    """KV-Cache 条目"""
    prompt: str
    input_ids: torch.Tensor
    past_kv: Tuple
    ts: float

class KVServer:
    """多租户KV-Cache服务器"""

    def _lpm(self, q_ids: torch.Tensor):
        """Longest Prefix Match - 缓存必须是查询的前缀"""
        best = None
        best_len = 0

        for cached, ent in self._cache.items():
            c_ids = ent.input_ids[0].tolist()
            q = q_ids[0].tolist()

            # 计算共同前缀长度
            mlen = 0
            for i, (a, b) in enumerate(zip(c_ids, q)):
                if a == b:
                    mlen = i + 1
                else:
                    break

            # 缓存有效条件:缓存是查询的前缀(mlen == len(cached))
            if mlen > best_len and mlen == len(c_ids) and len(c_ids) <= len(q):
                best = ent
                best_len = mlen

        return (best, best_len) if best else None

    def process(self, prompt: str, max_new=1, uid="anon", write_cache=True):
        """处理请求,返回详细的时序数据"""
        input_ids = self.tok.encode(prompt, return_tensors="pt")
        t0 = time.perf_counter()

        cache_r = self._lpm(input_ids)

        with torch.no_grad():
            if cache_r:
                # 缓存命中路径:复用past_key_values
                ent, matched = cache_r
                self._hits += 1

                if input_ids.shape[1] > matched:
                    # 部分匹配:计算增量部分
                    delta_ids = input_ids[:, matched:]
                    out = self.model(
                        delta_ids,
                        past_key_values=ent.past_kv,
                        use_cache=True
                    )
                    past_kv = out.past_key_values
                else:
                    # 完全命中:直接复用
                    past_kv = ent.past_kv

                prefill_t = (time.perf_counter() - t0) * 1000
                hit = True
            else:
                # 缓存未命中路径:完整前向传播
                self._miss += 1
                out = self.model(input_ids, use_cache=True)
                past_kv = out.past_key_values
                prefill_t = (time.perf_counter() - t0) * 1000
                hit = False

        # ... 生成阶段与缓存写回

运行结果

可以看到上面第一次请求Prefill用了大约380ms,这是模型执行完整前向传播的时间。对于GPT-2,这意味着要进行12层TransformerBlock的矩阵乘法运算。

然后当二次请求完全相同的时候Prefill仅仅为0.024ms ,几乎就只有内存操作时间。这个原因是因为past_key_values已存在,模型跳过了所有Attention层的Q×KT​计算,仅需简单的张量拼接。

实验2:prompt探测攻击

为了验证攻击者能否通过时序差异识别受害者的Prompt

核心代码:

def experiment_2_exact_match_attack():
    # 步骤1: 受害者缓存敏感Prompt
    victim_prompts = [
        "My secret password is hunter2",
        "My API key is sk-1234567890abcdef",
    ]

    for prompt in victim_prompts:
        server.process(prompt, uid="victim", write_cache=True)

    # 步骤2: 攻击者构造候选列表
    candidates = [
        "My secret password is hunter2",      # ✓ 匹配
        "My secret password is wrong",        # ✗ 不匹配
        "My secret code is hunter2",          # ✗ 不匹配
        "My API key is sk-1234567890abcdef",  # ✓ 匹配
        "My API key is sk-wrong-key",         # ✗ 不匹配
    ]

    # 步骤3: 逐个探测
    discovered = []
    for cand in candidates:
        r = server.process(cand, uid="attacker", write_cache=False)
        t = r['prefill_ms']

        if t < 1.0:  # 阈值判定
            discovered.append((cand, t))

    return discovered

运行结果

image-20251230222154860

我们可以从上面实验结果看到当探测内容与缓存完全一致时,模型无需任何计算,直接返回缓存指针。时间差异非常大。

我们可以从攻击者视角的视角来看到这件事:

1.攻击者构造"密码候选列表"(类似字典攻击),逐个探测。

2.只要有一个候选的响应时间<1ms,攻击者就能推断出被攻击者的完整prompt。

攻击简易流程图如下。

1

4.2 History Swapping攻击

实验目标:在不改变用户可见Prompt的情况下,通过替换推理过程中的past_key_values片段,把模型输出从“受害者话题”劫持到“攻击者话题”。

  • 受害者请求:正常的业务问题(例如“如何制作咖啡”)。
  • 攻击者能力


    • 能在同一推理进程/同一GPU的Worker内“写入或污染”共享的KV-Cache(例如:推理引擎实现了前缀缓存复用、调度/缓存对象复用存在隔离缺陷、或插件/监控/扩展组件可触达缓存对象)。
    • 攻击者提前离线生成一段目标主题的K-Cache。
    • 攻击效果:用户看到的prompt没变,但输出内容发生明显“叙事漂移”。

核心思路

  1. 攻击者用目标主题prompt跑一次Prefill,得到attacker_cache
  2. 受害者开始生成,达到某个swap_at_token时刻。
  3. 在所有层(或关键层)将受害者cache的一段时间步区间(如中间30%-60%)用attacker_cache的片段覆盖。

核心实现:

def gen_with_swap(model, tok, prompt: str, max_tok=30,
                 atk_cache=None, swap_at=2):
    """带缓存替换的生成函数"""
    ids = tok.encode(prompt, return_tensors="pt")
    generated = []

    # Prefill 阶段
    with torch.no_grad():
        out = model(input_ids=ids, use_cache=True)
        past = out.past_key_values
        logits = out.logits[0, -1, :]

    # 逐 token 生成
    for step in range(max_tok):
        tok_id = torch.argmax(logits).item()
        generated.append(tok_id)

        # 关键:在指定步数替换缓存
        if step == swap_at and atk_cache is not None:
            past = _swap_mix(past, atk_cache)

        nxt = torch.tensor([[tok_id]])
        with torch.no_grad():
            out = model(input_ids=nxt, past_key_values=past, use_cache=True)
            past = out.past_key_values
            logits = out.logits[0, -1, :]

    return tok.decode(generated, skip_special_tokens=True)

def _swap_mix(vic_cache, atk_cache):
    """混合策略:保留 10% 受害者前缀,替换中间 85% 为攻击者缓存"""
    from transformers.cache_utils import DynamicCache

    # 提取张量
    if hasattr(vic_cache, "key_cache"):
        v_k = [vic_cache.key_cache[i].clone() for i in range(len(vic_cache.key_cache))]
        v_v = [vic_cache.value_cache[i].clone() for i in range(len(vic_cache.value_cache))]
        a_k = [atk_cache.key_cache[i] for i in range(len(atk_cache.key_cache))]
        a_v = [atk_cache.value_cache[i] for i in range(len(atk_cache.value_cache))]
    else:
        # 兼容 tuple 格式
        v_k = [vic_cache[i][0].clone() for i in range(len(vic_cache))]
        v_v = [vic_cache[i][1].clone() for i in range(len(vic_cache))]
        a_k = [atk_cache[i][0] for i in range(len(atk_cache))]
        a_v = [atk_cache[i][1] for i in range(len(atk_cache))]

    new_cache = DynamicCache()

    for layer in range(len(v_k)):
        vk, vv = v_k[layer], v_v[layer]
        ak, av = a_k[layer], a_v[layer]

        seq_len = vk.shape[2]
        atk_len = ak.shape[2]

        # 替换策略:保留 10%,替换 10%-95%
        start = int(seq_len * 0.1)
        end = int(seq_len * 0.95)
        swap_sz = min(end - start, atk_len)

        nk = vk.clone()
        nv = vv.clone()

        if swap_sz > 0:
            # 关键:切片替换
            nk[:, :, start:start+swap_sz, :] = ak[:, :, :swap_sz, :]
            nv[:, :, start:start+swap_sz, :] = av[:, :, :swap_sz, :]

        new_cache.key_cache.append(nk)
        new_cache.value_cache.append(nv)

    return new_cache

image-20251229201348226

在我们进行swap_at_token 之后,输出出现明显话题漂移,原本应该是咖啡的制作方面的东西,结果话题漂移到了星空上面。

4.3 KV-Cache腐败攻击

实验:Cache Corruption(扰动Key 向量)

实验目标:在生成过程中对KV-Cache的Key张量注入扰动(噪声/置零/旋转),观察注意力机制被破坏后带来的输出质量退化(重复、语义漂移、幻觉倾向上升)。

更贴近实战的场景设定

  • 共享显存/共享推理Worker:攻击者通过越权写入或内存破坏类漏洞(例如缓存指针复用错误、越界写、错误的张量视图复用)影响到其他请求的KV。
  • RAG/Agent场景:缓存腐败会显著增加“把检索内容读错/拼接错”的概率,表现为幻觉或逻辑断裂。

扰动策略

  • corrupt_gaussian:K = K + N(0, σ^2)
  • corrupt_zeroing:以概p 将Key条目置零
  • corrupt_rotation:对Key的embedding子空间做正交旋转(简化实现为对最后维度两两旋转)

核心实现:

#高斯噪声
def corrupt_gaussian(cache, sig=1.0):
    """对 Key 向量添加高斯噪声:K = K + N(0, σ²)"""
    ts = _extract(cache)
    out = []
    mid = len(ts) // 2

    for i, (k, v) in enumerate(ts):
        if abs(i - mid) <= 1:  # 中层更敏感
            noise = torch.randn_like(k) * sig
            out.append((k + noise, v.clone()))
        else:
            out.append((k.clone(), v.clone()))

    return _rebuild(out)

#随机置零
def corrupt_zeroing(cache, p=0.3):
    """以概率 p 将 Key 条目置零"""
    ts = _extract(cache)
    out = []
    mid = len(ts) // 2

    for i, (k, v) in enumerate(ts):
        if abs(i - mid) <= 1:
            mask = (torch.rand_like(k) > p).float()
            out.append((k * mask, v.clone()))
        else:
            out.append((k.clone(), v.clone()))

    return _rebuild(out)

#正交旋转
def corrupt_rotation(cache, deg=45.0):
    """对 Key 的 embedding 子空间做正交旋转"""
    ts = _extract(cache)
    out = []
    mid = len(ts) // 2

    rad = np.radians(deg)
    c, s = np.cos(rad), np.sin(rad)

    for i, (k, v) in enumerate(ts):
        if abs(i - mid) <= 1:
            nk = k.clone()
            d = k.shape[-1]
            # 对最后维度两两旋转
            for j in range(0, d - 1, 2):
                kj = k[:, :, :, j].clone()
                kj1 = k[:, :, :, j + 1].clone()
                nk[:, :, :, j] = c * kj - s * kj1
                nk[:, :, :, j + 1] = s * kj + c * kj1
            out.append((nk, v.clone()))
        else:
            out.append((k.clone(), v.clone()))

    return _rebuild(out)

实验结果如下

image-20251230223603113

可以看到在不同扰动策略下模型输出的内容发生明显变化。

0x05 防御与缓解措施

以下从架构、系统、审计三层总结主流缓解措施,结合最新研究(如SafeKV、KV-Cloak),旨在平衡安全性、性能与部署成本。

5.1 架构层防御

  • 租户级缓存隔离:通过Tenant ID、Session Scope或用户唯一标识符划分KV命名空间,完全禁止跨租户共享。适用于高敏感场景,虽牺牲部分吞吐,但彻底消除侧信道。
  • 选择性共享:仅允许非敏感前缀共享,结合细粒度隐私策略(如基于内容分类)决定复用范围。
  • 缓存生命周期管理:单次请求后自动清除,或设置TTL过期策略,减少驻留时间泄露风险。
  • LPM随机化与分区:在最长前缀匹配中引入随机扰动,或按哈希分区缓存池,打乱命中可预测性。

5.2 系统层防御

  • 噪声注入与延迟模糊化:在缓存命中路径插入±Δt随机延迟,或对时间指标添加噪声,隐藏TTFT/顺序差异(针对时序攻击)。
  • 缓存内容混淆:使用可逆矩阵变换对KV向量加密/混淆,仅授权方可逆转。
  • 扰动检测与完整性校验:实时监控KV向量范数/哈希变化,检测异常扰动(针对腐败攻击)或引入dropout-mask随机化/注意力平滑,减轻操纵影响。
  • 参数与接口限制:禁止外部暴露use_cache、position_ids等敏感参数;结合速率限制(throttling)阻断高频探测请求。
  • 机密计算集成:利用TEE(Trusted Execution Environment)加密KV-Cache内存,防止物理/侧信道访问。

5.3 审计与合规层

  • 缓存审计器:记录每条请求的缓存命中/共享日志,绘制租户-缓存命中矩阵,便于事后追溯。
  • 异常行为监控:基于机器学习检测异常调度模式(如批量相似前缀探测),自动告警或隔离。
  • 合规框架支持:在SOC 2、ISO 27001或GDPR下,强制日志不可篡改,并定期审计共享安全性。

参考资料

https://openreview.net/pdf?id=gUj2fxQcLZ

https://www.ndss-symposium.org/wp-content/uploads/2025-1772-paper.pdf

https://huggingface.co/docs/transformers/en/kv_cache

https://arxiv.org/abs/2312.07104

https://www.arxiv.org/pdf/2510.17098

https://arxiv.org/pdf/2511.12752

https://pub.towardsai.net/lets-build-an-optimizer-for-a-gpt-model-from-scratch-in-pytorch-kv-caching-4d3f1f9516fa

vLLM 是一款专为大语言模型推理加速而设计的框架,实现了 KV 缓存内存几乎零浪费,解决了内存管理瓶颈问题。

更多 vLLM 中文文档及教程可访问 →https://vllm.hyper.ai/

*在线运行 vLLM 入门教程:零基础分步指南

源码 examples/offline_inference/rlhf_utils.py

import torch


def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):

    """
    vLLM 提供 `StatelessProcessGroup` 来创建进程组,
    无需考虑 torch.distributed 中的全局进程组。
    建议先创建 `StatelessProcessGroup`,然后初始化
    外部(训练进程)与 vLLM 工作进程之间的数据平面通信(NCCL)。
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup
    pg = StatelessProcessGroup.create(host=master_address,
                                      port=master_port,
                                      rank=rank,
                                      world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


class WorkerExtension:

    """
    vLLM 工作进程的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        """
        检查权重是否已更新为 0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated


class ColocateWorkerExtension:

    """
    vLLM 工作进程在协同部署场景下的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        handles = ipc_handles[self.device_uuid]
        device_id = self.device.index
        weights = []
        for name, handle in handles.items():
            func, args = handle
            list_args = list(args)
            # the key is to change device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            # 关键是将设备 ID 改为当前设备 ID,
            # 以防两个进程有不同的 CUDA_VISIBLE_DEVICES
            list_args[6] = device_id
            tensor = func(*list_args)
            weights.append((name, tensor))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):

        """
        检查权重是否已更新为0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated

vLLM 是一款专为大语言模型推理加速而设计的框架,实现了 KV 缓存内存几乎零浪费,解决了内存管理瓶颈问题。

更多 vLLM 中文文档及教程可访问 →https://vllm.hyper.ai/

*在线运行 vLLM 入门教程:零基础分步指南

源码 examples/offline_inference/rlhf_utils.py

import torch


def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):

    """
    vLLM 提供 `StatelessProcessGroup` 来创建进程组,
    无需考虑 torch.distributed 中的全局进程组。
    建议先创建 `StatelessProcessGroup`,然后初始化
    外部(训练进程)与 vLLM 工作进程之间的数据平面通信(NCCL)。
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup
    pg = StatelessProcessGroup.create(host=master_address,
                                      port=master_port,
                                      rank=rank,
                                      world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


class WorkerExtension:

    """
    vLLM 工作进程的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        """
        检查权重是否已更新为 0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated


class ColocateWorkerExtension:

    """
    vLLM 工作进程在协同部署场景下的基类。
    通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。
    这种方式使代码能同时兼容 vLLM V0 和 V1。
    注意:我们在单独模块中定义此类,主模块应将完整限定名
    作为 `worker_extension_cls` 参数传递。
    """

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        handles = ipc_handles[self.device_uuid]
        device_id = self.device.index
        weights = []
        for name, handle in handles.items():
            func, args = handle
            list_args = list(args)
            # the key is to change device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            # 关键是将设备 ID 改为当前设备 ID,
            # 以防两个进程有不同的 CUDA_VISIBLE_DEVICES
            list_args[6] = device_id
            tensor = func(*list_args)
            weights.append((name, tensor))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):

        """
        检查权重是否已更新为0。
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated