使用 Sentence Transformers 训练和微调多模态嵌入与重排序模型
速览
本文详细阐述了如何使用 Sentence Transformers 库对多模态嵌入模型和重排序模型进行训练与微调。通过提供具体的代码示例和最佳实践,帮助开发者构建更精准的跨模态检索系统。这对于提升多模态信息检索的准确性和效率具有重要意义。
AI 深度解读
使用 Sentence Transformers 训练与微调多模态嵌入及重排序模型
背景
通用多模态嵌入模型(如 Qwen/Qwen3-VL-Embedding-2B)通常在多样化数据上进行训练,旨在跨多种语言和任务(如图像-文本匹配、视觉问答、文档理解等)表现良好。然而,这种通用性意味着模型很少是任何特定任务的最佳选择。
以**视觉文档检索(Visual Document Retrieval, VDR)**为例,该任务的目标是根据给定的文本查询,从包含数千个文档页面的语料库中检索出最相关的文档页面(以图像形式呈现,保留图表、表格和布局)。例如,当查询为“该公司第三季度的收入是多少?”时,模型需要理解文档布局、图表、表格和文本。这与“将鞋子图片与产品描述匹配”所需的技能截然不同。
为了在特定领域获得最佳性能,需要对模型进行微调。本文以 tomaarsen/Qwen3-VL-Embedding-2B-vdr 为例,展示了如何通过针对自有领域数据进行微调来提升性能。在评估数据上,微调后的模型 NDCG@10 得分从基础模型的 0.888 提升至 0.947,优于所有测试过的现有 VDR 模型,包括那些参数量高达其 4 倍的模型。
核心内容
1. 微调的必要性
通用模型虽然覆盖面广,但在特定垂直领域(如 VDR)往往缺乏针对性。通过微调,模型可以学习特定领域的专用模式。实验表明,微调能显著提升检索精度,甚至超越更大规模的通用模型。
2. 训练组件
Sentence Transformers 库支持多模态模型的训练,其核心组件与文本-only 模型类似,但数据集包含图像等多模态数据,且处理器(Processor)会自动处理图像预处理。主要组件包括:
- Model(模型):要训练或微调的多模态模型。
- Dataset(数据集):用于训练和评估的数据。
- Loss Function(损失函数):量化模型性能并指导优化的函数。
- Training Arguments(训练参数):可选,影响训练性能及调试跟踪。
- Evaluator(评估器):可选,用于训练前、中、后的模型评估。
- Trainer(训练器):整合上述组件进行训练的核心类
SentenceTransformerTrainer。
3. 模型构建与配置
有两种主要方式构建多模态嵌入模型:
方式一:微调现有的多模态嵌入模型
如果模型已有 modules.json 文件,可以通过传递 processor_kwargs 和 model_kwargs 来控制预处理和模型加载。
processor_kwargs:直接传递给AutoProcessor.from_pretrained(...),例如设置min_pixels和max_pixels以平衡图像质量和内存占用。model_kwargs:传递给AutoModel.from_pretrained(...),例如设置attn_implementation(如flash_attention_2)和torch_dtype(如bfloat16)。
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(
"Qwen/Qwen3-VL-Embedding-2B",
model_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"},
processor_kwargs={"min_pixels": 28 * 28, "max_pixels": 600 * 600},
)
方式二:从头开始使用视觉-语言模型(VLM)检查点
可以使用尚未针对嵌入任务训练的 VLM 检查点。Sentence Transformers 会自动识别架构、推断支持的模态并设置前向传播方法和池化层。如果自动检测不准确,可以编辑保存的 sentence_bert_config.json 文件来调整模态设置。
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("Qwen/Qwen3-VL-2B")
替代方案:使用 Router 构建多模态模型
如果不使用单一的 VLM 骨干网络,可以使用 Router 模块组合不同模态的独立编码器。
- 原理:根据检测到的模态,将输入路由到相应的编码器。
- 优势:可以使用轻量级、专用的编码器,而非大型 VLM。
- 注意:由于不同模态使用独立编码器,嵌入空间初始未对齐,需要训练以对齐空间以实现有意义的跨模态相似度。
Dense投影层有助于将不同编码器的嵌入映射到共享空间。
from sentence_transformers import SentenceTransformer
from sentence_transformers.sentence_transformer.modules import Dense, Pooling, Router, Transformer
# 创建不同模态的独立编码器
text_encoder = Transformer("sentence-transformers/all-MiniLM-L6-v2")
text_pooling = Pooling(text_encoder.get_embedding_dimension(), pooling_mode="mean")
text_projection = Dense(text_encoder.get_embedding_dimension(), 768)
# SigLIP 直接输出池化嵌入,无需单独的 Pooling 模块
image_encoder = Transformer("google/siglip2-base-patch16-224")
# 基于模态路由输入
router = Router(
sub_modules={
"text": [text_encoder, text_pooling, text_projection],
"image": [image_encoder],
},
)
model = SentenceTransformer(modules=[router])
4. 数据集准备
以视觉文档检索为例,使用 tomaarsen/llamaindex-vdr-en-train-preprocessed 数据集。
- 来源:基于 LlamaIndex 发布的
llamaindex/vdr-multilingual-train数据集的预处理英文子集。 - 规模:原始数据集包含约 50 万条多语言查询-图像样本,预处理后筛选出 53,512 条英文样本。
- 结构:每个样本包含查询(query)、图像(image)和硬负样本(hard negatives)。预处理版本将部分基于 ID 的硬负样本解析为实际的文档截图图像,可直接用于训练。
- 数据格式:训练数据通常构建为(锚点,正样本,硬负样本)三元组。
from datasets import load_dataset
train_dataset = load_dataset("tomaarsen/llamaindex-vdr-en-train-preprocessed", "train", split="train")
train_dataset = train_dataset.select_columns(["query", "image", "negative_0"])
# 类似地加载评估数据集
5. 训练与评估
- 训练器:使用
SentenceTransformerTrainer进行训练,它支持多模态数据的自动处理。 - 评估指标:使用 NDCG@10 等指标评估检索性能。
- 结果:在 VDR 任务中,微调后的模型显著优于基础模型及同类竞品。
关键要点
- 领域微调至关重要:通用多模态模型在特定垂直任务(如视觉文档检索)中表现往往不如经过领域数据微调的模型。微调可以显著提升 NDCG 等检索指标。
- Sentence Transformers 支持多模态:该库不仅支持文本,还全面支持图像、视频等多模态数据的嵌入和重排序模型训练,API 与文本模型保持一致。
- 灵活的模型构建方式:
- 可以直接微调现有的多模态嵌入模型,通过
processor_kwargs和model_kwargs精细控制预处理和加载参数。 - 也可以从基础 VLM 检查点开始,库会自动配置嵌入所需的模块。
- 对于资源受限或需要高度定制的场景,可以使用
Router模块组合不同模态的独立编码器,但需注意嵌入空间对齐问题。
- 可以直接微调现有的多模态嵌入模型,通过
- 数据预处理是关键:高质量的多模态数据集(如包含真实图像负样本的 VDR 数据集)对训练效果有决定性影响。
- 性能提升显著:实验显示,微调后的 2B 参数模型在 VDR 任务上超越了参数量高达其 4 倍的现有模型,证明了高效微调
