1. 多标签多分类概念
- • 定义:给一个样本贴上多个标签,但每个标签不再是简单的“是/否”,而是有多个可选的类别。
- • 直观理解:这相当于你同时在做多道“多选题”或“单选题”。
- 1. 颜色(红、蓝、绿、黑...)—— 这是一个多分类任务。
- 2. 款式(T恤、衬衫、卫衣...)—— 这是一个多分类任务。
- 3. 材质(棉、麻、丝...)—— 这是一个多分类任务。
- • 一个样本的结果可能是:
{颜色: 红色, 款式: T恤, 材质: 棉}。 - • 每个分支内部使用 Softmax 激活函数(因为每个属性内部是互斥的多分类)。
2. 与多标签二分类的不同
| 多标签二分类 (Common Multi-label) | |
|---|
| 标签含义 | | |
| 典型算法 | Binary Relevance, Classifier Chains | |
| 激活函数 | | |
| 损失函数 | Binary Cross Entropy (BCE) | 多个 Categorical Cross Entropy 之和 |
| 概率关系 | | |
3. 核心代码示例
当任务变成多标签多分类时,我们必须自己写分类头
分类头重写
import torchimport torch.nn as nnfrom transformers import AutoModel, AutoConfigclass BertForMultiTaskClassification(nn.Module): def __init__(self, model_name, num_tasks, num_classes_each): super().__init__() self.config = AutoConfig.from_pretrained(model_name) self.bert = AutoModel.from_pretrained(model_name) hidden_size = self.config.hidden_size self.dropout = nn.Dropout(0.1) self.heads = nn.ModuleList([ nn.Linear(hidden_size, num_classes_each) for _ in range(num_tasks) ]) self.loss_fct = nn.CrossEntropyLoss() def forward(self, input_ids=None, attention_mask=None, labels=None): outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask ) pooled = outputs.last_hidden_state[:, 0] pooled = self.dropout(pooled) logits_list = [head(pooled) for head in self.heads] logits = torch.stack(logits_list, dim=1) loss = None if labels is not None: losses = [] for i in range(logits.size(1)): losses.append( self.loss_fct( logits[:, i, :], labels[:, i].long() ) ) loss = sum(losses) / len(losses) return {"loss": loss, "logits": logits}
模型训练函数
import osfrom transformers import AutoTokenizer, AutoModelForSequenceClassificationdef load_multilabel_model(model_dir, num_labels): """ 从本地目录加载用于多标签分类的预训练模型和分词器。 参数: - model_dir (str): 本地模型文件的存放路径。 - num_labels (int): 标签的总总数(即输出层的维度)。 返回: - tokenizer: 加载好的分词器实例。 - model: 配置好的多标签分类模型实例。 """ # 1. 加载分词器 tokenizer = AutoTokenizer.from_pretrained(model_dir) # 2. 加载模型并配置多标签任务 model = BertForMultiTaskClassification( model_dir, num_labels=num_labels, problem_type="multi_label_classification" ) return tokenizer, model
3.实验结果
将返回一个多标签多分类模型