Author InformationName: Shutter Zor(左祥太)Email: Shutter_Z@outlook.com
在前段时期,我出了一期使用大语言模型做文本分类任务的推文:「Python」用大语言模型做文本分类,以Qwen为例。
然而,使用 Qwen 在已经发表出来的文章中实际上不算是最主流的做法。现在的文章还是以 BERT 进行文本分类居多。当然,计算机等其他领域也有使用零样本分类(Zero-shot Classification)的方法。
所以,本期推文就来稍微介绍一下如何使用 BERT 做文本分类任务。
为什么需要计算机协助做文本分类?因为人力不可及。条件允许的话雇佣多人阅读文本当然是理想的,但同样地,经费是“理想般爆炸的”,结果的可复现性也是较差的。所以,我们需要一种简单的、经济节约型的方法来让计算机帮我们实现简单的分类工作。
使用 BERT 做文本分析,通俗来说就是,我先人工标注少量样本(1 或 0,或者多分类也可以)。然后我让 BERT 学习我的标注特征(微调)。再然后用微调后的 BERT 去预测剩下的大部分未被标注的文本应该属于何种标签。
所以在这里,我们的任务就被简化成为了三步:
我们现在假设有这样一个任务,即从政府采购合同中识别哪些采购项目属于固定资产。我已经预先抽出了 100 条,并完成了人工标签工作,大致如下:
text,label怀集县自然资源局资产评估服务定点议价采购合同。怀集县自然资源局资产评估服务定点采购。资产及其他评估服务,0吹扫捕集-气相色谱-质谱联用仪器。哈尔滨工业大学吹扫捕集-气相色谱-质谱联用仪器。吹扫捕集-气相色谱-质谱联用仪器,1普宁市里湖镇第三小学复印纸直接订购采购合同。普宁市里湖镇第三小学采购订单。复印纸,0鄂温克族自治旗文化旅游事业发展中心台式计算机直接订购采购合同。鄂温克族自治旗文化旅游事业发展中心采购订单。台式计算机,1梅州市市场监督管理局(梅州市知识产权局)印刷服务定点议价采购合同。梅州市市场监督管理局(梅州市知识产权局)印刷服务定点采购。印刷服务,0白银市应急成品粮储备承储协议。应急成品粮油储备服务001标段(二次)。应急成品粮油储备服务,0广东工贸职业技术学院便携式计算机直接订购采购合同。广东工贸职业技术学院采购订单。便携式计算机,1第一列是采购合同的内容。第二列是我人工打的标签:标记为 1 的表示该采购为固定资产;标记为 0 的表示该采购不是固定资产。
接下来我们可以通过以下代码实现 BERT 的微调与微调后模型的保存。
以下代码使用 CPU 版 torch,可在大多数电脑上运行
import torchimport pandas as pdimport numpy as npfrom torch.utils.data import Dataset, DataLoaderfrom torch.optim import AdamWfrom transformers import BertTokenizer, BertForSequenceClassificationfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import classification_report, confusion_matrix# --- 1. 配置参数 ---MODEL_PATH = "./chinese-bert-wwm-ext"# 本地模型文件夹DATA_PATH = "./Full_100.csv"# 你的CSV文件SAVE_PATH = "./my_finetuned_bert"DEVICE = torch.device("cpu")MAX_LEN = 128BATCH_SIZE = 4EPOCHS = 5LR = 5e-6# 针对小数据集使用极低学习率# --- 2. 加载数据 ---defload_and_prep_data(file_path):# 读取CSV df = pd.read_csv(file_path) text_col, label_col = 'text', 'label' df = df.dropna(subset=[text_col, label_col])# 1. 定义你的语义化映射# 强制确保 key 是数字,value 是字符串 mapping = {0: "非固定资产",1: "固定资产" }# 2. 检查数据中的标签是否在映射范围内# 如果 CSV 里的 label 是字符串形式的 "0", "1",先转成 int df[label_col] = df[label_col].astype(int)# 3. 生成 label2id 和 id2label# 即使数据里只有 0 或 1,我们也根据 mapping 完整生成 unique_ids = sorted(df[label_col].unique().tolist())# 核心:根据你的定义生成字典 id2label = {i: mapping[i] for i in unique_ids} label2id = {v: k for k, v in id2label.items()}# 4. 准备训练用的 label_id 列 df['label_id'] = df[label_col] # 已经是数字 0 和 1 了 print("--- 原始数据分布 ---")# 打印时显示中文名,方便确认 stats = df[label_col].map(mapping).value_counts() print(stats)return train_test_split(df, test_size=0.2, random_state=42), label2id, id2label, text_col# 执行加载(train_df, test_df), label2id, id2label, TEXT_COL = load_and_prep_data(DATA_PATH)print(f"\n检查标签映射结果: {id2label}")# --- 3. 构建 Dataset ---classTextDataset(Dataset):def__init__(self, texts, labels, tokenizer, max_len): self.texts = texts self.labels = labels self.tokenizer = tokenizer self.max_len = max_lendef__len__(self):return len(self.texts)def__getitem__(self, item): encoding = self.tokenizer.encode_plus( str(self.texts[item]), add_special_tokens=True, max_length=self.max_len, padding='max_length', truncation=True, return_tensors='pt' )return {'input_ids': encoding['input_ids'].flatten().long(),'attention_mask': encoding['attention_mask'].flatten().long(),'labels': torch.tensor(self.labels[item], dtype=torch.long) }tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)train_loader = DataLoader( TextDataset(train_df[TEXT_COL].tolist(), train_df['label_id'].tolist(), tokenizer, MAX_LEN), batch_size=BATCH_SIZE, shuffle=True)test_loader = DataLoader( TextDataset(test_df[TEXT_COL].tolist(), test_df['label_id'].tolist(), tokenizer, MAX_LEN), batch_size=BATCH_SIZE, shuffle=False)# --- 4. 初始化模型 ---model = BertForSequenceClassification.from_pretrained(MODEL_PATH, num_labels=len(label2id))model.to(DEVICE)optimizer = AdamW(model.parameters(), lr=LR)# --- 5. 训练循环 ---print(f"\n开始训练... 目标类别数: {len(label2id)}")for epoch in range(EPOCHS): model.train() total_loss = 0for batch in train_loader: optimizer.zero_grad() input_ids = batch['input_ids'].to(DEVICE) attention_mask = batch['attention_mask'].to(DEVICE) labels = batch['labels'].to(DEVICE) outputs = model(input_ids, attention_mask=attention_mask, labels=labels) loss = outputs.loss loss.backward() optimizer.step() total_loss += loss.item() avg_loss = total_loss / len(train_loader) print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_loss:.4f}")# --- 6. 保存模型 ---model.save_pretrained(SAVE_PATH)tokenizer.save_pretrained(SAVE_PATH)完成训练后,我们可以查看模型的效果:
# --- 7. 评估模型 (指标与混淆矩阵) ---defevaluate_model(model, loader, device, id2label_dict): model.eval() all_preds = [] all_labels = [] print("\n正在对测试集进行评估...")with torch.no_grad():for batch in loader: input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) labels = batch['labels'].to(device) outputs = model(input_ids, attention_mask=attention_mask)# 获取概率最大的索引 preds = torch.argmax(outputs.logits, dim=1).cpu().numpy() all_preds.extend(preds) all_labels.extend(labels.cpu().numpy())# 获取实际出现的类别索引,并转为对应的文本标签# 这样做是为了防止某些类别在测试集中没出现而导致报错 unique_ids = sorted(list(set(all_labels) | set(all_preds))) display_names = [id2label_dict[i] for i in unique_ids] print("\n" + "="*50) print(" 分类报告 (Metrics Report)") print("="*50)# 输出 Precision, Recall, F1, Accuracy print(classification_report( all_labels, all_preds, labels=unique_ids, target_names=display_names, zero_division=0 )) print("\n" + "="*50) print(" 混淆矩阵 (Confusion Matrix)") print("="*50)# 计算混淆矩阵 cm = confusion_matrix(all_labels, all_preds, labels=unique_ids)# 转为 DataFrame 格式,方便阅读(行是真实标签,列是预测标签) cm_df = pd.DataFrame(cm, index=[f"真实:{n}"for n in display_names], columns=[f"预测:{n}"for n in display_names]) print(cm_df) print("="*50)# 执行评估# 注意:直接传入你代码里的 id2label 字典evaluate_model(model, test_loader, DEVICE, id2label)结果如下:
正在对测试集进行评估...================================================== 分类报告 (Metrics Report)================================================== precision recall f1-score support 非固定资产 1.000.730.8411 固定资产 0.751.000.869 accuracy 0.8520 macro avg 0.880.860.8520weighted avg 0.890.850.8520================================================== 混淆矩阵 (Confusion Matrix)================================================== 预测:非固定资产 预测:固定资产真实:非固定资产 83真实:固定资产 09==================================================我们重点关注固定资产对应的 precision、recall、f1-score,以及下方的 accuracy。这些结果在每次运行之后都会发生变动,这是很正常的,因为前面有一个打乱数据的随机抽取操作。
接下来我们可以对新的文本进行预测:
# --- 8. 预测新文本 ---defpredict(text): model.eval() inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LEN, padding='max_length').to(DEVICE)with torch.no_grad(): outputs = model(**inputs) probs = torch.nn.functional.softmax(outputs.logits, dim=1) pred_id = torch.argmax(probs, dim=1).item()return id2label[pred_id], probs[0][pred_id].item()# 示例预测 1print("\n--- 预测演示 ---")sample_text = "厦门大学购置一批桌椅板凳"label, confidence = predict(sample_text)print(f"文本: {sample_text}\n结果: {label} (置信度: {confidence:.2%})")# 示例预测 2print("\n--- 预测演示 ---")sample_text = "厦门大学购置一项资产评估服务"label, confidence = predict(sample_text)print(f"文本: {sample_text}\n结果: {label} (置信度: {confidence:.2%})")程序返回结果:
--- 预测演示 ---文本: 厦门大学购置一批桌椅板凳结果: 固定资产 (置信度: 95.33%)--- 预测演示 ---文本: 厦门大学购置一项资产评估服务结果: 非固定资产 (置信度: 62.72%)预测结果相对准确。当然,还需要进一步结合你的文章内容选择是否需要以置信度作为预测的依据或者变量构建。
本期推文的代码与数据可以从如下链接获取:
BW-20260417链接: https://pan.baidu.com/s/1wG9txtFqHx2UpUzZP0jhYw 提取码: GWSL
上手简单,欢迎大家亲自尝试!