feat(eval): 添加模型评估脚本,支持文本分析与概率分布检测
This commit is contained in:
parent
919d0972e2
commit
a0e4d25b2f
|
|
@ -0,0 +1,744 @@
|
||||||
|
#!/usr/bin/env python3
"""
eval.py - Evaluate the model's performance on a given text.

Usage:
    python eval.py --text-file path/to/text.txt [--checkpoint path/to/model.pt] [--num-samples N]

What it does:
    1. Reads a text file, caps it at 300 characters with a random excerpt.
    2. Converts the text into model inputs (part1/part2/part3/part4/slot history),
       following dataset.py.
    3. Runs model inference, following test.py.
    4. Prints mistakes; cases whose probability distribution is nearly flat
       (prediction no better than guessing) are specially flagged.
"""

import argparse
import random
import re
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn.functional as F
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials

# Make the src directory importable when run from the repo root.
# NOTE(review): appending "src" while importing "src.model.*" looks redundant —
# the imports below resolve relative to the working directory, not this entry;
# confirm which one is actually needed.
sys.path.append("src")

from src.model.model import InputMethodEngine
from src.model.query import QueryEngine
from src.model.dataset import text_to_pinyin_ids

# Matches runs of CJK unified ideographs (the characters converted to pinyin).
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
||||||
|
class TextEvaluator:
    """Evaluates a trained input-method model on free text (see module docstring)."""

    def __init__(
        self,
        checkpoint_path: str,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Load the model, tokenizer and query engine; optionally run a debug sample.

        Args:
            checkpoint_path: path to the trained checkpoint (.pt file).
            device: torch device string ("cuda" or "cpu").
        """
        self.device = torch.device(device)
        self.checkpoint_path = checkpoint_path

        # Load the components in dependency order.
        self.load_model()
        self.load_tokenizer()
        self.load_query_engine()

        # Pinyin style weights (kept consistent with dataset.py).
        weights = np.array([9, 2, 1], dtype=float)
        self.py_style_weight = weights / weights.sum()

        print(f"✅ 评估器初始化完成 (设备: {self.device})")

        # Debug hook: run one fixed sample when EVAL_DEBUG is set.
        import os

        if os.environ.get("EVAL_DEBUG"):
            self._debug_sample()
def load_model(self):
|
||||||
|
"""加载训练好的模型"""
|
||||||
|
self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||||
|
checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
|
||||||
|
if "model_state_dict" in checkpoint:
|
||||||
|
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
else:
|
||||||
|
self.model.load_state_dict(checkpoint)
|
||||||
|
self.model.eval()
|
||||||
|
self.model.to(self.device)
|
||||||
|
print(
|
||||||
|
f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}"
|
||||||
|
)
|
||||||
|
print(f"✅ 模型词汇表大小: {self.model.vocab_size}")
|
||||||
|
|
||||||
|
def load_tokenizer(self):
|
||||||
|
"""加载tokenizer"""
|
||||||
|
try:
|
||||||
|
tokenizer_path = (
|
||||||
|
Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
|
||||||
|
)
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
|
||||||
|
print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ 无法加载tokenizer: {e}")
|
||||||
|
print("使用默认的bert-base-chinese tokenizer")
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
||||||
|
|
||||||
|
def load_query_engine(self):
|
||||||
|
"""加载查询引擎用于字符-ID转换"""
|
||||||
|
try:
|
||||||
|
self.query_engine = QueryEngine()
|
||||||
|
stats_path = (
|
||||||
|
Path(__file__).parent
|
||||||
|
/ "src"
|
||||||
|
/ "model"
|
||||||
|
/ "assets"
|
||||||
|
/ "pinyin_char_statistics.json"
|
||||||
|
)
|
||||||
|
if stats_path.exists():
|
||||||
|
self.query_engine.load(stats_path)
|
||||||
|
print(
|
||||||
|
f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f"⚠️ 统计文件不存在: {stats_path}")
|
||||||
|
self.query_engine = None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ 无法加载查询引擎: {e}")
|
||||||
|
self.query_engine = None
|
||||||
|
|
||||||
|
def _debug_sample(self):
|
||||||
|
"""调试:运行与test.py相同的固定样本"""
|
||||||
|
print("\n🧪 调试:运行固定样本(与test.py相同)")
|
||||||
|
|
||||||
|
# 复制test.py中的样本
|
||||||
|
part1 = "他是一名大学生,在上海读"
|
||||||
|
part2 = "dayi"
|
||||||
|
pinyin_ids = text_to_pinyin_ids(part2)
|
||||||
|
len_py = len(pinyin_ids)
|
||||||
|
if len_py < 24:
|
||||||
|
pinyin_ids.extend([0] * (24 - len_py))
|
||||||
|
else:
|
||||||
|
pinyin_ids = pinyin_ids[:24]
|
||||||
|
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||||
|
masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
|
||||||
|
part3 = "。"
|
||||||
|
part4 = "可行|特别|伤害"
|
||||||
|
|
||||||
|
encoded = self.tokenizer(
|
||||||
|
f"{part4}|{part1}",
|
||||||
|
part3,
|
||||||
|
max_length=128,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_token_type_ids=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
sample = {
|
||||||
|
"input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
|
||||||
|
"token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
|
||||||
|
"attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
|
||||||
|
"history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(
|
||||||
|
0
|
||||||
|
),
|
||||||
|
"prefix": f"{part4}^{part1}",
|
||||||
|
"suffix": part3,
|
||||||
|
"pinyin": part2,
|
||||||
|
"pinyin_ids": pinyin_ids,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
logits, probs = self.inference(sample)
|
||||||
|
analysis = self.analyze_probability_distribution(probs)
|
||||||
|
|
||||||
|
print(f"📊 调试样本结果:")
|
||||||
|
print(f" 最高概率: {analysis['max_prob']:.6f}")
|
||||||
|
print(f" 平均概率: {analysis['mean_prob']:.6f}")
|
||||||
|
print(f" 标准差: {analysis['std_prob']:.6f}")
|
||||||
|
print(f" 低方差: {analysis['low_variance']}")
|
||||||
|
|
||||||
|
# 打印top-5
|
||||||
|
print(f" Top-5预测:")
|
||||||
|
for j in range(5):
|
||||||
|
idx = analysis["top_indices"][j]
|
||||||
|
prob = analysis["top_probs"][j]
|
||||||
|
char_info = (
|
||||||
|
self.query_engine.query_by_id(idx) if self.query_engine else None
|
||||||
|
)
|
||||||
|
char = char_info.char if char_info else f"[ID:{idx}]"
|
||||||
|
print(f" {j + 1}. {char} (ID: {idx}) - 概率: {prob:.6f}")
|
||||||
|
|
||||||
|
def generate_pinyin(self, text: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
流式处理单条文本,转换为拼音列表。
|
||||||
|
参考dataset.py中的generate_pinyin方法。
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
text_len = len(text)
|
||||||
|
result: List[str] = [""] * text_len
|
||||||
|
|
||||||
|
# 遍历所有连续汉字片段
|
||||||
|
for match in _HANZI_RE.finditer(text):
|
||||||
|
start_idx = match.start()
|
||||||
|
hanzi_segment = match.group()
|
||||||
|
|
||||||
|
pinyin_list = lazy_pinyin(hanzi_segment)
|
||||||
|
|
||||||
|
if len(pinyin_list) != len(hanzi_segment):
|
||||||
|
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
|
||||||
|
|
||||||
|
for i, py in enumerate(pinyin_list):
|
||||||
|
result[start_idx + i] = py
|
||||||
|
|
||||||
|
# 填充非汉字字符
|
||||||
|
for i, char in enumerate(text):
|
||||||
|
if not result[i]:
|
||||||
|
result[i] = char
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_mask_pinyin(
|
||||||
|
self, text: str, pinyin_list: List[str]
|
||||||
|
) -> Tuple[int, List[str]]:
|
||||||
|
"""
|
||||||
|
生成需要预测汉字对应的拼音,并进行加强。
|
||||||
|
参考dataset.py中的get_mask_pinyin方法。
|
||||||
|
"""
|
||||||
|
mask_pinyin = []
|
||||||
|
for i in range(len(text)):
|
||||||
|
if not self.query_engine or not self.query_engine.is_chinese_char(text[i]):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
py = np.random.choice(
|
||||||
|
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
|
||||||
|
p=self.py_style_weight,
|
||||||
|
)
|
||||||
|
if py == "":
|
||||||
|
py = pinyin_list[i][0]
|
||||||
|
mask_pinyin.append(py)
|
||||||
|
return len(mask_pinyin), mask_pinyin
|
||||||
|
|
||||||
|
def create_sample_from_text(
|
||||||
|
self, text: str, position: Optional[int] = None
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
从文本创建单个样本,参考dataset.py的样本生成逻辑。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 输入文本
|
||||||
|
position: 预测位置(汉字位置),如果为None则随机选择
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
样本字典,包含模型所需的所有字段
|
||||||
|
"""
|
||||||
|
# 确保文本长度至少为1
|
||||||
|
if not text:
|
||||||
|
raise ValueError("文本不能为空")
|
||||||
|
|
||||||
|
# 生成拼音列表
|
||||||
|
pinyin_list = self.generate_pinyin(text)
|
||||||
|
|
||||||
|
# 找到所有汉字位置
|
||||||
|
chinese_positions = [
|
||||||
|
i
|
||||||
|
for i, char in enumerate(text)
|
||||||
|
if self.query_engine and self.query_engine.is_chinese_char(char)
|
||||||
|
]
|
||||||
|
|
||||||
|
if not chinese_positions:
|
||||||
|
raise ValueError("文本中没有汉字")
|
||||||
|
|
||||||
|
# 随机选择预测位置
|
||||||
|
if position is None:
|
||||||
|
position = random.choice(chinese_positions)
|
||||||
|
elif position >= len(text) or position < 0:
|
||||||
|
raise ValueError(f"位置 {position} 超出文本范围")
|
||||||
|
elif not self.query_engine or not self.query_engine.is_chinese_char(
|
||||||
|
text[position]
|
||||||
|
):
|
||||||
|
raise ValueError(f"位置 {position} 不是汉字")
|
||||||
|
|
||||||
|
i = position
|
||||||
|
|
||||||
|
# part1: 光标前文本(最多48个字符)
|
||||||
|
if i < 48:
|
||||||
|
part1 = text[0:i]
|
||||||
|
else:
|
||||||
|
part1 = text[i - 48 : i]
|
||||||
|
|
||||||
|
# part2: 拼音输入(随机长度1-8,高斯分布)
|
||||||
|
pinyin_len_probs = [0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
|
||||||
|
pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
|
||||||
|
py_end = min(i + pinyin_len, len(text))
|
||||||
|
|
||||||
|
# 获取拼音
|
||||||
|
pinyin_len_actual, part2_pinyin = self.get_mask_pinyin(
|
||||||
|
text[i:py_end], pinyin_list[i:py_end]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加分隔符
|
||||||
|
split_char = np.random.choice(["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02])
|
||||||
|
part2 = split_char.join(part2_pinyin)
|
||||||
|
|
||||||
|
# 转换为拼音ID
|
||||||
|
pinyin_ids = text_to_pinyin_ids(part2)
|
||||||
|
len_py = len(pinyin_ids)
|
||||||
|
if len_py < 24:
|
||||||
|
pinyin_ids.extend([0] * (24 - len_py))
|
||||||
|
else:
|
||||||
|
pinyin_ids = pinyin_ids[:24]
|
||||||
|
pinyin_ids_tensor = torch.tensor(pinyin_ids, dtype=torch.long)
|
||||||
|
|
||||||
|
# part3: 光标后文本(70%概率为空)
|
||||||
|
part3 = ""
|
||||||
|
if random.random() > 0.7:
|
||||||
|
part3 = text[
|
||||||
|
i + pinyin_len_actual : i
|
||||||
|
+ pinyin_len_actual
|
||||||
|
+ np.random.choice(range(1, 17))
|
||||||
|
]
|
||||||
|
|
||||||
|
# part4: 上下文提示(50%概率为空)
|
||||||
|
part4 = ""
|
||||||
|
if random.random() > 0.5:
|
||||||
|
num_strings = random.randint(1, 5)
|
||||||
|
string_list = []
|
||||||
|
for _ in range(num_strings):
|
||||||
|
start_pos = random.randint(0, len(text) - 1)
|
||||||
|
x = random.randint(2, 6)
|
||||||
|
end_pos = min(start_pos + x + 1, len(text))
|
||||||
|
string_list.append(text[start_pos:end_pos])
|
||||||
|
part4 = "|".join(string_list)
|
||||||
|
|
||||||
|
# 标签:预测字符的ID
|
||||||
|
labels = []
|
||||||
|
try:
|
||||||
|
labels = [
|
||||||
|
self.query_engine.get_char_info_by_char_pinyin(c, p).id
|
||||||
|
for c, p in zip(
|
||||||
|
text[i : i + pinyin_len_actual],
|
||||||
|
pinyin_list[i : i + pinyin_len_actual],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
except (AttributeError, TypeError) as e:
|
||||||
|
# 如果查询失败,使用默认ID
|
||||||
|
print(f"⚠️ 获取标签失败: {e}")
|
||||||
|
labels = [0] * pinyin_len_actual
|
||||||
|
|
||||||
|
# 历史槽位:当前预测字符之前已确认的字符
|
||||||
|
# 从位置i之前的字符中提取历史槽位(最多8个)
|
||||||
|
history_slot_ids = []
|
||||||
|
if self.query_engine:
|
||||||
|
# 收集i之前的汉字字符(最多8个)
|
||||||
|
for j in range(i - 1, max(-1, i - 9), -1):
|
||||||
|
if j < 0:
|
||||||
|
break
|
||||||
|
char = text[j]
|
||||||
|
if self.query_engine.is_chinese_char(char):
|
||||||
|
# 获取该字符的第一个拼音变体的ID
|
||||||
|
results = self.query_engine.query_by_char(char, limit=1)
|
||||||
|
if results:
|
||||||
|
history_slot_ids.append(results[0][0])
|
||||||
|
else:
|
||||||
|
# 如果查询失败,使用0
|
||||||
|
history_slot_ids.append(0)
|
||||||
|
else:
|
||||||
|
# 非汉字字符,使用0
|
||||||
|
history_slot_ids.append(0)
|
||||||
|
if len(history_slot_ids) >= 8:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果历史槽位为空,随机添加一些常见字符ID(模拟已确认的输入)
|
||||||
|
if not history_slot_ids and random.random() < 0.5:
|
||||||
|
# 随机选择1-3个常见字符ID(范围1-1000)
|
||||||
|
num_slots = random.randint(1, 3)
|
||||||
|
for _ in range(num_slots):
|
||||||
|
history_slot_ids.append(random.randint(1, 1000))
|
||||||
|
|
||||||
|
# 填充到8个槽位
|
||||||
|
if len(history_slot_ids) < 8:
|
||||||
|
history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
|
||||||
|
else:
|
||||||
|
history_slot_ids = history_slot_ids[:8]
|
||||||
|
|
||||||
|
# Tokenize输入
|
||||||
|
encoded = self.tokenizer(
|
||||||
|
f"{part4}|{part1}",
|
||||||
|
part3,
|
||||||
|
max_length=128,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_token_type_ids=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算有效槽位数量(非零元素)
|
||||||
|
valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id != 0)
|
||||||
|
|
||||||
|
# 拼音输入长度(字符数)
|
||||||
|
pinyin_input_length = len(part2)
|
||||||
|
|
||||||
|
# 构建样本
|
||||||
|
sample = {
|
||||||
|
"input_ids": encoded["input_ids"],
|
||||||
|
"token_type_ids": encoded["token_type_ids"],
|
||||||
|
"attention_mask": encoded["attention_mask"],
|
||||||
|
"history_slot_ids": torch.tensor(
|
||||||
|
history_slot_ids, dtype=torch.long
|
||||||
|
).unsqueeze(0),
|
||||||
|
"prefix": f"{part4}^{part1}",
|
||||||
|
"suffix": part3,
|
||||||
|
"pinyin": part2,
|
||||||
|
"pinyin_ids": pinyin_ids_tensor.unsqueeze(0),
|
||||||
|
"true_labels": labels, # 真实标签(多个字符)
|
||||||
|
"position": i,
|
||||||
|
"target_char": text[i] if i < len(text) else "",
|
||||||
|
"valid_slot_count": valid_slot_count, # 有效槽位数量
|
||||||
|
"pinyin_input_length": pinyin_input_length, # 拼音输入长度
|
||||||
|
}
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def inference(self, sample: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
执行模型推理。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(logits, probs)
|
||||||
|
"""
|
||||||
|
# 准备输入张量并移动到设备
|
||||||
|
input_ids = sample["input_ids"].to(self.device)
|
||||||
|
token_type_ids = sample["token_type_ids"].to(self.device)
|
||||||
|
attention_mask = sample["attention_mask"].to(self.device)
|
||||||
|
pinyin_ids = sample["pinyin_ids"].to(self.device)
|
||||||
|
history_slot_ids = sample["history_slot_ids"].to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.device.type == "cuda":
|
||||||
|
with torch.autocast(device_type="cuda"):
|
||||||
|
logits = self.model(
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
attention_mask,
|
||||||
|
pinyin_ids,
|
||||||
|
history_slot_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.model(
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
attention_mask,
|
||||||
|
pinyin_ids,
|
||||||
|
history_slot_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
probs = F.softmax(logits, dim=-1)
|
||||||
|
return logits, probs
|
||||||
|
|
||||||
|
def analyze_probability_distribution(
|
||||||
|
self, probs: torch.Tensor, threshold_low_var: float = 0.01
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
分析概率分布,检测低方差情况。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
probs: 概率张量,形状为 (1, vocab_size)
|
||||||
|
threshold_low_var: 低方差阈值,最高概率低于此值则标记为低方差
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分析结果字典
|
||||||
|
"""
|
||||||
|
probs_np = probs.cpu().numpy().flatten()
|
||||||
|
|
||||||
|
# 基本统计
|
||||||
|
max_prob = np.max(probs_np)
|
||||||
|
min_prob = np.min(probs_np)
|
||||||
|
mean_prob = np.mean(probs_np)
|
||||||
|
std_prob = np.std(probs_np)
|
||||||
|
|
||||||
|
# 熵
|
||||||
|
entropy = -np.sum(probs_np * np.log(probs_np + 1e-12))
|
||||||
|
|
||||||
|
# 低方差检测
|
||||||
|
low_variance = max_prob < threshold_low_var
|
||||||
|
|
||||||
|
# Top-k 分析
|
||||||
|
top_k = 20
|
||||||
|
top_indices = np.argsort(probs_np)[-top_k:][::-1]
|
||||||
|
top_probs = probs_np[top_indices]
|
||||||
|
|
||||||
|
# 检查前5个概率是否接近均匀
|
||||||
|
top5_probs = top_probs[:5]
|
||||||
|
top5_ratio = (
|
||||||
|
top5_probs[0] / (top5_probs[1] + 1e-12) if len(top5_probs) > 1 else 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
analysis = {
|
||||||
|
"max_prob": max_prob,
|
||||||
|
"min_prob": min_prob,
|
||||||
|
"mean_prob": mean_prob,
|
||||||
|
"std_prob": std_prob,
|
||||||
|
"entropy": entropy,
|
||||||
|
"low_variance": low_variance,
|
||||||
|
"top_indices": top_indices,
|
||||||
|
"top_probs": top_probs,
|
||||||
|
"top5_ratio": top5_ratio,
|
||||||
|
"vocab_size": len(probs_np),
|
||||||
|
}
|
||||||
|
|
||||||
|
return analysis
|
||||||
|
|
||||||
|
def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict:
|
||||||
|
"""
|
||||||
|
评估单个样本。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
评估结果字典
|
||||||
|
"""
|
||||||
|
# 执行推理
|
||||||
|
logits, probs = self.inference(sample)
|
||||||
|
|
||||||
|
# 分析概率分布
|
||||||
|
analysis = self.analyze_probability_distribution(probs)
|
||||||
|
|
||||||
|
# 获取真实标签
|
||||||
|
true_labels = sample.get("true_labels", [])
|
||||||
|
target_char = sample.get("target_char", "")
|
||||||
|
|
||||||
|
# 检查预测是否正确(只检查第一个字符)
|
||||||
|
correct = False
|
||||||
|
pred_top_idx = analysis["top_indices"][0]
|
||||||
|
if true_labels and len(true_labels) > 0:
|
||||||
|
correct = pred_top_idx == true_labels[0]
|
||||||
|
|
||||||
|
# 打印结果
|
||||||
|
if print_details:
|
||||||
|
# 获取槽位信息
|
||||||
|
valid_slot_count = sample.get("valid_slot_count", 0)
|
||||||
|
pinyin_input_length = sample.get("pinyin_input_length", 0)
|
||||||
|
|
||||||
|
# 简化输出格式(在同一行继续输出)
|
||||||
|
print(f" {target_char}", end="")
|
||||||
|
if true_labels:
|
||||||
|
print(f"(ID:{true_labels[0]})", end="")
|
||||||
|
print(
|
||||||
|
f" | 拼音:{sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}",
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_input_length}", end="")
|
||||||
|
|
||||||
|
# 关键概率信息
|
||||||
|
print(f" | 最高概率:{analysis['max_prob']:.3f}", end="")
|
||||||
|
if analysis["low_variance"]:
|
||||||
|
print(f" ⚠️", end="")
|
||||||
|
|
||||||
|
# Top-2 预测
|
||||||
|
top_strs = []
|
||||||
|
for j in range(min(2, len(analysis["top_indices"]))):
|
||||||
|
idx = analysis["top_indices"][j]
|
||||||
|
prob = analysis["top_probs"][j]
|
||||||
|
# 处理特殊ID
|
||||||
|
if idx == 0:
|
||||||
|
char = "[空]"
|
||||||
|
else:
|
||||||
|
char_info = (
|
||||||
|
self.query_engine.query_by_id(idx)
|
||||||
|
if self.query_engine
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
char = char_info.char if char_info else f"[{idx}]"
|
||||||
|
is_correct = (
|
||||||
|
(true_labels and idx == true_labels[0]) if true_labels else False
|
||||||
|
)
|
||||||
|
correct_mark = "✓" if is_correct else ""
|
||||||
|
top_strs.append(f"{char}{correct_mark}({prob:.2f})")
|
||||||
|
print(f" | Top-2: {' '.join(top_strs)}", end="")
|
||||||
|
|
||||||
|
print(f" | 结果:{'✓' if correct else '✗'}")
|
||||||
|
|
||||||
|
# 返回评估结果
|
||||||
|
result = {
|
||||||
|
"sample": sample,
|
||||||
|
"analysis": analysis,
|
||||||
|
"correct": correct,
|
||||||
|
"target_char": target_char,
|
||||||
|
"true_label": true_labels[0] if true_labels else None,
|
||||||
|
"predicted_label": pred_top_idx,
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
评估给定文本,生成多个样本进行评估。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 输入文本
|
||||||
|
num_samples: 生成的样本数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
评估结果列表
|
||||||
|
"""
|
||||||
|
# 限制文本长度
|
||||||
|
if len(text) > 300:
|
||||||
|
start = random.randint(0, len(text) - 300)
|
||||||
|
text = text[start : start + 300]
|
||||||
|
print(f"📝 文本长度超过300,随机截取300字符: {text[:50]}...")
|
||||||
|
else:
|
||||||
|
print(f"📝 文本长度: {len(text)} 字符")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
low_variance_samples = []
|
||||||
|
|
||||||
|
for sample_idx in range(num_samples):
|
||||||
|
print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 创建样本
|
||||||
|
sample = self.create_sample_from_text(text)
|
||||||
|
|
||||||
|
# 评估样本
|
||||||
|
result = self.evaluate_sample(sample, print_details=True)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# 记录低方差样本
|
||||||
|
if result["analysis"]["low_variance"]:
|
||||||
|
low_variance_samples.append(
|
||||||
|
{
|
||||||
|
"sample_idx": sample_idx,
|
||||||
|
"max_prob": result["analysis"]["max_prob"],
|
||||||
|
"entropy": result["analysis"]["entropy"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 样本生成或评估失败: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 打印汇总统计
|
||||||
|
if results:
|
||||||
|
self.print_summary(results, low_variance_samples)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]):
|
||||||
|
"""打印评估汇总"""
|
||||||
|
print(f"\n--- 汇总 ({len(results)}样本) ---")
|
||||||
|
|
||||||
|
# 计算准确率
|
||||||
|
correct_count = sum(1 for r in results if r["correct"])
|
||||||
|
accuracy = correct_count / len(results) if results else 0
|
||||||
|
print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})")
|
||||||
|
|
||||||
|
# 概率统计
|
||||||
|
max_probs = [r["analysis"]["max_prob"] for r in results]
|
||||||
|
mean_max_prob = np.mean(max_probs) if max_probs else 0
|
||||||
|
print(f"平均最高概率: {mean_max_prob:.4f}")
|
||||||
|
|
||||||
|
# 预测ID:0的比例
|
||||||
|
zero_pred_count = sum(1 for r in results if r.get("predicted_label") == 0)
|
||||||
|
zero_pred_ratio = zero_pred_count / len(results) if results else 0
|
||||||
|
print(f"预测[空]比例: {zero_pred_ratio:.1%} ({zero_pred_count}/{len(results)})")
|
||||||
|
|
||||||
|
# 平均熵
|
||||||
|
entropies = [r["analysis"]["entropy"] for r in results]
|
||||||
|
mean_entropy = np.mean(entropies) if entropies else 0
|
||||||
|
print(f"平均熵: {mean_entropy:.2f}")
|
||||||
|
|
||||||
|
# 槽位和拼音长度统计
|
||||||
|
if results and "sample" in results[0]:
|
||||||
|
valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results]
|
||||||
|
pinyin_lengths = [
|
||||||
|
r["sample"].get("pinyin_input_length", 0) for r in results
|
||||||
|
]
|
||||||
|
avg_slots = np.mean(valid_slots) if valid_slots else 0
|
||||||
|
avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0
|
||||||
|
print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}")
|
||||||
|
|
||||||
|
# 低方差样本
|
||||||
|
if low_variance_samples:
|
||||||
|
print(f"低方差样本: {len(low_variance_samples)}个")
|
||||||
|
if len(low_variance_samples) <= 5:
|
||||||
|
for lv in low_variance_samples:
|
||||||
|
print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
|
||||||
|
else:
|
||||||
|
for lv in low_variance_samples[:3]:
|
||||||
|
print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
|
||||||
|
print(f" ... 还有{len(low_variance_samples) - 3}个")
|
||||||
|
else:
|
||||||
|
print(f"低方差样本: 0个")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: parse arguments, load the text and run the evaluation."""
    parser = argparse.ArgumentParser(description="评估模型在文本上的表现")
    parser.add_argument("--text-file", type=str, required=True, help="文本文件路径")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="/home/songsenand/下载/best_model.pt",
        help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)",
    )
    parser.add_argument(
        "--num-samples", type=int, default=20, help="评估样本数量 (默认: 20)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="推理设备 (默认: auto)",
    )
    args = parser.parse_args()

    # Resolve "auto" to whatever hardware is available.
    device = (
        args.device
        if args.device != "auto"
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )

    # Read the input text.
    try:
        with open(args.text_file, "r", encoding="utf-8") as f:
            text = f.read().strip()
    except Exception as e:
        print(f"❌ 无法读取文本文件: {e}")
        sys.exit(1)

    if not text:
        print("❌ 文本文件为空")
        sys.exit(1)

    print(f"📄 读取文本文件: {args.text_file}")
    print(f"📏 原始文本长度: {len(text)} 字符")
    print(f"🔧 使用设备: {device}")
    print(f"📦 模型checkpoint: {args.checkpoint}")
    print(f"🔬 评估样本数: {args.num_samples}")

    # Build the evaluator (loads model / tokenizer / query engine).
    try:
        evaluator = TextEvaluator(args.checkpoint, device)
    except Exception as e:
        print(f"❌ 初始化评估器失败: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)

    # Run the evaluation.
    evaluator.evaluate_text(text, args.num_samples)


if __name__ == "__main__":
    main()
@ -9,14 +9,17 @@ from typing import Optional, Union
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from flask import Flask, render_template, request, jsonify, send_from_directory
|
from flask import Flask, render_template, request, jsonify, send_from_directory
|
||||||
|
|
||||||
app = Flask(__name__,
|
app = Flask(
|
||||||
template_folder=Path(__file__).parent / 'templates',
|
__name__,
|
||||||
static_folder=Path(__file__).parent / 'static')
|
template_folder=Path(__file__).parent / "templates",
|
||||||
|
static_folder=Path(__file__).parent / "static",
|
||||||
|
)
|
||||||
|
|
||||||
# 全局配置
|
# 全局配置
|
||||||
DEFAULT_STATUS_FILE = "./output/training_status.json"
|
DEFAULT_STATUS_FILE = "./output/training_status.json"
|
||||||
DEFAULT_PORT = 8501
|
DEFAULT_PORT = 8501
|
||||||
DEFAULT_HOST = "0.0.0.0"
|
DEFAULT_HOST = "0.0.0.0"
|
||||||
|
ALLOWED_DATA_SOURCE_TYPES = ["local", "remote"] # 允许的数据源类型
|
||||||
|
|
||||||
|
|
||||||
def load_training_data(file_path: str, is_url: bool = False) -> list:
|
def load_training_data(file_path: str, is_url: bool = False) -> list:
|
||||||
|
|
@ -73,72 +76,146 @@ def validate_and_clean_data(data: list) -> list:
|
||||||
|
|
||||||
cleaned = []
|
cleaned = []
|
||||||
for item in data:
|
for item in data:
|
||||||
if isinstance(item, dict) and ('step' in item or 'train/loss' in item or 'timestamp' in item):
|
if isinstance(item, dict) and (
|
||||||
|
"step" in item or "train/loss" in item or "timestamp" in item
|
||||||
|
):
|
||||||
cleaned.append(item)
|
cleaned.append(item)
|
||||||
|
|
||||||
return cleaned
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
@app.route('/')
|
@app.route("/")
|
||||||
def index():
|
def index():
|
||||||
"""主页"""
|
"""主页"""
|
||||||
data_source_type = request.args.get('data_source_type', 'local')
|
# 根据允许的数据源类型设置默认值
|
||||||
data_source = request.args.get('data_source', DEFAULT_STATUS_FILE)
|
if (
|
||||||
refresh_interval = int(request.args.get('refresh_interval', 5))
|
"local" in ALLOWED_DATA_SOURCE_TYPES
|
||||||
|
and "remote" not in ALLOWED_DATA_SOURCE_TYPES
|
||||||
|
):
|
||||||
|
default_type = "local"
|
||||||
|
default_source = DEFAULT_STATUS_FILE
|
||||||
|
elif (
|
||||||
|
"remote" in ALLOWED_DATA_SOURCE_TYPES
|
||||||
|
and "local" not in ALLOWED_DATA_SOURCE_TYPES
|
||||||
|
):
|
||||||
|
default_type = "remote"
|
||||||
|
default_source = ""
|
||||||
|
else:
|
||||||
|
default_type = "local"
|
||||||
|
default_source = DEFAULT_STATUS_FILE
|
||||||
|
|
||||||
return render_template('index.html',
|
data_source_type = request.args.get("data_source_type", default_type)
|
||||||
data_source_type=data_source_type,
|
data_source = request.args.get("data_source", default_source)
|
||||||
data_source=data_source,
|
refresh_interval = int(request.args.get("refresh_interval", 5))
|
||||||
refresh_interval=refresh_interval)
|
|
||||||
|
return render_template(
|
||||||
|
"index.html",
|
||||||
|
data_source_type=data_source_type,
|
||||||
|
data_source=data_source,
|
||||||
|
refresh_interval=refresh_interval,
|
||||||
|
allowed_data_source_types=ALLOWED_DATA_SOURCE_TYPES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/status')
|
@app.route("/api/status")
|
||||||
def api_status():
|
def api_status():
|
||||||
"""API接口,返回训练状态数据"""
|
"""API接口,返回训练状态数据"""
|
||||||
data_source_type = request.args.get('data_source_type', 'local')
|
# 根据允许的数据源类型设置默认值
|
||||||
data_source = request.args.get('data_source', DEFAULT_STATUS_FILE)
|
if (
|
||||||
|
"local" in ALLOWED_DATA_SOURCE_TYPES
|
||||||
|
and "remote" not in ALLOWED_DATA_SOURCE_TYPES
|
||||||
|
):
|
||||||
|
default_type = "local"
|
||||||
|
elif (
|
||||||
|
"remote" in ALLOWED_DATA_SOURCE_TYPES
|
||||||
|
and "local" not in ALLOWED_DATA_SOURCE_TYPES
|
||||||
|
):
|
||||||
|
default_type = "remote"
|
||||||
|
else:
|
||||||
|
default_type = "local"
|
||||||
|
|
||||||
|
data_source_type = request.args.get("data_source_type", default_type)
|
||||||
|
|
||||||
|
# 检查请求的数据源类型是否被允许
|
||||||
|
if data_source_type not in ALLOWED_DATA_SOURCE_TYPES:
|
||||||
|
return jsonify(
|
||||||
|
{
|
||||||
|
"error": f"数据源类型 '{data_source_type}' 不被允许。允许的类型: {ALLOWED_DATA_SOURCE_TYPES}",
|
||||||
|
"data": [],
|
||||||
|
}
|
||||||
|
), 400
|
||||||
|
|
||||||
|
# 根据数据源类型设置数据源路径
|
||||||
|
if data_source_type == "local":
|
||||||
|
# 本地文件模式:固定使用DEFAULT_STATUS_FILE,忽略用户提供的路径
|
||||||
|
data_source = DEFAULT_STATUS_FILE
|
||||||
|
else:
|
||||||
|
# 远程URL模式:使用用户提供的URL
|
||||||
|
data_source = request.args.get("data_source", "")
|
||||||
|
if not data_source:
|
||||||
|
return jsonify(
|
||||||
|
{"error": "远程数据源模式下必须提供data_source参数(URL)", "data": []}
|
||||||
|
), 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
is_url = data_source_type == 'remote'
|
is_url = data_source_type == "remote"
|
||||||
raw_data = load_training_data(data_source, is_url)
|
raw_data = load_training_data(data_source, is_url)
|
||||||
cleaned_data = validate_and_clean_data(raw_data)
|
cleaned_data = validate_and_clean_data(raw_data)
|
||||||
|
|
||||||
return jsonify(cleaned_data)
|
return jsonify(cleaned_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({
|
return jsonify({"error": str(e), "data": []}), 500
|
||||||
'error': str(e),
|
|
||||||
'data': []
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/info')
|
@app.route("/api/info")
|
||||||
def api_info():
|
def api_info():
|
||||||
"""获取服务器信息"""
|
"""获取服务器信息"""
|
||||||
return jsonify({
|
return jsonify(
|
||||||
'server_time': datetime.now().isoformat(),
|
{
|
||||||
'default_status_file': DEFAULT_STATUS_FILE,
|
"server_time": datetime.now().isoformat(),
|
||||||
'python_version': sys.version,
|
"default_status_file": DEFAULT_STATUS_FILE,
|
||||||
'working_directory': os.getcwd()
|
"python_version": sys.version,
|
||||||
})
|
"working_directory": os.getcwd(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def start_flask_server(host: str, port: int, debug: bool = False, use_wsgi: bool = False):
|
def start_flask_server(
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
debug: bool = False,
|
||||||
|
use_wsgi: bool = False,
|
||||||
|
status_file: Optional[str] = None,
|
||||||
|
allowed_data_sources: Optional[list] = None,
|
||||||
|
):
|
||||||
"""启动Flask服务器"""
|
"""启动Flask服务器"""
|
||||||
from flask import cli
|
from flask import cli
|
||||||
|
|
||||||
# 禁用Flask的默认启动消息
|
# 禁用Flask的默认启动消息
|
||||||
cli.show_server_banner = lambda *args: None
|
cli.show_server_banner = lambda *args: None
|
||||||
|
|
||||||
|
# 更新默认状态文件
|
||||||
|
global DEFAULT_STATUS_FILE
|
||||||
|
if status_file is not None:
|
||||||
|
DEFAULT_STATUS_FILE = status_file
|
||||||
|
|
||||||
|
# 设置允许的数据源类型
|
||||||
|
global ALLOWED_DATA_SOURCE_TYPES
|
||||||
|
if allowed_data_sources is not None:
|
||||||
|
ALLOWED_DATA_SOURCE_TYPES = allowed_data_sources
|
||||||
|
|
||||||
print(f"🚀 启动训练监控服务 ({'Waitress WSGI' if use_wsgi else 'Flask'})...")
|
print(f"🚀 启动训练监控服务 ({'Waitress WSGI' if use_wsgi else 'Flask'})...")
|
||||||
print(f"📁 默认状态文件: {os.path.abspath(DEFAULT_STATUS_FILE)}")
|
print(f"📁 默认状态文件: {os.path.abspath(DEFAULT_STATUS_FILE)}")
|
||||||
print(f"🌐 监控地址: http://{host}:{port}")
|
print(f"🌐 监控地址: http://{host}:{port}")
|
||||||
print(f"📊 API接口: http://{host}:{port}/api/status")
|
print(f"📊 API接口: http://{host}:{port}/api/status")
|
||||||
|
print(f"📋 允许的数据源: {', '.join(ALLOWED_DATA_SOURCE_TYPES)}")
|
||||||
print("\n按 Ctrl+C 停止监控服务\n")
|
print("\n按 Ctrl+C 停止监控服务\n")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if use_wsgi:
|
if use_wsgi:
|
||||||
try:
|
try:
|
||||||
import waitress
|
import waitress
|
||||||
|
|
||||||
waitress.serve(app, host=host, port=port, threads=4)
|
waitress.serve(app, host=host, port=port, threads=4)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("⚠️ waitress未安装,回退到Flask开发服务器")
|
print("⚠️ waitress未安装,回退到Flask开发服务器")
|
||||||
|
|
@ -162,11 +239,24 @@ def main():
|
||||||
global DEFAULT_STATUS_FILE
|
global DEFAULT_STATUS_FILE
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="AI模型训练监控工具 - Flask版本")
|
parser = argparse.ArgumentParser(description="AI模型训练监控工具 - Flask版本")
|
||||||
parser.add_argument('--host', default=DEFAULT_HOST, help=f'监控服务主机地址 (默认: {DEFAULT_HOST})')
|
parser.add_argument(
|
||||||
parser.add_argument('--port', type=int, default=DEFAULT_PORT, help=f'监控服务端口号 (默认: {DEFAULT_PORT})')
|
"--host", default=DEFAULT_HOST, help=f"监控服务主机地址 (默认: {DEFAULT_HOST})"
|
||||||
parser.add_argument('--debug', action='store_true', help='启用调试模式')
|
)
|
||||||
parser.add_argument('--use-wsgi', action='store_true', help='使用Waitress WSGI服务器替代Flask开发服务器')
|
parser.add_argument(
|
||||||
parser.add_argument('--status-file', default=DEFAULT_STATUS_FILE, help='默认状态文件路径')
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_PORT,
|
||||||
|
help=f"监控服务端口号 (默认: {DEFAULT_PORT})",
|
||||||
|
)
|
||||||
|
parser.add_argument("--debug", action="store_true", help="启用调试模式")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-wsgi",
|
||||||
|
action="store_true",
|
||||||
|
help="使用Waitress WSGI服务器替代Flask开发服务器",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--status-file", default=DEFAULT_STATUS_FILE, help="默认状态文件路径"
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -176,5 +266,5 @@ def main():
|
||||||
return start_flask_server(args.host, args.port, args.debug, args.use_wsgi)
|
return start_flask_server(args.host, args.port, args.debug, args.use_wsgi)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
|
|
@ -18,7 +18,11 @@ app = typer.Typer(help="AI模型训练监控工具 - 基于JSON旁路记录法
|
||||||
|
|
||||||
# 尝试导入Flask,如果失败则提供友好错误提示
|
# 尝试导入Flask,如果失败则提供友好错误提示
|
||||||
try:
|
try:
|
||||||
from .flask_monitor import start_flask_server, DEFAULT_STATUS_FILE as FLASK_DEFAULT_STATUS_FILE
|
from .flask_monitor import (
|
||||||
|
start_flask_server,
|
||||||
|
DEFAULT_STATUS_FILE as FLASK_DEFAULT_STATUS_FILE,
|
||||||
|
)
|
||||||
|
|
||||||
FLASK_AVAILABLE = True
|
FLASK_AVAILABLE = True
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
FLASK_AVAILABLE = False
|
FLASK_AVAILABLE = False
|
||||||
|
|
@ -41,6 +45,7 @@ def start_flask_monitor_server(
|
||||||
host: str,
|
host: str,
|
||||||
open_browser: bool,
|
open_browser: bool,
|
||||||
use_wsgi: bool = False,
|
use_wsgi: bool = False,
|
||||||
|
allowed_data_sources: Optional[list] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
启动Flask监控服务器
|
启动Flask监控服务器
|
||||||
|
|
@ -51,6 +56,7 @@ def start_flask_monitor_server(
|
||||||
host: 主机地址
|
host: 主机地址
|
||||||
open_browser: 是否自动打开浏览器
|
open_browser: 是否自动打开浏览器
|
||||||
use_wsgi: 是否使用Waitress WSGI服务器替代Flask开发服务器
|
use_wsgi: 是否使用Waitress WSGI服务器替代Flask开发服务器
|
||||||
|
allowed_data_sources: 允许的数据源类型列表,如["local"]或["remote"],默认None表示两者都允许
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
进程退出码
|
进程退出码
|
||||||
|
|
@ -62,14 +68,15 @@ def start_flask_monitor_server(
|
||||||
typer.echo("或在pyproject.toml中添加flask依赖")
|
typer.echo("或在pyproject.toml中添加flask依赖")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
# 设置环境变量,传递状态文件路径
|
|
||||||
os.environ["TRAINING_STATUS_FILE"] = os.path.abspath(status_file)
|
|
||||||
|
|
||||||
server_type = "Waitress WSGI" if use_wsgi else "Flask"
|
server_type = "Waitress WSGI" if use_wsgi else "Flask"
|
||||||
typer.echo(f"🚀 启动训练监控服务 ({server_type}版本)...")
|
typer.echo(f"🚀 启动训练监控服务 ({server_type}版本)...")
|
||||||
typer.echo(f"📁 状态文件: {os.path.abspath(status_file)}")
|
typer.echo(f"📁 状态文件: {os.path.abspath(status_file)}")
|
||||||
typer.echo(f"🌐 监控地址: http://{host}:{port}")
|
typer.echo(f"🌐 监控地址: http://{host}:{port}")
|
||||||
typer.echo(f"📊 API接口: http://{host}:{port}/api/status")
|
typer.echo(f"📊 API接口: http://{host}:{port}/api/status")
|
||||||
|
if allowed_data_sources is None:
|
||||||
|
typer.echo(f"📋 允许的数据源: local, remote")
|
||||||
|
else:
|
||||||
|
typer.echo(f"📋 允许的数据源: {', '.join(allowed_data_sources)}")
|
||||||
|
|
||||||
if open_browser:
|
if open_browser:
|
||||||
# 等待服务器启动后打开浏览器
|
# 等待服务器启动后打开浏览器
|
||||||
|
|
@ -81,7 +88,15 @@ def start_flask_monitor_server(
|
||||||
try:
|
try:
|
||||||
# 导入并启动Flask服务器
|
# 导入并启动Flask服务器
|
||||||
from .flask_monitor import start_flask_server
|
from .flask_monitor import start_flask_server
|
||||||
return start_flask_server(host=host, port=port, debug=False, use_wsgi=use_wsgi)
|
|
||||||
|
return start_flask_server(
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
debug=False,
|
||||||
|
use_wsgi=use_wsgi,
|
||||||
|
status_file=status_file,
|
||||||
|
allowed_data_sources=allowed_data_sources,
|
||||||
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
typer.echo("\n🛑 监控服务已停止")
|
typer.echo("\n🛑 监控服务已停止")
|
||||||
return 0
|
return 0
|
||||||
|
|
@ -101,7 +116,9 @@ def monitor_training(
|
||||||
port: int = typer.Option(8501, "--port", "-p", help="监控服务端口号"),
|
port: int = typer.Option(8501, "--port", "-p", help="监控服务端口号"),
|
||||||
host: str = typer.Option("0.0.0.0", "--host", help="监控服务主机地址"),
|
host: str = typer.Option("0.0.0.0", "--host", help="监控服务主机地址"),
|
||||||
open_browser: bool = typer.Option(False, "--open-browser", help="自动打开浏览器"),
|
open_browser: bool = typer.Option(False, "--open-browser", help="自动打开浏览器"),
|
||||||
use_wsgi: bool = typer.Option(False, "--use-wsgi", help="使用Waitress WSGI服务器替代Flask开发服务器"),
|
use_wsgi: bool = typer.Option(
|
||||||
|
False, "--use-wsgi", help="使用Waitress WSGI服务器替代Flask开发服务器"
|
||||||
|
),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
启动AI模型训练监控服务 (Flask版本)
|
启动AI模型训练监控服务 (Flask版本)
|
||||||
|
|
@ -131,6 +148,13 @@ def monitor_training(
|
||||||
typer.echo("或在pyproject.toml中添加flask依赖")
|
typer.echo("或在pyproject.toml中添加flask依赖")
|
||||||
raise typer.Exit(code=1)
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# 根据是否提供了-s参数决定允许的数据源
|
||||||
|
default_status_file = "./output/training_status.json"
|
||||||
|
if status_file == default_status_file:
|
||||||
|
allowed_data_sources = ["remote"] # 未提供-s,只允许远程
|
||||||
|
else:
|
||||||
|
allowed_data_sources = ["local"] # 提供了-s,只允许本地
|
||||||
|
|
||||||
# 启动Flask服务器
|
# 启动Flask服务器
|
||||||
return_code = start_flask_monitor_server(
|
return_code = start_flask_monitor_server(
|
||||||
status_file=status_file,
|
status_file=status_file,
|
||||||
|
|
@ -138,6 +162,7 @@ def monitor_training(
|
||||||
host=host,
|
host=host,
|
||||||
open_browser=open_browser,
|
open_browser=open_browser,
|
||||||
use_wsgi=use_wsgi,
|
use_wsgi=use_wsgi,
|
||||||
|
allowed_data_sources=allowed_data_sources,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise typer.Exit(code=return_code)
|
raise typer.Exit(code=return_code)
|
||||||
|
|
|
||||||
|
|
@ -273,6 +273,7 @@
|
||||||
</div>
|
</div>
|
||||||
{% endblock %} {% block extra_js %}
|
{% endblock %} {% block extra_js %}
|
||||||
<script>
|
<script>
|
||||||
|
const allowedDataSourceTypes = {{ allowed_data_source_types|tojson }};
|
||||||
let currentData = [];
|
let currentData = [];
|
||||||
let refreshTimer = null;
|
let refreshTimer = null;
|
||||||
let currentRefreshInterval = {{ refresh_interval }};
|
let currentRefreshInterval = {{ refresh_interval }};
|
||||||
|
|
@ -287,6 +288,36 @@
|
||||||
const refreshIntervalSelect = document.getElementById('refreshIntervalSelect');
|
const refreshIntervalSelect = document.getElementById('refreshIntervalSelect');
|
||||||
const applyConfigBtn = document.getElementById('applyConfigBtn');
|
const applyConfigBtn = document.getElementById('applyConfigBtn');
|
||||||
|
|
||||||
|
// 根据允许的数据源类型调整UI
|
||||||
|
if (allowedDataSourceTypes.length === 1) {
|
||||||
|
// 隐藏选择框,固定数据源类型
|
||||||
|
dataSourceSelect.style.display = 'none';
|
||||||
|
dataSourceSelect.value = allowedDataSourceTypes[0];
|
||||||
|
// 根据类型显示相应的输入字段
|
||||||
|
if (allowedDataSourceTypes[0] === 'local') {
|
||||||
|
localFileField.style.display = 'block';
|
||||||
|
remoteUrlField.style.display = 'none';
|
||||||
|
// 使本地文件路径输入框只读
|
||||||
|
document.getElementById('localFilePath').readOnly = true;
|
||||||
|
} else {
|
||||||
|
localFileField.style.display = 'none';
|
||||||
|
remoteUrlField.style.display = 'block';
|
||||||
|
// 远程URL输入框必须填写,保持可编辑
|
||||||
|
}
|
||||||
|
// 禁用选择框的change事件,因为只有一个选项
|
||||||
|
dataSourceSelect.disabled = true;
|
||||||
|
} else {
|
||||||
|
// 正常情况:两个选项都允许
|
||||||
|
// 根据当前选择显示相应字段
|
||||||
|
if (dataSourceSelect.value === 'local') {
|
||||||
|
localFileField.style.display = 'block';
|
||||||
|
remoteUrlField.style.display = 'none';
|
||||||
|
} else {
|
||||||
|
localFileField.style.display = 'none';
|
||||||
|
remoteUrlField.style.display = 'block';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
dataSourceSelect.addEventListener('change', function() {
|
dataSourceSelect.addEventListener('change', function() {
|
||||||
if (this.value === 'local') {
|
if (this.value === 'local') {
|
||||||
|
|
|
||||||
|
|
@ -1070,7 +1070,7 @@ def train(
|
||||||
max_seq_length=max_seq_len,
|
max_seq_length=max_seq_len,
|
||||||
text_field="text",
|
text_field="text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size=5000,
|
shuffle_buffer_size=50000,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1093,7 +1093,7 @@ def train(
|
||||||
max_seq_length=max_seq_len,
|
max_seq_length=max_seq_len,
|
||||||
text_field="text",
|
text_field="text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(9, 2, 1),
|
||||||
shuffle_buffer_size=50000,
|
shuffle_buffer_size=500000,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
2
test.py
2
test.py
|
|
@ -56,7 +56,7 @@ if len_py < 24:
|
||||||
else:
|
else:
|
||||||
pinyin_ids = pinyin_ids[:24]
|
pinyin_ids = pinyin_ids[:24]
|
||||||
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
|
||||||
masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
|
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
|
||||||
part3 = "。"
|
part3 = "。"
|
||||||
part4 = "可行|特别|伤害"
|
part4 = "可行|特别|伤害"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
他是一名大学生,在上海读大学。他喜欢编程和机器学习。
|
||||||
|
他的专业是计算机科学,他对人工智能很感兴趣。
|
||||||
|
今天天气很好,我们去公园玩。公园里有很多人在散步。
|
||||||
|
中国的历史文化悠久,有着丰富的文化遗产。
|
||||||
|
北京是中国的首都,上海是中国的经济中心。
|
||||||
Loading…
Reference in New Issue