feat(eval): 添加模型评估脚本,支持文本分析与概率分布检测

This commit is contained in:
songsenand 2026-04-11 00:21:08 +08:00
parent 919d0972e2
commit a0e4d25b2f
7 changed files with 963 additions and 68 deletions

744
eval.py Normal file
View File

@ -0,0 +1,744 @@
#!/usr/bin/env python3
"""
eval.py - 评估模型在给定文本上的表现
使用方法:
python eval.py --text-file path/to/text.txt [--checkpoint path/to/model.pt] [--num-samples N]
功能:
1. 读取文本文件限制长度300随机抽取
2. 参考dataset.py将文本内容转化为模型可接受数据包含part1/part2/part3/part4/槽位历史等信息
3. 使用模型推理参考test.py
4. 错误的打印对于方差极小概率和猜没什么差别的重点标记
"""
import argparse
import random
import re
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import numpy as np
import torch
import torch.nn.functional as F
from modelscope import AutoTokenizer
from pypinyin import lazy_pinyin
from pypinyin.contrib.tone_convert import to_initials
# 添加src目录到路径
sys.path.append("src")
from src.model.model import InputMethodEngine
from src.model.query import QueryEngine
from src.model.dataset import text_to_pinyin_ids
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
class TextEvaluator:
    """Evaluate a trained input-method model on free text.

    Builds samples from raw text (mirroring dataset.py's generation logic),
    runs model inference, and analyzes the output probability distribution,
    flagging low-variance (near-uniform) predictions.
    """

    def __init__(
        self,
        checkpoint_path: str,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Load model, tokenizer and query engine.

        Args:
            checkpoint_path: Path to the trained model checkpoint (.pt).
            device: Inference device string; defaults to CUDA when available.
                NOTE(review): the default is evaluated once at import time.
        """
        self.device = torch.device(device)
        self.checkpoint_path = checkpoint_path
        # Load components
        self.load_model()
        self.load_tokenizer()
        self.load_query_engine()
        # Pinyin style sampling weights (kept consistent with dataset.py):
        # full pinyin : initials : first letter = 9 : 2 : 1
        self.py_style_weight = np.array([9, 2, 1]) / sum([9, 2, 1])
        print(f"✅ 评估器初始化完成 (设备: {self.device})")
        # Debug hook: run one fixed sample when EVAL_DEBUG is set
        import os
        if os.environ.get("EVAL_DEBUG"):
            self._debug_sample()
def load_model(self):
"""加载训练好的模型"""
self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
if "model_state_dict" in checkpoint:
self.model.load_state_dict(checkpoint["model_state_dict"])
else:
self.model.load_state_dict(checkpoint)
self.model.eval()
self.model.to(self.device)
print(
f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}"
)
print(f"✅ 模型词汇表大小: {self.model.vocab_size}")
def load_tokenizer(self):
"""加载tokenizer"""
try:
tokenizer_path = (
Path(__file__).parent / "src" / "model" / "assets" / "tokenizer"
)
self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
print(f"✅ Tokenizer加载完成词汇表大小: {self.tokenizer.vocab_size}")
except Exception as e:
print(f"⚠️ 无法加载tokenizer: {e}")
print("使用默认的bert-base-chinese tokenizer")
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
def load_query_engine(self):
"""加载查询引擎用于字符-ID转换"""
try:
self.query_engine = QueryEngine()
stats_path = (
Path(__file__).parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
if stats_path.exists():
self.query_engine.load(stats_path)
print(
f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}"
)
else:
print(f"⚠️ 统计文件不存在: {stats_path}")
self.query_engine = None
except Exception as e:
print(f"⚠️ 无法加载查询引擎: {e}")
self.query_engine = None
    def _debug_sample(self):
        """Debug helper: run the same fixed sample as test.py and print stats."""
        print("\n🧪 调试运行固定样本与test.py相同")
        # Fixed sample copied from test.py
        part1 = "他是一名大学生,在上海读"
        part2 = "dayi"
        pinyin_ids = text_to_pinyin_ids(part2)
        len_py = len(pinyin_ids)
        # Pad/truncate the pinyin IDs to a fixed length of 24
        if len_py < 24:
            pinyin_ids.extend([0] * (24 - len_py))
        else:
            pinyin_ids = pinyin_ids[:24]
        pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
        # History slot IDs — presumably IDs of previously confirmed
        # characters; values copied from test.py (TODO confirm semantics)
        masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
        part3 = ""
        part4 = "可行|特别|伤害"
        encoded = self.tokenizer(
            f"{part4}|{part1}",
            part3,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )
        sample = {
            "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]),
            "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]),
            "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]),
            "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze(
                0
            ),
            "prefix": f"{part4}^{part1}",
            "suffix": part3,
            "pinyin": part2,
            "pinyin_ids": pinyin_ids,
        }
        # Run inference and analyze the distribution
        logits, probs = self.inference(sample)
        analysis = self.analyze_probability_distribution(probs)
        print(f"📊 调试样本结果:")
        print(f" 最高概率: {analysis['max_prob']:.6f}")
        print(f" 平均概率: {analysis['mean_prob']:.6f}")
        print(f" 标准差: {analysis['std_prob']:.6f}")
        print(f" 低方差: {analysis['low_variance']}")
        # Print the top-5 predictions with resolved characters
        print(f" Top-5预测:")
        for j in range(5):
            idx = analysis["top_indices"][j]
            prob = analysis["top_probs"][j]
            char_info = (
                self.query_engine.query_by_id(idx) if self.query_engine else None
            )
            char = char_info.char if char_info else f"[ID:{idx}]"
            print(f" {j + 1}. {char} (ID: {idx}) - 概率: {prob:.6f}")
def generate_pinyin(self, text: str) -> List[str]:
"""
流式处理单条文本转换为拼音列表
参考dataset.py中的generate_pinyin方法
"""
if not text:
return []
text_len = len(text)
result: List[str] = [""] * text_len
# 遍历所有连续汉字片段
for match in _HANZI_RE.finditer(text):
start_idx = match.start()
hanzi_segment = match.group()
pinyin_list = lazy_pinyin(hanzi_segment)
if len(pinyin_list) != len(hanzi_segment):
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
for i, py in enumerate(pinyin_list):
result[start_idx + i] = py
# 填充非汉字字符
for i, char in enumerate(text):
if not result[i]:
result[i] = char
return result
    def get_mask_pinyin(
        self, text: str, pinyin_list: List[str]
    ) -> Tuple[int, List[str]]:
        """Generate the pinyin input for the characters to be predicted.

        Mirrors dataset.py's get_mask_pinyin: walks the leading run of Chinese
        characters in ``text`` and, for each, randomly picks one of three input
        styles (full pinyin / initials / first letter) weighted by
        ``self.py_style_weight``.

        Args:
            text: Candidate target characters (starting at the cursor).
            pinyin_list: Full pinyin for each character in ``text``.

        Returns:
            (count, pinyins): how many characters were converted and their
            chosen pinyin forms.
        """
        mask_pinyin = []
        for i in range(len(text)):
            # Stop at the first non-Chinese character (or when no query
            # engine is available at all).
            if not self.query_engine or not self.query_engine.is_chinese_char(text[i]):
                break
            else:
                # Randomly choose full pinyin, initials, or first letter
                py = np.random.choice(
                    (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
                    p=self.py_style_weight,
                )
                # to_initials() returns "" for vowel-initial syllables;
                # fall back to the first letter in that case.
                if py == "":
                    py = pinyin_list[i][0]
                mask_pinyin.append(py)
        return len(mask_pinyin), mask_pinyin
    def create_sample_from_text(
        self, text: str, position: Optional[int] = None
    ) -> Dict:
        """Build one model-ready sample from raw text.

        Mirrors the sample-generation logic of dataset.py (context parts,
        random pinyin length/style, history slots).

        Args:
            text: Input text.
            position: Index of the character to predict (must be a Chinese
                character); a random Chinese position is chosen when None.

        Returns:
            Sample dict with every tensor field the model needs, plus
            bookkeeping fields (true labels, position, target char, ...).

        Raises:
            ValueError: If the text is empty, contains no Chinese characters,
                or ``position`` is out of range / not a Chinese character.
        """
        # Text must be non-empty
        if not text:
            raise ValueError("文本不能为空")
        # Per-character pinyin for the whole text
        pinyin_list = self.generate_pinyin(text)
        # Indices of all Chinese characters (empty when no query engine)
        chinese_positions = [
            i
            for i, char in enumerate(text)
            if self.query_engine and self.query_engine.is_chinese_char(char)
        ]
        if not chinese_positions:
            raise ValueError("文本中没有汉字")
        # Choose or validate the prediction position
        if position is None:
            position = random.choice(chinese_positions)
        elif position >= len(text) or position < 0:
            raise ValueError(f"位置 {position} 超出文本范围")
        elif not self.query_engine or not self.query_engine.is_chinese_char(
            text[position]
        ):
            raise ValueError(f"位置 {position} 不是汉字")
        i = position
        # part1: text before the cursor (at most 48 characters)
        if i < 48:
            part1 = text[0:i]
        else:
            part1 = text[i - 48 : i]
        # part2: pinyin input; target length 1-8 drawn from a bell-shaped
        # distribution peaked at 3
        pinyin_len_probs = [0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02]
        pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs)
        py_end = min(i + pinyin_len, len(text))
        # Convert the target span to pinyin (may stop early at non-Chinese)
        pinyin_len_actual, part2_pinyin = self.get_mask_pinyin(
            text[i:py_end], pinyin_list[i:py_end]
        )
        # Optionally join syllables with a separator character
        split_char = np.random.choice(["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02])
        part2 = split_char.join(part2_pinyin)
        # Encode the pinyin string to IDs, padded/truncated to length 24
        pinyin_ids = text_to_pinyin_ids(part2)
        len_py = len(pinyin_ids)
        if len_py < 24:
            pinyin_ids.extend([0] * (24 - len_py))
        else:
            pinyin_ids = pinyin_ids[:24]
        pinyin_ids_tensor = torch.tensor(pinyin_ids, dtype=torch.long)
        # part3: text after the cursor (empty with 70% probability)
        part3 = ""
        if random.random() > 0.7:
            part3 = text[
                i + pinyin_len_actual : i
                + pinyin_len_actual
                + np.random.choice(range(1, 17))
            ]
        # part4: context hint strings (empty with 50% probability)
        part4 = ""
        if random.random() > 0.5:
            num_strings = random.randint(1, 5)
            string_list = []
            for _ in range(num_strings):
                start_pos = random.randint(0, len(text) - 1)
                x = random.randint(2, 6)
                end_pos = min(start_pos + x + 1, len(text))
                string_list.append(text[start_pos:end_pos])
            part4 = "|".join(string_list)
        # Labels: ground-truth IDs of the characters to predict
        labels = []
        try:
            labels = [
                self.query_engine.get_char_info_by_char_pinyin(c, p).id
                for c, p in zip(
                    text[i : i + pinyin_len_actual],
                    pinyin_list[i : i + pinyin_len_actual],
                )
            ]
        except (AttributeError, TypeError) as e:
            # Lookup failed (e.g. no query engine): fall back to ID 0
            print(f"⚠️ 获取标签失败: {e}")
            labels = [0] * pinyin_len_actual
        # History slots: characters confirmed before the prediction point.
        # Collected from position i backwards (most recent first), up to 8.
        history_slot_ids = []
        if self.query_engine:
            for j in range(i - 1, max(-1, i - 9), -1):
                if j < 0:
                    break
                char = text[j]
                if self.query_engine.is_chinese_char(char):
                    # Use the ID of the character's first pinyin variant
                    results = self.query_engine.query_by_char(char, limit=1)
                    if results:
                        history_slot_ids.append(results[0][0])
                    else:
                        # Lookup failed: use 0
                        history_slot_ids.append(0)
                else:
                    # Non-Chinese character: use 0
                    history_slot_ids.append(0)
                if len(history_slot_ids) >= 8:
                    break
        # If no history was found, sometimes simulate confirmed input with
        # random common character IDs
        if not history_slot_ids and random.random() < 0.5:
            # Add 1-3 random IDs in the range 1-1000
            num_slots = random.randint(1, 3)
            for _ in range(num_slots):
                history_slot_ids.append(random.randint(1, 1000))
        # Pad/truncate to exactly 8 slots
        if len(history_slot_ids) < 8:
            history_slot_ids.extend([0] * (8 - len(history_slot_ids)))
        else:
            history_slot_ids = history_slot_ids[:8]
        # Tokenize the textual context (hints|prefix as segment A, suffix as B)
        encoded = self.tokenizer(
            f"{part4}|{part1}",
            part3,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )
        # Number of non-zero (valid) history slots
        valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id != 0)
        # Length of the pinyin input string (in characters)
        pinyin_input_length = len(part2)
        # Assemble the sample
        sample = {
            "input_ids": encoded["input_ids"],
            "token_type_ids": encoded["token_type_ids"],
            "attention_mask": encoded["attention_mask"],
            "history_slot_ids": torch.tensor(
                history_slot_ids, dtype=torch.long
            ).unsqueeze(0),
            "prefix": f"{part4}^{part1}",
            "suffix": part3,
            "pinyin": part2,
            "pinyin_ids": pinyin_ids_tensor.unsqueeze(0),
            "true_labels": labels,  # ground-truth IDs (possibly several chars)
            "position": i,
            "target_char": text[i] if i < len(text) else "",
            "valid_slot_count": valid_slot_count,  # number of valid slots
            "pinyin_input_length": pinyin_input_length,  # pinyin input length
        }
        return sample
def inference(self, sample: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
"""
执行模型推理
Returns:
(logits, probs)
"""
# 准备输入张量并移动到设备
input_ids = sample["input_ids"].to(self.device)
token_type_ids = sample["token_type_ids"].to(self.device)
attention_mask = sample["attention_mask"].to(self.device)
pinyin_ids = sample["pinyin_ids"].to(self.device)
history_slot_ids = sample["history_slot_ids"].to(self.device)
with torch.no_grad():
if self.device.type == "cuda":
with torch.autocast(device_type="cuda"):
logits = self.model(
input_ids,
token_type_ids,
attention_mask,
pinyin_ids,
history_slot_ids,
)
else:
logits = self.model(
input_ids,
token_type_ids,
attention_mask,
pinyin_ids,
history_slot_ids,
)
probs = F.softmax(logits, dim=-1)
return logits, probs
def analyze_probability_distribution(
self, probs: torch.Tensor, threshold_low_var: float = 0.01
) -> Dict:
"""
分析概率分布检测低方差情况
Args:
probs: 概率张量形状为 (1, vocab_size)
threshold_low_var: 低方差阈值最高概率低于此值则标记为低方差
Returns:
分析结果字典
"""
probs_np = probs.cpu().numpy().flatten()
# 基本统计
max_prob = np.max(probs_np)
min_prob = np.min(probs_np)
mean_prob = np.mean(probs_np)
std_prob = np.std(probs_np)
# 熵
entropy = -np.sum(probs_np * np.log(probs_np + 1e-12))
# 低方差检测
low_variance = max_prob < threshold_low_var
# Top-k 分析
top_k = 20
top_indices = np.argsort(probs_np)[-top_k:][::-1]
top_probs = probs_np[top_indices]
# 检查前5个概率是否接近均匀
top5_probs = top_probs[:5]
top5_ratio = (
top5_probs[0] / (top5_probs[1] + 1e-12) if len(top5_probs) > 1 else 1.0
)
analysis = {
"max_prob": max_prob,
"min_prob": min_prob,
"mean_prob": mean_prob,
"std_prob": std_prob,
"entropy": entropy,
"low_variance": low_variance,
"top_indices": top_indices,
"top_probs": top_probs,
"top5_ratio": top5_ratio,
"vocab_size": len(probs_np),
}
return analysis
def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict:
"""
评估单个样本
Returns:
评估结果字典
"""
# 执行推理
logits, probs = self.inference(sample)
# 分析概率分布
analysis = self.analyze_probability_distribution(probs)
# 获取真实标签
true_labels = sample.get("true_labels", [])
target_char = sample.get("target_char", "")
# 检查预测是否正确(只检查第一个字符)
correct = False
pred_top_idx = analysis["top_indices"][0]
if true_labels and len(true_labels) > 0:
correct = pred_top_idx == true_labels[0]
# 打印结果
if print_details:
# 获取槽位信息
valid_slot_count = sample.get("valid_slot_count", 0)
pinyin_input_length = sample.get("pinyin_input_length", 0)
# 简化输出格式(在同一行继续输出)
print(f" {target_char}", end="")
if true_labels:
print(f"(ID:{true_labels[0]})", end="")
print(
f" | 拼音:{sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}",
end="",
)
print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_input_length}", end="")
# 关键概率信息
print(f" | 最高概率:{analysis['max_prob']:.3f}", end="")
if analysis["low_variance"]:
print(f" ⚠️", end="")
# Top-2 预测
top_strs = []
for j in range(min(2, len(analysis["top_indices"]))):
idx = analysis["top_indices"][j]
prob = analysis["top_probs"][j]
# 处理特殊ID
if idx == 0:
char = "[空]"
else:
char_info = (
self.query_engine.query_by_id(idx)
if self.query_engine
else None
)
char = char_info.char if char_info else f"[{idx}]"
is_correct = (
(true_labels and idx == true_labels[0]) if true_labels else False
)
correct_mark = "" if is_correct else ""
top_strs.append(f"{char}{correct_mark}({prob:.2f})")
print(f" | Top-2: {' '.join(top_strs)}", end="")
print(f" | 结果:{'' if correct else ''}")
# 返回评估结果
result = {
"sample": sample,
"analysis": analysis,
"correct": correct,
"target_char": target_char,
"true_label": true_labels[0] if true_labels else None,
"predicted_label": pred_top_idx,
}
return result
def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]:
"""
评估给定文本生成多个样本进行评估
Args:
text: 输入文本
num_samples: 生成的样本数量
Returns:
评估结果列表
"""
# 限制文本长度
if len(text) > 300:
start = random.randint(0, len(text) - 300)
text = text[start : start + 300]
print(f"📝 文本长度超过300随机截取300字符: {text[:50]}...")
else:
print(f"📝 文本长度: {len(text)} 字符")
results = []
low_variance_samples = []
for sample_idx in range(num_samples):
print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ")
try:
# 创建样本
sample = self.create_sample_from_text(text)
# 评估样本
result = self.evaluate_sample(sample, print_details=True)
results.append(result)
# 记录低方差样本
if result["analysis"]["low_variance"]:
low_variance_samples.append(
{
"sample_idx": sample_idx,
"max_prob": result["analysis"]["max_prob"],
"entropy": result["analysis"]["entropy"],
}
)
except Exception as e:
print(f"❌ 样本生成或评估失败: {e}")
import traceback
traceback.print_exc()
continue
# 打印汇总统计
if results:
self.print_summary(results, low_variance_samples)
return results
    def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]):
        """Print an aggregate summary over all evaluated samples.

        Args:
            results: Per-sample result dicts from ``evaluate_sample``.
            low_variance_samples: Subset of samples flagged as low-variance.
        """
        print(f"\n--- 汇总 ({len(results)}样本) ---")
        # Top-1 accuracy
        correct_count = sum(1 for r in results if r["correct"])
        accuracy = correct_count / len(results) if results else 0
        print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})")
        # Probability statistics
        max_probs = [r["analysis"]["max_prob"] for r in results]
        mean_max_prob = np.mean(max_probs) if max_probs else 0
        print(f"平均最高概率: {mean_max_prob:.4f}")
        # Fraction of samples whose top prediction is ID 0 ("empty")
        zero_pred_count = sum(1 for r in results if r.get("predicted_label") == 0)
        zero_pred_ratio = zero_pred_count / len(results) if results else 0
        print(f"预测[空]比例: {zero_pred_ratio:.1%} ({zero_pred_count}/{len(results)})")
        # Mean entropy
        entropies = [r["analysis"]["entropy"] for r in results]
        mean_entropy = np.mean(entropies) if entropies else 0
        print(f"平均熵: {mean_entropy:.2f}")
        # Slot-count and pinyin-length statistics
        if results and "sample" in results[0]:
            valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results]
            pinyin_lengths = [
                r["sample"].get("pinyin_input_length", 0) for r in results
            ]
            avg_slots = np.mean(valid_slots) if valid_slots else 0
            avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0
            print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}")
        # Low-variance samples: list up to 5, otherwise first 3 + count
        if low_variance_samples:
            print(f"低方差样本: {len(low_variance_samples)}")
            if len(low_variance_samples) <= 5:
                for lv in low_variance_samples:
                    print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
            else:
                for lv in low_variance_samples[:3]:
                    print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}")
                print(f" ... 还有{len(low_variance_samples) - 3}")
        else:
            print(f"低方差样本: 0个")
def main():
    """CLI entry point: parse arguments, load the text, run the evaluator."""
    parser = argparse.ArgumentParser(description="评估模型在文本上的表现")
    parser.add_argument("--text-file", type=str, required=True, help="文本文件路径")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="/home/songsenand/下载/best_model.pt",
        help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)",
    )
    parser.add_argument(
        "--num-samples", type=int, default=20, help="评估样本数量 (默认: 20)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="推理设备 (默认: auto)",
    )
    args = parser.parse_args()
    # Resolve "auto" to a concrete device
    device = args.device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Read the input text file
    try:
        with open(args.text_file, "r", encoding="utf-8") as fh:
            text = fh.read().strip()
    except Exception as e:
        print(f"❌ 无法读取文本文件: {e}")
        sys.exit(1)
    if not text:
        print("❌ 文本文件为空")
        sys.exit(1)
    print(f"📄 读取文本文件: {args.text_file}")
    print(f"📏 原始文本长度: {len(text)} 字符")
    print(f"🔧 使用设备: {device}")
    print(f"📦 模型checkpoint: {args.checkpoint}")
    print(f"🔬 评估样本数: {args.num_samples}")
    # Build the evaluator (loads model, tokenizer, query engine)
    try:
        evaluator = TextEvaluator(args.checkpoint, device)
    except Exception as e:
        print(f"❌ 初始化评估器失败: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    # Run the evaluation
    evaluator.evaluate_text(text, args.num_samples)


if __name__ == "__main__":
    main()

View File

@ -9,14 +9,17 @@ from typing import Optional, Union
import pandas as pd
from flask import Flask, render_template, request, jsonify, send_from_directory
app = Flask(__name__,
template_folder=Path(__file__).parent / 'templates',
static_folder=Path(__file__).parent / 'static')
app = Flask(
__name__,
template_folder=Path(__file__).parent / "templates",
static_folder=Path(__file__).parent / "static",
)
# 全局配置
DEFAULT_STATUS_FILE = "./output/training_status.json"
DEFAULT_PORT = 8501
DEFAULT_HOST = "0.0.0.0"
ALLOWED_DATA_SOURCE_TYPES = ["local", "remote"] # 允许的数据源类型
def load_training_data(file_path: str, is_url: bool = False) -> list:
@ -73,72 +76,146 @@ def validate_and_clean_data(data: list) -> list:
cleaned = []
for item in data:
if isinstance(item, dict) and ('step' in item or 'train/loss' in item or 'timestamp' in item):
if isinstance(item, dict) and (
"step" in item or "train/loss" in item or "timestamp" in item
):
cleaned.append(item)
return cleaned
@app.route('/')
@app.route("/")
def index():
"""主页"""
data_source_type = request.args.get('data_source_type', 'local')
data_source = request.args.get('data_source', DEFAULT_STATUS_FILE)
refresh_interval = int(request.args.get('refresh_interval', 5))
# 根据允许的数据源类型设置默认值
if (
"local" in ALLOWED_DATA_SOURCE_TYPES
and "remote" not in ALLOWED_DATA_SOURCE_TYPES
):
default_type = "local"
default_source = DEFAULT_STATUS_FILE
elif (
"remote" in ALLOWED_DATA_SOURCE_TYPES
and "local" not in ALLOWED_DATA_SOURCE_TYPES
):
default_type = "remote"
default_source = ""
else:
default_type = "local"
default_source = DEFAULT_STATUS_FILE
return render_template('index.html',
data_source_type=data_source_type,
data_source=data_source,
refresh_interval=refresh_interval)
data_source_type = request.args.get("data_source_type", default_type)
data_source = request.args.get("data_source", default_source)
refresh_interval = int(request.args.get("refresh_interval", 5))
return render_template(
"index.html",
data_source_type=data_source_type,
data_source=data_source,
refresh_interval=refresh_interval,
allowed_data_source_types=ALLOWED_DATA_SOURCE_TYPES,
)
@app.route('/api/status')
@app.route("/api/status")
def api_status():
"""API接口返回训练状态数据"""
data_source_type = request.args.get('data_source_type', 'local')
data_source = request.args.get('data_source', DEFAULT_STATUS_FILE)
# 根据允许的数据源类型设置默认值
if (
"local" in ALLOWED_DATA_SOURCE_TYPES
and "remote" not in ALLOWED_DATA_SOURCE_TYPES
):
default_type = "local"
elif (
"remote" in ALLOWED_DATA_SOURCE_TYPES
and "local" not in ALLOWED_DATA_SOURCE_TYPES
):
default_type = "remote"
else:
default_type = "local"
data_source_type = request.args.get("data_source_type", default_type)
# 检查请求的数据源类型是否被允许
if data_source_type not in ALLOWED_DATA_SOURCE_TYPES:
return jsonify(
{
"error": f"数据源类型 '{data_source_type}' 不被允许。允许的类型: {ALLOWED_DATA_SOURCE_TYPES}",
"data": [],
}
), 400
# 根据数据源类型设置数据源路径
if data_source_type == "local":
# 本地文件模式固定使用DEFAULT_STATUS_FILE忽略用户提供的路径
data_source = DEFAULT_STATUS_FILE
else:
# 远程URL模式使用用户提供的URL
data_source = request.args.get("data_source", "")
if not data_source:
return jsonify(
{"error": "远程数据源模式下必须提供data_source参数URL", "data": []}
), 400
try:
is_url = data_source_type == 'remote'
is_url = data_source_type == "remote"
raw_data = load_training_data(data_source, is_url)
cleaned_data = validate_and_clean_data(raw_data)
return jsonify(cleaned_data)
except Exception as e:
return jsonify({
'error': str(e),
'data': []
}), 500
return jsonify({"error": str(e), "data": []}), 500
@app.route('/api/info')
@app.route("/api/info")
def api_info():
"""获取服务器信息"""
return jsonify({
'server_time': datetime.now().isoformat(),
'default_status_file': DEFAULT_STATUS_FILE,
'python_version': sys.version,
'working_directory': os.getcwd()
})
return jsonify(
{
"server_time": datetime.now().isoformat(),
"default_status_file": DEFAULT_STATUS_FILE,
"python_version": sys.version,
"working_directory": os.getcwd(),
}
)
def start_flask_server(host: str, port: int, debug: bool = False, use_wsgi: bool = False):
def start_flask_server(
host: str,
port: int,
debug: bool = False,
use_wsgi: bool = False,
status_file: Optional[str] = None,
allowed_data_sources: Optional[list] = None,
):
"""启动Flask服务器"""
from flask import cli
# 禁用Flask的默认启动消息
cli.show_server_banner = lambda *args: None
# 更新默认状态文件
global DEFAULT_STATUS_FILE
if status_file is not None:
DEFAULT_STATUS_FILE = status_file
# 设置允许的数据源类型
global ALLOWED_DATA_SOURCE_TYPES
if allowed_data_sources is not None:
ALLOWED_DATA_SOURCE_TYPES = allowed_data_sources
print(f"🚀 启动训练监控服务 ({'Waitress WSGI' if use_wsgi else 'Flask'})...")
print(f"📁 默认状态文件: {os.path.abspath(DEFAULT_STATUS_FILE)}")
print(f"🌐 监控地址: http://{host}:{port}")
print(f"📊 API接口: http://{host}:{port}/api/status")
print(f"📋 允许的数据源: {', '.join(ALLOWED_DATA_SOURCE_TYPES)}")
print("\n按 Ctrl+C 停止监控服务\n")
try:
if use_wsgi:
try:
import waitress
waitress.serve(app, host=host, port=port, threads=4)
except ImportError:
print("⚠️ waitress未安装回退到Flask开发服务器")
@ -162,11 +239,24 @@ def main():
global DEFAULT_STATUS_FILE
parser = argparse.ArgumentParser(description="AI模型训练监控工具 - Flask版本")
parser.add_argument('--host', default=DEFAULT_HOST, help=f'监控服务主机地址 (默认: {DEFAULT_HOST})')
parser.add_argument('--port', type=int, default=DEFAULT_PORT, help=f'监控服务端口号 (默认: {DEFAULT_PORT})')
parser.add_argument('--debug', action='store_true', help='启用调试模式')
parser.add_argument('--use-wsgi', action='store_true', help='使用Waitress WSGI服务器替代Flask开发服务器')
parser.add_argument('--status-file', default=DEFAULT_STATUS_FILE, help='默认状态文件路径')
parser.add_argument(
"--host", default=DEFAULT_HOST, help=f"监控服务主机地址 (默认: {DEFAULT_HOST})"
)
parser.add_argument(
"--port",
type=int,
default=DEFAULT_PORT,
help=f"监控服务端口号 (默认: {DEFAULT_PORT})",
)
parser.add_argument("--debug", action="store_true", help="启用调试模式")
parser.add_argument(
"--use-wsgi",
action="store_true",
help="使用Waitress WSGI服务器替代Flask开发服务器",
)
parser.add_argument(
"--status-file", default=DEFAULT_STATUS_FILE, help="默认状态文件路径"
)
args = parser.parse_args()
@ -176,5 +266,5 @@ def main():
return start_flask_server(args.host, args.port, args.debug, args.use_wsgi)
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main())

View File

@ -18,7 +18,11 @@ app = typer.Typer(help="AI模型训练监控工具 - 基于JSON旁路记录法
# 尝试导入Flask如果失败则提供友好错误提示
try:
from .flask_monitor import start_flask_server, DEFAULT_STATUS_FILE as FLASK_DEFAULT_STATUS_FILE
from .flask_monitor import (
start_flask_server,
DEFAULT_STATUS_FILE as FLASK_DEFAULT_STATUS_FILE,
)
FLASK_AVAILABLE = True
except ImportError as e:
FLASK_AVAILABLE = False
@ -41,6 +45,7 @@ def start_flask_monitor_server(
host: str,
open_browser: bool,
use_wsgi: bool = False,
allowed_data_sources: Optional[list] = None,
) -> int:
"""
启动Flask监控服务器
@ -51,6 +56,7 @@ def start_flask_monitor_server(
host: 主机地址
open_browser: 是否自动打开浏览器
use_wsgi: 是否使用Waitress WSGI服务器替代Flask开发服务器
allowed_data_sources: 允许的数据源类型列表["local"]["remote"]默认None表示两者都允许
Returns:
进程退出码
@ -62,14 +68,15 @@ def start_flask_monitor_server(
typer.echo("或在pyproject.toml中添加flask依赖")
return 1
# 设置环境变量,传递状态文件路径
os.environ["TRAINING_STATUS_FILE"] = os.path.abspath(status_file)
server_type = "Waitress WSGI" if use_wsgi else "Flask"
typer.echo(f"🚀 启动训练监控服务 ({server_type}版本)...")
typer.echo(f"📁 状态文件: {os.path.abspath(status_file)}")
typer.echo(f"🌐 监控地址: http://{host}:{port}")
typer.echo(f"📊 API接口: http://{host}:{port}/api/status")
if allowed_data_sources is None:
typer.echo(f"📋 允许的数据源: local, remote")
else:
typer.echo(f"📋 允许的数据源: {', '.join(allowed_data_sources)}")
if open_browser:
# 等待服务器启动后打开浏览器
@ -81,7 +88,15 @@ def start_flask_monitor_server(
try:
# 导入并启动Flask服务器
from .flask_monitor import start_flask_server
return start_flask_server(host=host, port=port, debug=False, use_wsgi=use_wsgi)
return start_flask_server(
host=host,
port=port,
debug=False,
use_wsgi=use_wsgi,
status_file=status_file,
allowed_data_sources=allowed_data_sources,
)
except KeyboardInterrupt:
typer.echo("\n🛑 监控服务已停止")
return 0
@ -101,7 +116,9 @@ def monitor_training(
port: int = typer.Option(8501, "--port", "-p", help="监控服务端口号"),
host: str = typer.Option("0.0.0.0", "--host", help="监控服务主机地址"),
open_browser: bool = typer.Option(False, "--open-browser", help="自动打开浏览器"),
use_wsgi: bool = typer.Option(False, "--use-wsgi", help="使用Waitress WSGI服务器替代Flask开发服务器"),
use_wsgi: bool = typer.Option(
False, "--use-wsgi", help="使用Waitress WSGI服务器替代Flask开发服务器"
),
):
"""
启动AI模型训练监控服务 (Flask版本)
@ -131,6 +148,13 @@ def monitor_training(
typer.echo("或在pyproject.toml中添加flask依赖")
raise typer.Exit(code=1)
# 根据是否提供了-s参数决定允许的数据源
default_status_file = "./output/training_status.json"
if status_file == default_status_file:
allowed_data_sources = ["remote"] # 未提供-s只允许远程
else:
allowed_data_sources = ["local"] # 提供了-s只允许本地
# 启动Flask服务器
return_code = start_flask_monitor_server(
status_file=status_file,
@ -138,6 +162,7 @@ def monitor_training(
host=host,
open_browser=open_browser,
use_wsgi=use_wsgi,
allowed_data_sources=allowed_data_sources,
)
raise typer.Exit(code=return_code)

View File

@ -273,6 +273,7 @@
</div>
{% endblock %} {% block extra_js %}
<script>
const allowedDataSourceTypes = {{ allowed_data_source_types|tojson }};
let currentData = [];
let refreshTimer = null;
let currentRefreshInterval = {{ refresh_interval }};
@ -287,6 +288,36 @@
const refreshIntervalSelect = document.getElementById('refreshIntervalSelect');
const applyConfigBtn = document.getElementById('applyConfigBtn');
// 根据允许的数据源类型调整UI
if (allowedDataSourceTypes.length === 1) {
// 隐藏选择框,固定数据源类型
dataSourceSelect.style.display = 'none';
dataSourceSelect.value = allowedDataSourceTypes[0];
// 根据类型显示相应的输入字段
if (allowedDataSourceTypes[0] === 'local') {
localFileField.style.display = 'block';
remoteUrlField.style.display = 'none';
// 使本地文件路径输入框只读
document.getElementById('localFilePath').readOnly = true;
} else {
localFileField.style.display = 'none';
remoteUrlField.style.display = 'block';
// 远程URL输入框必须填写保持可编辑
}
// 禁用选择框的change事件因为只有一个选项
dataSourceSelect.disabled = true;
} else {
// 正常情况:两个选项都允许
// 根据当前选择显示相应字段
if (dataSourceSelect.value === 'local') {
localFileField.style.display = 'block';
remoteUrlField.style.display = 'none';
} else {
localFileField.style.display = 'none';
remoteUrlField.style.display = 'block';
}
}
dataSourceSelect.addEventListener('change', function() {
if (this.value === 'local') {

View File

@ -1070,7 +1070,7 @@ def train(
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=5000,
shuffle_buffer_size=50000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
@ -1093,7 +1093,7 @@ def train(
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=50000,
shuffle_buffer_size=500000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)

View File

@ -56,7 +56,7 @@ if len_py < 24:
else:
pinyin_ids = pinyin_ids[:24]
pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0)
masked_labels = [15, 4, 0, 0, 0, 0, 0, 0]
masked_labels = [0, 0, 0, 0, 0, 0, 0, 0]
part3 = ""
part4 = "可行|特别|伤害"

5
test_text.txt Normal file
View File

@ -0,0 +1,5 @@
他是一名大学生,在上海读大学。他喜欢编程和机器学习。
他的专业是计算机科学,他对人工智能很感兴趣。
今天天气很好,我们去公园玩。公园里有很多人在散步。
中国的历史文化悠久,有着丰富的文化遗产。
北京是中国的首都,上海是中国的经济中心。