"""
AI本地知识库 - 完整代码
功能:上传文档 → 自动切分 → 向量化 → 智能问答
运行:streamlit run app.py
"""
import os
import streamlit as st
from sentence_transformers import SentenceTransformer
from chromadb import PersistentClient
import ollama
from PyPDF2 import PdfReader
# ========== Configuration ==========
MODEL_NAME = "qwen2.5:7b"  # Ollama chat model used to generate answers
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # sentence-transformers model for embeddings
CHUNK_SIZE = 500  # characters per document chunk
CHUNK_OVERLAP = 100  # characters shared between consecutive chunks
TOP_K = 3  # number of chunks retrieved per question
# ========== 初始化 ==========
@st.cache_resource
def init_components():
    """Load the embedding model and open the persistent vector store.

    Decorated with st.cache_resource so these heavyweight objects are
    created once per Streamlit session instead of on every script rerun.

    Returns:
        (SentenceTransformer, chromadb Collection) tuple.
    """
    model = SentenceTransformer(EMBEDDING_MODEL)
    db_client = PersistentClient(path="./chroma_db")
    store = db_client.get_or_create_collection(
        name="knowledge_base",
        metadata={"hnsw:space": "cosine"},  # cosine distance for the HNSW index
    )
    return model, store
# ========== 文档处理 ==========
def split_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split a long document into fixed-size chunks with overlap.

    Consecutive chunks share ``overlap`` characters so context is not
    lost at chunk boundaries.

    Args:
        text: the full document text.
        chunk_size: maximum characters per chunk (must be > 0).
        overlap: characters shared between consecutive chunks
            (must be < chunk_size).

    Returns:
        List of chunk strings; empty list for empty input.

    Raises:
        ValueError: if the step (chunk_size - overlap) is not positive,
            which would otherwise loop forever.
    """
    step = chunk_size - overlap
    # Original code looped with `start += step`; a non-positive step
    # never advances and hangs, so reject it explicitly.
    if chunk_size <= 0 or step <= 0:
        raise ValueError("chunk_size must be > 0 and overlap < chunk_size")
    return [text[start:start + chunk_size] for start in range(0, len(text), step)]
def process_file(file_path, embedder, collection):
    """Ingest one uploaded document: read -> split -> embed -> store.

    Args:
        file_path: uploaded file object exposing ``.name`` and ``.read()``
            (a Streamlit UploadedFile; also accepted by PdfReader).
        embedder: SentenceTransformer used to embed the chunks.
        collection: chromadb collection the chunks are upserted into.

    Returns:
        Number of chunks stored (0 for an empty document).
    """
    if file_path.name.endswith(".pdf"):
        reader = PdfReader(file_path)
        # extract_text() returns None for image-only/scanned pages;
        # substitute "" so the join does not raise TypeError.
        text = "\n".join((page.extract_text() or "") for page in reader.pages)
    else:
        # Non-PDF uploads (txt / md) are assumed to be UTF-8 text.
        text = file_path.read().decode("utf-8")
    chunks = split_text(text)
    if not chunks:
        # Empty document: nothing to embed; avoid encode([]) / empty upsert.
        return 0
    embeddings = embedder.encode(chunks).tolist()
    # Deterministic ids: re-uploading the same file overwrites its old chunks.
    ids = [f"{file_path.name}_{i}" for i in range(len(chunks))]
    collection.upsert(
        ids=ids,
        embeddings=embeddings,
        documents=chunks,
        metadatas=[{"source": file_path.name}] * len(chunks),
    )
    return len(chunks)
# ========== 智能问答(RAG) ==========
def ask_question(question, embedder, collection):
    """Answer a question with RAG: embed it, retrieve TOP_K chunks, ask the LLM.

    Args:
        question: the user's natural-language question.
        embedder: SentenceTransformer used to embed the question.
        collection: chromadb collection to search.

    Returns:
        The model's answer text.
    """
    query_embedding = embedder.encode([question]).tolist()
    hits = collection.query(query_embeddings=query_embedding, n_results=TOP_K)
    # Concatenate the retrieved chunks into one context block.
    context = "\n\n".join(hits["documents"][0])
    prompt = f"""请根据以下知识库内容回答问题。如果知识库中没有相关信息,请如实告知。
【知识库内容】
{context}
【用户问题】
{question}
请分点详细回答:"""
    reply = ollama.chat(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": prompt}],
    )
    return reply["message"]["content"]
# ========== 界面 ==========
def main():
    """Streamlit entry point: sidebar for document ingestion, main area for Q&A."""
    st.set_page_config(page_title="AI本地知识库", page_icon="📚")
    st.title("📚 AI本地知识库")
    st.caption("数据完全本地 · 安全可控 · 零成本")
    embedder, collection = init_components()
    with st.sidebar:
        st.header("📂 知识库管理")
        uploaded = st.file_uploader("上传文档", type=["txt", "pdf", "md"])
        if uploaded:
            # Streamlit reruns this whole script on every widget interaction.
            # Only re-embed when a different file arrives, not on every rerun.
            file_key = (uploaded.name, uploaded.size)
            if st.session_state.get("last_ingested") != file_key:
                st.session_state["last_chunk_count"] = process_file(
                    uploaded, embedder, collection
                )
                st.session_state["last_ingested"] = file_key
            st.success(f"✅ 导入了 {st.session_state['last_chunk_count']} 个知识片段")
        st.caption(f"📊 当前 {collection.count()} 个片段")
    q = st.text_input("💬 输入问题")
    if q:
        with st.spinner("🤔 检索知识库 + AI生成中..."):
            answer = ask_question(q, embedder, collection)
        st.markdown("### 🤖 回答")
        st.markdown(answer)
if __name__ == "__main__":
    main()