# -*- coding: utf-8 -*-
"""Hybrid retrieval demo over a small enterprise knowledge base.

A BM25 keyword retriever (jieba tokenization) and a Chroma + Ollama vector
retriever are combined with Reciprocal Rank Fusion (RRF).
"""

import os
import re
import shutil
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import jieba
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
from rank_bm25 import BM25Okapi

# Content of a "[...]" group; used for document IDs and extra BM25 tokens.
_BRACKET_RE = re.compile(r'\[([^\]]+)\]')
# Hyphenated identifiers (e.g. ticket/bug numbers), matched on lower-cased text.
_ID_RE = re.compile(r'[a-z][a-z0-9]*(?:-[a-z0-9]+)+')


# ============================================================
# 1. Document loading (enterprise knowledge-base test data)
# ============================================================

def load_sample_docs() -> List[str]:
    """Return the 15 sample knowledge-base documents.

    The set mixes exact-match content (ticket/bug IDs, configuration terms),
    semantic content (how-to style text), and mixed content that mentions
    IDs inside free text.
    """
    docs = [
        # Exact ID matching
        "[工单编号 ORD-2024-00123] 客户:张三,状态:已处理,负责部门:技术支持部,解决时间:2小时",
        "[工单编号 ORD-2024-00124] 客户:李四,状态:处理中,负责部门:售后部,解决时间:待定",
        "[Bug编号 BUG-2024-00888] HTTP 404错误,发生在用户登录页面,影响版本:v2.1.0,已修复",
        "[Bug编号 BUG-2024-00899] NullPointerException,发生在订单模块,影响版本:v2.1.1,处理中",
        # Exact term matching
        "系统配置:JWT Token 过期时间为 3600 秒(1小时),刷新 Token 有效期为 7 天",
        "系统配置:Redis Session 超时时间为 30 分钟,MySQL 连接池最大连接数为 100",
        "API 接口限流规则:普通用户每秒 10 次请求,VIP 用户每秒 100 次请求",
        # Semantic matching (conceptual understanding)
        "如何部署应用到测试环境?步骤:1. 登录 Jenkins 2. 选择项目 3. 点击构建 4. 验证部署结果",
        "安装教程:Python 依赖通过 pip install -r requirements.txt 安装,Node 依赖通过 npm install",
        "服务器迁移指南:迁移前需备份数据库、配置文件、日志文件,迁移后验证服务可用性",
        # Mixed (ID + semantic)
        "[升级指南 v2.1.0 to v2.1.1] 本次更新修复了 BUG-2024-00899 的 NullPointerException,建议所有用户升级",
        "[公告] 张三(技术支持部)已提交 ORD-2024-00123 的处理报告,请相关人员查阅",
        # Performance / employee information
        "员工绩效评级:2024年Q1,张三绩效评级为 A(优秀),李四绩效评级为 B(良好),王五绩效评级为 C(待改进)",
        "绩效考核标准:A级(>=90分)、B级(80-89分)、C级(70-79分)、D级(<70分)",
        "[流程文档] 请假申请流程:员工提交 → 部门主管审批 → HR 备案 → 系统记录",
    ]
    return docs


# ============================================================
# 2. BM25 retriever (LangChain BaseRetriever wrapper)
# ============================================================

def get_doc_id(doc: str) -> str:
    """Derive a document ID from its text.

    Returns the content of the first ``[...]`` group (the bracketed label of
    ticket/bug style documents), or the first 30 characters when the text has
    no brackets. Uniqueness is not guaranteed for arbitrary corpora.
    """
    match = _BRACKET_RE.search(doc)
    if match:
        return match.group(1)
    return doc[:30]


class BM25Retriever(BaseRetriever):
    """BM25 keyword retriever exposed as a LangChain ``BaseRetriever``.

    ``docs``, ``doc_ids`` and ``k`` are pydantic fields and are handed to
    ``super().__init__()``. The BM25 index (``_bm25``) and the tokenized corpus
    (``_tokenized_corpus``) are underscore attributes assigned after pydantic
    initialization, so pydantic does not validate them as fields.
    """

    # Pydantic fields must be declared at class level.
    docs: List[str] = Field(default_factory=list)
    doc_ids: List[str] = Field(default_factory=list)
    k: int = Field(default=5)

    def __init__(self, docs: List[str], doc_ids: List[str] | None = None,
                 k: int = 5, k1: float = 1.5, b: float = 0.75):
        """Tokenize ``docs`` and build the BM25 index.

        Args:
            docs: Raw document texts.
            doc_ids: One ID per document; derived with ``get_doc_id`` when omitted.
            k: Number of documents returned by ``invoke``.
            k1: ``k1`` parameter passed to ``BM25Okapi``.
            b: ``b`` parameter passed to ``BM25Okapi``.

        Raises:
            ValueError: If ``doc_ids`` does not have one entry per document.
        """
        resolved_ids = doc_ids if doc_ids is not None else [get_doc_id(d) for d in docs]
        if len(resolved_ids) != len(docs):
            raise ValueError(
                f"doc_ids has {len(resolved_ids)} entries but docs has {len(docs)}"
            )
        tokenized = [self._tokenize(d) for d in docs]
        # BM25Okapi divides by the corpus size, so an empty corpus gets no index.
        bm25 = BM25Okapi(tokenized, k1=k1, b=b) if tokenized else None

        super().__init__(docs=docs, doc_ids=resolved_ids, k=k)
        # Plain instance attributes, set after pydantic initialization.
        self._tokenized_corpus = tokenized
        self._bm25 = bm25

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Tokenize mixed Chinese/English text for BM25.

        Everything is lower-cased before tokenizing, so documents and queries
        share one vocabulary. Besides the jieba tokens, the whitespace-separated
        words inside ``[...]`` groups and whole hyphenated IDs such as
        ``ord-2024-00123`` are appended. A whole ID then matches a plain ID
        query exactly, instead of only through jieba's split pieces.
        """
        lowered = text.lower()
        tokens = [t.strip() for t in jieba.cut(lowered) if t.strip()]
        extra: List[str] = []
        for inner in _BRACKET_RE.findall(lowered):
            extra.extend(inner.split())
        for ident in _ID_RE.findall(lowered):
            if ident not in extra:  # avoid double-counting IDs already in brackets
                extra.append(ident)
        tokens.extend(extra)
        return tokens

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Return the top ``self.k`` documents by BM25 score.

        Each Document carries ``doc_id`` and ``score`` in its metadata.
        """
        if self._bm25 is None:
            return []
        scores = self._bm25.get_scores(self._tokenize(query))
        ranked = sorted(zip(scores, self.doc_ids, self.docs),
                        key=lambda item: item[0], reverse=True)
        return [
            Document(page_content=content,
                     metadata={"doc_id": doc_id, "score": float(score)})
            for score, doc_id, content in ranked[:self.k]
        ]

    def get_scores(self, query: str) -> Dict[str, float]:
        """Return ``{doc_id: BM25 score}`` for every document (used for RRF)."""
        if self._bm25 is None:
            return {}
        scores = self._bm25.get_scores(self._tokenize(query))
        return {doc_id: float(score) for doc_id, score in zip(self.doc_ids, scores)}


# ============================================================
# 3. Vector retriever (Chroma + Ollama embeddings)
# ============================================================

def run_chromadb_patch() -> None:
    """Best-effort monkey patch for chromadb 1.5.7 + langchain-chroma 1.1.0 compatibility.

    Wraps ``get_or_create_collection`` on the first public class in
    ``chromadb.api`` that defines it. The wrapper drops an
    ``embedding_function`` keyword argument and forwards everything else
    unchanged. Runs at most once per process; failures are reported and
    otherwise ignored.

    NOTE(review): concrete client classes that override the method are not
    affected by patching a base class; confirm the patch is still needed.
    """
    if getattr(run_chromadb_patch, '_done', False):
        return
    run_chromadb_patch._done = True

    try:
        import chromadb.api
    except Exception as err:
        print(f"[WARN] chromadb patch skipped: {err}")
        return

    def _wrap(original):
        # Factory so each wrapper binds its own original method.
        def _patched(self, *args, **kwargs):
            kwargs.pop('embedding_function', None)
            return original(self, *args, **kwargs)
        return _patched

    for name in getattr(chromadb.api, '__all__', dir(chromadb.api)):
        if name.startswith('_'):
            continue
        target = getattr(chromadb.api, name, None)
        # Only patch classes: wrapping a bound method on a module/instance
        # would pass ``self`` twice.
        if not isinstance(target, type):
            continue
        original = getattr(target, 'get_or_create_collection', None)
        if original is None:
            continue
        try:
            target.get_or_create_collection = _wrap(original)
        except (AttributeError, TypeError):
            continue
        print("[OK] chromadb patch applied")
        break


class OllamaEmbeddings:
    """Minimal embeddings adapter that calls the ``ollama`` client directly.

    Provides the ``embed_documents`` / ``embed_query`` pair consumed by
    ``langchain_chroma.Chroma``, bypassing ``langchain-ollama``. The
    ``ollama`` package is imported lazily on first use.
    """

    def __init__(self, model: str = "nomic-embed-text:latest",
                 base_url: str = "http://localhost:11434"):
        self.model = model
        self.base_url = base_url
        self._client = None

    def _get_client(self):
        """Create the ``ollama.Client`` on first use and cache it."""
        if self._client is None:
            try:
                import ollama
            except ImportError as err:
                raise ImportError("请安装 ollama: pip install ollama") from err
            self._client = ollama.Client(host=self.base_url)
        return self._client

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with one request per text."""
        client = self._get_client()
        return [client.embeddings(model=self.model, prompt=text)["embedding"]
                for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text."""
        client = self._get_client()
        response = client.embeddings(model=self.model, prompt=text)
        return response["embedding"]


def get_vector_retriever(docs: List[str], k: int = 5) -> BaseRetriever:
    """Build a Chroma vector retriever over ``docs`` using Ollama embeddings.

    Requires a running Ollama server at http://localhost:11434 with the
    ``nomic-embed-text:latest`` model pulled. The ``chroma_db`` directory next
    to this file is deleted and rebuilt on every call. In interactive sessions,
    where ``__file__`` is missing, the directory is under the current working
    directory instead.

    Each document's ``get_doc_id`` value is stored both as its Chroma ID and
    as ``metadata["doc_id"]``. This lets ``hybrid_search`` join vector results
    with BM25 results.

    Args:
        docs: Document texts. Their ``get_doc_id`` values are used as Chroma
            IDs, so they should be unique.
        k: Number of documents the retriever returns.
    """
    run_chromadb_patch()
    from langchain_chroma import Chroma

    embeddings = OllamaEmbeddings(model="nomic-embed-text:latest",
                                  base_url="http://localhost:11434")

    # Fall back to the CWD when __file__ is undefined (interactive sessions).
    try:
        base_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        base_dir = os.getcwd()
    persist_dir = os.path.join(base_dir, "chroma_db")
    # Start from an empty index so stale documents never appear in results.
    if os.path.exists(persist_dir):
        shutil.rmtree(persist_dir)

    doc_ids = [get_doc_id(d) for d in docs]
    vectorstore = Chroma.from_texts(
        texts=docs,
        embedding=embeddings,
        metadatas=[{"doc_id": doc_id} for doc_id in doc_ids],
        persist_directory=persist_dir,
        ids=doc_ids,
    )
    return vectorstore.as_retriever(search_kwargs={"k": k})


# ============================================================
# 4. RRF rank fusion
# ============================================================

def rrf_fusion(ranker_results: List[Tuple[str, List[Tuple[str, int]]]],
               k: int = 60) -> List[Tuple[str, float]]:
    """Reciprocal Rank Fusion.

    Each ranker contributes ``1 / (k + rank)`` for every document it ranks.
    Ranks are expected to be 1-based, so the top document has rank 1.
    Documents ranked by several rankers accumulate a higher fused score.

    Args:
        ranker_results: ``(ranker_name, [(doc_id, rank), ...])`` pairs.
        k: Smoothing constant added to every rank.

    Returns:
        ``(doc_id, fused_score)`` pairs sorted by descending score.
    """
    doc_scores: Dict[str, float] = defaultdict(float)
    for _ranker_name, ranked_list in ranker_results:
        for doc_id, rank in ranked_list:
            doc_scores[doc_id] += 1.0 / (k + rank)
    return sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)


# ============================================================
# 5. Hybrid search
# ============================================================

def _vector_doc_id(doc: Document) -> str:
    """Return the doc_id of a vector-store result.

    The ID is taken from the first available source:
    1. ``metadata["doc_id"]``;
    2. the Document's ``id`` attribute, when present;
    3. ``get_doc_id`` applied to the text, which mirrors how
       ``get_vector_retriever`` assigns Chroma IDs.
    """
    doc_id = (doc.metadata or {}).get("doc_id") or getattr(doc, "id", None)
    return doc_id or get_doc_id(doc.page_content)


def hybrid_search(query: str,
                  bm25_retriever: BM25Retriever,
                  vector_retriever: Optional[BaseRetriever],
                  top_k: int = 5,
                  rrf_k: int = 60) -> List[Tuple[Document, float, Dict[str, Any]]]:
    """Hybrid search: BM25 + vector retrieval fused with RRF.

    Args:
        query: Query text.
        bm25_retriever: Keyword retriever; every document it holds is scored.
        vector_retriever: Vector retriever, or ``None`` to fuse BM25 alone.
        top_k: Maximum number of fused results.
        rrf_k: RRF smoothing constant (see ``rrf_fusion``).

    Returns:
        ``(document, rrf_score, details)`` triples, best first. ``details``
        holds ``bm25_rank``, ``vector_rank`` (1-based; ``None`` when that
        retriever did not rank the document), ``bm25_score`` and
        ``vector_score``.
    """
    # Step 1: query both retrievers.
    bm25_scores = bm25_retriever.get_scores(query)
    vector_results = vector_retriever.invoke(query) if vector_retriever is not None else []

    # Step 2: build 1-based rank lists. BM25 only ranks documents with a
    # positive score. A zero score means no query term matched, so the order
    # among such documents is arbitrary and must not earn RRF credit.
    bm25_hits = [(doc_id, score) for doc_id, score in bm25_scores.items() if score > 0]
    bm25_hits.sort(key=lambda item: item[1], reverse=True)
    bm25_ranked_list = [(doc_id, rank)
                        for rank, (doc_id, _) in enumerate(bm25_hits, start=1)]

    vector_docs: Dict[str, Document] = {}
    vector_ranked_list: List[Tuple[str, int]] = []
    for rank, doc in enumerate(vector_results, start=1):
        doc_id = _vector_doc_id(doc)
        if doc_id not in vector_docs:  # count each document once per ranker
            vector_docs[doc_id] = doc
            vector_ranked_list.append((doc_id, rank))

    # Step 3: RRF fusion.
    fused = rrf_fusion([("bm25", bm25_ranked_list), ("vector", vector_ranked_list)],
                       k=rrf_k)

    # Step 4: assemble results. Text is looked up by the BM25 retriever's own
    # doc_ids, with the vector hit's text as fallback.
    bm25_texts: Dict[str, str] = {}
    for doc_id, text in zip(bm25_retriever.doc_ids, bm25_retriever.docs):
        bm25_texts.setdefault(doc_id, text)
    bm25_ranks = dict(bm25_ranked_list)
    vector_ranks = dict(vector_ranked_list)

    results = []
    for doc_id, rrf_score in fused[:top_k]:
        vector_doc = vector_docs.get(doc_id)
        content = bm25_texts.get(doc_id)
        if content is None and vector_doc is not None:
            content = vector_doc.page_content
        if content is None:
            continue
        results.append((
            Document(page_content=content,
                     metadata={"doc_id": doc_id, "rrf_score": rrf_score}),
            rrf_score,
            {"bm25_rank": bm25_ranks.get(doc_id),
             "vector_rank": vector_ranks.get(doc_id),
             "bm25_score": bm25_scores.get(doc_id, 0),
             "vector_score": (vector_doc.metadata.get("score", 0)
                              if vector_doc is not None else 0)},
        ))
    return results


# ============================================================
# 6. Demo
# ============================================================

def demo() -> None:
    """Compare BM25, vector, and hybrid retrieval on the sample documents."""
    docs = load_sample_docs()
    print(f"[OK] Load {len(docs)} docs")

    bm25_retriever = BM25Retriever(docs=docs, k=5)
    print(f"[OK] BM25 retriever ready (_bm25={type(bm25_retriever._bm25).__name__})")

    # The vector side is optional; creating it fails without a running Ollama.
    try:
        vector_retriever = get_vector_retriever(docs)
        print("[OK] Vector retriever ready (Chroma+Ollama)")
    except Exception as e:
        print(f"[WARN] Vector retriever unavailable ({e}), skip vector part")
        vector_retriever = None

    test_queries = [
        ("ORD-2024-00123", "exact-id"),
        ("张三的工单", "exact-name"),
        ("部署应用到测试环境怎么做", "semantic"),
        ("请假申请流程", "semantic"),
        ("JWT 配置多少秒", "exact-term"),
    ]

    for query, qtype in test_queries:
        print(f"\n--- Query ({qtype}): {query} ---")
        bm25_results = bm25_retriever.invoke(query)
        print("[BM25 Top-3]")
        for i, doc in enumerate(bm25_results[:3], 1):
            print(f" #{i} score={doc.metadata['score']:.4f} {doc.metadata['doc_id']}")

        if vector_retriever is None:
            continue

        vector_results = vector_retriever.invoke(query)
        print("[Vector Top-3]")
        for i, doc in enumerate(vector_results[:3], 1):
            print(f" #{i} {doc.page_content[:60]}")

        hybrid = hybrid_search(query, bm25_retriever, vector_retriever, top_k=3)
        print("[Hybrid Top-3]")
        for i, (doc, score, ranks) in enumerate(hybrid[:3], 1):
            print(f" #{i} rrf={score:.4f} b_rank={ranks['bm25_rank']} v_rank={ranks['vector_rank']}")


# ============================================================
# 7. Convenience functions (for external import)
# ============================================================

# Lazily created module-level retrievers shared by search().
_bm25_r: Optional[BM25Retriever] = None
_vec_r: Optional[BaseRetriever] = None


def get_retrievers():
    """Return ``(bm25_retriever, vector_retriever_or_None)``, creating them on first call.

    The vector retriever is ``None`` when it could not be created. Creation
    is attempted only once, together with the BM25 retriever.
    """
    global _bm25_r, _vec_r
    if _bm25_r is None:
        docs = load_sample_docs()
        _bm25_r = BM25Retriever(docs=docs, k=5)
        try:
            _vec_r = get_vector_retriever(docs)
        except Exception as e:
            print(f"[WARN] Vector retriever unavailable: {e}")
            _vec_r = None
    return _bm25_r, _vec_r


def search(query: str, mode: str = "hybrid", top_k: int = 5) -> List[Document]:
    """Search the sample knowledge base.

    Args:
        query: Query text.
        mode: ``"bm25"``, ``"vector"`` or ``"hybrid"``; any other value is
            treated as ``"hybrid"``. ``"vector"`` and ``"hybrid"`` fall back
            to BM25 when no vector retriever is available.
        top_k: Maximum number of documents returned. In BM25 and vector modes
            the result is also capped by the retriever's own ``k`` (5).

    Returns:
        Matching documents, best first; possibly empty.
    """
    bm25_r, vec_r = get_retrievers()
    if mode == "bm25" or vec_r is None:
        return bm25_r.invoke(query)[:top_k]
    if mode == "vector":
        return vec_r.invoke(query)[:top_k]
    # A list comprehension (unlike zip(*...) unpacking) also handles zero results.
    return [doc for doc, _score, _details in hybrid_search(query, bm25_r, vec_r, top_k=top_k)]


if __name__ == "__main__":
    demo()