高级 RAG 08自适应 RAG

高级 RAG 08自适应 RAG

Barry Lv6

直观示例、原理、代码解析及自适应RAG的洞察

本文以一个常见场景——开卷考试为起点。我们通常有两种策略:

  • 方法一:对于熟悉的话题,迅速作答;对于不熟悉的话题,翻阅参考书查找,快速定位相关部分,在脑海中整理和总结,然后在试卷上作答。
  • 方法二:对于每个话题,都参考书籍。找到相关章节,在心中整理和总结,然后在试卷上写出答案。

显然,方法一是更优的选择。方法二可能会耗费时间,并可能引入无关或错误的信息,这可能导致混淆和错误,甚至在你原本理解的地方也会出错。

然而,方法二体现了经典RAG 流程,而方法一则代表了自适应RAG 流程,本文将进一步探讨。

概述

图1展示了RAG与自适应RAG 主要流程的对比:

自适应RAG包含三个步骤:

  1. 按需检索:当模型需要进行检索时,例如查询“美国各州是如何命名的?”(图1右上角),模型的输出将包含一个**[Retrieve]**标记。这表明需要检索与查询相关的内容。相反,当被要求“写一篇关于你最棒的暑假的文章”(图1右下角)时,模型选择直接生成答案,无需检索。
  2. 并行生成:模型利用提示和检索到的内容生成输出。在此过程中,三种类型的反思标记指示检索内容的关联性。
  3. 评估与选择:对步骤2中生成的内容进行评估,并选择最佳片段作为输出。

请注意,上述模型是经过专门训练的模型。其训练过程将在本文后续部分讨论。

反思令牌

与RAG相比,自RAG框架的不同之处在于它使用反思令牌在生成过程中进行更精确的控制,如图2所示。

本质上,自RAG进行四个不同的判断:

  • **[检索]**:决定是否从资源**R**中检索信息的决策过程。
  • **[相关性检查]**:判断给定数据**d**是否包含解决问题**x**所需信息的相关性检查。
  • **[支持性验证]**:检查提供的响应**y**中的陈述是否得到数据**d**支持的验证过程。
  • **[实用性评估]**:评估响应**y**对问题x`的实用性的评估过程。输出是一个1到5的分数,其中5代表最高实用性。

在RAG中,检索是一个固定的过程,总是首先进行,无论条件如何。相比之下,自RAG引入了反思令牌,使LLM更具适应性和智能性。当LLM生成文本并遇到不确定区域时,它会在反思令牌处暂停,进行快速精确的检索,然后使用新获取的信息继续生成。

代码解释

为了直观理解自RAG过程,我们将首先审视代码,然后讨论模型的训练过程。

自RAG是开源的 Langchain 和LlamaIndex都有各自的实现。我们将以LlamaIndex的实现 作为参考进行解释。

环境配置

首先,配置环境。

1
2
3
4
5
6
7
8
9
(base) Florian@instance-1:~$ conda create -n llamaindex python=3.11

(base) Florian@instance-1:~$ conda activate llamaindex

(llamaindex) Florian@instance-1:~$ pip install llama-index

(llamaindex) Florian@instance-1:~$ pip install huggingface-hub

(llamaindex) Florian@instance-1:~$ huggingface-cli login

安装完成后,LlamaIndex 的对应版本如下:

1
2
3
llama-index                             0.10.20

llama-index-core 0.10.20.post2

下载论文提供的 Llama2–7B 模型,大小约为 4.08G。您也可以从这里 下载。

1
2
3
4
(llamaindex) Florian@instance-1:~$ huggingface-cli download m4r1/selfrag_llama2_7b-GGUF selfrag_llama2_7b.q4_k_m.gguf --local-dir "YOUR_DOWNLOAD_MODEL_DIR" --local-dir-use-symlinks False

(llamaindex) Florian@instance-1:~$ ls "YOUR_DOWNLOAD_MODEL_DIR"
selfrag_llama2_7b.q4_k_m.gguf

测试代码

测试代码如下所示。首次执行需要下载 SelfRAGPack

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
os.environ["OPENAI_API_KEY"] = "YOUR_OPENAI_API_KEY"

from llama_index.core import Document, VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.readers import SimpleDirectoryReader
from pathlib import Path


# 选项:下载 SelfRAGPack
# 首次执行需要下载 SelfRAGPack。
# 后续执行可以注释掉这一部分。
from llama_index.core.llama_pack import download_llama_pack
download_llama_pack(
"SelfRAGPack",
"./self_rag_pack")

from llama_index.packs.self_rag import SelfRAGQueryEngine

# Llama2 模型之前下载并保存的目录。
download_dir = "YOUR_DOWNLOAD_MODEL_DIR"

# 创建测试文档
documents = [
Document(
text="一群企鹅,在陆地上被称为'摇摆队',在南极冰面上蹒跚前行,它们像燕尾服一样的羽毛在雪地上格外显眼。"
),
Document(
text="帝企鹅是所有企鹅中最高的种类,它们可以潜入比任何鸟类都深的地方,达到超过500米的深度。"
),
Document(
text="企鹅的黑白配色是一种名为反荫蔽的伪装;从上方看,它们的黑色背部与海洋深处融为一体,而从下方看,它们的白色腹部与明亮的表面相匹配。"
),
Document(
text="尽管企鹅直立姿态,但它们是不能飞行的鸟类;它们的翅膀已经进化成鳍状肢,使它们成为游泳高手。"
),
Document(
text="速度最快的种类,巴布亚企鹅,可以以每小时36公里的速度游泳,利用它们的鳍状肢和流线型身体在水中穿梭。"
),
Document(
text="企鹅是群居鸟类;许多种类形成大规模的繁殖群落,数量可达数万只。"
),
Document(
text="有趣的是,企鹅具有出色的听力,并依赖独特的叫声在嘈杂的群落中识别它们的配偶和幼崽。"
),
Document(
text="最小的企鹅种类,小蓝企鹅,身高仅约40厘米,分布在澳大利亚和纽西兰的沿海地区。"
),
Document(
text="在繁殖季节,雄性帝企鹅会忍受严酷的南极冬季数月,禁食并孵化它们的蛋,而雌性则在海上捕猎。"
),
Document(
text="企鹅食用各种海鲜;它们的饮食主要由鱼类、鱿鱼和磷虾组成,这些食物是它们在潜水探险中捕获的。"
),
]

index = VectorStoreIndex.from_documents(documents)

# 设置一个简单的检索器
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=10,
)


model_path = Path(download_dir) / "selfrag_llama2_7b.q4_k_m.gguf"
query_engine = SelfRAGQueryEngine(str(model_path), retriever, verbose=True)

# 无检索示例
response = query_engine.query("《傲慢与偏见》这本书属于哪个类型?")

# 检索示例
response = query_engine.query("最小的企鹅有多高?")

测试代码产生了以下结果(大部分 llama_cpp 调试信息已被移除):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
...
...
Model metadata: {'tokenizer.ggml.add_eos_token': 'false', 'tokenizer.ggml.eos_token_id': '2', 'general.architecture': 'llama', 'llama.rope.freq_base': '10000.000000', 'llama.context_length': '4096', 'general.name': 'LLaMA v2', 'tokenizer.ggml.add_bos_token': 'true', 'llama.embedding_length': '4096', 'llama.feed_forward_length': '11008', 'llama.attention.layer_norm_rms_epsilon': '0.000010', 'llama.rope.dimension_count': '128', 'tokenizer.ggml.bos_token_id': '1', 'llama.attention.head_count': '32', 'llama.block_count': '32', 'llama.attention.head_count_kv': '32', 'general.quantization_version': '2', 'tokenizer.ggml.model': 'llama', 'general.file_type': '15'}
Using fallback chat format: None

llama_print_timings: load time = 4887.53 ms
llama_print_timings: sample time = 11.29 ms / 22 runs ( 0.51 ms per token, 1947.76 tokens per second)
llama_print_timings: prompt eval time = 4887.46 ms / 24 tokens ( 203.64 ms per token, 4.91 tokens per second)
llama_print_timings: eval time = 5883.27 ms / 21 runs ( 280.16 ms per token, 3.57 tokens per second)
llama_print_timings: total time = 10901.84 ms / 45 tokens
Final answer: 《傲慢与偏见》是简·奥斯汀所著的一部浪漫小说。
...
...
llama_print_timings: load time = 4887.53 ms
llama_print_timings: sample time = 11.74 ms / 20 runs ( 0.59 ms per token, 1703.29 tokens per second)
llama_print_timings: prompt eval time = 7473.66 ms / 37 tokens ( 201.99 ms per token, 4.95 tokens per second)
llama_print_timings: eval time = 5414.34 ms / 19 runs ( 284.96 ms per token, 3.51 tokens per second)
llama_print_timings: total time = 13076.88 ms / 56 tokens
Input: ### Instruction:
最小的企鹅有多高?

### Response:
[Retrieval]<paragraph>企鹅食用各种海鲜;它们的饮食主要由鱼类、鱿鱼和磷虾组成,这些食物是它们在潜水探险中捕获的。</paragraph>
Prediction: [Relevant]最小的企鹅种类的高度会根据种类而有所不同。[No support / Contradictory][Utility:5]
Score: 1.4213598342974367
10/10 paragraphs done

End evaluation
Selected the best answer: [Relevant]最小的企鹅种类是小蓝企鹅(也称为仙女企鹅),它们可以长到大约40厘米(16英寸)高。[Fully supported][Utility:5]
Final answer: 最小的企鹅种类是小蓝企鹅(也称为仙女企鹅),它们可以长到大约40厘米(16英寸)高。

我们可以观察到,第一个查询不需要检索,而第二个查询已经进行了检索和评估。

理解测试代码的关键在于 de>class SelfRAGQueryEngine 的实现,让我们深入研究这个类。

class SelfRAGQueryEngine

首先是构造函数 ,主要用于通过llama_cpp加载Llama2–7B模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class SelfRAGQueryEngine(CustomQueryEngine):
"""简单的自我RAG查询引擎短形式。"""

llm: Any = Field(default=None, description="llM模型")
retriever: BaseRetriever = Field(default=None, description="检索器")
generate_kwargs: Dict = Field(default=None, description="LLM生成参数")
verbose: bool = Field(default=True, description="是否详细输出")

def __init__(
self,
model_path: str,
retriever: BaseRetriever,
verbose: bool = False,
model_kwargs: Dict = None,
generate_kwargs: Dict = None,
**kwargs: Any,
) -> None:
"""初始化参数。"""
super().__init__(verbose=verbose, **kwargs)
model_kwargs = model_kwargs or _MODEL_KWARGS
self.generate_kwargs = generate_kwargs or _GENERATE_KWARGS
try:
from llama_cpp import Llama
except ImportError:
raise ImportError(_IMPORT_ERROR_MSG)
self.llm = Llama(model_path=model_path, verbose=verbose, **model_kwargs)
self.retriever = retriever

接下来,我们将解释查询函数 。其主要流程如图3所示:

关键部分已添加注释以便更好地理解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def custom_query(self, query_str: str) -> Response:
"""运行自我RAG。"""
# 使用Llama2模型获取响应。
response = self.llm(prompt=_format_prompt(query_str), **_GENERATE_KWARGS)
answer = response["choices"][0]["text"]
source_nodes = []

# 判断是否需要检索。
if "[Retrieval]" in answer:
if self.verbose:
print_text("需要检索\n", color="blue")
# 图1的步骤1,按需检索。
documents = self.retriever.retrieve(query_str)
if self.verbose:
print_text(f"收到: {len(documents)}份文档\n", color="blue")
paragraphs = [
_format_prompt(query_str, document.node.text) for document in documents
]

if self.verbose:
print_text("开始评估\n", color="blue")

# 图1的步骤2和3,并行生成和评估
# (代码未实现并行)
critic_output = self._run_critic(paragraphs)

paragraphs_final_score = critic_output.paragraphs_final_score
llm_response_per_paragraph = critic_output.llm_response_per_paragraph
source_nodes = critic_output.source_nodes

if self.verbose:
print_text("结束评估\n", color="blue")

# 选择得分最高的段落并返回。
best_paragraph_id = max(
paragraphs_final_score, key=paragraphs_final_score.get
)
answer = llm_response_per_paragraph[best_paragraph_id]
if self.verbose:
print_text(f"选定最佳答案: {answer}\n", color="blue")

answer = _postprocess_answer(answer)
if self.verbose:
print_text(f"最终答案: {answer}\n", color="green")
return Response(response=str(answer), source_nodes=source_nodes)

从代码中可以看出,图1中的三个步骤都有体现。然而,LlamaIndex的代码并未实现并行化。感兴趣的读者可以进一步查看self._run_critic 函数,它还处理了与各种反思标记对应的分数。

如何训练Llama2-7B模型

我们之前已经多次使用过Llama2-7B模型,现在让我们探讨如何获取它。

训练目标

使语言模型能够生成包含反思标记的文本。

两种模型

在训练过程中,需要两个模型:评价模型 **C** 和生成模型 **M**。评价模型 **C** 生成模型 **M** 所需的监督数据。

然而,在推理过程中,仅使用模型 M,模型 C 不需要。

评价模型 C

评价模型旨在生成反思标记。使用该模型的目的是为了能够离线将反思标记插入任务输出中,从而更新训练语料库。

手动为每个片段标注反思标记成本高昂。Self-RAG 利用 GPT-4 为每个反思标记分配独特的指令,因其定义、输入和输出各异,从而高效完成数据标注任务。例如,**[retrieval]** 标记的指令会提示 GPT-4 评估是否引入外部文档将提升结果。

一旦我们获得了训练数据 **D_critic**,我们就可以基于标准条件语言模型构建训练目标,如下所示:

评价模型 C 可以初始化为任意语言模型。例如,它可以与生成器使用相同的模型进行初始化,如 Llama2–7B。

生成器模型 M

图4展示了收集训练数据的具体过程。给定一个输入-输出对 **(x, y)**,自适应RAG利用检索和批评模型来增强原始输出 **y**,以创建监督数据。对于每个片段 **yt ∈ y**

需要注意的是,图4中的每个条件判断都是通过批评模型 **C** 执行的。获得的训练数据如图5所示:

在获得训练数据 **D_gen** 后,我们可以构建标准的下一个词预测目标函数,如下所示:

生成器 **M** 不仅需要预测输出,还需要预测反思标记。

关于self-RAG的见解与思考

总体而言,self-RAG为提升RAG流程提供了新的视角。然而,它要求更为复杂的训练过程以及在生成阶段进行多次标签生成与判断,这无疑会增加推理成本。对于需要实时性能的项目,这可能产生显著影响。

此外,该框架内仍有大量优化空间。为激发进一步讨论与创新,以下是几个要点:

  • 如何优化反思标记。Self-RAG设计了四种反思标记。除了**[Retrieve]**标记外,其他三种(**[IsREL]****[IsSUP]****[IsUSE]**)存在一定相似性。考虑使用更少的反思标记或代表其他语义的反思标记是一个可行方向。
  • 为何评判模型采用LLM?我认为这可能是因为像**[IsUSE]**这样的标记严重依赖于常识。判断对查询答案的实用性是小型模型可能完成的任务。然而,这些模型通常仅从其特定训练数据中学习,缺乏全面知识。因此,使用LLM作为评判模型是合理的。
  • 评判模型大小的选择。Self-RAG已通过7B和13B模型测试,取得了优异成果。但若切换至更小的LLM,如3B,我们能观察到哪些差异?同样,若转向更大的LLM,如33B,我们又能期待多少提升?
  • 为何不采用基于人类反馈的强化学习(RLHF)?论文建议在任务示例上训练目标语言模型。这些示例通过离线评判模型中的反思标记进行增强,相较于RLHF,训练成本大幅降低。此外,self-RAG中的反思标记使得生成在推理过程中可控,而RLHF则侧重于训练中的人类偏好对齐。不过,论文中并未包含与RLHF相关的对比实验。

结论

本文从一个直观的例子开始,介绍了自RAG的基本流程,并辅以代码解释,同时分享了我的见解和思考。

如果您对RAG技术感兴趣,欢迎查阅我的其他文章。

最后,如有任何错误或遗漏,或您有任何疑问,请随时在评论区讨论。

  • 标题: 高级 RAG 08自适应 RAG
  • 作者: Barry
  • 创建于 : 2024-03-22 04:18:22
  • 更新于 : 2024-08-31 06:59:45
  • 链接: https://wx.role.fun/2024/03/22/a8952873fe734f19ad177141a4aee237/
  • 版权声明: 本文章采用 CC BY-NC-SA 4.0 进行许可。