From e8eab1f260560c9aeb0ec5ada3e5c525310569c3 Mon Sep 17 00:00:00 2001 From: songsenand Date: Sat, 9 May 2026 13:36:48 +0800 Subject: [PATCH] =?UTF-8?q?refactor(generate=5Fpinyin):=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E6=8B=BC=E9=9F=B3=E7=94=9F=E6=88=90=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E5=88=A9=E7=94=A8=20pypinyin=20=E5=88=86=E8=AF=8D?= =?UTF-8?q?=E8=83=BD=E5=8A=9B=E5=A4=84=E7=90=86=E5=A4=9A=E9=9F=B3=E5=AD=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- eval.py | 33 +- .../analyze_frequency.py | 140 +++++---- analyze_range.py => scripts/analyze_range.py | 84 +++-- check_weights.py => scripts/check_weights.py | 9 +- .../comprehensive_analysis.py | 146 ++++++--- resign_stat.py => scripts/resign_stat.py | 22 +- .../visualize_distribution.py | 100 +++--- src/model/dataset.py | 292 ++++++++++-------- src/model/supplement_missing.py | 197 +++++++++++- test.py => tests/test.py | 11 +- test_compile.py => tests/test_compile.py | 8 +- .../test_cpu_inference.py | 6 +- test_dataset.py => tests/test_dataset.py | 56 +++- .../test_epoch_checkpoint.py | 0 test_trainer.py => tests/test_trainer.py | 4 +- 15 files changed, 739 insertions(+), 369 deletions(-) rename analyze_frequency.py => scripts/analyze_frequency.py (79%) rename analyze_range.py => scripts/analyze_range.py (69%) rename check_weights.py => scripts/check_weights.py (68%) rename comprehensive_analysis.py => scripts/comprehensive_analysis.py (68%) rename resign_stat.py => scripts/resign_stat.py (84%) rename visualize_distribution.py => scripts/visualize_distribution.py (69%) rename test.py => tests/test.py (91%) rename test_compile.py => tests/test_compile.py (98%) rename test_cpu_inference.py => tests/test_cpu_inference.py (98%) rename test_dataset.py => tests/test_dataset.py (58%) rename test_epoch_checkpoint.py => tests/test_epoch_checkpoint.py (100%) rename test_trainer.py => tests/test_trainer.py (99%) diff --git a/eval.py b/eval.py index 4429e83..8e8d6f1 100644 --- a/eval.py +++ b/eval.py @@ -14,7 +14,6 @@ eval.py - 评估模型在给定文本上的表现 import argparse import random -import re import sys from pathlib import Path from typing import Dict, List, Tuple, Optional @@ -33,8 +32,6 @@ from src.model.model import InputMethodEngine from src.model.query import QueryEngine from src.model.dataset import text_to_pinyin_ids -_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+") - class TextEvaluator: def __init__( @@ -171,34 +168,20 @@ class TextEvaluator: def generate_pinyin(self, text: str) -> List[str]: """ - 流式处理单条文本,转换为拼音列表。 - 参考dataset.py中的generate_pinyin方法。 + 将文本转换为拼音列表。对整段文本调用 lazy_pinyin, + 利用 pypinyin 内部的分词能力处理多音字。 + 参考 dataset.py 中的 generate_pinyin 方法。 """ if not text: return [] - text_len = len(text) - result: List[str] = [""] * text_len + pinyin_list = lazy_pinyin(text) - # 遍历所有连续汉字片段 - for match in _HANZI_RE.finditer(text): - start_idx = match.start() - hanzi_segment = match.group() + # 健壮性兜底:若长度不匹配(极罕见),降级为逐字转换 + if len(pinyin_list) != len(text): + pinyin_list = [lazy_pinyin(c)[0] for c in text] - pinyin_list = lazy_pinyin(hanzi_segment) - - if len(pinyin_list) != len(hanzi_segment): - pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment] - - for i, py in enumerate(pinyin_list): - result[start_idx + i] = py - - # 填充非汉字字符 - for i, char in enumerate(text): - if not result[i]: - result[i] = char - - return result + return pinyin_list def get_mask_pinyin( self, text: str, pinyin_list: List[str] diff --git a/analyze_frequency.py b/scripts/analyze_frequency.py similarity index 79% rename from analyze_frequency.py rename to scripts/analyze_frequency.py index c018105..5215014 100644 --- a/analyze_frequency.py +++ b/scripts/analyze_frequency.py @@ -10,34 +10,41 @@ import math from collections import defaultdict from pathlib import Path + def main(): # Path to the JSON file - json_path = Path("src/model/assets/pinyin_char_statistics.json") + json_path = ( + Path(__file__).parent.parent + / "src" + / "model" + / "assets" + / "pinyin_char_statistics.json" + ) if not json_path.exists(): print(f"Error: File not found: {json_path}") sys.exit(1) - + print(f"Loading {json_path}...") - with open(json_path, 'r', encoding='utf-8') as f: + with open(json_path, "r", encoding="utf-8") as f: data = json.load(f) - + print(f"Timestamp: {data.get('timestamp')}") print(f"Total characters: {data.get('total_characters')}") print(f"Total pinyins: {data.get('total_pinyins')}") print(f"Valid input character count: {data.get('valid_input_character_count')}") - - pairs = data.get('pairs', {}) + + pairs = data.get("pairs", {}) print(f"Number of pairs: {len(pairs)}") - + # Extract counts and IDs counts = [] id_to_count = {} char_to_count = {} for key, pair in pairs.items(): try: - char_id = pair.get('id') - count = pair.get('count') - char = pair.get('char', '') + char_id = pair.get("id") + count = pair.get("count") + char = pair.get("char", "") if char_id is not None and count is not None: counts.append(count) id_to_count[char_id] = count @@ -46,21 +53,21 @@ def main(): except (ValueError, TypeError) as e: print(f"Warning: Could not parse pair {key}: {e}") continue - + if not counts: print("No valid count data found.") return - + # Basic statistics min_count = min(counts) max_count = max(counts) total_count = sum(counts) mean_count = total_count / len(counts) - + # Sort counts for percentiles sorted_counts = sorted(counts) n = len(sorted_counts) - + # Percentiles p10 = sorted_counts[int(0.1 * n)] p25 = sorted_counts[int(0.25 * n)] @@ -68,11 +75,11 @@ def main(): p75 = sorted_counts[int(0.75 * n)] p90 = sorted_counts[int(0.9 * n)] p99 = sorted_counts[int(0.99 * n)] - + # Variance and std dev variance = sum((x - mean_count) ** 2 for x in counts) / n std_dev = math.sqrt(variance) - + print("\n=== BASIC STATISTICS ===") print(f"Min frequency: {min_count}") print(f"Max frequency: {max_count}") @@ -80,7 +87,7 @@ def main(): print(f"Standard deviation: {std_dev:.2f}") print(f"Total frequency sum: {total_count}") print(f"Number of entries: {n}") - + print("\n=== PERCENTILES ===") print(f"10th percentile: {p10}") print(f"25th percentile: {p25}") @@ -88,52 +95,54 @@ def main(): print(f"75th percentile: {p75}") print(f"90th percentile: {p90}") print(f"99th percentile: {p99}") - + # Find IDs with min and max counts min_ids = [id for id, count in id_to_count.items() if count == min_count] max_ids = [id for id, count in id_to_count.items() if count == max_count] - + print(f"\nIDs with min frequency ({min_count}): {min_ids}") print(f"IDs with max frequency ({max_count}): {max_ids}") - + # Check if IDs are assigned in frequency order # Compute correlation between ID and count ids = list(id_to_count.keys()) id_counts = [id_to_count[id] for id in ids] - + # Sort by ID and check if counts are decreasing sorted_by_id = sorted(ids) counts_by_id = [id_to_count[id] for id in sorted_by_id] - + # Calculate monotonicity: count of times count decreases as ID increases decreases = 0 increases = 0 for i in range(1, len(counts_by_id)): - if counts_by_id[i] < counts_by_id[i-1]: + if counts_by_id[i] < counts_by_id[i - 1]: decreases += 1 - elif counts_by_id[i] > counts_by_id[i-1]: + elif counts_by_id[i] > counts_by_id[i - 1]: increases += 1 - + print(f"\n=== ID ORDER ANALYSIS ===") print(f"Total pairs: {len(counts_by_id)}") print(f"Decreases as ID increases: {decreases} times") print(f"Increases as ID increases: {increases} times") - print(f"Percentage decreasing: {decreases/(len(counts_by_id)-1)*100:.2f}%") - + print(f"Percentage decreasing: {decreases / (len(counts_by_id) - 1) * 100:.2f}%") + # Check if IDs are roughly sorted by frequency # Compute Spearman rank correlation (simplified) sorted_by_count = sorted(ids, key=lambda x: id_to_count[x], reverse=True) rank_by_id = {id: i for i, id in enumerate(sorted_by_id)} rank_by_count = {id: i for i, id in enumerate(sorted_by_count)} - + # Average rank difference rank_diffs = [abs(rank_by_id[id] - rank_by_count[id]) for id in ids] avg_rank_diff = sum(rank_diffs) / len(rank_diffs) max_rank_diff = max(rank_diffs) - - print(f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}") + + print( + f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}" + ) print(f"Maximum rank difference: {max_rank_diff}") - + # Analyze specific ID range 5000-5500 print("\n=== ANALYSIS OF ID RANGE 5000-5500 ===") range_counts = [] @@ -142,7 +151,7 @@ def main(): if id in id_to_count: range_counts.append(id_to_count[id]) range_ids.append(id) - + if range_counts: range_min = min(range_counts) range_max = max(range_counts) @@ -152,7 +161,7 @@ def main(): range_p10 = range_sorted[int(0.1 * range_n)] if range_n > 0 else 0 range_p50 = range_sorted[int(0.5 * range_n)] if range_n > 0 else 0 range_p90 = range_sorted[int(0.9 * range_n)] if range_n > 0 else 0 - + print(f"IDs in range 5000-5500: {len(range_counts)}") print(f"Min frequency in range: {range_min}") print(f"Max frequency in range: {range_max}") @@ -160,66 +169,82 @@ def main(): print(f"10th percentile in range: {range_p10}") print(f"50th percentile in range: {range_p50}") print(f"90th percentile in range: {range_p90}") - + # Find IDs with min frequency in this range min_in_range_ids = [id for id in range_ids if id_to_count[id] == range_min] - print(f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}") + print( + f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}" + ) else: print("No IDs found in range 5000-5500") - + # Histogram of frequencies (log bins) print("\n=== FREQUENCY DISTRIBUTION (LOG BINS) ===") if max_count > 0: log_min = math.log10(min_count) if min_count > 0 else 0 log_max = math.log10(max_count) num_bins = 20 - bin_edges = [10**(log_min + i*(log_max-log_min)/num_bins) for i in range(num_bins+1)] - + bin_edges = [ + 10 ** (log_min + i * (log_max - log_min) / num_bins) + for i in range(num_bins + 1) + ] + hist = [0] * num_bins for count in counts: if count > 0: log_val = math.log10(count) - bin_idx = min(int((log_val - log_min) / (log_max - log_min) * num_bins), num_bins-1) + bin_idx = min( + int((log_val - log_min) / (log_max - log_min) * num_bins), + num_bins - 1, + ) hist[bin_idx] += 1 - + print("Log-scale histogram (count range -> frequency count):") for i in range(num_bins): if hist[i] > 0: lower = bin_edges[i] - upper = bin_edges[i+1] + upper = bin_edges[i + 1] print(f" {lower:.2e} - {upper:.2e}: {hist[i]} entries") - + # Check for zero or near-zero frequencies zero_count = sum(1 for c in counts if c == 0) low_count = sum(1 for c in counts if 0 < c <= 10) very_low_count = sum(1 for c in counts if 0 < c <= 100) - + print(f"\n=== LOW FREQUENCY ANALYSIS ===") print(f"Entries with zero frequency: {zero_count}") print(f"Entries with frequency <= 10: {low_count}") print(f"Entries with frequency <= 100: {very_low_count}") - + # Find the actual min frequency (excluding zeros if any) non_zero_counts = [c for c in counts if c > 0] if non_zero_counts: actual_min = min(non_zero_counts) print(f"Actual min frequency (non-zero): {actual_min}") - actual_min_ids = [id for id, count in id_to_count.items() if count == actual_min] - print(f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}") - + actual_min_ids = [ + id for id, count in id_to_count.items() if count == actual_min + ] + print( + f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}" + ) + # Summary for smoothing algorithm design print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===") - print(f"Frequency range spans {max_count/min_count if min_count>0 else 'inf'}:1 ratio") + print( + f"Frequency range spans {max_count / min_count if min_count > 0 else 'inf'}:1 ratio" + ) print(f"Most entries ({p50}) have frequency around {p50}") print(f"Top 10% of entries have frequency > {p90}") print(f"Bottom 10% of entries have frequency < {p10}") - print(f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency") - + print( + f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency" + ) + # Save detailed data for further analysis output_file = "frequency_analysis_results.txt" - with open(output_file, 'w', encoding='utf-8') as f: + with open(output_file, "w", encoding="utf-8") as f: f.write("Frequency Analysis Results\n") - f.write("="*50 + "\n") + f.write("=" * 50 + "\n") f.write(f"Min frequency: {min_count}\n") f.write(f"Max frequency: {max_count}\n") f.write(f"Mean frequency: {mean_count:.2f}\n") @@ -227,10 +252,15 @@ def main(): f.write(f"10th percentile: {p10}\n") f.write(f"50th percentile: {p50}\n") f.write(f"90th percentile: {p90}\n") - f.write(f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n") - f.write(f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n") - + f.write( + f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n" + ) + f.write( + f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n" + ) + print(f"\nDetailed results saved to {output_file}") + if __name__ == "__main__": main() diff --git a/analyze_range.py b/scripts/analyze_range.py similarity index 69% rename from analyze_range.py rename to scripts/analyze_range.py index 0df339d..914790f 100644 --- a/analyze_range.py +++ b/scripts/analyze_range.py @@ -7,58 +7,70 @@ import json import sys from pathlib import Path + def main(): - json_path = Path("src/model/assets/pinyin_char_statistics.json") - with open(json_path, 'r', encoding='utf-8') as f: + json_path = ( + Path(__file__).parent.parent + / "src" + / "model" + / "assets" + / "pinyin_char_statistics.json" + ) + with open(json_path, "r", encoding="utf-8") as f: data = json.load(f) - - pairs = data.get('pairs', {}) - + + pairs = data.get("pairs", {}) + # Build ID to count mapping id_to_count = {} for key, pair in pairs.items(): - char_id = pair.get('id') - count = pair.get('count') + char_id = pair.get("id") + count = pair.get("count") if char_id is not None and count is not None: id_to_count[char_id] = count - + # Analyze range 5000-5500 in detail print("ID range 5000-5500 detailed analysis:") print("ID\tCount\tChar\tPinyin") - + range_data = [] for id in range(5000, 5501): if id in id_to_count: # Find the pair to get char and pinyin for key, pair in pairs.items(): - if pair.get('id') == id: - char = pair.get('char', '') - pinyin = pair.get('pinyin', '') - count = pair.get('count', 0) + if pair.get("id") == id: + char = pair.get("char", "") + pinyin = pair.get("pinyin", "") + count = pair.get("count", 0) range_data.append((id, count, char, pinyin)) if id % 100 == 0: # Print every 100th for overview print(f"{id}\t{count}\t{char}\t{pinyin}") break - + # Print min and max in range if range_data: min_item = min(range_data, key=lambda x: x[1]) max_item = max(range_data, key=lambda x: x[1]) - print(f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'") - print(f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'") - + print( + f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'" + ) + print( + f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'" + ) + # Check if frequencies are monotonic in this range counts = [item[1] for item in range_data] - increasing = all(counts[i] <= counts[i+1] for i in range(len(counts)-1)) - decreasing = all(counts[i] >= counts[i+1] for i in range(len(counts)-1)) + increasing = all(counts[i] <= counts[i + 1] for i in range(len(counts) - 1)) + decreasing = all(counts[i] >= counts[i + 1] for i in range(len(counts) - 1)) print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}") - + # Check for frequency plateaus from collections import Counter + freq_count = Counter(counts) most_common = freq_count.most_common(5) print(f"Most common frequencies in range: {most_common}") - + # Analyze the tail (IDs with frequency 1) print("\n\nAnalysis of frequency=1 entries:") freq_one_ids = [id for id, count in id_to_count.items() if count == 1] @@ -67,30 +79,34 @@ def main(): print(f"ID range of frequency=1: {min(freq_one_ids)} to {max(freq_one_ids)}") print(f"First 10 IDs: {freq_one_ids[:10]}") print(f"Last 10 IDs: {freq_one_ids[-10:]}") - + # Check if they're contiguous sorted_ids = sorted(freq_one_ids) - contiguous = all(sorted_ids[i] + 1 == sorted_ids[i+1] for i in range(len(sorted_ids)-1)) + contiguous = all( + sorted_ids[i] + 1 == sorted_ids[i + 1] for i in range(len(sorted_ids) - 1) + ) print(f"Are they contiguous IDs? {contiguous}") - + # Sample some characters print("\nSample characters with frequency=1:") sample_count = 0 for key, pair in pairs.items(): - if pair.get('count') == 1 and sample_count < 10: - print(f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'") + if pair.get("count") == 1 and sample_count < 10: + print( + f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'" + ) sample_count += 1 - + # Check overall ID-frequency ordering print("\n\nOverall ID-frequency ordering analysis:") all_ids = sorted(id_to_count.keys()) all_counts = [id_to_count[id] for id in all_ids] - + # Count monotonic segments non_increasing_segments = 0 current_segment_length = 1 for i in range(1, len(all_counts)): - if all_counts[i] <= all_counts[i-1]: + if all_counts[i] <= all_counts[i - 1]: current_segment_length += 1 else: if current_segment_length > 1: @@ -98,18 +114,22 @@ def main(): current_segment_length = 1 if current_segment_length > 1: non_increasing_segments += 1 - + print(f"Total IDs: {len(all_ids)}") print(f"Non-increasing segments: {non_increasing_segments}") - + # Check for frequency plateaus overall from collections import Counter + overall_freq_count = Counter(all_counts) - plateaus = [(freq, count) for freq, count in overall_freq_count.items() if count > 1] + plateaus = [ + (freq, count) for freq, count in overall_freq_count.items() if count > 1 + ] plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10] print(f"Top 10 frequency plateaus (freq: count of IDs sharing that freq):") for freq, count in plateaus_sorted: print(f" {freq}: {count} IDs") + if __name__ == "__main__": main() diff --git a/check_weights.py b/scripts/check_weights.py similarity index 68% rename from check_weights.py rename to scripts/check_weights.py index d77e015..4c4a544 100755 --- a/check_weights.py +++ b/scripts/check_weights.py @@ -1,10 +1,15 @@ +import os +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src")) + from model.dataset import PinyinInputDataset from torch.utils.data import DataLoader from model.trainer import collate_fn, worker_init_fn -data = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/') +data = PinyinInputDataset("/home/songsenand/Data/corpus/CCI-Data/") dataloader = DataLoader( data, @@ -18,5 +23,5 @@ dataloader = DataLoader( ) for i in dataloader: - print((i['labels'] == 1).sum()) + print((i["labels"] == 1).sum()) break diff --git a/comprehensive_analysis.py b/scripts/comprehensive_analysis.py similarity index 68% rename from comprehensive_analysis.py rename to scripts/comprehensive_analysis.py index 25ca936..0c386d5 100644 --- a/comprehensive_analysis.py +++ b/scripts/comprehensive_analysis.py @@ -9,90 +9,125 @@ import math from collections import Counter from pathlib import Path + def main(): - json_path = Path("src/model/assets/pinyin_char_statistics.json") - with open(json_path, 'r', encoding='utf-8') as f: + json_path = ( + Path(__file__).parent.parent + / "src" + / "model" + / "assets" + / "pinyin_char_statistics.json" + ) + with open(json_path, "r", encoding="utf-8") as f: data = json.load(f) - - pairs = data.get('pairs', {}) - + + pairs = data.get("pairs", {}) + # Extract counts counts = [] for key, pair in pairs.items(): - count = pair.get('count') + count = pair.get("count") if count is not None: counts.append(count) - + n = len(counts) print(f"Total entries: {n}") - + # Sort descending for rank-frequency analysis counts_sorted_desc = sorted(counts, reverse=True) - + # Basic statistics min_count = min(counts) max_count = max(counts) mean_count = sum(counts) / n - + # Percentiles percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99] print("\n=== PERCENTILE DISTRIBUTION ===") for p in percentiles: idx = int(p * n) value = counts_sorted_desc[idx] - print(f"{p*100:5.1f}%: {value:>12} (rank ~{idx})") - + print(f"{p * 100:5.1f}%: {value:>12} (rank ~{idx})") + # Cumulative distribution print("\n=== CUMULATIVE DISTRIBUTION ===") - thresholds = [1, 2, 3, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000, 1000000, 5000000, 10000000, 50000000, 100000000, 500000000] + thresholds = [ + 1, + 2, + 3, + 5, + 10, + 20, + 50, + 100, + 200, + 500, + 1000, + 2000, + 5000, + 10000, + 20000, + 50000, + 100000, + 200000, + 500000, + 1000000, + 5000000, + 10000000, + 50000000, + 100000000, + 500000000, + ] for thresh in thresholds: if thresh > max_count: break below = sum(1 for c in counts if c <= thresh) above = sum(1 for c in counts if c >= thresh) - print(f"Count <= {thresh:10}: {below:6} entries ({below/n*100:5.1f}%)") + print(f"Count <= {thresh:10}: {below:6} entries ({below / n * 100:5.1f}%)") # print(f"Count >= {thresh:10}: {above:6} entries ({above/n*100:5.1f}%)") - + # Check min_count=109 parameter print("\n=== ANALYSIS OF THRESHOLD 109 ===") below_109 = sum(1 for c in counts if c < 109) at_or_above_109 = sum(1 for c in counts if c >= 109) - print(f"Entries with count < 109: {below_109} ({below_109/n*100:.1f}%)") - print(f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109/n*100:.1f}%)") - + print(f"Entries with count < 109: {below_109} ({below_109 / n * 100:.1f}%)") + print( + f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109 / n * 100:.1f}%)" + ) + # If 109 is a threshold, what's the actual min among those >= 109? counts_ge_109 = [c for c in counts if c >= 109] if counts_ge_109: actual_min_ge_109 = min(counts_ge_109) print(f"Actual min frequency among those >= 109: {actual_min_ge_109}") - + # Rank-frequency analysis (Zipf's law) print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===") print("Rank\tFrequency\tlog(rank)\tlog(freq)") for rank in range(1, 101): - freq = counts_sorted_desc[rank-1] + freq = counts_sorted_desc[rank - 1] print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{math.log(freq):.3f}") - + # Frequency spectrum (how many distinct frequencies) freq_counter = Counter(counts) print(f"\n=== FREQUENCY SPECTRUM ===") print(f"Distinct frequency values: {len(freq_counter)}") - + # Most common frequencies print("\nTop 20 most common frequencies (plateau sizes):") for freq, freq_count in freq_counter.most_common(20): print(f" Frequency {freq}: {freq_count} entries") - + # Analyze ID ranges print("\n=== ID RANGE ANALYSIS ===") # Build ID to count mapping id_to_count = {} for key, pair in pairs.items(): - char_id = pair.get('id') - count = pair.get('count') + char_id = pair.get("id") + count = pair.get("count") if char_id is not None and count is not None: id_to_count[char_id] = count - + ranges = [ (0, 100, "Top 100 IDs"), (100, 500, "IDs 100-500"), @@ -106,66 +141,80 @@ def main(): (19000, 19500, "IDs 19000-19500 (before freq=1)"), (19499, 20647, "IDs with freq=1"), ] - + for start, end, label in ranges: - range_counts = [id_to_count[id] for id in range(start, end) if id in id_to_count] + range_counts = [ + id_to_count[id] for id in range(start, end) if id in id_to_count + ] if range_counts: min_c = min(range_counts) max_c = max(range_counts) mean_c = sum(range_counts) / len(range_counts) - median_c = sorted(range_counts)[len(range_counts)//2] - print(f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}") - + median_c = sorted(range_counts)[len(range_counts) // 2] + print( + f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}" + ) + # Check if IDs are perfectly sorted by frequency print("\n=== ID ORDER VERIFICATION ===") all_ids = sorted(id_to_count.keys()) all_counts = [id_to_count[id] for id in all_ids] - + # Check for any violations of non-increasing order violations = 0 for i in range(1, len(all_counts)): - if all_counts[i] > all_counts[i-1]: + if all_counts[i] > all_counts[i - 1]: violations += 1 if violations <= 5: - print(f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i-1]} (ID {all_ids[i-1]})") - + print( + f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i - 1]} (ID {all_ids[i - 1]})" + ) + print(f"Total violations of non-increasing order: {violations}") - + # Check if equal frequencies are grouped together print("\n=== FREQUENCY GROUPING ANALYSIS ===") current_freq = None group_start = None group_sizes = [] - + for i, (id, count) in enumerate(zip(all_ids, all_counts)): if count != current_freq: if current_freq is not None: - group_sizes.append((current_freq, group_start, all_ids[i-1], i - group_start)) + group_sizes.append( + (current_freq, group_start, all_ids[i - 1], i - group_start) + ) current_freq = count group_start = i - + # Last group if current_freq is not None: - group_sizes.append((current_freq, group_start, all_ids[-1], len(all_ids) - group_start)) - + group_sizes.append( + (current_freq, group_start, all_ids[-1], len(all_ids) - group_start) + ) + # Sort groups by size group_sizes.sort(key=lambda x: x[3], reverse=True) print("Top 10 largest frequency groups (plateaus):") for freq, start_id_idx, end_id, size in group_sizes[:10]: start_id = all_ids[start_id_idx] print(f" Frequency {freq}: IDs {start_id}-{end_id} ({size} entries)") - + # Summary for smoothing algorithm print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===") print("1. IDs are perfectly sorted by frequency (non-increasing).") - print(f"2. Frequency range: {min_count} to {max_count} (ratio {max_count/min_count:.1e}:1).") - print(f"3. {below_109} entries ({below_109/n*100:.1f}%) have frequency < 109.") - print(f"4. Median frequency: {counts_sorted_desc[n//2]}.") - print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9*n)]}.") - print(f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01*n)]}.") + print( + f"2. Frequency range: {min_count} to {max_count} (ratio {max_count / min_count:.1e}:1)." + ) + print(f"3. {below_109} entries ({below_109 / n * 100:.1f}%) have frequency < 109.") + print(f"4. Median frequency: {counts_sorted_desc[n // 2]}.") + print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9 * n)]}.") + print( + f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01 * n)]}." + ) print("7. Large frequency plateaus exist (many IDs share same frequency).") print("8. Smoothing should handle extreme frequency ratios (1:5e8).") - + # Save data for plotting with open("rank_freq.csv", "w") as f: f.write("rank,frequency\n") @@ -173,5 +222,6 @@ def main(): f.write(f"{rank},{freq}\n") print("\nRank-frequency data saved to rank_freq.csv") + if __name__ == "__main__": main() diff --git a/resign_stat.py b/scripts/resign_stat.py similarity index 84% rename from resign_stat.py rename to scripts/resign_stat.py index a451002..6f52764 100644 --- a/resign_stat.py +++ b/scripts/resign_stat.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Dict, Any import sys + def modify_pinyin_statistics(file_path: Path) -> None: """ 一次性修改拼音统计JSON文件。 @@ -13,7 +14,7 @@ def modify_pinyin_statistics(file_path: Path) -> None: """ # 1. 加载原数据 try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: data: Dict[str, Any] = json.load(f) except FileNotFoundError: print(f"错误:文件不存在 {file_path}", file=sys.stderr) @@ -34,7 +35,7 @@ def modify_pinyin_statistics(file_path: Path) -> None: "id": 0, "char": "", "pinyin": "", - "count": original_zero_count + 1 # 原count + 1 + "count": original_zero_count + 1, # 原count + 1 } # 3.2 处理其他所有记录,键和id都+1 @@ -54,15 +55,16 @@ def modify_pinyin_statistics(file_path: Path) -> None: # 这里保持原时间戳不变,因为是一次性修改 # 写回文件,保持可读格式 - backup_path = file_path.with_suffix('.json.bak') + backup_path = file_path.with_suffix(".json.bak") try: # 先备份原文件 import shutil + shutil.copy2(file_path, backup_path) print(f"已创建备份: {backup_path}") # 写入新数据 - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) print(f"修改完成!") @@ -78,15 +80,21 @@ def modify_pinyin_statistics(file_path: Path) -> None: # 使用示例 if __name__ == "__main__": - # 假设你的JSON文件在当前目录 - json_file = Path("./src/model/assets/pinyin_char_statistics.json") + # JSON文件相对于项目根目录 + json_file = ( + Path(__file__).parent.parent + / "src" + / "model" + / "assets" + / "pinyin_char_statistics.json" + ) # 执行修改 modify_pinyin_statistics(json_file) # 验证修改:读取并显示前几条记录 print("\n验证前5条记录:") - with open(json_file, 'r', encoding='utf-8') as f: + with open(json_file, "r", encoding="utf-8") as f: data = json.load(f) for i in range(5): diff --git a/visualize_distribution.py b/scripts/visualize_distribution.py similarity index 69% rename from visualize_distribution.py rename to scripts/visualize_distribution.py index 2cd7be9..2f18ae1 100644 --- a/visualize_distribution.py +++ b/scripts/visualize_distribution.py @@ -8,110 +8,127 @@ import math import sys from pathlib import Path + def ascii_histogram(data, bins=20, width=60): """Create ASCII histogram""" if not data: return "" - + min_val = min(data) max_val = max(data) - + # Use log bins for wide range if max_val / min_val > 1000: log_min = math.log10(min_val) if min_val > 0 else 0 log_max = math.log10(max_val) - bin_edges = [10**(log_min + i*(log_max-log_min)/bins) for i in range(bins+1)] + bin_edges = [ + 10 ** (log_min + i * (log_max - log_min) / bins) for i in range(bins + 1) + ] hist = [0] * bins for val in data: if val > 0: log_val = math.log10(val) - bin_idx = min(int((log_val - log_min) / (log_max - log_min) * bins), bins-1) + bin_idx = min( + int((log_val - log_min) / (log_max - log_min) * bins), bins - 1 + ) hist[bin_idx] += 1 - bin_labels = [f"{bin_edges[i]:.1e}-{bin_edges[i+1]:.1e}" for i in range(bins)] + bin_labels = [f"{bin_edges[i]:.1e}-{bin_edges[i + 1]:.1e}" for i in range(bins)] else: bin_width = (max_val - min_val) / bins - bin_edges = [min_val + i*bin_width for i in range(bins+1)] + bin_edges = [min_val + i * bin_width for i in range(bins + 1)] hist = [0] * bins for val in data: - bin_idx = min(int((val - min_val) / (max_val - min_val) * bins), bins-1) + bin_idx = min(int((val - min_val) / (max_val - min_val) * bins), bins - 1) hist[bin_idx] += 1 - bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i+1]:.1f}" for i in range(bins)] - + bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i + 1]:.1f}" for i in range(bins)] + max_count = max(hist) result = [] for i in range(bins): if hist[i] == 0: continue - bar = '#' * int(hist[i] / max_count * width) + bar = "#" * int(hist[i] / max_count * width) result.append(f"{bin_labels[i]:20} | {bar} {hist[i]}") - + return "\n".join(result) + def main(): - json_path = Path("src/model/assets/pinyin_char_statistics.json") - with open(json_path, 'r', encoding='utf-8') as f: + json_path = ( + Path(__file__).parent.parent + / "src" + / "model" + / "assets" + / "pinyin_char_statistics.json" + ) + with open(json_path, "r", encoding="utf-8") as f: data = json.load(f) - - pairs = data.get('pairs', {}) - counts = [pair.get('count', 0) for pair in pairs.values() if pair.get('count') is not None] - + + pairs = data.get("pairs", {}) + counts = [ + pair.get("count", 0) for pair in pairs.values() if pair.get("count") is not None + ] + print("FREQUENCY DISTRIBUTION ANALYSIS") - print("="*60) - + print("=" * 60) + print("\n1. ASCII Histogram (log bins):") print(ascii_histogram(counts, bins=20, width=60)) - + # Rank-frequency plot in ASCII print("\n2. Rank-Frequency Relationship (Top 50):") counts_sorted_desc = sorted(counts, reverse=True) max_freq = counts_sorted_desc[0] max_rank = 50 - + for rank in range(1, max_rank + 1): - freq = counts_sorted_desc[rank-1] + freq = counts_sorted_desc[rank - 1] bar_length = int(math.log(freq) / math.log(max_freq) * 40) - bar = '#' * bar_length + bar = "#" * bar_length print(f"Rank {rank:3}: {freq:12} {bar}") - + # ID vs Frequency plot (sampled) print("\n3. ID vs Frequency (sampled every 500 IDs):") # Build ID to count mapping id_to_count = {} for key, pair in pairs.items(): - char_id = pair.get('id') - count = pair.get('count') + char_id = pair.get("id") + count = pair.get("count") if char_id is not None and count is not None: id_to_count[char_id] = count - + all_ids = sorted(id_to_count.keys()) max_id = all_ids[-1] - + print("ID Frequency log10(freq)") for id in range(0, max_id + 1, 500): if id in id_to_count: freq = id_to_count[id] log_freq = math.log10(freq) if freq > 0 else 0 - bar = '#' * int(log_freq / math.log10(max_freq) * 40) + bar = "#" * int(log_freq / math.log10(max_freq) * 40) print(f"{id:6} {freq:10} {log_freq:6.2f} {bar}") - + # Zipf's law fit print("\n4. Zipf's Law Analysis:") print(" Rank * Frequency ≈ constant for Zipf's law") print(" Top 10 ranks:") for rank in range(1, 11): - freq = counts_sorted_desc[rank-1] + freq = counts_sorted_desc[rank - 1] product = rank * freq print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}") - + # Check if product is roughly constant - products = [(rank+1) * counts_sorted_desc[rank] for rank in range(10)] + products = [(rank + 1) * counts_sorted_desc[rank] for rank in range(10)] avg_product = sum(products) / len(products) - std_product = math.sqrt(sum((p - avg_product)**2 for p in products) / len(products)) + std_product = math.sqrt( + sum((p - avg_product) ** 2 for p in products) / len(products) + ) print(f" Average product (ranks 2-11): {avg_product:.3e} ± {std_product:.3e}") - print(f" Coefficient of variation: {std_product/avg_product*100:.1f}%") - + print(f" Coefficient of variation: {std_product / avg_product * 100:.1f}%") + # Frequency spectrum from collections import Counter + freq_counter = Counter(counts) print("\n5. Frequency Spectrum (how many entries have each frequency):") print(" Frequency Count Cumulative") @@ -120,21 +137,21 @@ def main(): count = freq_counter[freq] cum += count print(f" {freq:10} {count:6} {cum:6}") - + # Summary statistics print("\n6. Key Statistics:") n = len(counts) print(f" Total entries: {n}") print(f" Min frequency: {min(counts)}") print(f" Max frequency: {max(counts)}") - print(f" Ratio max/min: {max(counts)/min(counts):.2e}") - + print(f" Ratio max/min: {max(counts) / min(counts):.2e}") + percentiles = [0.01, 0.1, 0.5, 0.9, 0.99] for p in percentiles: idx = int(p * n) value = counts_sorted_desc[idx] - print(f" {p*100:5.1f}th percentile: {value:12} (rank ~{idx})") - + print(f" {p * 100:5.1f}th percentile: {value:12} (rank ~{idx})") + # Save data for external plotting with open("id_vs_freq.csv", "w") as f: f.write("id,frequency\n") @@ -142,5 +159,6 @@ def main(): f.write(f"{id},{id_to_count[id]}\n") print("\nData saved to id_vs_freq.csv for external plotting") + if __name__ == "__main__": main() diff --git a/src/model/dataset.py b/src/model/dataset.py index 9571291..7e84db7 100644 --- a/src/model/dataset.py +++ b/src/model/dataset.py @@ -5,10 +5,9 @@ warnings.filterwarnings("ignore", message=".*pkg_resources.*") import jieba import math import random -import re from importlib.resources import files from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Set, Tuple import numpy as np import torch @@ -21,7 +20,6 @@ from torch.utils.data import IterableDataset from .query import QueryEngine -_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+") CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26 CHAR_TO_ID["`"] = 27 # 显式添加反引号 @@ -76,6 +74,8 @@ class PinyinInputDataset(IterableDataset): merge_max_total_chars: int = 6, low_freq_repeat: float = 50.0, high_freq_repeat: float = 0.1, + data_kwargs: Optional[Dict] = None, + target_labels: Optional[Set[int]] = None, ): # 频率调整参数 - 幂律平滑方案 self.min_freq = 109 @@ -88,6 +88,9 @@ class PinyinInputDataset(IterableDataset): self.merge_max_short_words = merge_max_short_words self.merge_max_total_chars = merge_max_total_chars + self.data_kwargs = data_kwargs or {} + self.target_labels = target_labels + jieba.initialize() self.tokenizer = AutoTokenizer.from_pretrained( @@ -98,7 +101,9 @@ class PinyinInputDataset(IterableDataset): self.max_iter_length = max_iter_length self.max_seq_length = max_seq_length self.text_field = text_field - self.dataset = load_dataset(data_path, split="train", streaming=True) + load_kwargs = {"split": "train", "streaming": True} + load_kwargs.update(self.data_kwargs) + self.dataset = load_dataset(data_path, **load_kwargs) self.max_workers = max_workers self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight) self.shuffle_buffer_size = shuffle_buffer_size @@ -155,12 +160,13 @@ class PinyinInputDataset(IterableDataset): # 生成对应文本的拼音 def generate_pinyin(self, text: str) -> List[str]: """ - 流式处理单条文本,转换为拼音列表。 + 将文本转换为拼音列表。对整段文本调用 lazy_pinyin, + 利用 errors 回调确保一一对应,对生僻字从 QueryEngine 回退。 特性: 1. 严格一一对应:len(result) == len(text) - 2. 高多音字准确率:利用 pypinyin 内部的词语分词能力 - 3. 高性能:预分配内存,无多余对象创建 + 2. 对 pypinyin 不认识的生僻字,回退到 QueryEngine 最高频读音 + 3. 非汉字字符原样占位 Args: text: 输入字符串 @@ -171,39 +177,35 @@ class PinyinInputDataset(IterableDataset): if not text: return [] - text_len = len(text) - # 2. 预分配结果列表,初始化占位符。 - # 使用 None 或空字符串均可,这里用空字符串方便后续判断 - result: List[str] = [""] * text_len + def _fallback(chars): + # lazy_pinyin 会把连续无拼音的字符聚合成一个字符串传入, + # 必须逐字符处理,确保返回列表长度与输入字符数一致。 + result = [] + for char in chars: + if self.query_engine.is_chinese_char(char): + ids = self.query_engine.query_by_char(char, limit=1) + if ids: + result.append(ids[0][1]) + else: + result.append(char) + else: + result.append(char) + return result - # 3. 遍历所有连续汉字片段 - for match in _HANZI_RE.finditer(text): - start_idx = match.start() - hanzi_segment = match.group() + pinyin_list = lazy_pinyin(text, errors=_fallback) - # 4. 核心转换:利用 pypinyin 的分词能力处理该片段 - # style=Style.NORMAL 获取不带声调的拼音 - pinyin_list = lazy_pinyin(hanzi_segment) + # 防御性校验:若长度仍不匹配(极罕见),逐字回退 + if len(pinyin_list) != len(text): + logger.warning( + f"pinyin length mismatch: text_len={len(text)}, " + f"pinyin_len={len(pinyin_list)}, text={text[:50]!r}" + ) + pinyin_list = [] + for c in text: + result = lazy_pinyin(c, errors=_fallback) + pinyin_list.append(result[0] if result else c) - # 5. 健壮性兜底: - # 正常情况下,pypinyin 返回的拼音数应等于汉字数。 - # 若不等(极罕见,如遇到特殊 Unicode 标点被误判为汉字),降级为单字转换 - if len(pinyin_list) != len(hanzi_segment): - pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment] - - # 6. 直接通过索引填充到预分配的位置 - # 这比 list slicing assignment (result[start:end] = pinyin_list) 略快且更直观 - for i, py in enumerate(pinyin_list): - result[start_idx + i] = py - - # 7. 填充非汉字字符 - # 遍历原文,如果 result 对应位置为空,则填入原字符 - # 注意:对于纯汉字文本,这一步很快;对于混合文本,这是必要的 - for i, char in enumerate(text): - if not result[i]: - result[i] = char - - return result + return pinyin_list def get_mask_pinyin( self, text: str, pinyin_list: List[str] @@ -243,51 +245,61 @@ class PinyinInputDataset(IterableDataset): pinyin_ids = pinyin_ids[:24] return torch.tensor(pinyin_ids, dtype=torch.long) - def _add_word_samples( + def _build_single_sample( self, - batch_samples: list, - labels: list, - encoded: dict, - part4: str, - part1: str, - part3: str, - pinyin_str: str, + label: int, + history: list, + text: str, + word_start: int, + word_end: int, + part2: str, pinyin_ids: torch.Tensor, - ) -> list: - for label_idx, label in enumerate(labels): - base_repeats = self.adjust_frequency(self.sample_freqs.get(label, 0)) - if base_repeats == 0: - continue - weight = ( - self._history_weights[label_idx] - if label_idx < len(self._history_weights) - else 3.0 - ) - repeats = max(1, int(base_repeats * weight)) + words: list, + ) -> dict: + """构造单条样本,每次调用都会重新随机采样上下文""" - history = labels[:label_idx] - if len(history) > 8: - history = history[-8:] - else: - history.extend([0] * (8 - len(history))) + # part1 长度:高斯分布 N(36, 6^2),截断 [0, min(48, word_start)] + part1_len = min(max(int(random.gauss(36, 6)), 0), 48, word_start) + part1 = text[word_start - part1_len : word_start] - sample_dict = { - "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long), - "token_type_ids": torch.tensor( - encoded["token_type_ids"], dtype=torch.long - ), - "attention_mask": torch.tensor( - encoded["attention_mask"], dtype=torch.long - ), - "label": torch.tensor([label], dtype=torch.long), - "history_slot_ids": torch.tensor(history, dtype=torch.long), - "prefix": f"{part4}^{part1}", - "suffix": part3, - "pinyin": pinyin_str, - "pinyin_ids": pinyin_ids, - } - batch_samples.extend([sample_dict] * repeats) - return batch_samples + # part3:每次重新 roll + part3 = "" + if random.random() > 0.7: + part3 = text[word_end : word_end + random.randint(1, 16)] + + # part4:每次重新 roll + part4 = "" + if random.random() > 0.7 and words: + num_words = random.randint(1, 3) + selected_words = random.sample(words, min(num_words, len(words))) + part4 = "|".join(selected_words) + + encoded = self.tokenizer( + f"{part4}|{part1}", + part3, + max_length=self.max_seq_length, + truncation=True, + return_token_type_ids=True, + ) + + # 确保 history 长度为 8 + hist = list(history) + if len(hist) > 8: + hist = hist[-8:] + while len(hist) < 8: + hist.append(0) + + return { + "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long), + "token_type_ids": torch.tensor(encoded["token_type_ids"], dtype=torch.long), + "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long), + "label": torch.tensor([label], dtype=torch.long), + "history_slot_ids": torch.tensor(hist, dtype=torch.long), + "prefix": f"{part4}^{part1}", + "suffix": part3, + "pinyin": part2, + "pinyin_ids": pinyin_ids, + } def __iter__(self): worker_info = torch.utils.data.get_worker_info() @@ -436,42 +448,43 @@ class PinyinInputDataset(IterableDataset): if not should_break and random.random() <= 0.1: labels.append(0) - # part1: 词起点前的文本(所有样本共享) - part1 = text[max(0, word_start - 48) : word_start] + # 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文 + processed_history = [] + for label_idx, label in enumerate(labels): + base_repeats = self.adjust_frequency( + self.sample_freqs.get(label, 0) + ) + if base_repeats == 0: + processed_history.append(label) + continue + if ( + self.target_labels is not None + and label not in self.target_labels + ): + processed_history.append(label) + continue - # part3: 词后文本 - part3 = "" - if random.random() > 0.7: - part3 = text[word_end : word_end + random.randint(1, 16)] + weight = ( + self._history_weights[label_idx] + if label_idx < len(self._history_weights) + else 3.0 + ) + repeats = max(1, int(base_repeats * weight)) - # part4: 词提示 - part4 = "" - if random.random() > 0.7: - num_words = random.randint(1, 3) - if words: - selected_words = random.sample( - words, min(num_words, len(words)) + for _ in range(repeats): + sample = self._build_single_sample( + label=label, + history=processed_history, + text=text, + word_start=word_start, + word_end=word_end, + part2=part2, + pinyin_ids=pinyin_ids, + words=words, ) - part4 = "|".join(selected_words) + batch_samples.append(sample) - encoded = self.tokenizer( - f"{part4}|{part1}", - part3, - max_length=self.max_seq_length, - truncation=True, - return_token_type_ids=True, - ) - - batch_samples = self._add_word_samples( - batch_samples, - labels, - encoded, - part4, - part1, - part3, - part2, - pinyin_ids, - ) + processed_history.append(label) # ========== Phase 2: 破词续接 ========== if should_break and break_pos < word_len_chars: @@ -533,33 +546,44 @@ class PinyinInputDataset(IterableDataset): if random.random() <= 0.1: cont_labels.append(0) - # part1_cont: 包含已确认前缀的上下文 - part1_cont = text[max(0, cont_start - 48) : cont_start] - - # part3_cont: 续接目标后的文本 + # 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文 + cont_processed_history = [] cont_end = cont_positions[-1] + 1 - part3_cont = "" - if random.random() > 0.7: - part3_cont = text[cont_end : cont_end + random.randint(1, 16)] + for label_idx, label in enumerate(cont_labels): + base_repeats = self.adjust_frequency( + self.sample_freqs.get(label, 0) + ) + if base_repeats == 0: + cont_processed_history.append(label) + continue + if ( + self.target_labels is not None + and label not in self.target_labels + ): + cont_processed_history.append(label) + continue - encoded_cont = self.tokenizer( - f"{part4}|{part1_cont}", - part3_cont, - max_length=self.max_seq_length, - truncation=True, - return_token_type_ids=True, - ) + weight = ( + self._history_weights[label_idx] + if label_idx < len(self._history_weights) + else 3.0 + ) + repeats = max(1, int(base_repeats * weight)) - batch_samples = self._add_word_samples( - batch_samples, - cont_labels, - encoded_cont, - part4, - part1_cont, - part3_cont, - part2_cont, - pinyin_ids_cont, - ) + for _ in range(repeats): + sample = self._build_single_sample( + label=label, + history=cont_processed_history, + text=text, + word_start=cont_start, + word_end=cont_end, + part2=part2_cont, + pinyin_ids=pinyin_ids_cont, + words=words, + ) + batch_samples.append(sample) + + cont_processed_history.append(label) idx = merge_end_idx diff --git a/src/model/supplement_missing.py b/src/model/supplement_missing.py index ba7dab5..0d0cd3e 100644 --- a/src/model/supplement_missing.py +++ b/src/model/supplement_missing.py @@ -4,6 +4,7 @@ 步骤 1: find-missing — 扫描已预处理数据,找出从未出现的 label ID,输出 JSON 步骤 2: generate-template — 根据 JSON 生成 JSONL 占位文件,供用户手动填入包含缺失字的真实文本 +步骤 3: preprocess-supplement — 将填好的 JSONL 文本预处理为 .npz 分片,输出到独立目录 用法: python -m model.supplement_missing find-missing \ @@ -13,6 +14,12 @@ python -m model.supplement_missing generate-template \ --missing-chars missing_chars.json \ --output supplement_texts.jsonl + + python -m model.supplement_missing preprocess-supplement \ + --missing-chars missing_chars.json \ + --supplement-texts supplement_texts.jsonl \ + --output-dir ./preprocessed/supplement \ + --num-samples 100000 """ import argparse @@ -21,12 +28,17 @@ from pathlib import Path from typing import Set import numpy as np +import torch from loguru import logger from rich.console import Console from rich.table import Table +from torch.utils.data import DataLoader from tqdm import tqdm +from .dataset import PinyinInputDataset +from .preprocess import collect_samples from .query import QueryEngine +from .trainer import preprocess_collate_fn, worker_init_fn def scan_labels(preprocessed_dir: Path) -> Set[int]: @@ -175,14 +187,116 @@ def cmd_generate_template(args): ) +def cmd_preprocess_supplement(args): + console = Console() + + # 加载缺失字符 + missing_path = Path(args.missing_chars) + if not missing_path.exists(): + console.print(f"[bold red]文件不存在: {missing_path}[/bold red]") + return + + with open(missing_path, "r", encoding="utf-8") as f: + data = json.load(f) + + missing_chars = data.get("missing_chars", []) + if not missing_chars: + console.print("[bold green]没有缺失字符,无需处理[/bold green]") + return + + target_labels = {entry["id"] for entry in missing_chars} + target_labels.add(0) # 包含 EOS + + # 解析参数 + py_style_weight = tuple(int(x) for x in args.py_style_weight.split(",")) + length_weights = { + int(k): int(v) + for k, v in (item.split(":") for item in args.length_weights.split(",")) + } + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + max_iter = args.num_samples * 5 + num_workers = args.num_workers + + console.print("[bold cyan]=== 补充数据预处理 ===[/bold cyan]") + console.print(f"补充文本: {args.supplement_texts}") + console.print(f"缺失字符数: {len(missing_chars)}") + console.print(f"目标样本: {args.num_samples:,}") + console.print(f"输出目录: {output_dir}") + console.print(f"Worker 数: {num_workers}") + console.print() + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + console.print("[bold cyan]创建补充数据集...[/bold cyan]") + dataset = PinyinInputDataset( + data_path="json", + max_workers=num_workers, + max_iter_length=max_iter, + max_seq_length=args.max_seq_length, + text_field="text", + py_style_weight=py_style_weight, + shuffle_buffer_size=100, + length_weights=length_weights, + data_kwargs={ + "data_files": args.supplement_texts, + "streaming": False, + }, + target_labels=target_labels, + ) + + dataloader_kwargs = { + "batch_size": args.batch_size, + "num_workers": num_workers, + "pin_memory": False, + "worker_init_fn": worker_init_fn, + "collate_fn": preprocess_collate_fn(args.max_seq_length), + } + if num_workers > 0: + dataloader_kwargs["prefetch_factor"] = 2 + dataloader_kwargs["persistent_workers"] = True + + dataloader = DataLoader(dataset, **dataloader_kwargs) + + logger.info("开始收集补充数据...") + count = collect_samples( + dataloader, + args.num_samples, + output_dir, + "supplement", + args.max_seq_length, + args.shard_size, + ) + + if count < args.num_samples: + logger.warning(f"补充样本不足: 目标 {args.num_samples}, 实际 {count}") + + console.print("\n[bold green]=== 补充预处理完成 ===[/bold green]") + console.print(f"生成样本: {count:,}") + console.print(f"输出目录: {output_dir}") + + total_size = sum( + f.stat().st_size for f in output_dir.iterdir() if f.suffix == ".npz" + ) + console.print(f"总大小: {total_size / (1024**3):.2f} GB (compressed)") + console.print() + console.print( + "[bold yellow]提示[/bold yellow]: 请检查补充数据质量,清洗无误后手动将 shard_*.npz 合并到 train/ 目录并更新 metadata.json" + ) + + def main(): parser = argparse.ArgumentParser( description="缺失字符补充工具", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 子命令: - find-missing 扫描已预处理数据,找出从未出现的 label ID - generate-template 根据缺失字符 JSON 生成 JSONL 占位文件 + find-missing 扫描已预处理数据,找出从未出现的 label ID + generate-template 根据缺失字符 JSON 生成 JSONL 占位文件 + preprocess-supplement 将填好的 JSONL 预处理为 .npz 分片(独立目录) 示例: python -m model.supplement_missing find-missing \\ @@ -192,6 +306,12 @@ def main(): python -m model.supplement_missing generate-template \\ --missing-chars missing_chars.json \\ --output supplement_texts.jsonl + + python -m model.supplement_missing preprocess-supplement \\ + --missing-chars missing_chars.json \\ + --supplement-texts supplement_texts.jsonl \\ + --output-dir ./preprocessed/supplement \\ + --num-samples 100000 """, ) subparsers = parser.add_subparsers(dest="command", help="子命令") @@ -232,6 +352,77 @@ def main(): help="每个缺失字符生成的模板条数(默认: 3)", ) + # preprocess-supplement + p_pre = subparsers.add_parser( + "preprocess-supplement", help="将 JSONL 预处理为 .npz 分片" + ) + p_pre.add_argument( + "--missing-chars", + type=str, + required=True, + help="缺失字符 JSON 文件路径(由 find-missing 生成)", + ) + p_pre.add_argument( + "--supplement-texts", + type=str, + required=True, + help="已填写的补充文本 JSONL 文件路径", + ) + p_pre.add_argument( + "--output-dir", + type=str, + required=True, + help="输出目录(独立目录,不会覆盖已有数据)", + ) + p_pre.add_argument( + "--num-samples", + type=int, + required=True, + help="目标样本数量", + ) + p_pre.add_argument( + "--batch-size", + type=int, + default=128, + help="批大小(默认: 128)", + ) + p_pre.add_argument( + "--num-workers", + type=int, + default=0, + help="DataLoader worker 数量。本地 JSONL 小文件建议 0(默认: 0)", + ) + p_pre.add_argument( + "--max-seq-length", + type=int, + default=128, + help="最大序列长度(默认: 128)", + ) + p_pre.add_argument( + "--seed", + type=int, + default=42, + help="随机种子(默认: 42)", + ) + p_pre.add_argument( + "--shard-size", + type=int, + default=5_000_000, + help="分片大小(样本数),控制内存峰值(默认: 500万)", + ) + p_pre.add_argument( + "--py-style-weight", + type=str, + default="9,2,1", + help="拼音风格权重(逗号分隔,默认: 9,2,1)", + ) + p_pre.add_argument( + "--length-weights", + type=str, + default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2", + help="词长权重(默认: 1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2)", + ) + args = parser.parse_args() if args.command is None: @@ -242,6 +433,8 @@ def main(): cmd_find_missing(args) elif args.command == "generate-template": cmd_generate_template(args) + elif args.command == "preprocess-supplement": + cmd_preprocess_supplement(args) app = main diff --git a/test.py b/tests/test.py similarity index 91% rename from test.py rename to tests/test.py index 8ff8453..b809dc5 100644 --- a/test.py +++ b/tests/test.py @@ -1,6 +1,7 @@ +import os import sys -sys.path.append("src") +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src")) import time import torch @@ -26,7 +27,7 @@ from pypinyin.contrib.tone_convert import to_initials from torch.utils.data import IterableDataset tokenizer = AutoTokenizer.from_pretrained( - Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer" + Path(__file__).parent.parent / "src" / "model" / "assets" / "tokenizer" ) _HANZI_RE = re.compile(r"[\u4e00-\u9fff]+") @@ -83,7 +84,9 @@ sample = { model = InputMethodEngine(pinyin_vocab_size=30, compile=False) -checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu") +checkpoint = torch.load( + "/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu" +) model.load_state_dict(checkpoint["model_state_dict"]) model.eval() @@ -100,7 +103,7 @@ for k, v in sample.items(): start = time.time() with torch.no_grad(): res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids) - print(f'计算时长: {(time.time() - start) * 1000:4f}ms') + print(f"计算时长: {(time.time() - start) * 1000:4f}ms") sort_res = sorted( [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True ) diff --git a/test_compile.py b/tests/test_compile.py similarity index 98% rename from test_compile.py rename to tests/test_compile.py index dfa1d98..2bb86d2 100644 --- a/test_compile.py +++ b/tests/test_compile.py @@ -392,7 +392,13 @@ def check_compile_issues(): issues = [] # 检查 components.py 中的潜在问题 - with open("src/model/components.py", "r") as f: + components_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "src", + "model", + "components.py", + ) + with open(components_path, "r") as f: content = f.read() # 检查 float('-inf') diff --git a/test_cpu_inference.py b/tests/test_cpu_inference.py similarity index 98% rename from test_cpu_inference.py rename to tests/test_cpu_inference.py index 17e983e..7a9b1c8 100644 --- a/test_cpu_inference.py +++ b/tests/test_cpu_inference.py @@ -4,9 +4,13 @@ 解决设备转换和权重加载问题 """ +import os import sys from pathlib import Path +# 添加项目根目录到路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) + import torch @@ -196,7 +200,7 @@ def test_id_mapping(): query_engine = QueryEngine() stats_path = ( - Path(__file__).parent + Path(__file__).parent.parent / "src" / "model" / "assets" diff --git a/test_dataset.py b/tests/test_dataset.py similarity index 58% rename from test_dataset.py rename to tests/test_dataset.py index 6984dcb..1cf5da2 100644 --- a/test_dataset.py +++ b/tests/test_dataset.py @@ -1,6 +1,7 @@ +import os import sys -sys.path.append("src") +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src")) from typing import Any, Callable, Dict, List, Optional, Tuple, Union import time @@ -42,23 +43,49 @@ def worker_init_fn(worker_id: int) -> None: def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: """ - 自定义批处理函数,将多个样本组合成一个batch - - Args: - batch: 样本列表,每个样本是一个字典 - - Returns: - 批处理后的字典,tensor字段已stack,字符串字段保持为列表 + 自定义批处理函数,将多个样本组合成一个batch。 + 支持动态padding:根据batch内最大序列长度进行padding。 """ - # 处理tensor字段 - 使用squeeze去除多余的batch维度 - input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch]) - token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch]) - attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch]) + input_ids_list = [item["input_ids"] for item in batch] + token_type_ids_list = [item["token_type_ids"] for item in batch] + attention_mask_list = [item["attention_mask"] for item in batch] + + target_len = max(ids.shape[0] for ids in input_ids_list) + + padded_input_ids = [] + padded_token_type_ids = [] + padded_attention_mask = [] + for ids, tt_ids, mask in zip( + input_ids_list, token_type_ids_list, attention_mask_list + ): + seq_len = ids.shape[0] + if seq_len < target_len: + pad_len = target_len - seq_len + padded_input_ids.append( + torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)]) + ) + padded_token_type_ids.append( + torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)]) + ) + padded_attention_mask.append( + torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)]) + ) + elif seq_len > target_len: + padded_input_ids.append(ids[:target_len]) + padded_token_type_ids.append(tt_ids[:target_len]) + padded_attention_mask.append(mask[:target_len]) + else: + padded_input_ids.append(ids) + padded_token_type_ids.append(tt_ids) + padded_attention_mask.append(mask) + + input_ids = torch.stack(padded_input_ids) + token_type_ids = torch.stack(padded_token_type_ids) + attention_mask = torch.stack(padded_attention_mask) labels = torch.stack([item["label"].squeeze(0) for item in batch]) history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch]) pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch]) - # 字符串字段保持为列表 prefixes = [item["prefix"] for item in batch] suffixes = [item["suffix"] for item in batch] pinyins = [item["pinyin"] for item in batch] @@ -96,6 +123,5 @@ dataloader = DataLoader( persistent_workers=True, ) -for i, shape in tqdm(enumerate(dataloader), total=1000000/512): +for i, shape in tqdm(enumerate(dataloader), total=1000000 / 512): pass - diff --git a/test_epoch_checkpoint.py b/tests/test_epoch_checkpoint.py similarity index 100% rename from test_epoch_checkpoint.py rename to tests/test_epoch_checkpoint.py diff --git a/test_trainer.py b/tests/test_trainer.py similarity index 99% rename from test_trainer.py rename to tests/test_trainer.py index 7136abd..585f009 100644 --- a/test_trainer.py +++ b/tests/test_trainer.py @@ -6,8 +6,8 @@ import torch.nn as nn from rich.console import Console from torch.utils.data import DataLoader, Dataset -# 添加src目录到路径 -sys.path.insert(0, str(Path(__file__).parent)) +# 添加项目根目录到路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) from src.model.model import InputMethodEngine from src.model.trainer import Trainer