"""A minimal RAG (retrieval-augmented generation) pipeline.

Reads .pdf/.txt/.docx files from a local folder ('文档库'), splits them into
overlapping chunks, indexes the chunks with FAISS over SentenceTransformer
embeddings, and answers user questions by retrieving the most similar chunks
and prompting the DeepSeek chat API with them (plus recent session history).
"""
import os

# Must be set before paddleocr is imported, so it stays right after `import os`.
os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'

import io
import json
import warnings
from typing import List, Dict

import numpy as np
import requests
import faiss
import fitz  # PyMuPDF
import PyPDF2
from PyPDF2.errors import PdfReadWarning
from docx import Document
from paddleocr import PaddleOCR
from PIL import Image
from sentence_transformers import SentenceTransformer

# Silence PyPDF2 encoding warnings.
warnings.filterwarnings('ignore', category=PdfReadWarning)

# DeepSeek API credentials and endpoint.
# NOTE(review): prefer reading the key from os.environ instead of hard-coding it.
DEEPSEEK_API_KEY = "DEEPSEEK_API_KEY"  # DEEPSEEK_API_KEY
DEEPSEEK_API_URL = "https://api.deepseek.com/v1/chat/completions"

# Embedding model shared by document indexing and query encoding.
embedding_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# PaddleOCR instance (Chinese models, text-line orientation enabled).
ocr = PaddleOCR(use_textline_orientation=True, lang="ch")


def read_pdf(file_path: str) -> str:
    """Extract all visible text from a PDF.

    Combines the native text layer (digital-born PDFs) with OCR output for
    every image embedded in each page (scanned PDFs), using PaddleOCR.

    Returns the accumulated text, stripped; returns "" on any read error.
    """
    full_text = ""
    doc = None
    try:
        doc = fitz.open(file_path)
        for page_num in range(len(doc)):
            print(f"正在处理第 {page_num + 1} 页...")
            page = doc.load_page(page_num)

            # 1. Native text layer.
            page_text = page.get_text()
            if page_text.strip():
                full_text += page_text + "\n"

            # 2. OCR every image embedded in the page.
            for img in page.get_images(full=True):
                xref = img[0]
                image_bytes = doc.extract_image(xref)["image"]
                # Convert image bytes -> PIL -> numpy array (PaddleOCR's input format).
                pil_img = Image.open(io.BytesIO(image_bytes))
                if pil_img.mode != 'RGB':
                    pil_img = pil_img.convert('RGB')
                img_np = np.array(pil_img)
                # Use predict() to avoid the deprecation warning on __call__.
                result = ocr.predict(img_np)
                # result may be None or [[]]; each `line` is expected to be
                # [box, [text, score]] — presumably; verify against the
                # installed paddleocr version's return format.
                if result and result[0]:
                    for line in result[0]:
                        if isinstance(line, list) and len(line) > 1 and isinstance(line[1], list):
                            full_text += line[1][0] + "\n"
    except Exception as e:
        print(f"读取PDF文件出错: {e}")
    finally:
        # Original leaked the document on mid-loop errors; always close it.
        if doc is not None:
            doc.close()
    return full_text.strip()


def read_txt(file_path: str) -> str:
    """Read a UTF-8 text file; return "" on error."""
    text = ""
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
    except Exception as e:
        print(f"读取txt文件出错: {e}")
    return text


def read_docx(file_path: str) -> str:
    """Read all paragraph text from a .docx file; return "" on error."""
    text = ""
    try:
        doc = Document(file_path)
        for paragraph in doc.paragraphs:
            text += paragraph.text + "\n"
    except Exception as e:
        print(f"读取docx文件出错: {e}")
    return text


def read_file(file_path: str) -> str:
    """Dispatch to the reader matching the file extension (.pdf/.txt/.docx)."""
    ext = os.path.splitext(file_path)[1].lower()
    if ext == '.pdf':
        return read_pdf(file_path)
    elif ext == '.txt':
        return read_txt(file_path)
    elif ext == '.docx':
        return read_docx(file_path)
    else:
        print(f"不支持的文件类型: {ext}")
        return ""


def chunk_document(text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
    """Split `text` into chunks of `chunk_size` chars overlapping by `overlap`.

    Raises ValueError if overlap >= chunk_size (which would loop forever).
    Returns [] for empty text.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    chunks = []
    start = 0
    n = len(text)
    while start < n:
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= n:
            # Fix: without this break the loop emitted one extra chunk that
            # was purely the overlap of the last real chunk.
            break
        start = end - overlap
    return chunks


def load_documents_from_folder(folder_path: str) -> List[str]:
    """Walk `folder_path`, read every supported file, and return all chunks."""
    documents = []
    supported_extensions = ('.pdf', '.txt', '.docx')
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            ext = os.path.splitext(file)[1].lower()
            if ext in supported_extensions:
                file_path = os.path.join(root, file)
                print(f"正在读取文件: {file_path}")
                content = read_file(file_path)
                if content:
                    # Chunk each document before indexing.
                    documents.extend(chunk_document(content))
    return documents


# Document store — loaded from the folder and chunked (runs at import time).
documents = load_documents_from_folder('文档库')
print(f"文档分块完成,共{len(documents)}个片段")


def build_vector_index(docs: List[str]) -> faiss.IndexFlatIP:
    """Embed `docs`, L2-normalize, and index with inner product (= cosine)."""
    # faiss.normalize_L2 requires a contiguous float32 array — coerce first.
    embeddings = np.ascontiguousarray(embedding_model.encode(docs), dtype='float32')
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    faiss.normalize_L2(embeddings)
    index.add(embeddings)
    return index


def retrieve_documents(query: str, index: faiss.IndexFlatIP, docs: List[str],
                       top_k: int = 3) -> List[str]:
    """Return up to `top_k` document chunks most similar to `query`."""
    query_embedding = np.ascontiguousarray(embedding_model.encode([query]), dtype='float32')
    faiss.normalize_L2(query_embedding)
    distances, indices = index.search(query_embedding, top_k)
    # FAISS pads with -1 when fewer than top_k vectors exist — skip those.
    return [docs[i] for i in indices[0] if i != -1]


def generate_answer(prompt: str) -> str:
    """Call the DeepSeek chat-completions API and return the answer text.

    Raises Exception on any non-200 response.
    """
    headers = {
        "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "deepseek-chat",
        "messages": [{"role": "user", "content": prompt}],
        "stream": False
    }
    # json= handles serialization; timeout keeps a dead endpoint from hanging us.
    response = requests.post(DEEPSEEK_API_URL, headers=headers, json=payload,
                             timeout=120)
    if response.status_code == 200:
        result = response.json()
        return result["choices"][0]["message"]["content"]
    else:
        raise Exception(f"请求失败:{response.status_code}, {response.text}")


def rag_pipeline(query: str, session_history: list, index: faiss.IndexFlatIP):
    """Full RAG turn: retrieve context, build prompt (with history), answer."""
    print("检索相关文档...")
    retrieved_docs = retrieve_documents(query, index, documents)
    context = "\n".join(retrieved_docs)
    # Fold prior turns into the prompt so follow-up questions have context.
    history_text = "\n".join(
        [f"用户:{h['query']}\n助手:{h['answer']}" for h in session_history]
    )
    if history_text:
        prompt = f"根据以下内容和历史对话回答问题:\n{context}\n\n历史对话:\n{history_text}\n\n问题:{query}"
    else:
        prompt = f"根据以下内容回答问题:\n{context}\n\n问题:{query}"
    print("生成回答中...")
    answer = generate_answer(prompt)
    print("回答:", answer)
    return answer


if __name__ == "__main__":
    print("构建向量索引...")
    index = build_vector_index(documents)
    print("向量索引构建完成,可以开始提问了!")
    session_history = []
    while True:
        user_query = input("\n请输入你的问题(输入'退出'结束会话):")
        if user_query.strip() == "退出":
            print("会话结束,谢谢使用!")
            break
        answer = rag_pipeline(user_query, session_history, index)
        session_history.append({"query": user_query, "answer": answer})
        # Cap history length so the prompt does not grow without bound.
        if len(session_history) > 5:
            session_history = session_history[-5:]