文章

摘要级和句子级检索

摘要级和句子级检索

在信息检索系统中,摘要级检索与句子级检索模型的训练存在显著差异,主要体现在数据构造、模型架构、训练目标和应用场景等方面。以下是两者的核心区别及对应的技术方案设计:


1. 训练数据构造差异

维度摘要级检索模型句子级检索模型
文本长度处理200-1000词的长文本(文档摘要)处理10-50词的短文本(独立句子)
标注粒度文档级相关性标签(0/1或分级评分)句子级细粒度标签(精确匹配度0-1)
负样本来源跨文档负采样(不同主题文档的摘要)同文档负采样(同一文档中不相关句子)
数据增强策略摘要改写、主题混淆同义词替换、局部删除

示例数据格式:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 摘要级训练样本
{
    "query": "气候变化的影响",
    "summary": "全球变暖导致极地冰川融化...(200词)",
    "label": 0.8  # 相关性分数
}

# 句子级训练样本
{
    "query": "巴黎气候协定签署时间",
    "sentence": "2015年12月12日,195个国家在巴黎达成气候变化协议",
    "label": 1.0  # 精确匹配
}

2. 模型架构设计差异

摘要级检索模型

1
2
3
4
5
6
7
8
9
10
11
12
# 使用长文本编码器(如Longformer)
from transformers import LongformerModel

class SummaryEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = LongformerModel.from_pretrained("allenai/longformer-base-4096")
        self.pooler = nn.Linear(768, 256)  # 维度压缩

    def forward(self, texts):
        outputs = self.encoder(**texts)
        return self.pooler(outputs.last_hidden_state[:,0])  # [CLS] pooling

句子级检索模型

1
2
3
4
5
6
7
8
9
10
11
# 使用密集检索架构(如DPR)
class DenseSentenceRetriever(nn.Module):
    def __init__(self):
        super().__init__()
        self.query_encoder = AutoModel.from_pretrained("bert-base-uncased")
        self.ctx_encoder = AutoModel.from_pretrained("bert-base-uncased")
    
    def forward(self, query_inputs, ctx_inputs):
        q_emb = self.query_encoder(**query_inputs).last_hidden_state[:,0]
        c_emb = self.ctx_encoder(**ctx_inputs).last_hidden_state[:,0]
        return torch.matmul(q_emb, c_emb.T)  # 相似度矩阵

关键区别:

  • 上下文窗口:摘要级模型需支持4K+ tokens,句子级通常≤512 tokens
  • 参数共享:句子级常使用双编码器,摘要级多用单一编码器
  • 位置编码:摘要级需处理长程位置关系,常使用相对位置编码

3. 训练目标差异

摘要级训练目标

1
2
3
4
5
6
7
# 对比学习目标(InfoNCE Loss)
def contrastive_loss(anchor, positive, negatives, temperature=0.05):
    sim_pos = torch.cosine_similarity(anchor, positive, dim=-1) / temperature
    sim_neg = torch.cosine_similarity(anchor.unsqueeze(1), negatives, dim=-1) / temperature
    logits = torch.cat([sim_pos.unsqueeze(1), sim_neg], dim=1)
    labels = torch.zeros(logits.size(0), dtype=torch.long)
    return F.cross_entropy(logits, labels)

句子级训练目标

1
2
3
4
5
6
7
8
# 精细排序损失(Listwise Loss)
class ListNetLoss(nn.Module):
    def forward(self, scores, labels):
        # scores: (batch_size, num_candidates)
        # labels: (batch_size, num_candidates)
        P_y = F.softmax(scores, dim=1)
        P_z = F.softmax(labels, dim=1)
        return -(P_z * torch.log(P_y)).sum(dim=1).mean()

目标对比:

目标类型摘要级模型句子级模型
主要目标提高召回率(Recall)提升精确率(Precision)
辅助目标主题覆盖度(Topic Coverage)局部一致性(Local Context)
评估指标MRR@100NDCG@10

4. 训练策略差异

摘要级模型训练技巧

1
2
3
4
5
6
# 困难负样本挖掘
class HardNegativeMiner:
    def mine(self, query_vec, corpus_vecs, top_k=50):
        similarities = torch.matmul(query_vec, corpus_vecs.T)
        _, indices = torch.topk(similarities, top_k, dim=1)
        return indices[:,1:]  # 排除正样本后的最高相似负样本

句子级模型训练技巧

1
2
3
4
5
6
# 上下文增强
class ContextAugmenter:
    def add_context(self, sentence, window_size=2):
        # 添加前后相邻句子作为上下文
        idx = self.get_sentence_index(sentence)
        return " ".join(self.corpus[idx-window_size : idx+window_size+1])

关键训练策略对比:

策略摘要级模型句子级模型
Batch Size较小(16-32)因长文本内存限制较大(64-128)短文本允许批量处理
学习率调度线性warmup(10\%训练步数)余弦退火(快速收敛)
正则化方法梯度裁剪(防止长文本梯度爆炸)Dropout(防过拟合短文本)

5. 典型应用场景

摘要级模型适用场景

1
2
3
4
5
# 文档检索系统
def document_retrieval(query):
    summary_emb = summary_encoder.encode(query)
    doc_ids = faiss_index.search(summary_emb, top_k=100)
    return [doc_db.get(id) for id in doc_ids]

句子级模型适用场景

1
2
3
4
5
6
# 问答系统证据提取
def answer_extraction(query, document):
    sentences = split_into_sentences(document)
    sentence_embs = sentence_encoder.encode(sentences)
    scores = torch.matmul(query_emb, sentence_embs.T)
    return sentences[torch.argmax(scores)]

6. 协同训练建议

  1. 级联训练(Cascade Training)
    • 先训练摘要级模型,固定其参数后训练句子级模型
    • 使用摘要级检索结果作为句子级训练的候选池
  2. 联合优化(Joint Optimization)

    1
    2
    3
    4
    5
    
    class JointLoss(nn.Module):
        def forward(self, doc_scores, sent_scores, doc_labels, sent_labels):
            doc_loss = F.binary_cross_entropy(doc_scores, doc_labels)
            sent_loss = F.mse_loss(sent_scores, sent_labels)
            return 0.7*doc_loss + 0.3*sent_loss  # 动态权重调整
    
  3. 共享表示学习
    1
    2
    3
    4
    5
    6
    7
    
    # 共享底层编码器
    class SharedEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.base_encoder = BertModel.from_pretrained('bert-base-uncased')
            self.doc_head = nn.Linear(768, 256)  # 摘要级投影头
            self.sent_head = nn.Linear(768, 256) # 句子级投影头
    

总结:两种模型的核心差异

维度摘要级模型句子级模型
核心任务快速筛选相关文档精准定位关键证据
关注点宏观主题匹配微观语义匹配
计算消耗高(长文本处理)低(短文本处理)
延迟要求容忍较高(百毫秒级)要求极低(十毫秒级)
可解释性较低(整体相关性)较高(可定位具体句子)

实际应用中,通常采用级联架构:摘要级模型作为召回器(Recall Optimizer),句子级模型作为精排器(Precision Optimizer),通过两阶段处理平衡效率与精度。

本文由作者按照 CC BY 4.0 进行授权