feat(train): 添加训练脚本并重构模型输入处理逻辑
This commit is contained in:
parent
1af85a36bc
commit
69349a88a6
|
|
@ -25,6 +25,9 @@ dependencies = [
|
||||||
"typer>=0.21.1",
|
"typer>=0.21.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
train-model = "model.trainer:app"
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
# 设置当前项目的默认索引源
|
# 设置当前项目的默认索引源
|
||||||
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
from modelscope import AutoModel
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------- 注意力池化模块----------------------------
|
# ---------------------------- 注意力池化模块----------------------------
|
||||||
|
|
@ -25,9 +25,8 @@ class AttentionPooling(nn.Module):
|
||||||
pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
|
pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
|
||||||
return pooled
|
return pooled
|
||||||
|
|
||||||
# ---------------------------- 残差块 ----------------------------
|
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------- 残差块 ----------------------------
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, dim, dropout_prob=0.3):
|
def __init__(self, dim, dropout_prob=0.3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -92,9 +91,12 @@ class ContextEncoder(nn.Module):
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
||||||
# Embeddings
|
# Embeddings
|
||||||
self.text_emb = nn.Embedding(vocab_size, dim)
|
self.text_emb = AutoModel.from_pretrained(
|
||||||
|
"iic/nlp_structbert_backbone_lite_std"
|
||||||
|
).embeddings
|
||||||
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
|
self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim)
|
||||||
self.pos_emb = nn.Embedding(max_len, dim)
|
self.pos_emb = nn.Embedding(max_len, dim)
|
||||||
|
self.pinyin_pooling = AttentionPooling(dim)
|
||||||
|
|
||||||
# Transformer Encoder (4 layers, 4 heads) [1]
|
# Transformer Encoder (4 layers, 4 heads) [1]
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
|
@ -118,9 +120,22 @@ class ContextEncoder(nn.Module):
|
||||||
Returns:
|
Returns:
|
||||||
H: [batch, seq_len, 512] Context representation [1]
|
H: [batch, seq_len, 512] Context representation [1]
|
||||||
"""
|
"""
|
||||||
# 1. Embedding Fusion: Text + Pinyin + Position
|
# 1. Embed text
|
||||||
|
text_emb = self.text_emb(text_ids) # [B, 128, dim]
|
||||||
|
|
||||||
|
# 2. Embed and pool pinyin to global feature
|
||||||
|
pinyin_emb = self.pinyin_emb(pinyin_ids) # [B, 24, dim]
|
||||||
|
# 方式1:Attention Pooling(推荐)
|
||||||
|
pinyin_global = self.pinyin_pooling(
|
||||||
|
pinyin_emb, mask=None
|
||||||
|
) # [B, dim] # 1. Embedding Fusion: Text + Pinyin + Position
|
||||||
|
|
||||||
|
# Broadcast pinyin to all text positions
|
||||||
|
pinyin_global = pinyin_global.unsqueeze(1) # [B, 1, dim]
|
||||||
|
pinyin_broadcast = pinyin_global.expand_as(text_emb) # [B, 128, dim]
|
||||||
|
|
||||||
# 策略:拼音作为增强特征叠加到文本上,符合轻量级设计
|
# 策略:拼音作为增强特征叠加到文本上,符合轻量级设计
|
||||||
x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids)
|
x = text_emb + pinyin_broadcast
|
||||||
|
|
||||||
seq_len = x.size(1)
|
seq_len = x.size(1)
|
||||||
pos_ids = (
|
pos_ids = (
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,42 @@
|
||||||
import random
|
import random
|
||||||
|
import re
|
||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from loguru import logger
|
|
||||||
from modelscope import AutoTokenizer
|
from modelscope import AutoTokenizer
|
||||||
from pypinyin import Style, lazy_pinyin
|
from pypinyin import lazy_pinyin
|
||||||
from pypinyin.contrib.tone_convert import to_initials
|
from pypinyin.contrib.tone_convert import to_initials
|
||||||
from torch.utils.data import DataLoader, IterableDataset
|
from torch.utils.data import IterableDataset
|
||||||
|
|
||||||
from .query import QueryEngine
|
from .query import QueryEngine
|
||||||
|
|
||||||
|
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||||
|
|
||||||
|
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
|
||||||
|
CHAR_TO_ID["`"] = 27 # 显式添加反引号
|
||||||
|
CHAR_TO_ID["'"] = 28 # 显式添加引号
|
||||||
|
CHAR_TO_ID["-"] = 29 # 显式添加短横
|
||||||
|
|
||||||
|
|
||||||
|
def text_to_pinyin_ids(pinyin_str: str) -> List[int]:
|
||||||
|
"""
|
||||||
|
将拼音字符串转换为 ID 列表。
|
||||||
|
支持 a-z 和 `。
|
||||||
|
未知字符映射为 0 (PAD/UNK)。
|
||||||
|
"""
|
||||||
|
# 使用 dict.get(key, default) 处理未知字符,默认返回 0
|
||||||
|
return [CHAR_TO_ID.get(c, 0) for c in pinyin_str]
|
||||||
|
|
||||||
|
|
||||||
class PinyinInputDataset(IterableDataset):
|
class PinyinInputDataset(IterableDataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_path: str,
|
data_path: str,
|
||||||
max_workes: int = -1,
|
max_workers: int = -1,
|
||||||
max_iter_length=1e6,
|
max_iter_length=1e6,
|
||||||
max_seq_length=128,
|
max_seq_length=128,
|
||||||
text_field: str = "text",
|
text_field: str = "text",
|
||||||
|
|
@ -35,7 +52,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
self.min_freq = 109
|
self.min_freq = 109
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
Path(files(__package__) / "assets" / "tokenizer")
|
Path(str(files(__package__))) / "assets" / "tokenizer"
|
||||||
)
|
)
|
||||||
self.data_path = data_path
|
self.data_path = data_path
|
||||||
|
|
||||||
|
|
@ -43,7 +60,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
self.max_seq_length = max_seq_length
|
self.max_seq_length = max_seq_length
|
||||||
self.text_field = text_field
|
self.text_field = text_field
|
||||||
self.dataset = load_dataset(data_path, split="train", streaming=True)
|
self.dataset = load_dataset(data_path, split="train", streaming=True)
|
||||||
self.max_workers = max_workes
|
self.max_workers = max_workers
|
||||||
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
|
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
|
||||||
self.shuffle_buffer_size = shuffle_buffer_size
|
self.shuffle_buffer_size = shuffle_buffer_size
|
||||||
self.possible_lengths = list(length_weights.keys())
|
self.possible_lengths = list(length_weights.keys())
|
||||||
|
|
@ -98,7 +115,56 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
# 生成对应文本的拼音
|
# 生成对应文本的拼音
|
||||||
def generate_pinyin(self, text: str) -> List[str]:
|
def generate_pinyin(self, text: str) -> List[str]:
|
||||||
return lazy_pinyin(text, errors=lambda x: [c for c in x])
|
"""
|
||||||
|
流式处理单条文本,转换为拼音列表。
|
||||||
|
|
||||||
|
特性:
|
||||||
|
1. 严格一一对应:len(result) == len(text)
|
||||||
|
2. 高多音字准确率:利用 pypinyin 内部的词语分词能力
|
||||||
|
3. 高性能:预分配内存,无多余对象创建
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 输入字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 拼音或非汉字字符的列表
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
text_len = len(text)
|
||||||
|
# 2. 预分配结果列表,初始化占位符。
|
||||||
|
# 使用 None 或空字符串均可,这里用空字符串方便后续判断
|
||||||
|
result: List[str] = [""] * text_len
|
||||||
|
|
||||||
|
# 3. 遍历所有连续汉字片段
|
||||||
|
for match in _HANZI_RE.finditer(text):
|
||||||
|
start_idx = match.start()
|
||||||
|
hanzi_segment = match.group()
|
||||||
|
|
||||||
|
# 4. 核心转换:利用 pypinyin 的分词能力处理该片段
|
||||||
|
# style=Style.NORMAL 获取不带声调的拼音
|
||||||
|
pinyin_list = lazy_pinyin(hanzi_segment)
|
||||||
|
|
||||||
|
# 5. 健壮性兜底:
|
||||||
|
# 正常情况下,pypinyin 返回的拼音数应等于汉字数。
|
||||||
|
# 若不等(极罕见,如遇到特殊 Unicode 标点被误判为汉字),降级为单字转换
|
||||||
|
if len(pinyin_list) != len(hanzi_segment):
|
||||||
|
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
|
||||||
|
|
||||||
|
# 6. 直接通过索引填充到预分配的位置
|
||||||
|
# 这比 list slicing assignment (result[start:end] = pinyin_list) 略快且更直观
|
||||||
|
for i, py in enumerate(pinyin_list):
|
||||||
|
result[start_idx + i] = py
|
||||||
|
|
||||||
|
# 7. 填充非汉字字符
|
||||||
|
# 遍历原文,如果 result 对应位置为空,则填入原字符
|
||||||
|
# 注意:对于纯汉字文本,这一步很快;对于混合文本,这是必要的
|
||||||
|
for i, char in enumerate(text):
|
||||||
|
if not result[i]:
|
||||||
|
result[i] = char
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
# 生成需要预测汉字对应的拼音,并进行加强
|
# 生成需要预测汉字对应的拼音,并进行加强
|
||||||
def get_mask_pinyin(
|
def get_mask_pinyin(
|
||||||
|
|
@ -107,7 +173,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
mask_pinyin = []
|
mask_pinyin = []
|
||||||
for i in range(len(text)):
|
for i in range(len(text)):
|
||||||
if not self.query_engine.is_chinese_char(text[i]):
|
if not self.query_engine.is_chinese_char(text[i]):
|
||||||
return i, mask_pinyin
|
break
|
||||||
else:
|
else:
|
||||||
py = np.random.choice(
|
py = np.random.choice(
|
||||||
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
|
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
|
||||||
|
|
@ -116,7 +182,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if py == "":
|
if py == "":
|
||||||
py = pinyin_list[i][0]
|
py = pinyin_list[i][0]
|
||||||
mask_pinyin.append(py)
|
mask_pinyin.append(py)
|
||||||
return len(text) - 1, mask_pinyin
|
return len(mask_pinyin), mask_pinyin
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
|
|
@ -165,7 +231,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if current_iter_index >= worker_quota:
|
if current_iter_index >= worker_quota:
|
||||||
break
|
break
|
||||||
|
|
||||||
labels = [] # 添加起始符
|
labels = []
|
||||||
# 如果text[i]不在字符库中,则跳过
|
# 如果text[i]不在字符库中,则跳过
|
||||||
# 当i小于48时候,则将part1取text[0:i]
|
# 当i小于48时候,则将part1取text[0:i]
|
||||||
# 当i大于48时候,则将part1取text[i-48:i]
|
# 当i大于48时候,则将part1取text[i-48:i]
|
||||||
|
|
@ -186,7 +252,19 @@ class PinyinInputDataset(IterableDataset):
|
||||||
pinyin_len, part2 = self.get_mask_pinyin(
|
pinyin_len, part2 = self.get_mask_pinyin(
|
||||||
text[i:py_end], pinyin_list[i:py_end]
|
text[i:py_end], pinyin_list[i:py_end]
|
||||||
)
|
)
|
||||||
part2 = "".join(part2)
|
|
||||||
|
split_char = np.random.choice(
|
||||||
|
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
|
||||||
|
)
|
||||||
|
|
||||||
|
part2 = split_char.join(part2)
|
||||||
|
pinyin_ids = text_to_pinyin_ids(part2)
|
||||||
|
len_py = len(pinyin_ids)
|
||||||
|
if len_py < 24:
|
||||||
|
pinyin_ids.extend([0] * (24 - len_py))
|
||||||
|
else:
|
||||||
|
pinyin_ids = pinyin_ids[:24]
|
||||||
|
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long)
|
||||||
|
|
||||||
# part3为文本,大概率(0.70)为空
|
# part3为文本,大概率(0.70)为空
|
||||||
# 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符
|
# 不为空则是i+pinyin_len所指向的字符以及所指向字符后x个字符
|
||||||
|
|
@ -219,7 +297,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
string_list.append(text[start_pos:end_pos])
|
string_list.append(text[start_pos:end_pos])
|
||||||
# 用|连接所有字符串
|
# 用|连接所有字符串
|
||||||
part4 = "|".join(string_list)
|
part4 = "|".join(string_list)
|
||||||
|
|
||||||
labels = [
|
labels = [
|
||||||
self.query_engine.get_char_info_by_char_pinyin(c, p).id
|
self.query_engine.get_char_info_by_char_pinyin(c, p).id
|
||||||
for c, p in zip(
|
for c, p in zip(
|
||||||
|
|
@ -240,9 +317,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
samples = []
|
samples = []
|
||||||
for i, label in enumerate(labels):
|
for i, label in enumerate(labels):
|
||||||
repeats = self.adjust_frequency(label)
|
repeats = self.adjust_frequency(label)
|
||||||
l = labels[:i]
|
masked_labels = labels[:i]
|
||||||
len_l = len(l)
|
len_l = len(masked_labels)
|
||||||
l.extend([0] * (8 - len_l))
|
masked_labels.extend([0] * (8 - len_l))
|
||||||
|
|
||||||
samples.extend(
|
samples.extend(
|
||||||
[
|
[
|
||||||
|
|
@ -252,11 +329,12 @@ class PinyinInputDataset(IterableDataset):
|
||||||
"attention_mask": encoded["attention_mask"],
|
"attention_mask": encoded["attention_mask"],
|
||||||
"label": torch.tensor([label], dtype=torch.long),
|
"label": torch.tensor([label], dtype=torch.long),
|
||||||
"history_slot_ids": torch.tensor(
|
"history_slot_ids": torch.tensor(
|
||||||
l, dtype=torch.long
|
masked_labels, dtype=torch.long
|
||||||
),
|
),
|
||||||
"prefix": f"{part4}^{part1}",
|
"prefix": f"{part4}^{part1}",
|
||||||
"suffix": part3,
|
"suffix": part3,
|
||||||
"pinyin": part2,
|
"pinyin": part2,
|
||||||
|
"pinyin_ids": pinyin_ids,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
* repeats
|
* repeats
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
# 导入 components.py 中的组件
|
# 导入 components.py 中的组件
|
||||||
from .components import (
|
from .components import (
|
||||||
AttentionPooling, # 可选,暂不使用
|
|
||||||
ContextEncoder,
|
ContextEncoder,
|
||||||
CrossAttentionFusion,
|
CrossAttentionFusion,
|
||||||
MoELayer,
|
MoELayer,
|
||||||
|
|
@ -41,19 +40,17 @@ class InputMethodEngine(nn.Module):
|
||||||
n_heads: int = 4, # 注意力头数
|
n_heads: int = 4, # 注意力头数
|
||||||
num_experts: int = 20, # MoE 专家数量
|
num_experts: int = 20, # MoE 专家数量
|
||||||
max_seq_len: int = 128, # 最大上下文长度
|
max_seq_len: int = 128, # 最大上下文长度
|
||||||
use_pinyin: bool = False, # 是否使用拼音特征(若为 False,拼音嵌入恒为零)
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.num_slots = num_slots
|
self.num_slots = num_slots
|
||||||
self.use_pinyin = use_pinyin
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
# 1. 上下文编码器 (ContextEncoder)
|
# 1. 上下文编码器 (ContextEncoder)
|
||||||
# 若 use_pinyin=False,则传入 pinyin_vocab_size=1 并固定嵌入为零
|
# 若 use_pinyin=False,则传入 pinyin_vocab_size=1 并固定嵌入为零
|
||||||
self.context_encoder = ContextEncoder(
|
self.context_encoder = ContextEncoder(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
pinyin_vocab_size=pinyin_vocab_size if use_pinyin else 1,
|
pinyin_vocab_size=pinyin_vocab_size,
|
||||||
dim=dim,
|
dim=dim,
|
||||||
n_layers=n_layers,
|
n_layers=n_layers,
|
||||||
n_heads=n_heads,
|
n_heads=n_heads,
|
||||||
|
|
@ -79,16 +76,12 @@ class InputMethodEngine(nn.Module):
|
||||||
# 5. 分类头
|
# 5. 分类头
|
||||||
self.classifier = nn.Linear(dim, vocab_size)
|
self.classifier = nn.Linear(dim, vocab_size)
|
||||||
|
|
||||||
# 可选:如果不需要拼音,将拼音嵌入矩阵固定为零
|
|
||||||
if not use_pinyin and hasattr(self.context_encoder, "pinyin_emb"):
|
|
||||||
# 将拼音嵌入权重置零,确保对输出无影响
|
|
||||||
nn.init.zeros_(self.context_encoder.pinyin_emb.weight)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
token_type_ids: torch.Tensor,
|
token_type_ids: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
|
pinyin_ids: torch.Tensor,
|
||||||
history_slot_ids: torch.Tensor,
|
history_slot_ids: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
|
|
@ -100,28 +93,20 @@ class InputMethodEngine(nn.Module):
|
||||||
if history_slot_ids.dim() == 1:
|
if history_slot_ids.dim() == 1:
|
||||||
history_slot_ids = history_slot_ids.unsqueeze(0).expand(batch_size, -1)
|
history_slot_ids = history_slot_ids.unsqueeze(0).expand(batch_size, -1)
|
||||||
|
|
||||||
# 1. 构造拼音输入(如果 use_pinyin=False,则使用全零占位符)
|
# 1. 上下文编码 -> H [batch, seq_len, dim]
|
||||||
if self.use_pinyin:
|
|
||||||
# 注意:这里需要真实的拼音 ids,但当前输入未提供,故用零占位(实际应用中应从外部获取)
|
|
||||||
# 为简化演示,此处使用全零张量,并假设拼音词汇表中 0 为 padding。
|
|
||||||
pinyin_ids = torch.zeros_like(input_ids)
|
|
||||||
else:
|
|
||||||
pinyin_ids = torch.zeros_like(input_ids)
|
|
||||||
|
|
||||||
# 2. 上下文编码 -> H [batch, seq_len, dim]
|
|
||||||
# 注意:ContextEncoder.forward 接受 text_ids, pinyin_ids, mask
|
# 注意:ContextEncoder.forward 接受 text_ids, pinyin_ids, mask
|
||||||
H = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
|
H = self.context_encoder(input_ids, pinyin_ids, mask=attention_mask)
|
||||||
|
|
||||||
# 3. 槽位记忆编码 -> S [batch, num_slots, dim]
|
# 2. 槽位记忆编码 -> S [batch, num_slots, dim]
|
||||||
S = self.slot_memory(history_slot_ids) # history_slot_ids: [batch, num_slots]
|
S = self.slot_memory(history_slot_ids) # history_slot_ids: [batch, num_slots]
|
||||||
|
|
||||||
# 4. 交叉注意力融合 (使用 CrossAttentionFusion)
|
# 3. 交叉注意力融合 (使用 CrossAttentionFusion)
|
||||||
fused = self.cross_attn(S, H, context_mask=attention_mask)
|
fused = self.cross_attn(S, H, context_mask=attention_mask)
|
||||||
|
|
||||||
# 5. MoE 处理 -> [batch, num_slots, dim]
|
# 4. MoE 处理 -> [batch, num_slots, dim]
|
||||||
moe_out = self.moe(fused)
|
moe_out = self.moe(fused)
|
||||||
|
|
||||||
# 6. 池化与分类:对槽位维度求平均(或使用 mask 池化)
|
# 5. 池化与分类:对槽位维度求平均(或使用 mask 池化)
|
||||||
# 这里简单平均,若需要忽略 padding 槽位,可根据 history_slot_ids 是否为 0 构造 mask
|
# 这里简单平均,若需要忽略 padding 槽位,可根据 history_slot_ids 是否为 0 构造 mask
|
||||||
slot_mask = (history_slot_ids != 0).float() # [batch, num_slots]
|
slot_mask = (history_slot_ids != 0).float() # [batch, num_slots]
|
||||||
slot_mask = slot_mask.unsqueeze(-1) # [batch, num_slots, 1]
|
slot_mask = slot_mask.unsqueeze(-1) # [batch, num_slots, 1]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,913 @@
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import typer
|
||||||
|
from loguru import logger
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.progress import (
|
||||||
|
BarColumn,
|
||||||
|
Progress,
|
||||||
|
SpinnerColumn,
|
||||||
|
TextColumn,
|
||||||
|
TimeElapsedColumn,
|
||||||
|
TimeRemainingColumn,
|
||||||
|
)
|
||||||
|
from rich.table import Table
|
||||||
|
from torch import autocast
|
||||||
|
from torch.amp.grad_scaler import GradScaler
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from .dataset import PinyinInputDataset
|
||||||
|
|
||||||
|
# 导入模型和数据
|
||||||
|
from .model import InputMethodEngine
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
"""
|
||||||
|
输入法模型训练器
|
||||||
|
|
||||||
|
实现训练InputMethodEngine模型,支持:
|
||||||
|
- 预热+余弦退火学习率调度
|
||||||
|
- TensorBoard日志记录
|
||||||
|
- AdamW优化器(weight_decay=0.1)
|
||||||
|
- 混合精度训练
|
||||||
|
- CrossEntropyLoss损失函数(支持weight和label_smoothing)
|
||||||
|
- Rich终端美化输出
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: InputMethodEngine,
|
||||||
|
train_dataloader: DataLoader,
|
||||||
|
eval_dataloader: DataLoader,
|
||||||
|
total_steps: int,
|
||||||
|
output_dir: str = "./output",
|
||||||
|
num_epochs: int = 10,
|
||||||
|
learning_rate: float = 1e-4,
|
||||||
|
min_learning_rate: float = 1e-6,
|
||||||
|
weight_decay: float = 0.1,
|
||||||
|
warmup_ratio: float = 0.1,
|
||||||
|
label_smoothing: float = 0.15,
|
||||||
|
loss_weight: Optional[torch.Tensor] = None,
|
||||||
|
grad_accum_steps: int = 1,
|
||||||
|
clip_grad_norm: float = 1.0,
|
||||||
|
eval_frequency: int = 500,
|
||||||
|
save_frequency: int = 1000,
|
||||||
|
mixed_precision: bool = True,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
use_tensorboard: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化训练器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: 要训练的InputMethodEngine模型
|
||||||
|
train_dataloader: 训练数据加载器
|
||||||
|
eval_dataloader: 评估数据加载器(可选)
|
||||||
|
output_dir: 输出目录,用于保存模型和日志
|
||||||
|
num_epochs: 训练轮数
|
||||||
|
total_steps: 总训练步数,如果为None则根据epochs计算
|
||||||
|
learning_rate: 最大学习率(预热后)
|
||||||
|
min_learning_rate: 最小学习率(余弦退火后的最低值)
|
||||||
|
weight_decay: AdamW优化器的权重衰减
|
||||||
|
warmup_ratio: 热身步数占总步数的比例
|
||||||
|
label_smoothing: CrossEntropyLoss的标签平滑参数
|
||||||
|
loss_weight: CrossEntropyLoss的类别权重
|
||||||
|
grad_accum_steps: 梯度累积步数
|
||||||
|
clip_grad_norm: 梯度裁剪的最大范数
|
||||||
|
eval_frequency: 评估频率(步数)
|
||||||
|
save_frequency: 保存检查点频率(步数)
|
||||||
|
mixed_precision: 是否使用混合精度训练
|
||||||
|
device: 训练设备,如果为None则自动选择
|
||||||
|
use_tensorboard: 是否使用TensorBoard记录
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.train_dataloader = train_dataloader
|
||||||
|
self.eval_dataloader = list([i for i in eval_dataloader])
|
||||||
|
self.output_dir = Path(output_dir)
|
||||||
|
self.num_epochs = num_epochs
|
||||||
|
self.learning_rate = learning_rate
|
||||||
|
self.min_learning_rate = min_learning_rate
|
||||||
|
self.weight_decay = weight_decay
|
||||||
|
self.warmup_ratio = warmup_ratio
|
||||||
|
self.label_smoothing = label_smoothing
|
||||||
|
self.loss_weight = loss_weight
|
||||||
|
self.grad_accum_steps = grad_accum_steps
|
||||||
|
self.clip_grad_norm = clip_grad_norm
|
||||||
|
self.eval_frequency = eval_frequency
|
||||||
|
self.save_frequency = save_frequency
|
||||||
|
self.mixed_precision = mixed_precision
|
||||||
|
self.use_tensorboard = use_tensorboard
|
||||||
|
|
||||||
|
# 设置设备
|
||||||
|
if device is None:
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
else:
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# 移动模型到设备
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
# 创建输出目录
|
||||||
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.checkpoint_dir = self.output_dir / "checkpoints"
|
||||||
|
self.checkpoint_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# 计算总步数
|
||||||
|
self.total_steps = total_steps
|
||||||
|
|
||||||
|
self.warmup_steps = int(self.total_steps * warmup_ratio)
|
||||||
|
|
||||||
|
# 初始化优化器
|
||||||
|
self.optimizer = optim.AdamW(
|
||||||
|
model.parameters(),
|
||||||
|
lr=learning_rate,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
betas=(0.9, 0.999),
|
||||||
|
eps=1e-8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化损失函数
|
||||||
|
if loss_weight is not None:
|
||||||
|
self.criterion = nn.CrossEntropyLoss(
|
||||||
|
weight=loss_weight.to(self.device), label_smoothing=label_smoothing
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
||||||
|
|
||||||
|
# 初始化混合精度训练器
|
||||||
|
device_type = "cuda" if "cuda" in str(self.device) else "cpu"
|
||||||
|
self.scaler = GradScaler(device_type, enabled=mixed_precision)
|
||||||
|
|
||||||
|
# 初始化TensorBoard
|
||||||
|
if use_tensorboard:
|
||||||
|
self.writer = SummaryWriter(log_dir=self.output_dir / "tensorboard")
|
||||||
|
else:
|
||||||
|
self.writer = None
|
||||||
|
|
||||||
|
# 初始化Rich控制台
|
||||||
|
self.console = Console()
|
||||||
|
|
||||||
|
# 训练状态
|
||||||
|
self.current_step = 0
|
||||||
|
self.current_epoch = 0
|
||||||
|
self.best_eval_loss = float("inf")
|
||||||
|
|
||||||
|
# 学习率调度函数
|
||||||
|
self.lr_scheduler = self._create_lr_scheduler()
|
||||||
|
|
||||||
|
logger.info(f"Trainer initialized with device: {self.device}")
|
||||||
|
logger.info(
|
||||||
|
f"Total steps: {self.total_steps}, Warmup steps: {self.warmup_steps}"
|
||||||
|
)
|
||||||
|
logger.info(f"Learning rate: {learning_rate}, Weight decay: {weight_decay}")
|
||||||
|
|
||||||
|
def _create_lr_scheduler(self) -> Callable[[int], float]:
|
||||||
|
"""创建学习率调度函数(预热 + 余弦退火)"""
|
||||||
|
|
||||||
|
def lr_scheduler(step: int) -> float:
|
||||||
|
if step < self.warmup_steps:
|
||||||
|
# 线性预热
|
||||||
|
return self.learning_rate * (step / self.warmup_steps)
|
||||||
|
else:
|
||||||
|
# 余弦退火
|
||||||
|
progress = (step - self.warmup_steps) / (
|
||||||
|
self.total_steps - self.warmup_steps
|
||||||
|
)
|
||||||
|
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
|
||||||
|
decayed_lr = (
|
||||||
|
self.min_learning_rate
|
||||||
|
+ (self.learning_rate - self.min_learning_rate) * cosine_decay
|
||||||
|
)
|
||||||
|
return decayed_lr
|
||||||
|
|
||||||
|
return lr_scheduler
|
||||||
|
|
||||||
|
def _get_current_lr(self) -> float:
|
||||||
|
"""获取当前学习率"""
|
||||||
|
return self.lr_scheduler(self.current_step)
|
||||||
|
|
||||||
|
def _update_learning_rate(self):
|
||||||
|
"""更新优化器中的学习率"""
|
||||||
|
current_lr = self._get_current_lr()
|
||||||
|
for param_group in self.optimizer.param_groups:
|
||||||
|
param_group["lr"] = current_lr
|
||||||
|
return current_lr
|
||||||
|
|
||||||
|
def train_step(
|
||||||
|
self, batch: Dict[str, torch.Tensor]
|
||||||
|
) -> Tuple[float, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
执行单个训练步骤
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: 包含输入数据的批次
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
loss: 损失值
|
||||||
|
metrics: 训练指标字典
|
||||||
|
"""
|
||||||
|
self.model.train()
|
||||||
|
|
||||||
|
# 移动数据到设备
|
||||||
|
input_ids = batch["input_ids"].to(self.device)
|
||||||
|
token_type_ids = batch["token_type_ids"].to(self.device)
|
||||||
|
attention_mask = batch["attention_mask"].to(self.device)
|
||||||
|
history_slot_ids = batch["history_slot_ids"].to(self.device)
|
||||||
|
pinyin_ids = batch["pinyin_ids"].to(self.device).squeeze(-1)
|
||||||
|
labels = batch["labels"].to(self.device).squeeze(-1) # [batch_size]
|
||||||
|
|
||||||
|
# 混合精度训练
|
||||||
|
with autocast(device_type=self.device.type, enabled=self.mixed_precision):
|
||||||
|
# 前向传播
|
||||||
|
logits = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
pinyin_ids=pinyin_ids,
|
||||||
|
history_slot_ids=history_slot_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算损失
|
||||||
|
loss = self.criterion(logits, labels)
|
||||||
|
loss = loss / self.grad_accum_steps
|
||||||
|
|
||||||
|
# 反向传播
|
||||||
|
self.scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"loss": loss.item() * self.grad_accum_steps,
|
||||||
|
"lr": self._get_current_lr(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 计算准确率
|
||||||
|
with torch.no_grad():
|
||||||
|
preds = torch.argmax(logits, dim=-1)
|
||||||
|
correct = (preds == labels).sum().item()
|
||||||
|
total = labels.size(0)
|
||||||
|
metrics["accuracy"] = correct / total if total > 0 else 0.0
|
||||||
|
|
||||||
|
return loss.item() * self.grad_accum_steps, metrics
|
||||||
|
|
||||||
|
def evaluate(self) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
在评估集上评估模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
评估指标字典
|
||||||
|
"""
|
||||||
|
if self.eval_dataloader is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
total_correct = 0
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in self.eval_dataloader:
|
||||||
|
# 移动数据到设备
|
||||||
|
input_ids = batch["input_ids"].to(self.device)
|
||||||
|
token_type_ids = batch["token_type_ids"].to(self.device)
|
||||||
|
attention_mask = batch["attention_mask"].to(self.device)
|
||||||
|
history_slot_ids = batch["history_slot_ids"].to(self.device)
|
||||||
|
labels = batch["label"].to(self.device).squeeze(1)
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
logits = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
history_slot_ids=history_slot_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算损失
|
||||||
|
loss = self.criterion(logits, labels)
|
||||||
|
total_loss += loss.item() * labels.size(0)
|
||||||
|
|
||||||
|
# 计算准确率
|
||||||
|
preds = torch.argmax(logits, dim=-1)
|
||||||
|
correct = (preds == labels).sum().item()
|
||||||
|
total_correct += correct
|
||||||
|
total_samples += labels.size(0)
|
||||||
|
|
||||||
|
avg_loss = total_loss / total_samples if total_samples > 0 else 0.0
|
||||||
|
accuracy = total_correct / total_samples if total_samples > 0 else 0.0
|
||||||
|
|
||||||
|
return {"eval_loss": avg_loss, "eval_accuracy": accuracy}
|
||||||
|
|
||||||
|
def save_checkpoint(self, filename: str, is_best: bool = False):
|
||||||
|
"""
|
||||||
|
保存检查点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: 检查点文件名
|
||||||
|
is_best: 是否是最佳模型
|
||||||
|
"""
|
||||||
|
checkpoint_path = self.checkpoint_dir / filename
|
||||||
|
|
||||||
|
checkpoint = {
|
||||||
|
"step": self.current_step,
|
||||||
|
"epoch": self.current_epoch,
|
||||||
|
"model_state_dict": self.model.state_dict(),
|
||||||
|
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||||
|
"scaler_state_dict": self.scaler.state_dict(),
|
||||||
|
"best_eval_loss": self.best_eval_loss,
|
||||||
|
"config": {
|
||||||
|
"learning_rate": self.learning_rate,
|
||||||
|
"weight_decay": self.weight_decay,
|
||||||
|
"warmup_ratio": self.warmup_ratio,
|
||||||
|
"label_smoothing": self.label_smoothing,
|
||||||
|
"total_steps": self.total_steps,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save(checkpoint, checkpoint_path)
|
||||||
|
logger.info(f"Checkpoint saved to {checkpoint_path}")
|
||||||
|
|
||||||
|
if is_best:
|
||||||
|
best_path = self.checkpoint_dir / "best_model.pt"
|
||||||
|
torch.save(checkpoint, best_path)
|
||||||
|
logger.info(f"Best model saved to {best_path}")
|
||||||
|
|
||||||
|
def load_checkpoint(self, checkpoint_path: Union[str, Path]):
|
||||||
|
"""
|
||||||
|
加载检查点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path: 检查点文件路径
|
||||||
|
"""
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
|
|
||||||
|
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||||
|
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
|
||||||
|
|
||||||
|
self.current_step = checkpoint["step"]
|
||||||
|
self.current_epoch = checkpoint["epoch"]
|
||||||
|
self.best_eval_loss = checkpoint["best_eval_loss"]
|
||||||
|
|
||||||
|
logger.info(f"Checkpoint loaded from {checkpoint_path}")
|
||||||
|
logger.info(
|
||||||
|
f"Resuming from step {self.current_step}, epoch {self.current_epoch}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _log_to_tensorboard(self, metrics: Dict[str, float], step: int):
|
||||||
|
"""将指标记录到TensorBoard"""
|
||||||
|
if self.writer is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for key, value in metrics.items():
|
||||||
|
self.writer.add_scalar(key, value, step)
|
||||||
|
|
||||||
|
def _create_progress_bar(self) -> Progress:
|
||||||
|
"""创建Rich进度条"""
|
||||||
|
return Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
console=self.console,
|
||||||
|
expand=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _print_training_info(self):
|
||||||
|
"""打印训练信息"""
|
||||||
|
info_table = Table(
|
||||||
|
title="Training Configuration",
|
||||||
|
show_header=True,
|
||||||
|
header_style="bold magenta",
|
||||||
|
)
|
||||||
|
info_table.add_column("Parameter", style="cyan")
|
||||||
|
info_table.add_column("Value", style="green")
|
||||||
|
|
||||||
|
info_table.add_row("Device", str(self.device))
|
||||||
|
info_table.add_row("Total Steps", str(self.total_steps))
|
||||||
|
info_table.add_row("Warmup Steps", str(self.warmup_steps))
|
||||||
|
info_table.add_row("Learning Rate", f"{self.learning_rate:.2e}")
|
||||||
|
info_table.add_row("Min Learning Rate", f"{self.min_learning_rate:.2e}")
|
||||||
|
info_table.add_row("Weight Decay", str(self.weight_decay))
|
||||||
|
info_table.add_row("Label Smoothing", str(self.label_smoothing))
|
||||||
|
info_table.add_row("Gradient Accumulation", str(self.grad_accum_steps))
|
||||||
|
info_table.add_row("Mixed Precision", str(self.mixed_precision))
|
||||||
|
|
||||||
|
self.console.print(info_table)
|
||||||
|
|
||||||
|
def train(self, resume_from: Optional[str] = None):
    """
    Main training loop.

    Runs up to ``num_epochs`` epochs or ``total_steps`` steps, whichever
    comes first, with gradient accumulation, periodic evaluation,
    TensorBoard logging and checkpointing.

    Args:
        resume_from: optional checkpoint path to resume training from.
    """
    # Restore model/optimizer/step state if a checkpoint was given.
    if resume_from is not None:
        self.load_checkpoint(resume_from)

    # Print the configuration summary table.
    self._print_training_info()

    # Initialize loop state; current_step is non-zero when resuming.
    global_step = self.current_step
    accumulated_loss = 0.0
    accumulated_accuracy = 0.0
    accumulation_counter = 0

    # The Rich progress bar drives all console output inside the loop.
    with self._create_progress_bar() as progress:
        epoch_task = progress.add_task(
            f"[cyan]Epoch {self.current_epoch + 1}/{self.num_epochs}",
            total=self.total_steps,
        )

        # Outer loop over epochs (resumes at current_epoch).
        for epoch in range(self.current_epoch, self.num_epochs):
            self.current_epoch = epoch
            progress.update(
                epoch_task, description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs}"
            )

            for batch_idx, batch in enumerate(self.train_dataloader):
                # The warmup/decay schedule is stepped manually per batch.
                current_lr = self._update_learning_rate()

                # Forward/backward pass; returns (loss, metrics dict).
                loss, metrics = self.train_step(batch)

                # Accumulate running metrics for the next logging window.
                accumulated_loss += loss
                accumulated_accuracy += metrics.get("accuracy", 0.0)
                accumulation_counter += 1

                # Gradient accumulation: update parameters once every
                # grad_accum_steps batches.
                if (global_step + 1) % self.grad_accum_steps == 0:
                    # Unscale before clipping so the norm is measured in
                    # true (un-scaled) gradient units.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.clip_grad_norm
                    )

                    # Optimizer step routed through the AMP grad scaler.
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()

                # Refresh the progress bar with the latest loss/LR.
                progress.update(
                    epoch_task,
                    advance=1,
                    description=f"[cyan]Epoch {epoch + 1}/{self.num_epochs} | "
                    f"Step {global_step}/{self.total_steps} | "
                    f"Loss: {loss:.4f} | "
                    f"LR: {current_lr:.2e}",
                )

                # Periodic evaluation + logging window.
                if (global_step + 1) % self.eval_frequency == 0:
                    # Average the metrics accumulated since the last window.
                    avg_loss = accumulated_loss / accumulation_counter
                    avg_accuracy = accumulated_accuracy / accumulation_counter

                    # Run the held-out evaluation pass.
                    eval_metrics = self.evaluate()

                    # Metrics destined for TensorBoard.
                    log_metrics = {
                        "train/loss": avg_loss,
                        "train/accuracy": avg_accuracy,
                        "train/learning_rate": current_lr,
                    }

                    if eval_metrics:
                        log_metrics.update(
                            {
                                "eval/loss": eval_metrics["eval_loss"],
                                "eval/accuracy": eval_metrics["eval_accuracy"],
                            }
                        )

                        # Track the best eval loss and save a "best" checkpoint.
                        # NOTE(review): this lookup assumes eval_metrics is
                        # non-empty; it must stay inside the `if eval_metrics:`
                        # guard or it raises KeyError when evaluate() returns
                        # nothing — confirm intended placement.
                        if eval_metrics["eval_loss"] < self.best_eval_loss:
                            self.best_eval_loss = eval_metrics["eval_loss"]
                            self.save_checkpoint(
                                f"step_{global_step}.pt", is_best=True
                            )

                    # Write the window's metrics to TensorBoard.
                    self._log_to_tensorboard(log_metrics, global_step)

                    # Human-readable console log line.
                    log_text = (
                        f"[Epoch {epoch + 1}/{self.num_epochs}] "
                        f"[Step {global_step}/{self.total_steps}] "
                        f"Train Loss: {avg_loss:.4f} | "
                        f"Train Acc: {avg_accuracy:.4f} | "
                        f"LR: {current_lr:.2e}"
                    )

                    if eval_metrics:
                        log_text += (
                            f" | Eval Loss: {eval_metrics['eval_loss']:.4f} | "
                            f"Eval Acc: {eval_metrics['eval_accuracy']:.4f}"
                        )

                    progress.console.log(log_text)

                    # Reset the accumulation window.
                    accumulated_loss = 0.0
                    accumulated_accuracy = 0.0
                    accumulation_counter = 0

                # Periodic checkpointing, independent of evaluation.
                if (global_step + 1) % self.save_frequency == 0:
                    self.save_checkpoint(f"step_{global_step}.pt")

                # Advance the global step counter (persisted in checkpoints).
                global_step += 1
                self.current_step = global_step

                # Stop early once the configured step budget is exhausted.
                if global_step >= self.total_steps:
                    progress.update(epoch_task, completed=self.total_steps)
                    break

            # Reset the bar for the next epoch's progress.
            progress.reset(epoch_task)

            # Always checkpoint at epoch boundaries.
            self.save_checkpoint(f"epoch_{epoch + 1}.pt")

            # Propagate the step-budget stop out of the epoch loop.
            if global_step >= self.total_steps:
                break

    # Training finished: persist the final model and close the writer.
    logger.info("Training completed!")

    self.save_checkpoint("final_model.pt")

    if self.writer is not None:
        self.writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def worker_init_fn(worker_id: int) -> None:
    """Seed the stdlib and NumPy RNGs of a DataLoader worker.

    Derives the per-worker seed from torch's initial seed so runs are
    reproducible across worker processes.

    Args:
        worker_id: index of the DataLoader worker (unused; torch already
            differentiates the base seed per worker).
    """
    derived_seed = torch.initial_seed() % (2**32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Combine a list of samples into a single batch.

    Tensor fields are stacked (the per-sample leading singleton batch
    dimension is squeezed away first where the dataset emits one);
    string fields are kept as plain Python lists.

    Args:
        batch: list of sample dicts produced by the dataset.

    Returns:
        Batched dict with stacked tensors and list-valued string fields.
    """

    def stack_squeezed(key: str) -> torch.Tensor:
        # Drop the leading singleton dim each sample carries, then stack.
        return torch.stack([sample[key].squeeze(0) for sample in batch])

    def stack(key: str) -> torch.Tensor:
        return torch.stack([sample[key] for sample in batch])

    def gather(key: str) -> List[Any]:
        # String fields stay as lists.
        return [sample[key] for sample in batch]

    return {
        "input_ids": stack_squeezed("input_ids"),
        "token_type_ids": stack_squeezed("token_type_ids"),
        "attention_mask": stack_squeezed("attention_mask"),
        "labels": stack_squeezed("label"),
        "history_slot_ids": stack("history_slot_ids"),
        "prefix": gather("prefix"),
        "suffix": gather("suffix"),
        "pinyin": gather("pinyin"),
        "pinyin_ids": stack("pinyin_ids"),
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Typer CLI application; exposed as the `train-model` console script
# (see [project.scripts] in pyproject.toml).
app = typer.Typer(help="输入法模型训练命令行工具", add_completion=False)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def train(
    # --- data options ---
    train_data_path: str = typer.Option(
        ..., "--train-data-path", "-t", help="训练数据集路径"
    ),
    eval_data_path: str = typer.Option(
        ..., "--eval-data-path", "-e", help="评估数据集路径"
    ),
    output_dir: str = typer.Option("./output", "--output-dir", "-o", help="输出目录"),
    # --- model options ---
    vocab_size: int = typer.Option(10019, "--vocab-size", help="词汇表大小"),
    pinyin_vocab_size: int = typer.Option(
        30, "--pinyin-vocab-size", help="拼音词汇表大小"
    ),
    # NOTE(review): option name uses underscores, unlike the other
    # dash-separated options — confirm this asymmetry is intentional.
    max_iter_length: int = typer.Option(
        1024 * 1024 * 128, "--max_iter_length", help="数据集大小"
    ),
    dim: int = typer.Option(512, "--dim", help="模型维度"),
    num_slots: int = typer.Option(8, "--num-slots", help="历史槽位数量"),
    n_layers: int = typer.Option(4, "--n-layers", help="Transformer层数"),
    n_heads: int = typer.Option(4, "--n-heads", help="注意力头数"),
    num_experts: int = typer.Option(20, "--num-experts", help="MoE专家数量"),
    max_seq_len: int = typer.Option(128, "--max-seq-len", help="最大序列长度"),
    use_pinyin: bool = typer.Option(False, "--use-pinyin", help="是否使用拼音特征"),
    # --- training options ---
    batch_size: int = typer.Option(128, "--batch-size", "-b", help="批次大小"),
    num_epochs: int = typer.Option(10, "--num-epochs", help="训练轮数"),
    learning_rate: float = typer.Option(1e-5, "--learning-rate", "-lr", help="学习率"),
    min_learning_rate: float = typer.Option(
        1e-9, "--min-learning-rate", help="最小学习率"
    ),
    weight_decay: float = typer.Option(0.1, "--weight-decay", help="权重衰减"),
    warmup_ratio: float = typer.Option(0.1, "--warmup-ratio", help="热身步数比例"),
    label_smoothing: float = typer.Option(
        0.15, "--label-smoothing", help="标签平滑参数"
    ),
    grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"),
    clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"),
    eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"),
    save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"),
    # --- misc options ---
    mixed_precision: bool = typer.Option(
        True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练"
    ),
    use_tensorboard: bool = typer.Option(
        True, "--tensorboard/--no-tensorboard", help="是否使用TensorBoard"
    ),
    resume_from: Optional[str] = typer.Option(
        None, "--resume-from", help="从检查点恢复训练"
    ),
    seed: int = typer.Option(42, "--seed", help="随机种子"),
):
    """
    训练输入法模型
    """
    # "file_system" sharing avoids exhausting file descriptors when many
    # DataLoader workers pass tensors between processes.
    torch.multiprocessing.set_sharing_strategy("file_system")

    # Seed all RNGs for reproducibility.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    console = Console()

    # Banner + configuration table printed before anything heavy starts.
    console.print(
        Panel.fit("[bold cyan]输入法模型训练配置[/bold cyan]", border_style="cyan")
    )

    config_table = Table(show_header=True, header_style="bold magenta")
    config_table.add_column("Category", style="cyan")
    config_table.add_column("Parameter", style="green")
    config_table.add_column("Value", style="yellow")

    # Data-related rows.
    config_table.add_row("数据", "训练数据路径", train_data_path)
    config_table.add_row("数据", "评估数据路径", eval_data_path)
    config_table.add_row("数据", "输出目录", output_dir)
    config_table.add_row("数据", "批次大小", str(batch_size))

    # Model-related rows.
    config_table.add_row("模型", "词汇表大小", str(vocab_size))
    config_table.add_row("模型", "拼音词汇表", str(pinyin_vocab_size))
    config_table.add_row("模型", "模型维度", str(dim))
    config_table.add_row("模型", "槽位数量", str(num_slots))
    config_table.add_row("模型", "Transformer层数", str(n_layers))
    config_table.add_row("模型", "注意力头数", str(n_heads))
    config_table.add_row("模型", "MoE专家数", str(num_experts))
    config_table.add_row("模型", "使用拼音", str(use_pinyin))

    # Training-related rows.
    config_table.add_row("训练", "训练轮数", str(num_epochs))
    config_table.add_row("训练", "学习率", f"{learning_rate:.2e}")
    config_table.add_row("训练", "最小学习率", f"{min_learning_rate:.2e}")
    config_table.add_row("训练", "权重衰减", str(weight_decay))
    config_table.add_row("训练", "热身比例", str(warmup_ratio))
    config_table.add_row("训练", "标签平滑", str(label_smoothing))
    config_table.add_row("训练", "梯度累积", str(grad_accum_steps))
    config_table.add_row("训练", "梯度裁剪", str(clip_grad_norm))
    config_table.add_row("训练", "混合精度", str(mixed_precision))

    console.print(config_table)

    # Create the output directory (parents included, idempotent).
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Persist the full run configuration for reproducibility.
    config = {
        "train_data_path": train_data_path,
        "eval_data_path": eval_data_path,
        "output_dir": output_dir,
        "vocab_size": vocab_size,
        "pinyin_vocab_size": pinyin_vocab_size,
        "dim": dim,
        "num_slots": num_slots,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "num_experts": num_experts,
        "max_seq_len": max_seq_len,
        "use_pinyin": use_pinyin,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "learning_rate": learning_rate,
        "min_learning_rate": min_learning_rate,
        "weight_decay": weight_decay,
        "warmup_ratio": warmup_ratio,
        "label_smoothing": label_smoothing,
        "grad_accum_steps": grad_accum_steps,
        "clip_grad_norm": clip_grad_norm,
        "eval_frequency": eval_frequency,
        "save_frequency": save_frequency,
        "mixed_precision": mixed_precision,
        "use_tensorboard": use_tensorboard,
        "seed": seed,
        "max_iter_length": max_iter_length,
    }

    config_file = output_path / "training_config.json"
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    logger.info(f"Configuration saved to {config_file}")

    # Build the data loaders.
    console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")

    # Training dataset.
    train_dataset = PinyinInputDataset(
        data_path=train_data_path,
        max_workers=-1,  # auto-select worker count
        max_iter_length=max_iter_length,
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=5000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    # Training data loader.
    # NOTE: PinyinInputDataset is an IterableDataset, so shuffle= cannot
    # be used here; each worker handles one shard inside dataset.__iter__.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=max(1, (os.cpu_count() or 1) - 1),
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=64,  # each worker prefetches 64 batches (large-RAM setups)
        persistent_workers=True,  # keep workers alive to avoid respawn cost
    )

    # Evaluation dataset (same settings, smaller iteration budget).
    eval_dataset = PinyinInputDataset(
        data_path=eval_data_path,
        max_workers=-1,
        max_iter_length=1024,  # eval set is kept small
        max_seq_length=max_seq_len,
        text_field="text",
        py_style_weight=(9, 2, 1),
        shuffle_buffer_size=1000,
        length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
    )

    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        num_workers=1,
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        prefetch_factor=64,  # each worker prefetches 64 batches
        persistent_workers=True,  # keep workers alive
    )

    console.print("[green]✓ 数据加载器创建完成[/green]")
    console.print(f"  训练批次大小: {batch_size}")
    console.print(f"  评估批次大小: {batch_size}")

    # Build the model.
    console.print("[bold cyan]正在创建模型...[/bold cyan]")
    model = InputMethodEngine(
        vocab_size=vocab_size,
        pinyin_vocab_size=pinyin_vocab_size,
        dim=dim,
        num_slots=num_slots,
        n_layers=n_layers,
        n_heads=n_heads,
        num_experts=num_experts,
        max_seq_len=max_seq_len,
    )

    console.print(
        f"[green]✓ 模型创建完成,参数量: {sum(p.numel() for p in model.parameters()):,}[/green]"
    )

    # Build the trainer.
    console.print("[bold cyan]正在创建训练器...[/bold cyan]")
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        # Step budget derived from dataset size and batch size.
        total_steps=int(max_iter_length / batch_size),
        output_dir=output_dir,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        min_learning_rate=min_learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        label_smoothing=label_smoothing,
        grad_accum_steps=grad_accum_steps,
        clip_grad_norm=clip_grad_norm,
        eval_frequency=eval_frequency,
        save_frequency=save_frequency,
        mixed_precision=mixed_precision,
        use_tensorboard=use_tensorboard,
    )

    console.print("[green]✓ 训练器创建完成[/green]")

    # Kick off training.
    console.print("\n[bold cyan]开始训练...[/bold cyan]")
    console.print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    trainer.train(resume_from=resume_from)

    console.print("[bold green]✓ 训练完成![/bold green]")
    console.print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(f"模型和日志保存在: {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def evaluate(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    data_path: str = typer.Option(..., "--data-path", "-d", help="数据集路径"),
    batch_size: int = typer.Option(32, "--batch-size", "-b", help="批次大小"),
):
    """
    评估训练好的模型
    """
    console = Console()
    console.print(f"[bold cyan]评估模型: {checkpoint_path}[/bold cyan]")

    # Evaluation logic is not implemented yet. Planned steps:
    # 1. load the checkpoint
    # 2. build the data loader
    # 3. run the model over the dataset

    console.print("[yellow]评估功能待实现[/yellow]")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def export(
    checkpoint_path: str = typer.Option(..., "--checkpoint", "-c", help="检查点路径"),
    output_path: str = typer.Option(
        "./exported_model.onnx", "--output", "-o", help="输出路径"
    ),
):
    """
    导出模型为ONNX格式
    """
    console = Console()
    console.print(f"[bold cyan]导出模型到: {output_path}[/bold cyan]")

    # Export logic is not implemented yet. Planned steps:
    # 1. load the checkpoint
    # 2. export to ONNX

    console.print("[yellow]导出功能待实现[/yellow]")
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this module directly, in addition to the console script.
if __name__ == "__main__":
    app()
|
||||||
24
test.py
24
test.py
|
|
@ -1,17 +1,29 @@
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from model.dataset import PinyinInputDataset
|
from model.dataset import PinyinInputDataset
|
||||||
|
from model.trainer import collate_fn, worker_init_fn
|
||||||
|
|
||||||
if sys.platform == "win32":
|
if sys.platform == "win32":
|
||||||
dataset_path = "data"
|
dataset_path = "data"
|
||||||
else:
|
else:
|
||||||
dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"
|
dataset_path = "/home/songsenand/Data/corpus/CCI-Data/"
|
||||||
|
|
||||||
dataset = PinyinInputDataset(dataset_path, max_iter_length=20, max_workes=3)
|
dataset = PinyinInputDataset(dataset_path, max_iter_length=128 * 128)
|
||||||
for i, line in enumerate(dataset):
|
dataloader = DataLoader(
|
||||||
for k, v in line.items():
|
dataset,
|
||||||
if isinstance(v, str):
|
batch_size=128,
|
||||||
continue
|
num_workers=2,
|
||||||
print(k, v.shape)
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
prefetch_factor=64, # 每个worker预取64个batch,适合大内存场景
|
||||||
|
persistent_workers=True, # 保持worker进程存活,避免重建开销
|
||||||
|
)
|
||||||
|
dataloader = list([i for i in dataloader])
|
||||||
|
print(len(dataloader[0]["labels"]))
|
||||||
|
for i, line in tqdm(enumerate(dataloader), total=128):
|
||||||
|
print(line["pinyin_ids"].squeeze(-1).shape)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,451 @@
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from rich.console import Console
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
# 添加src目录到路径
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from src.model.model import InputMethodEngine
|
||||||
|
from src.model.trainer import Trainer
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
class MockDataset(Dataset):
    """In-memory synthetic dataset for exercising the training pipeline."""

    def __init__(self, num_samples=100, vocab_size=100, seq_len=128, num_slots=8):
        # num_samples: dataset length reported by __len__
        # vocab_size: exclusive upper bound for random token ids
        # seq_len: length of each generated sequence
        # num_slots: number of history slots per sample
        self.num_samples = num_samples
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_slots = num_slots

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate random tensors shaped like real preprocessed samples;
        # the index is ignored on purpose.
        high = self.vocab_size
        seq = self.seq_len
        return {
            "input_ids": torch.randint(0, high, (seq,)),
            "token_type_ids": torch.randint(0, 2, (seq,)),
            "attention_mask": torch.ones(seq, dtype=torch.long),
            "history_slot_ids": torch.randint(0, high, (self.num_slots,)),
            "label": torch.randint(0, high, (1,)),
        }
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_creation():
    """Smoke-test MockDataset through a DataLoader and verify batch shapes."""
    console.print("[bold cyan]测试数据集创建...[/bold cyan]")

    dataloader = DataLoader(MockDataset(num_samples=10), batch_size=2, shuffle=False)
    batch = next(iter(dataloader))

    console.print(f"批处理形状:")
    # Expected shape per field; also drives the printing and the asserts.
    expected_shapes = {
        "input_ids": (2, 128),
        "token_type_ids": (2, 128),
        "attention_mask": (2, 128),
        "history_slot_ids": (2, 8),
        "label": (2, 1),
    }
    for field in expected_shapes:
        console.print(f"  {field}: {batch[field].shape}")

    for field, shape in expected_shapes.items():
        assert batch[field].shape == shape, f"{field}形状不正确"

    console.print("[green]✓ 数据集测试通过[/green]\n")
    return dataloader
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_creation():
    """Build a small InputMethodEngine and verify its forward pass."""
    console.print("[bold cyan]测试模型创建...[/bold cyan]")

    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,  # small dimension to keep the test fast
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    # Exercise the forward pass with random inputs.
    bsz, seq = 2, 128
    forward_inputs = dict(
        input_ids=torch.randint(0, 100, (bsz, seq)),
        token_type_ids=torch.randint(0, 2, (bsz, seq)),
        attention_mask=torch.ones(bsz, seq, dtype=torch.long),
        history_slot_ids=torch.randint(0, 100, (bsz, 8)),
    )
    with torch.no_grad():
        logits = model(**forward_inputs)

    console.print(f"模型输出形状: {logits.shape}")
    console.print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

    assert logits.shape == (bsz, 100), "模型输出形状不正确"

    console.print("[green]✓ 模型测试通过[/green]\n")
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_trainer_initialization():
    """Construct a Trainer and verify its derived state and components."""
    console.print("[bold cyan]测试训练器初始化...[/bold cyan]")

    # Small model + mock data keep the test lightweight.
    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    train_dataloader = DataLoader(
        MockDataset(num_samples=50), batch_size=4, shuffle=False
    )
    eval_dataloader = DataLoader(
        MockDataset(num_samples=10), batch_size=4, shuffle=False
    )

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=10,  # cap the run length
        learning_rate=1e-4,
        min_learning_rate=1e-6,
        weight_decay=0.1,
        warmup_ratio=0.1,
        label_smoothing=0.1,
        grad_accum_steps=1,
        clip_grad_norm=1.0,
        eval_frequency=5,
        save_frequency=10,
        mixed_precision=False,  # disabled while testing
        use_tensorboard=False,  # disabled while testing
    )

    console.print(f"训练器设备: {trainer.device}")
    console.print(f"总步数: {trainer.total_steps}")
    console.print(f"热身步数: {trainer.warmup_steps}")
    console.print(f"优化器类型: {type(trainer.optimizer)}")
    console.print(f"损失函数类型: {type(trainer.criterion)}")

    assert trainer.device.type in ["cpu", "cuda"], "设备类型不正确"
    assert trainer.total_steps == 10, "总步数不正确"
    assert trainer.warmup_steps == 1, "热身步数不正确"  # 10 * 0.1 = 1
    assert isinstance(trainer.optimizer, torch.optim.AdamW), "优化器类型不正确"
    assert isinstance(trainer.criterion, nn.CrossEntropyLoss), "损失函数类型不正确"

    console.print("[green]✓ 训练器初始化测试通过[/green]\n")
    return trainer
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_step():
    """Run one train_step and sanity-check the returned loss and metrics."""
    console.print("[bold cyan]测试训练步骤...[/bold cyan]")

    # Build a minimal trainer over mock data.
    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    train_dataloader = DataLoader(
        MockDataset(num_samples=10), batch_size=2, shuffle=False
    )

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=None,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=5,
        learning_rate=1e-4,
        mixed_precision=False,
        use_tensorboard=False,
    )

    # Execute a single training step on the first batch.
    first_batch = next(iter(train_dataloader))
    loss, metrics = trainer.train_step(first_batch)

    console.print(f"训练步骤损失: {loss:.4f}")
    console.print(f"训练步骤指标: {metrics}")

    assert isinstance(loss, float), "损失值类型不正确"
    assert loss >= 0, "损失值应为非负数"
    assert "lr" in metrics, "指标中缺少学习率"
    assert "accuracy" in metrics, "指标中缺少准确率"
    assert 0 <= metrics["accuracy"] <= 1, "准确率应在0-1之间"

    console.print("[green]✓ 训练步骤测试通过[/green]\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluation():
    """Run Trainer.evaluate() and check the returned metric dict."""
    console.print("[bold cyan]测试评估功能...[/bold cyan]")

    # Build a minimal trainer with both train and eval mock loaders.
    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    train_dataloader = DataLoader(
        MockDataset(num_samples=10), batch_size=2, shuffle=False
    )
    eval_dataloader = DataLoader(
        MockDataset(num_samples=5), batch_size=2, shuffle=False
    )

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=5,
        learning_rate=1e-4,
        mixed_precision=False,
        use_tensorboard=False,
    )

    # Run the evaluation pass.
    eval_metrics = trainer.evaluate()

    console.print(f"评估指标: {eval_metrics}")

    assert "eval_loss" in eval_metrics, "评估指标中缺少eval_loss"
    assert "eval_accuracy" in eval_metrics, "评估指标中缺少eval_accuracy"
    assert eval_metrics["eval_loss"] >= 0, "评估损失应为非负数"
    assert 0 <= eval_metrics["eval_accuracy"] <= 1, "评估准确率应在0-1之间"

    console.print("[green]✓ 评估功能测试通过[/green]\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_lr_scheduler():
    """Sample the warmup/cosine LR schedule and verify its shape."""
    console.print("[bold cyan]测试学习率调度器...[/bold cyan]")

    model = InputMethodEngine(
        vocab_size=100,
        pinyin_vocab_size=28,
        dim=64,
        num_slots=8,
        n_layers=2,
        n_heads=2,
        num_experts=4,
        max_seq_len=128,
        use_pinyin=False,
    )

    train_dataloader = DataLoader(
        MockDataset(num_samples=10), batch_size=2, shuffle=False
    )

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=None,
        output_dir="./test_output",
        num_epochs=1,
        total_steps=100,
        learning_rate=1e-3,
        min_learning_rate=1e-5,
        warmup_ratio=0.2,  # 20% of steps spent warming up
        mixed_precision=False,
        use_tensorboard=False,
    )

    # Sample the schedule at a few representative steps.
    sample_steps = [0, 10, 20, 50, 99]
    sampled_lrs = []
    for step in sample_steps:
        trainer.current_step = step
        current_lr = trainer._get_current_lr()
        sampled_lrs.append(current_lr)
        console.print(f"步数 {step}: 学习率 = {current_lr:.2e}")

    # Verify warmup rise followed by cosine decay down to the floor.
    assert sampled_lrs[0] == 0.0, "第0步学习率应为0"
    assert sampled_lrs[1] > sampled_lrs[0], "热身阶段学习率应增加"
    assert sampled_lrs[4] < sampled_lrs[2], "余弦退火阶段学习率应下降"
    assert sampled_lrs[4] >= 1e-5, "最终学习率不应低于最小值"

    console.print("[green]✓ 学习率调度器测试通过[/green]\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_checkpoint_saving():
    """Round-trip a training checkpoint: save it, load it into a fresh
    trainer, and confirm step/epoch counters are restored."""
    console.print("[bold cyan]测试检查点保存...[/bold cyan]")

    import shutil
    import tempfile

    # Work inside an isolated temp directory so the test leaves no trace.
    temp_dir = tempfile.mkdtemp()

    try:
        def build_model():
            # Same tiny configuration for both trainers under test.
            return InputMethodEngine(
                vocab_size=100,
                pinyin_vocab_size=28,
                dim=64,
                num_slots=8,
                n_layers=2,
                n_heads=2,
                num_experts=4,
                max_seq_len=128,
                use_pinyin=False,
            )

        train_dataloader = DataLoader(
            MockDataset(num_samples=10), batch_size=2, shuffle=False
        )

        trainer_kwargs = dict(
            train_dataloader=train_dataloader,
            eval_dataloader=None,
            output_dir=temp_dir,
            num_epochs=1,
            total_steps=5,
            learning_rate=1e-4,
            mixed_precision=False,
            use_tensorboard=False,
        )

        trainer = Trainer(model=build_model(), **trainer_kwargs)

        # Save, then verify the checkpoint file landed where expected.
        checkpoint_path = Path(temp_dir) / "checkpoints" / "test_checkpoint.pt"
        trainer.save_checkpoint("test_checkpoint.pt")

        console.print(f"检查点保存路径: {checkpoint_path}")
        assert checkpoint_path.exists(), "检查点文件未创建"

        # Load into a second, independently constructed trainer.
        trainer2 = Trainer(model=build_model(), **trainer_kwargs)
        trainer2.load_checkpoint(checkpoint_path)

        console.print(f"加载后的步数: {trainer2.current_step}")
        console.print(f"加载后的epoch: {trainer2.current_epoch}")

        # Restored counters must match the saving trainer's state.
        assert trainer2.current_step == trainer.current_step, "步数未正确恢复"
        assert trainer2.current_epoch == trainer.current_epoch, "epoch未正确恢复"

        console.print("[green]✓ 检查点保存测试通过[/green]\n")

    finally:
        # Always remove the temp directory, even on assertion failure.
        shutil.rmtree(temp_dir)
def main():
    """Run every Trainer test in order.

    Returns:
        0 if all tests pass, 1 if any test raises.
    """
    console.print("[bold blue]开始测试Trainer类...[/bold blue]\n")

    try:
        # Run the suite in dependency order: data, model, trainer,
        # training step, evaluation, LR schedule, checkpointing.
        test_dataset_creation()
        test_model_creation()
        test_trainer_initialization()
        test_training_step()
        test_evaluation()
        test_lr_scheduler()
        test_checkpoint_saving()

        console.print("[bold green]所有测试通过! ✅[/bold green]")

    except Exception as e:
        console.print(f"[bold red]测试失败: {e}[/bold red]")
        import traceback

        traceback.print_exc()
        return 1

    return 0
if __name__ == "__main__":
    import shutil

    # Start from a clean slate: earlier runs may have left ./test_output.
    test_output_dir = Path("./test_output")
    if test_output_dir.exists():
        shutil.rmtree(test_output_dir)

    exit_code = main()

    # Remove artifacts produced by this run before exiting.
    if test_output_dir.exists():
        shutil.rmtree(test_output_dir)

    sys.exit(exit_code)
Loading…
Reference in New Issue