diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..7cc0b58 --- /dev/null +++ b/eval.py @@ -0,0 +1,744 @@ +#!/usr/bin/env python3 +""" +eval.py - 评估模型在给定文本上的表现 + +使用方法: + python eval.py --text-file path/to/text.txt [--checkpoint path/to/model.pt] [--num-samples N] + +功能: + 1. 读取文本文件,限制长度300,随机抽取 + 2. 参考dataset.py将文本内容转化为模型可接受数据,包含part1/part2/part3/part4/槽位历史等信息 + 3. 使用模型推理,参考test.py + 4. 错误的打印,对于方差极小,概率和猜没什么差别的,重点标记 +""" + +import argparse +import random +import re +import sys +from pathlib import Path +from typing import Dict, List, Tuple, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from modelscope import AutoTokenizer +from pypinyin import lazy_pinyin +from pypinyin.contrib.tone_convert import to_initials + +# 添加src目录到路径 +sys.path.append("src") + +from src.model.model import InputMethodEngine +from src.model.query import QueryEngine +from src.model.dataset import text_to_pinyin_ids + +_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+") + + +class TextEvaluator: + def __init__( + self, + checkpoint_path: str, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + ): + self.device = torch.device(device) + self.checkpoint_path = checkpoint_path + + # 加载组件 + self.load_model() + self.load_tokenizer() + self.load_query_engine() + + # 拼音风格权重(与dataset.py一致) + self.py_style_weight = np.array([9, 2, 1]) / sum([9, 2, 1]) + + print(f"✅ 评估器初始化完成 (设备: {self.device})") + + # 调试:运行一个固定样本测试 + import os + + if os.environ.get("EVAL_DEBUG"): + self._debug_sample() + + def load_model(self): + """加载训练好的模型""" + self.model = InputMethodEngine(pinyin_vocab_size=30, compile=False) + checkpoint = torch.load(self.checkpoint_path, map_location="cpu") + if "model_state_dict" in checkpoint: + self.model.load_state_dict(checkpoint["model_state_dict"]) + else: + self.model.load_state_dict(checkpoint) + self.model.eval() + self.model.to(self.device) + print( + f"✅ 模型加载完成,参数量: {sum(p.numel() for p in self.model.parameters()):,}" + ) + print(f"✅ 模型词汇表大小: {self.model.vocab_size}") + + def load_tokenizer(self): + """加载tokenizer""" + try: + tokenizer_path = ( + Path(__file__).parent / "src" / "model" / "assets" / "tokenizer" + ) + self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path)) + print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}") + except Exception as e: + print(f"⚠️ 无法加载tokenizer: {e}") + print("使用默认的bert-base-chinese tokenizer") + self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese") + + def load_query_engine(self): + """加载查询引擎用于字符-ID转换""" + try: + self.query_engine = QueryEngine() + stats_path = ( + Path(__file__).parent + / "src" + / "model" + / "assets" + / "pinyin_char_statistics.json" + ) + if stats_path.exists(): + self.query_engine.load(stats_path) + print( + f"✅ 查询引擎加载完成,字符对数量: {len(self.query_engine._id_to_info)}" + ) + else: + print(f"⚠️ 统计文件不存在: {stats_path}") + self.query_engine = None + except Exception as e: + print(f"⚠️ 无法加载查询引擎: {e}") + self.query_engine = None + + def _debug_sample(self): + """调试:运行与test.py相同的固定样本""" + print("\n🧪 调试:运行固定样本(与test.py相同)") + + # 复制test.py中的样本 + part1 = "他是一名大学生,在上海读" + part2 = "dayi" + pinyin_ids = text_to_pinyin_ids(part2) + len_py = len(pinyin_ids) + if len_py < 24: + pinyin_ids.extend([0] * (24 - len_py)) + else: + pinyin_ids = pinyin_ids[:24] + pinyin_ids = torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0) + masked_labels = [15, 4, 0, 0, 0, 0, 0, 0] + part3 = "。" + part4 = "可行|特别|伤害" + + encoded = self.tokenizer( + f"{part4}|{part1}", + part3, + max_length=128, + padding="max_length", + truncation=True, + return_tensors="pt", + return_token_type_ids=True, + ) + + sample = { + "input_ids": torch.stack([encoded["input_ids"].squeeze(0)]), + "token_type_ids": torch.stack([encoded["token_type_ids"].squeeze(0)]), + "attention_mask": torch.stack([encoded["attention_mask"].squeeze(0)]), + "history_slot_ids": torch.tensor(masked_labels, dtype=torch.long).unsqueeze( + 0 + ), + "prefix": f"{part4}^{part1}", + "suffix": part3, + "pinyin": part2, + "pinyin_ids": pinyin_ids, + } + + # 推理 + logits, probs = self.inference(sample) + analysis = self.analyze_probability_distribution(probs) + + print(f"📊 调试样本结果:") + print(f" 最高概率: {analysis['max_prob']:.6f}") + print(f" 平均概率: {analysis['mean_prob']:.6f}") + print(f" 标准差: {analysis['std_prob']:.6f}") + print(f" 低方差: {analysis['low_variance']}") + + # 打印top-5 + print(f" Top-5预测:") + for j in range(5): + idx = analysis["top_indices"][j] + prob = analysis["top_probs"][j] + char_info = ( + self.query_engine.query_by_id(idx) if self.query_engine else None + ) + char = char_info.char if char_info else f"[ID:{idx}]" + print(f" {j + 1}. {char} (ID: {idx}) - 概率: {prob:.6f}") + + def generate_pinyin(self, text: str) -> List[str]: + """ + 流式处理单条文本,转换为拼音列表。 + 参考dataset.py中的generate_pinyin方法。 + """ + if not text: + return [] + + text_len = len(text) + result: List[str] = [""] * text_len + + # 遍历所有连续汉字片段 + for match in _HANZI_RE.finditer(text): + start_idx = match.start() + hanzi_segment = match.group() + + pinyin_list = lazy_pinyin(hanzi_segment) + + if len(pinyin_list) != len(hanzi_segment): + pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment] + + for i, py in enumerate(pinyin_list): + result[start_idx + i] = py + + # 填充非汉字字符 + for i, char in enumerate(text): + if not result[i]: + result[i] = char + + return result + + def get_mask_pinyin( + self, text: str, pinyin_list: List[str] + ) -> Tuple[int, List[str]]: + """ + 生成需要预测汉字对应的拼音,并进行加强。 + 参考dataset.py中的get_mask_pinyin方法。 + """ + mask_pinyin = [] + for i in range(len(text)): + if not self.query_engine or not self.query_engine.is_chinese_char(text[i]): + break + else: + py = np.random.choice( + (pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]), + p=self.py_style_weight, + ) + if py == "": + py = pinyin_list[i][0] + mask_pinyin.append(py) + return len(mask_pinyin), mask_pinyin + + def create_sample_from_text( + self, text: str, position: Optional[int] = None + ) -> Dict: + """ + 从文本创建单个样本,参考dataset.py的样本生成逻辑。 + + Args: + text: 输入文本 + position: 预测位置(汉字位置),如果为None则随机选择 + + Returns: + 样本字典,包含模型所需的所有字段 + """ + # 确保文本长度至少为1 + if not text: + raise ValueError("文本不能为空") + + # 生成拼音列表 + pinyin_list = self.generate_pinyin(text) + + # 找到所有汉字位置 + chinese_positions = [ + i + for i, char in enumerate(text) + if self.query_engine and self.query_engine.is_chinese_char(char) + ] + + if not chinese_positions: + raise ValueError("文本中没有汉字") + + # 随机选择预测位置 + if position is None: + position = random.choice(chinese_positions) + elif position >= len(text) or position < 0: + raise ValueError(f"位置 {position} 超出文本范围") + elif not self.query_engine or not self.query_engine.is_chinese_char( + text[position] + ): + raise ValueError(f"位置 {position} 不是汉字") + + i = position + + # part1: 光标前文本(最多48个字符) + if i < 48: + part1 = text[0:i] + else: + part1 = text[i - 48 : i] + + # part2: 拼音输入(随机长度1-8,高斯分布) + pinyin_len_probs = [0.05, 0.16, 0.45, 0.16, 0.08, 0.05, 0.03, 0.02] + pinyin_len = np.random.choice(range(1, 9), p=pinyin_len_probs) + py_end = min(i + pinyin_len, len(text)) + + # 获取拼音 + pinyin_len_actual, part2_pinyin = self.get_mask_pinyin( + text[i:py_end], pinyin_list[i:py_end] + ) + + # 添加分隔符 + split_char = np.random.choice(["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]) + part2 = split_char.join(part2_pinyin) + + # 转换为拼音ID + pinyin_ids = text_to_pinyin_ids(part2) + len_py = len(pinyin_ids) + if len_py < 24: + pinyin_ids.extend([0] * (24 - len_py)) + else: + pinyin_ids = pinyin_ids[:24] + pinyin_ids_tensor = torch.tensor(pinyin_ids, dtype=torch.long) + + # part3: 光标后文本(70%概率为空) + part3 = "" + if random.random() > 0.7: + part3 = text[ + i + pinyin_len_actual : i + + pinyin_len_actual + + np.random.choice(range(1, 17)) + ] + + # part4: 上下文提示(50%概率为空) + part4 = "" + if random.random() > 0.5: + num_strings = random.randint(1, 5) + string_list = [] + for _ in range(num_strings): + start_pos = random.randint(0, len(text) - 1) + x = random.randint(2, 6) + end_pos = min(start_pos + x + 1, len(text)) + string_list.append(text[start_pos:end_pos]) + part4 = "|".join(string_list) + + # 标签:预测字符的ID + labels = [] + try: + labels = [ + self.query_engine.get_char_info_by_char_pinyin(c, p).id + for c, p in zip( + text[i : i + pinyin_len_actual], + pinyin_list[i : i + pinyin_len_actual], + ) + ] + except (AttributeError, TypeError) as e: + # 如果查询失败,使用默认ID + print(f"⚠️ 获取标签失败: {e}") + labels = [0] * pinyin_len_actual + + # 历史槽位:当前预测字符之前已确认的字符 + # 从位置i之前的字符中提取历史槽位(最多8个) + history_slot_ids = [] + if self.query_engine: + # 收集i之前的汉字字符(最多8个) + for j in range(i - 1, max(-1, i - 9), -1): + if j < 0: + break + char = text[j] + if self.query_engine.is_chinese_char(char): + # 获取该字符的第一个拼音变体的ID + results = self.query_engine.query_by_char(char, limit=1) + if results: + history_slot_ids.append(results[0][0]) + else: + # 如果查询失败,使用0 + history_slot_ids.append(0) + else: + # 非汉字字符,使用0 + history_slot_ids.append(0) + if len(history_slot_ids) >= 8: + break + + # 如果历史槽位为空,随机添加一些常见字符ID(模拟已确认的输入) + if not history_slot_ids and random.random() < 0.5: + # 随机选择1-3个常见字符ID(范围1-1000) + num_slots = random.randint(1, 3) + for _ in range(num_slots): + history_slot_ids.append(random.randint(1, 1000)) + + # 填充到8个槽位 + if len(history_slot_ids) < 8: + history_slot_ids.extend([0] * (8 - len(history_slot_ids))) + else: + history_slot_ids = history_slot_ids[:8] + + # Tokenize输入 + encoded = self.tokenizer( + f"{part4}|{part1}", + part3, + max_length=128, + padding="max_length", + truncation=True, + return_tensors="pt", + return_token_type_ids=True, + ) + + # 计算有效槽位数量(非零元素) + valid_slot_count = sum(1 for slot_id in history_slot_ids if slot_id != 0) + + # 拼音输入长度(字符数) + pinyin_input_length = len(part2) + + # 构建样本 + sample = { + "input_ids": encoded["input_ids"], + "token_type_ids": encoded["token_type_ids"], + "attention_mask": encoded["attention_mask"], + "history_slot_ids": torch.tensor( + history_slot_ids, dtype=torch.long + ).unsqueeze(0), + "prefix": f"{part4}^{part1}", + "suffix": part3, + "pinyin": part2, + "pinyin_ids": pinyin_ids_tensor.unsqueeze(0), + "true_labels": labels, # 真实标签(多个字符) + "position": i, + "target_char": text[i] if i < len(text) else "", + "valid_slot_count": valid_slot_count, # 有效槽位数量 + "pinyin_input_length": pinyin_input_length, # 拼音输入长度 + } + + return sample + + def inference(self, sample: Dict) -> Tuple[torch.Tensor, torch.Tensor]: + """ + 执行模型推理。 + + Returns: + (logits, probs) + """ + # 准备输入张量并移动到设备 + input_ids = sample["input_ids"].to(self.device) + token_type_ids = sample["token_type_ids"].to(self.device) + attention_mask = sample["attention_mask"].to(self.device) + pinyin_ids = sample["pinyin_ids"].to(self.device) + history_slot_ids = sample["history_slot_ids"].to(self.device) + + with torch.no_grad(): + if self.device.type == "cuda": + with torch.autocast(device_type="cuda"): + logits = self.model( + input_ids, + token_type_ids, + attention_mask, + pinyin_ids, + history_slot_ids, + ) + else: + logits = self.model( + input_ids, + token_type_ids, + attention_mask, + pinyin_ids, + history_slot_ids, + ) + + probs = F.softmax(logits, dim=-1) + return logits, probs + + def analyze_probability_distribution( + self, probs: torch.Tensor, threshold_low_var: float = 0.01 + ) -> Dict: + """ + 分析概率分布,检测低方差情况。 + + Args: + probs: 概率张量,形状为 (1, vocab_size) + threshold_low_var: 低方差阈值,最高概率低于此值则标记为低方差 + + Returns: + 分析结果字典 + """ + probs_np = probs.cpu().numpy().flatten() + + # 基本统计 + max_prob = np.max(probs_np) + min_prob = np.min(probs_np) + mean_prob = np.mean(probs_np) + std_prob = np.std(probs_np) + + # 熵 + entropy = -np.sum(probs_np * np.log(probs_np + 1e-12)) + + # 低方差检测 + low_variance = max_prob < threshold_low_var + + # Top-k 分析 + top_k = 20 + top_indices = np.argsort(probs_np)[-top_k:][::-1] + top_probs = probs_np[top_indices] + + # 检查前5个概率是否接近均匀 + top5_probs = top_probs[:5] + top5_ratio = ( + top5_probs[0] / (top5_probs[1] + 1e-12) if len(top5_probs) > 1 else 1.0 + ) + + analysis = { + "max_prob": max_prob, + "min_prob": min_prob, + "mean_prob": mean_prob, + "std_prob": std_prob, + "entropy": entropy, + "low_variance": low_variance, + "top_indices": top_indices, + "top_probs": top_probs, + "top5_ratio": top5_ratio, + "vocab_size": len(probs_np), + } + + return analysis + + def evaluate_sample(self, sample: Dict, print_details: bool = True) -> Dict: + """ + 评估单个样本。 + + Returns: + 评估结果字典 + """ + # 执行推理 + logits, probs = self.inference(sample) + + # 分析概率分布 + analysis = self.analyze_probability_distribution(probs) + + # 获取真实标签 + true_labels = sample.get("true_labels", []) + target_char = sample.get("target_char", "") + + # 检查预测是否正确(只检查第一个字符) + correct = False + pred_top_idx = analysis["top_indices"][0] + if true_labels and len(true_labels) > 0: + correct = pred_top_idx == true_labels[0] + + # 打印结果 + if print_details: + # 获取槽位信息 + valid_slot_count = sample.get("valid_slot_count", 0) + pinyin_input_length = sample.get("pinyin_input_length", 0) + + # 简化输出格式(在同一行继续输出) + print(f" {target_char}", end="") + if true_labels: + print(f"(ID:{true_labels[0]})", end="") + print( + f" | 拼音:{sample.get('pinyin', '')[:10]}{'...' if len(sample.get('pinyin', '')) > 10 else ''}", + end="", + ) + print(f" | 槽位:{valid_slot_count}/8 拼音长:{pinyin_input_length}", end="") + + # 关键概率信息 + print(f" | 最高概率:{analysis['max_prob']:.3f}", end="") + if analysis["low_variance"]: + print(f" ⚠️", end="") + + # Top-2 预测 + top_strs = [] + for j in range(min(2, len(analysis["top_indices"]))): + idx = analysis["top_indices"][j] + prob = analysis["top_probs"][j] + # 处理特殊ID + if idx == 0: + char = "[空]" + else: + char_info = ( + self.query_engine.query_by_id(idx) + if self.query_engine + else None + ) + char = char_info.char if char_info else f"[{idx}]" + is_correct = ( + (true_labels and idx == true_labels[0]) if true_labels else False + ) + correct_mark = "✓" if is_correct else "" + top_strs.append(f"{char}{correct_mark}({prob:.2f})") + print(f" | Top-2: {' '.join(top_strs)}", end="") + + print(f" | 结果:{'✓' if correct else '✗'}") + + # 返回评估结果 + result = { + "sample": sample, + "analysis": analysis, + "correct": correct, + "target_char": target_char, + "true_label": true_labels[0] if true_labels else None, + "predicted_label": pred_top_idx, + } + + return result + + def evaluate_text(self, text: str, num_samples: int = 10) -> List[Dict]: + """ + 评估给定文本,生成多个样本进行评估。 + + Args: + text: 输入文本 + num_samples: 生成的样本数量 + + Returns: + 评估结果列表 + """ + # 限制文本长度 + if len(text) > 300: + start = random.randint(0, len(text) - 300) + text = text[start : start + 300] + print(f"📝 文本长度超过300,随机截取300字符: {text[:50]}...") + else: + print(f"📝 文本长度: {len(text)} 字符") + + results = [] + low_variance_samples = [] + + for sample_idx in range(num_samples): + print(f"\n[样本 {sample_idx + 1}/{num_samples}]", end=" ") + + try: + # 创建样本 + sample = self.create_sample_from_text(text) + + # 评估样本 + result = self.evaluate_sample(sample, print_details=True) + results.append(result) + + # 记录低方差样本 + if result["analysis"]["low_variance"]: + low_variance_samples.append( + { + "sample_idx": sample_idx, + "max_prob": result["analysis"]["max_prob"], + "entropy": result["analysis"]["entropy"], + } + ) + + except Exception as e: + print(f"❌ 样本生成或评估失败: {e}") + import traceback + + traceback.print_exc() + continue + + # 打印汇总统计 + if results: + self.print_summary(results, low_variance_samples) + + return results + + def print_summary(self, results: List[Dict], low_variance_samples: List[Dict]): + """打印评估汇总""" + print(f"\n--- 汇总 ({len(results)}样本) ---") + + # 计算准确率 + correct_count = sum(1 for r in results if r["correct"]) + accuracy = correct_count / len(results) if results else 0 + print(f"准确率: {accuracy:.1%} ({correct_count}/{len(results)})") + + # 概率统计 + max_probs = [r["analysis"]["max_prob"] for r in results] + mean_max_prob = np.mean(max_probs) if max_probs else 0 + print(f"平均最高概率: {mean_max_prob:.4f}") + + # 预测ID:0的比例 + zero_pred_count = sum(1 for r in results if r.get("predicted_label") == 0) + zero_pred_ratio = zero_pred_count / len(results) if results else 0 + print(f"预测[空]比例: {zero_pred_ratio:.1%} ({zero_pred_count}/{len(results)})") + + # 平均熵 + entropies = [r["analysis"]["entropy"] for r in results] + mean_entropy = np.mean(entropies) if entropies else 0 + print(f"平均熵: {mean_entropy:.2f}") + + # 槽位和拼音长度统计 + if results and "sample" in results[0]: + valid_slots = [r["sample"].get("valid_slot_count", 0) for r in results] + pinyin_lengths = [ + r["sample"].get("pinyin_input_length", 0) for r in results + ] + avg_slots = np.mean(valid_slots) if valid_slots else 0 + avg_pinyin = np.mean(pinyin_lengths) if pinyin_lengths else 0 + print(f"平均槽位: {avg_slots:.1f}/8, 平均拼音长度: {avg_pinyin:.1f}") + + # 低方差样本 + if low_variance_samples: + print(f"低方差样本: {len(low_variance_samples)}个") + if len(low_variance_samples) <= 5: + for lv in low_variance_samples: + print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}") + else: + for lv in low_variance_samples[:3]: + print(f" 样本{lv['sample_idx'] + 1}: 最高概率{lv['max_prob']:.4f}") + print(f" ... 还有{len(low_variance_samples) - 3}个") + else: + print(f"低方差样本: 0个") + + +def main(): + parser = argparse.ArgumentParser(description="评估模型在文本上的表现") + parser.add_argument("--text-file", type=str, required=True, help="文本文件路径") + parser.add_argument( + "--checkpoint", + type=str, + default="/home/songsenand/下载/best_model.pt", + help="模型checkpoint路径 (默认: /home/songsenand/下载/best_model.pt)", + ) + parser.add_argument( + "--num-samples", type=int, default=20, help="评估样本数量 (默认: 20)" + ) + parser.add_argument( + "--device", + type=str, + default="auto", + choices=["auto", "cpu", "cuda"], + help="推理设备 (默认: auto)", + ) + + args = parser.parse_args() + + # 选择设备 + if args.device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + else: + device = args.device + + # 读取文本文件 + try: + with open(args.text_file, "r", encoding="utf-8") as f: + text = f.read().strip() + except Exception as e: + print(f"❌ 无法读取文本文件: {e}") + sys.exit(1) + + if not text: + print("❌ 文本文件为空") + sys.exit(1) + + print(f"📄 读取文本文件: {args.text_file}") + print(f"📏 原始文本长度: {len(text)} 字符") + print(f"🔧 使用设备: {device}") + print(f"📦 模型checkpoint: {args.checkpoint}") + print(f"🔬 评估样本数: {args.num_samples}") + + # 初始化评估器 + try: + evaluator = TextEvaluator(args.checkpoint, device) + except Exception as e: + print(f"❌ 初始化评估器失败: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + # 执行评估 + evaluator.evaluate_text(text, args.num_samples) + + +if __name__ == "__main__": + main() diff --git a/src/model/flask_monitor.py b/src/model/flask_monitor.py index 350384b..677e4db 100644 --- a/src/model/flask_monitor.py +++ b/src/model/flask_monitor.py @@ -9,14 +9,17 @@ from typing import Optional, Union import pandas as pd from flask import Flask, render_template, request, jsonify, send_from_directory -app = Flask(__name__, - template_folder=Path(__file__).parent / 'templates', - static_folder=Path(__file__).parent / 'static') +app = Flask( + __name__, + template_folder=Path(__file__).parent / "templates", + static_folder=Path(__file__).parent / "static", +) # 全局配置 DEFAULT_STATUS_FILE = "./output/training_status.json" DEFAULT_PORT = 8501 DEFAULT_HOST = "0.0.0.0" +ALLOWED_DATA_SOURCE_TYPES = ["local", "remote"] # 允许的数据源类型 def load_training_data(file_path: str, is_url: bool = False) -> list: @@ -33,15 +36,15 @@ def load_from_url(url: str) -> list: import requests except ImportError: raise RuntimeError("requests库未安装,无法从HTTP URL加载数据") - + try: response = requests.get(url, timeout=10) response.raise_for_status() data = response.json() - + if not isinstance(data, list): raise ValueError("远程返回的数据不是列表格式") - + return data except Exception as e: raise RuntimeError(f"从URL加载数据失败: {e}") @@ -52,13 +55,13 @@ def load_from_local_file(file_path: str) -> list: try: if not os.path.exists(file_path): return [] - + with open(file_path, "r", encoding="utf-8") as f: data = json.load(f) - + if not isinstance(data, list): raise ValueError("文件内容不是列表格式") - + return data except json.JSONDecodeError: return [] @@ -70,75 +73,149 @@ def validate_and_clean_data(data: list) -> list: """验证和清理数据""" if not isinstance(data, list): return [] - + cleaned = [] for item in data: - if isinstance(item, dict) and ('step' in item or 'train/loss' in item or 'timestamp' in item): + if isinstance(item, dict) and ( + "step" in item or "train/loss" in item or "timestamp" in item + ): cleaned.append(item) - + return cleaned -@app.route('/') +@app.route("/") def index(): """主页""" - data_source_type = request.args.get('data_source_type', 'local') - data_source = request.args.get('data_source', DEFAULT_STATUS_FILE) - refresh_interval = int(request.args.get('refresh_interval', 5)) - - return render_template('index.html', - data_source_type=data_source_type, - data_source=data_source, - refresh_interval=refresh_interval) + # 根据允许的数据源类型设置默认值 + if ( + "local" in ALLOWED_DATA_SOURCE_TYPES + and "remote" not in ALLOWED_DATA_SOURCE_TYPES + ): + default_type = "local" + default_source = DEFAULT_STATUS_FILE + elif ( + "remote" in ALLOWED_DATA_SOURCE_TYPES + and "local" not in ALLOWED_DATA_SOURCE_TYPES + ): + default_type = "remote" + default_source = "" + else: + default_type = "local" + default_source = DEFAULT_STATUS_FILE + + data_source_type = request.args.get("data_source_type", default_type) + data_source = request.args.get("data_source", default_source) + refresh_interval = int(request.args.get("refresh_interval", 5)) + + return render_template( + "index.html", + data_source_type=data_source_type, + data_source=data_source, + refresh_interval=refresh_interval, + allowed_data_source_types=ALLOWED_DATA_SOURCE_TYPES, + ) -@app.route('/api/status') +@app.route("/api/status") def api_status(): """API接口,返回训练状态数据""" - data_source_type = request.args.get('data_source_type', 'local') - data_source = request.args.get('data_source', DEFAULT_STATUS_FILE) - + # 根据允许的数据源类型设置默认值 + if ( + "local" in ALLOWED_DATA_SOURCE_TYPES + and "remote" not in ALLOWED_DATA_SOURCE_TYPES + ): + default_type = "local" + elif ( + "remote" in ALLOWED_DATA_SOURCE_TYPES + and "local" not in ALLOWED_DATA_SOURCE_TYPES + ): + default_type = "remote" + else: + default_type = "local" + + data_source_type = request.args.get("data_source_type", default_type) + + # 检查请求的数据源类型是否被允许 + if data_source_type not in ALLOWED_DATA_SOURCE_TYPES: + return jsonify( + { + "error": f"数据源类型 '{data_source_type}' 不被允许。允许的类型: {ALLOWED_DATA_SOURCE_TYPES}", + "data": [], + } + ), 400 + + # 根据数据源类型设置数据源路径 + if data_source_type == "local": + # 本地文件模式:固定使用DEFAULT_STATUS_FILE,忽略用户提供的路径 + data_source = DEFAULT_STATUS_FILE + else: + # 远程URL模式:使用用户提供的URL + data_source = request.args.get("data_source", "") + if not data_source: + return jsonify( + {"error": "远程数据源模式下必须提供data_source参数(URL)", "data": []} + ), 400 + try: - is_url = data_source_type == 'remote' + is_url = data_source_type == "remote" raw_data = load_training_data(data_source, is_url) cleaned_data = validate_and_clean_data(raw_data) - + return jsonify(cleaned_data) except Exception as e: - return jsonify({ - 'error': str(e), - 'data': [] - }), 500 + return jsonify({"error": str(e), "data": []}), 500 -@app.route('/api/info') +@app.route("/api/info") def api_info(): """获取服务器信息""" - return jsonify({ - 'server_time': datetime.now().isoformat(), - 'default_status_file': DEFAULT_STATUS_FILE, - 'python_version': sys.version, - 'working_directory': os.getcwd() - }) + return jsonify( + { + "server_time": datetime.now().isoformat(), + "default_status_file": DEFAULT_STATUS_FILE, + "python_version": sys.version, + "working_directory": os.getcwd(), + } + ) -def start_flask_server(host: str, port: int, debug: bool = False, use_wsgi: bool = False): +def start_flask_server( + host: str, + port: int, + debug: bool = False, + use_wsgi: bool = False, + status_file: Optional[str] = None, + allowed_data_sources: Optional[list] = None, +): """启动Flask服务器""" from flask import cli - + # 禁用Flask的默认启动消息 cli.show_server_banner = lambda *args: None - + + # 更新默认状态文件 + global DEFAULT_STATUS_FILE + if status_file is not None: + DEFAULT_STATUS_FILE = status_file + + # 设置允许的数据源类型 + global ALLOWED_DATA_SOURCE_TYPES + if allowed_data_sources is not None: + ALLOWED_DATA_SOURCE_TYPES = allowed_data_sources + print(f"🚀 启动训练监控服务 ({'Waitress WSGI' if use_wsgi else 'Flask'})...") print(f"📁 默认状态文件: {os.path.abspath(DEFAULT_STATUS_FILE)}") print(f"🌐 监控地址: http://{host}:{port}") print(f"📊 API接口: http://{host}:{port}/api/status") + print(f"📋 允许的数据源: {', '.join(ALLOWED_DATA_SOURCE_TYPES)}") print("\n按 Ctrl+C 停止监控服务\n") - + try: if use_wsgi: try: import waitress + waitress.serve(app, host=host, port=port, threads=4) except ImportError: print("⚠️ waitress未安装,回退到Flask开发服务器") @@ -151,30 +228,43 @@ def start_flask_server(host: str, port: int, debug: bool = False, use_wsgi: bool except Exception as e: print(f"❌ 启动监控服务时出错: {e}") return 1 - + return 0 def main(): """命令行入口点""" import argparse - + global DEFAULT_STATUS_FILE - + parser = argparse.ArgumentParser(description="AI模型训练监控工具 - Flask版本") - parser.add_argument('--host', default=DEFAULT_HOST, help=f'监控服务主机地址 (默认: {DEFAULT_HOST})') - parser.add_argument('--port', type=int, default=DEFAULT_PORT, help=f'监控服务端口号 (默认: {DEFAULT_PORT})') - parser.add_argument('--debug', action='store_true', help='启用调试模式') - parser.add_argument('--use-wsgi', action='store_true', help='使用Waitress WSGI服务器替代Flask开发服务器') - parser.add_argument('--status-file', default=DEFAULT_STATUS_FILE, help='默认状态文件路径') - + parser.add_argument( + "--host", default=DEFAULT_HOST, help=f"监控服务主机地址 (默认: {DEFAULT_HOST})" + ) + parser.add_argument( + "--port", + type=int, + default=DEFAULT_PORT, + help=f"监控服务端口号 (默认: {DEFAULT_PORT})", + ) + parser.add_argument("--debug", action="store_true", help="启用调试模式") + parser.add_argument( + "--use-wsgi", + action="store_true", + help="使用Waitress WSGI服务器替代Flask开发服务器", + ) + parser.add_argument( + "--status-file", default=DEFAULT_STATUS_FILE, help="默认状态文件路径" + ) + args = parser.parse_args() - + # 更新默认状态文件 DEFAULT_STATUS_FILE = args.status_file - + return start_flask_server(args.host, args.port, args.debug, args.use_wsgi) -if __name__ == '__main__': - sys.exit(main()) \ No newline at end of file +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/model/monitor.py b/src/model/monitor.py index 01253a7..f0a6ee7 100644 --- a/src/model/monitor.py +++ b/src/model/monitor.py @@ -18,7 +18,11 @@ app = typer.Typer(help="AI模型训练监控工具 - 基于JSON旁路记录法 # 尝试导入Flask,如果失败则提供友好错误提示 try: - from .flask_monitor import start_flask_server, DEFAULT_STATUS_FILE as FLASK_DEFAULT_STATUS_FILE + from .flask_monitor import ( + start_flask_server, + DEFAULT_STATUS_FILE as FLASK_DEFAULT_STATUS_FILE, + ) + FLASK_AVAILABLE = True except ImportError as e: FLASK_AVAILABLE = False @@ -41,6 +45,7 @@ def start_flask_monitor_server( host: str, open_browser: bool, use_wsgi: bool = False, + allowed_data_sources: Optional[list] = None, ) -> int: """ 启动Flask监控服务器 @@ -51,6 +56,7 @@ def start_flask_monitor_server( host: 主机地址 open_browser: 是否自动打开浏览器 use_wsgi: 是否使用Waitress WSGI服务器替代Flask开发服务器 + allowed_data_sources: 允许的数据源类型列表,如["local"]或["remote"],默认None表示两者都允许 Returns: 进程退出码 @@ -61,27 +67,36 @@ def start_flask_monitor_server( typer.echo("请安装Flask: pip install flask") typer.echo("或在pyproject.toml中添加flask依赖") return 1 - - # 设置环境变量,传递状态文件路径 - os.environ["TRAINING_STATUS_FILE"] = os.path.abspath(status_file) - + server_type = "Waitress WSGI" if use_wsgi else "Flask" typer.echo(f"🚀 启动训练监控服务 ({server_type}版本)...") typer.echo(f"📁 状态文件: {os.path.abspath(status_file)}") typer.echo(f"🌐 监控地址: http://{host}:{port}") typer.echo(f"📊 API接口: http://{host}:{port}/api/status") - + if allowed_data_sources is None: + typer.echo(f"📋 允许的数据源: local, remote") + else: + typer.echo(f"📋 允许的数据源: {', '.join(allowed_data_sources)}") + if open_browser: # 等待服务器启动后打开浏览器 threading.Timer(2.0, lambda: webbrowser.open(f"http://{host}:{port}")).start() typer.echo("🌐 正在打开浏览器...") - + typer.echo("\n按 Ctrl+C 停止监控服务\n") - + try: # 导入并启动Flask服务器 from .flask_monitor import start_flask_server - return start_flask_server(host=host, port=port, debug=False, use_wsgi=use_wsgi) + + return start_flask_server( + host=host, + port=port, + debug=False, + use_wsgi=use_wsgi, + status_file=status_file, + allowed_data_sources=allowed_data_sources, + ) except KeyboardInterrupt: typer.echo("\n🛑 监控服务已停止") return 0 @@ -101,7 +116,9 @@ def monitor_training( port: int = typer.Option(8501, "--port", "-p", help="监控服务端口号"), host: str = typer.Option("0.0.0.0", "--host", help="监控服务主机地址"), open_browser: bool = typer.Option(False, "--open-browser", help="自动打开浏览器"), - use_wsgi: bool = typer.Option(False, "--use-wsgi", help="使用Waitress WSGI服务器替代Flask开发服务器"), + use_wsgi: bool = typer.Option( + False, "--use-wsgi", help="使用Waitress WSGI服务器替代Flask开发服务器" + ), ): """ 启动AI模型训练监控服务 (Flask版本) @@ -131,6 +148,13 @@ def monitor_training( typer.echo("或在pyproject.toml中添加flask依赖") raise typer.Exit(code=1) + # 根据是否提供了-s参数决定允许的数据源 + default_status_file = "./output/training_status.json" + if status_file == default_status_file: + allowed_data_sources = ["remote"] # 未提供-s,只允许远程 + else: + allowed_data_sources = ["local"] # 提供了-s,只允许本地 + # 启动Flask服务器 return_code = start_flask_monitor_server( status_file=status_file, @@ -138,6 +162,7 @@ def monitor_training( host=host, open_browser=open_browser, use_wsgi=use_wsgi, + allowed_data_sources=allowed_data_sources, ) raise typer.Exit(code=return_code) diff --git a/src/model/templates/index.html b/src/model/templates/index.html index e2cb27f..9af898c 100644 --- a/src/model/templates/index.html +++ b/src/model/templates/index.html @@ -273,6 +273,7 @@ {% endblock %} {% block extra_js %}