refactor(generate_pinyin): 优化拼音生成逻辑,利用 pypinyin 分词能力处理多音字

This commit is contained in:
songsenand 2026-05-09 13:36:48 +08:00
parent 8b41bcdc6f
commit e8eab1f260
15 changed files with 739 additions and 369 deletions

33
eval.py
View File

@ -14,7 +14,6 @@ eval.py - 评估模型在给定文本上的表现
import argparse import argparse
import random import random
import re
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Tuple, Optional
@ -33,8 +32,6 @@ from src.model.model import InputMethodEngine
from src.model.query import QueryEngine from src.model.query import QueryEngine
from src.model.dataset import text_to_pinyin_ids from src.model.dataset import text_to_pinyin_ids
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
class TextEvaluator: class TextEvaluator:
def __init__( def __init__(
@ -171,34 +168,20 @@ class TextEvaluator:
def generate_pinyin(self, text: str) -> List[str]: def generate_pinyin(self, text: str) -> List[str]:
""" """
流式处理单条文本转换为拼音列表 将文本转换为拼音列表对整段文本调用 lazy_pinyin
参考dataset.py中的generate_pinyin方法 利用 pypinyin 内部的分词能力处理多音字
参考 dataset.py 中的 generate_pinyin 方法
""" """
if not text: if not text:
return [] return []
text_len = len(text) pinyin_list = lazy_pinyin(text)
result: List[str] = [""] * text_len
# 遍历所有连续汉字片段 # 健壮性兜底:若长度不匹配(极罕见),降级为逐字转换
for match in _HANZI_RE.finditer(text): if len(pinyin_list) != len(text):
start_idx = match.start() pinyin_list = [lazy_pinyin(c)[0] for c in text]
hanzi_segment = match.group()
pinyin_list = lazy_pinyin(hanzi_segment) return pinyin_list
if len(pinyin_list) != len(hanzi_segment):
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
for i, py in enumerate(pinyin_list):
result[start_idx + i] = py
# 填充非汉字字符
for i, char in enumerate(text):
if not result[i]:
result[i] = char
return result
def get_mask_pinyin( def get_mask_pinyin(
self, text: str, pinyin_list: List[str] self, text: str, pinyin_list: List[str]

View File

@ -10,34 +10,41 @@ import math
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
def main(): def main():
# Path to the JSON file # Path to the JSON file
json_path = Path("src/model/assets/pinyin_char_statistics.json") json_path = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
if not json_path.exists(): if not json_path.exists():
print(f"Error: File not found: {json_path}") print(f"Error: File not found: {json_path}")
sys.exit(1) sys.exit(1)
print(f"Loading {json_path}...") print(f"Loading {json_path}...")
with open(json_path, 'r', encoding='utf-8') as f: with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
print(f"Timestamp: {data.get('timestamp')}") print(f"Timestamp: {data.get('timestamp')}")
print(f"Total characters: {data.get('total_characters')}") print(f"Total characters: {data.get('total_characters')}")
print(f"Total pinyins: {data.get('total_pinyins')}") print(f"Total pinyins: {data.get('total_pinyins')}")
print(f"Valid input character count: {data.get('valid_input_character_count')}") print(f"Valid input character count: {data.get('valid_input_character_count')}")
pairs = data.get('pairs', {}) pairs = data.get("pairs", {})
print(f"Number of pairs: {len(pairs)}") print(f"Number of pairs: {len(pairs)}")
# Extract counts and IDs # Extract counts and IDs
counts = [] counts = []
id_to_count = {} id_to_count = {}
char_to_count = {} char_to_count = {}
for key, pair in pairs.items(): for key, pair in pairs.items():
try: try:
char_id = pair.get('id') char_id = pair.get("id")
count = pair.get('count') count = pair.get("count")
char = pair.get('char', '') char = pair.get("char", "")
if char_id is not None and count is not None: if char_id is not None and count is not None:
counts.append(count) counts.append(count)
id_to_count[char_id] = count id_to_count[char_id] = count
@ -46,21 +53,21 @@ def main():
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
print(f"Warning: Could not parse pair {key}: {e}") print(f"Warning: Could not parse pair {key}: {e}")
continue continue
if not counts: if not counts:
print("No valid count data found.") print("No valid count data found.")
return return
# Basic statistics # Basic statistics
min_count = min(counts) min_count = min(counts)
max_count = max(counts) max_count = max(counts)
total_count = sum(counts) total_count = sum(counts)
mean_count = total_count / len(counts) mean_count = total_count / len(counts)
# Sort counts for percentiles # Sort counts for percentiles
sorted_counts = sorted(counts) sorted_counts = sorted(counts)
n = len(sorted_counts) n = len(sorted_counts)
# Percentiles # Percentiles
p10 = sorted_counts[int(0.1 * n)] p10 = sorted_counts[int(0.1 * n)]
p25 = sorted_counts[int(0.25 * n)] p25 = sorted_counts[int(0.25 * n)]
@ -68,11 +75,11 @@ def main():
p75 = sorted_counts[int(0.75 * n)] p75 = sorted_counts[int(0.75 * n)]
p90 = sorted_counts[int(0.9 * n)] p90 = sorted_counts[int(0.9 * n)]
p99 = sorted_counts[int(0.99 * n)] p99 = sorted_counts[int(0.99 * n)]
# Variance and std dev # Variance and std dev
variance = sum((x - mean_count) ** 2 for x in counts) / n variance = sum((x - mean_count) ** 2 for x in counts) / n
std_dev = math.sqrt(variance) std_dev = math.sqrt(variance)
print("\n=== BASIC STATISTICS ===") print("\n=== BASIC STATISTICS ===")
print(f"Min frequency: {min_count}") print(f"Min frequency: {min_count}")
print(f"Max frequency: {max_count}") print(f"Max frequency: {max_count}")
@ -80,7 +87,7 @@ def main():
print(f"Standard deviation: {std_dev:.2f}") print(f"Standard deviation: {std_dev:.2f}")
print(f"Total frequency sum: {total_count}") print(f"Total frequency sum: {total_count}")
print(f"Number of entries: {n}") print(f"Number of entries: {n}")
print("\n=== PERCENTILES ===") print("\n=== PERCENTILES ===")
print(f"10th percentile: {p10}") print(f"10th percentile: {p10}")
print(f"25th percentile: {p25}") print(f"25th percentile: {p25}")
@ -88,52 +95,54 @@ def main():
print(f"75th percentile: {p75}") print(f"75th percentile: {p75}")
print(f"90th percentile: {p90}") print(f"90th percentile: {p90}")
print(f"99th percentile: {p99}") print(f"99th percentile: {p99}")
# Find IDs with min and max counts # Find IDs with min and max counts
min_ids = [id for id, count in id_to_count.items() if count == min_count] min_ids = [id for id, count in id_to_count.items() if count == min_count]
max_ids = [id for id, count in id_to_count.items() if count == max_count] max_ids = [id for id, count in id_to_count.items() if count == max_count]
print(f"\nIDs with min frequency ({min_count}): {min_ids}") print(f"\nIDs with min frequency ({min_count}): {min_ids}")
print(f"IDs with max frequency ({max_count}): {max_ids}") print(f"IDs with max frequency ({max_count}): {max_ids}")
# Check if IDs are assigned in frequency order # Check if IDs are assigned in frequency order
# Compute correlation between ID and count # Compute correlation between ID and count
ids = list(id_to_count.keys()) ids = list(id_to_count.keys())
id_counts = [id_to_count[id] for id in ids] id_counts = [id_to_count[id] for id in ids]
# Sort by ID and check if counts are decreasing # Sort by ID and check if counts are decreasing
sorted_by_id = sorted(ids) sorted_by_id = sorted(ids)
counts_by_id = [id_to_count[id] for id in sorted_by_id] counts_by_id = [id_to_count[id] for id in sorted_by_id]
# Calculate monotonicity: count of times count decreases as ID increases # Calculate monotonicity: count of times count decreases as ID increases
decreases = 0 decreases = 0
increases = 0 increases = 0
for i in range(1, len(counts_by_id)): for i in range(1, len(counts_by_id)):
if counts_by_id[i] < counts_by_id[i-1]: if counts_by_id[i] < counts_by_id[i - 1]:
decreases += 1 decreases += 1
elif counts_by_id[i] > counts_by_id[i-1]: elif counts_by_id[i] > counts_by_id[i - 1]:
increases += 1 increases += 1
print(f"\n=== ID ORDER ANALYSIS ===") print(f"\n=== ID ORDER ANALYSIS ===")
print(f"Total pairs: {len(counts_by_id)}") print(f"Total pairs: {len(counts_by_id)}")
print(f"Decreases as ID increases: {decreases} times") print(f"Decreases as ID increases: {decreases} times")
print(f"Increases as ID increases: {increases} times") print(f"Increases as ID increases: {increases} times")
print(f"Percentage decreasing: {decreases/(len(counts_by_id)-1)*100:.2f}%") print(f"Percentage decreasing: {decreases / (len(counts_by_id) - 1) * 100:.2f}%")
# Check if IDs are roughly sorted by frequency # Check if IDs are roughly sorted by frequency
# Compute Spearman rank correlation (simplified) # Compute Spearman rank correlation (simplified)
sorted_by_count = sorted(ids, key=lambda x: id_to_count[x], reverse=True) sorted_by_count = sorted(ids, key=lambda x: id_to_count[x], reverse=True)
rank_by_id = {id: i for i, id in enumerate(sorted_by_id)} rank_by_id = {id: i for i, id in enumerate(sorted_by_id)}
rank_by_count = {id: i for i, id in enumerate(sorted_by_count)} rank_by_count = {id: i for i, id in enumerate(sorted_by_count)}
# Average rank difference # Average rank difference
rank_diffs = [abs(rank_by_id[id] - rank_by_count[id]) for id in ids] rank_diffs = [abs(rank_by_id[id] - rank_by_count[id]) for id in ids]
avg_rank_diff = sum(rank_diffs) / len(rank_diffs) avg_rank_diff = sum(rank_diffs) / len(rank_diffs)
max_rank_diff = max(rank_diffs) max_rank_diff = max(rank_diffs)
print(f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}") print(
f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}"
)
print(f"Maximum rank difference: {max_rank_diff}") print(f"Maximum rank difference: {max_rank_diff}")
# Analyze specific ID range 5000-5500 # Analyze specific ID range 5000-5500
print("\n=== ANALYSIS OF ID RANGE 5000-5500 ===") print("\n=== ANALYSIS OF ID RANGE 5000-5500 ===")
range_counts = [] range_counts = []
@ -142,7 +151,7 @@ def main():
if id in id_to_count: if id in id_to_count:
range_counts.append(id_to_count[id]) range_counts.append(id_to_count[id])
range_ids.append(id) range_ids.append(id)
if range_counts: if range_counts:
range_min = min(range_counts) range_min = min(range_counts)
range_max = max(range_counts) range_max = max(range_counts)
@ -152,7 +161,7 @@ def main():
range_p10 = range_sorted[int(0.1 * range_n)] if range_n > 0 else 0 range_p10 = range_sorted[int(0.1 * range_n)] if range_n > 0 else 0
range_p50 = range_sorted[int(0.5 * range_n)] if range_n > 0 else 0 range_p50 = range_sorted[int(0.5 * range_n)] if range_n > 0 else 0
range_p90 = range_sorted[int(0.9 * range_n)] if range_n > 0 else 0 range_p90 = range_sorted[int(0.9 * range_n)] if range_n > 0 else 0
print(f"IDs in range 5000-5500: {len(range_counts)}") print(f"IDs in range 5000-5500: {len(range_counts)}")
print(f"Min frequency in range: {range_min}") print(f"Min frequency in range: {range_min}")
print(f"Max frequency in range: {range_max}") print(f"Max frequency in range: {range_max}")
@ -160,66 +169,82 @@ def main():
print(f"10th percentile in range: {range_p10}") print(f"10th percentile in range: {range_p10}")
print(f"50th percentile in range: {range_p50}") print(f"50th percentile in range: {range_p50}")
print(f"90th percentile in range: {range_p90}") print(f"90th percentile in range: {range_p90}")
# Find IDs with min frequency in this range # Find IDs with min frequency in this range
min_in_range_ids = [id for id in range_ids if id_to_count[id] == range_min] min_in_range_ids = [id for id in range_ids if id_to_count[id] == range_min]
print(f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}") print(
f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}"
)
else: else:
print("No IDs found in range 5000-5500") print("No IDs found in range 5000-5500")
# Histogram of frequencies (log bins) # Histogram of frequencies (log bins)
print("\n=== FREQUENCY DISTRIBUTION (LOG BINS) ===") print("\n=== FREQUENCY DISTRIBUTION (LOG BINS) ===")
if max_count > 0: if max_count > 0:
log_min = math.log10(min_count) if min_count > 0 else 0 log_min = math.log10(min_count) if min_count > 0 else 0
log_max = math.log10(max_count) log_max = math.log10(max_count)
num_bins = 20 num_bins = 20
bin_edges = [10**(log_min + i*(log_max-log_min)/num_bins) for i in range(num_bins+1)] bin_edges = [
10 ** (log_min + i * (log_max - log_min) / num_bins)
for i in range(num_bins + 1)
]
hist = [0] * num_bins hist = [0] * num_bins
for count in counts: for count in counts:
if count > 0: if count > 0:
log_val = math.log10(count) log_val = math.log10(count)
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * num_bins), num_bins-1) bin_idx = min(
int((log_val - log_min) / (log_max - log_min) * num_bins),
num_bins - 1,
)
hist[bin_idx] += 1 hist[bin_idx] += 1
print("Log-scale histogram (count range -> frequency count):") print("Log-scale histogram (count range -> frequency count):")
for i in range(num_bins): for i in range(num_bins):
if hist[i] > 0: if hist[i] > 0:
lower = bin_edges[i] lower = bin_edges[i]
upper = bin_edges[i+1] upper = bin_edges[i + 1]
print(f" {lower:.2e} - {upper:.2e}: {hist[i]} entries") print(f" {lower:.2e} - {upper:.2e}: {hist[i]} entries")
# Check for zero or near-zero frequencies # Check for zero or near-zero frequencies
zero_count = sum(1 for c in counts if c == 0) zero_count = sum(1 for c in counts if c == 0)
low_count = sum(1 for c in counts if 0 < c <= 10) low_count = sum(1 for c in counts if 0 < c <= 10)
very_low_count = sum(1 for c in counts if 0 < c <= 100) very_low_count = sum(1 for c in counts if 0 < c <= 100)
print(f"\n=== LOW FREQUENCY ANALYSIS ===") print(f"\n=== LOW FREQUENCY ANALYSIS ===")
print(f"Entries with zero frequency: {zero_count}") print(f"Entries with zero frequency: {zero_count}")
print(f"Entries with frequency <= 10: {low_count}") print(f"Entries with frequency <= 10: {low_count}")
print(f"Entries with frequency <= 100: {very_low_count}") print(f"Entries with frequency <= 100: {very_low_count}")
# Find the actual min frequency (excluding zeros if any) # Find the actual min frequency (excluding zeros if any)
non_zero_counts = [c for c in counts if c > 0] non_zero_counts = [c for c in counts if c > 0]
if non_zero_counts: if non_zero_counts:
actual_min = min(non_zero_counts) actual_min = min(non_zero_counts)
print(f"Actual min frequency (non-zero): {actual_min}") print(f"Actual min frequency (non-zero): {actual_min}")
actual_min_ids = [id for id, count in id_to_count.items() if count == actual_min] actual_min_ids = [
print(f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}") id for id, count in id_to_count.items() if count == actual_min
]
print(
f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}"
)
# Summary for smoothing algorithm design # Summary for smoothing algorithm design
print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===") print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===")
print(f"Frequency range spans {max_count/min_count if min_count>0 else 'inf'}:1 ratio") print(
f"Frequency range spans {max_count / min_count if min_count > 0 else 'inf'}:1 ratio"
)
print(f"Most entries ({p50}) have frequency around {p50}") print(f"Most entries ({p50}) have frequency around {p50}")
print(f"Top 10% of entries have frequency > {p90}") print(f"Top 10% of entries have frequency > {p90}")
print(f"Bottom 10% of entries have frequency < {p10}") print(f"Bottom 10% of entries have frequency < {p10}")
print(f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency") print(
f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency"
)
# Save detailed data for further analysis # Save detailed data for further analysis
output_file = "frequency_analysis_results.txt" output_file = "frequency_analysis_results.txt"
with open(output_file, 'w', encoding='utf-8') as f: with open(output_file, "w", encoding="utf-8") as f:
f.write("Frequency Analysis Results\n") f.write("Frequency Analysis Results\n")
f.write("="*50 + "\n") f.write("=" * 50 + "\n")
f.write(f"Min frequency: {min_count}\n") f.write(f"Min frequency: {min_count}\n")
f.write(f"Max frequency: {max_count}\n") f.write(f"Max frequency: {max_count}\n")
f.write(f"Mean frequency: {mean_count:.2f}\n") f.write(f"Mean frequency: {mean_count:.2f}\n")
@ -227,10 +252,15 @@ def main():
f.write(f"10th percentile: {p10}\n") f.write(f"10th percentile: {p10}\n")
f.write(f"50th percentile: {p50}\n") f.write(f"50th percentile: {p50}\n")
f.write(f"90th percentile: {p90}\n") f.write(f"90th percentile: {p90}\n")
f.write(f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n") f.write(
f.write(f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n") f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n"
)
f.write(
f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n"
)
print(f"\nDetailed results saved to {output_file}") print(f"\nDetailed results saved to {output_file}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -7,58 +7,70 @@ import json
import sys import sys
from pathlib import Path from pathlib import Path
def main(): def main():
json_path = Path("src/model/assets/pinyin_char_statistics.json") json_path = (
with open(json_path, 'r', encoding='utf-8') as f: Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
pairs = data.get('pairs', {}) pairs = data.get("pairs", {})
# Build ID to count mapping # Build ID to count mapping
id_to_count = {} id_to_count = {}
for key, pair in pairs.items(): for key, pair in pairs.items():
char_id = pair.get('id') char_id = pair.get("id")
count = pair.get('count') count = pair.get("count")
if char_id is not None and count is not None: if char_id is not None and count is not None:
id_to_count[char_id] = count id_to_count[char_id] = count
# Analyze range 5000-5500 in detail # Analyze range 5000-5500 in detail
print("ID range 5000-5500 detailed analysis:") print("ID range 5000-5500 detailed analysis:")
print("ID\tCount\tChar\tPinyin") print("ID\tCount\tChar\tPinyin")
range_data = [] range_data = []
for id in range(5000, 5501): for id in range(5000, 5501):
if id in id_to_count: if id in id_to_count:
# Find the pair to get char and pinyin # Find the pair to get char and pinyin
for key, pair in pairs.items(): for key, pair in pairs.items():
if pair.get('id') == id: if pair.get("id") == id:
char = pair.get('char', '') char = pair.get("char", "")
pinyin = pair.get('pinyin', '') pinyin = pair.get("pinyin", "")
count = pair.get('count', 0) count = pair.get("count", 0)
range_data.append((id, count, char, pinyin)) range_data.append((id, count, char, pinyin))
if id % 100 == 0: # Print every 100th for overview if id % 100 == 0: # Print every 100th for overview
print(f"{id}\t{count}\t{char}\t{pinyin}") print(f"{id}\t{count}\t{char}\t{pinyin}")
break break
# Print min and max in range # Print min and max in range
if range_data: if range_data:
min_item = min(range_data, key=lambda x: x[1]) min_item = min(range_data, key=lambda x: x[1])
max_item = max(range_data, key=lambda x: x[1]) max_item = max(range_data, key=lambda x: x[1])
print(f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'") print(
print(f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'") f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'"
)
print(
f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'"
)
# Check if frequencies are monotonic in this range # Check if frequencies are monotonic in this range
counts = [item[1] for item in range_data] counts = [item[1] for item in range_data]
increasing = all(counts[i] <= counts[i+1] for i in range(len(counts)-1)) increasing = all(counts[i] <= counts[i + 1] for i in range(len(counts) - 1))
decreasing = all(counts[i] >= counts[i+1] for i in range(len(counts)-1)) decreasing = all(counts[i] >= counts[i + 1] for i in range(len(counts) - 1))
print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}") print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}")
# Check for frequency plateaus # Check for frequency plateaus
from collections import Counter from collections import Counter
freq_count = Counter(counts) freq_count = Counter(counts)
most_common = freq_count.most_common(5) most_common = freq_count.most_common(5)
print(f"Most common frequencies in range: {most_common}") print(f"Most common frequencies in range: {most_common}")
# Analyze the tail (IDs with frequency 1) # Analyze the tail (IDs with frequency 1)
print("\n\nAnalysis of frequency=1 entries:") print("\n\nAnalysis of frequency=1 entries:")
freq_one_ids = [id for id, count in id_to_count.items() if count == 1] freq_one_ids = [id for id, count in id_to_count.items() if count == 1]
@ -67,30 +79,34 @@ def main():
print(f"ID range of frequency=1: {min(freq_one_ids)} to {max(freq_one_ids)}") print(f"ID range of frequency=1: {min(freq_one_ids)} to {max(freq_one_ids)}")
print(f"First 10 IDs: {freq_one_ids[:10]}") print(f"First 10 IDs: {freq_one_ids[:10]}")
print(f"Last 10 IDs: {freq_one_ids[-10:]}") print(f"Last 10 IDs: {freq_one_ids[-10:]}")
# Check if they're contiguous # Check if they're contiguous
sorted_ids = sorted(freq_one_ids) sorted_ids = sorted(freq_one_ids)
contiguous = all(sorted_ids[i] + 1 == sorted_ids[i+1] for i in range(len(sorted_ids)-1)) contiguous = all(
sorted_ids[i] + 1 == sorted_ids[i + 1] for i in range(len(sorted_ids) - 1)
)
print(f"Are they contiguous IDs? {contiguous}") print(f"Are they contiguous IDs? {contiguous}")
# Sample some characters # Sample some characters
print("\nSample characters with frequency=1:") print("\nSample characters with frequency=1:")
sample_count = 0 sample_count = 0
for key, pair in pairs.items(): for key, pair in pairs.items():
if pair.get('count') == 1 and sample_count < 10: if pair.get("count") == 1 and sample_count < 10:
print(f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'") print(
f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'"
)
sample_count += 1 sample_count += 1
# Check overall ID-frequency ordering # Check overall ID-frequency ordering
print("\n\nOverall ID-frequency ordering analysis:") print("\n\nOverall ID-frequency ordering analysis:")
all_ids = sorted(id_to_count.keys()) all_ids = sorted(id_to_count.keys())
all_counts = [id_to_count[id] for id in all_ids] all_counts = [id_to_count[id] for id in all_ids]
# Count monotonic segments # Count monotonic segments
non_increasing_segments = 0 non_increasing_segments = 0
current_segment_length = 1 current_segment_length = 1
for i in range(1, len(all_counts)): for i in range(1, len(all_counts)):
if all_counts[i] <= all_counts[i-1]: if all_counts[i] <= all_counts[i - 1]:
current_segment_length += 1 current_segment_length += 1
else: else:
if current_segment_length > 1: if current_segment_length > 1:
@ -98,18 +114,22 @@ def main():
current_segment_length = 1 current_segment_length = 1
if current_segment_length > 1: if current_segment_length > 1:
non_increasing_segments += 1 non_increasing_segments += 1
print(f"Total IDs: {len(all_ids)}") print(f"Total IDs: {len(all_ids)}")
print(f"Non-increasing segments: {non_increasing_segments}") print(f"Non-increasing segments: {non_increasing_segments}")
# Check for frequency plateaus overall # Check for frequency plateaus overall
from collections import Counter from collections import Counter
overall_freq_count = Counter(all_counts) overall_freq_count = Counter(all_counts)
plateaus = [(freq, count) for freq, count in overall_freq_count.items() if count > 1] plateaus = [
(freq, count) for freq, count in overall_freq_count.items() if count > 1
]
plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10] plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10]
print(f"Top 10 frequency plateaus (freq: count of IDs sharing that freq):") print(f"Top 10 frequency plateaus (freq: count of IDs sharing that freq):")
for freq, count in plateaus_sorted: for freq, count in plateaus_sorted:
print(f" {freq}: {count} IDs") print(f" {freq}: {count} IDs")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,10 +1,15 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from model.dataset import PinyinInputDataset from model.dataset import PinyinInputDataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from model.trainer import collate_fn, worker_init_fn from model.trainer import collate_fn, worker_init_fn
data = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/') data = PinyinInputDataset("/home/songsenand/Data/corpus/CCI-Data/")
dataloader = DataLoader( dataloader = DataLoader(
data, data,
@ -18,5 +23,5 @@ dataloader = DataLoader(
) )
for i in dataloader: for i in dataloader:
print((i['labels'] == 1).sum()) print((i["labels"] == 1).sum())
break break

View File

@ -9,90 +9,125 @@ import math
from collections import Counter from collections import Counter
from pathlib import Path from pathlib import Path
def main(): def main():
json_path = Path("src/model/assets/pinyin_char_statistics.json") json_path = (
with open(json_path, 'r', encoding='utf-8') as f: Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
pairs = data.get('pairs', {}) pairs = data.get("pairs", {})
# Extract counts # Extract counts
counts = [] counts = []
for key, pair in pairs.items(): for key, pair in pairs.items():
count = pair.get('count') count = pair.get("count")
if count is not None: if count is not None:
counts.append(count) counts.append(count)
n = len(counts) n = len(counts)
print(f"Total entries: {n}") print(f"Total entries: {n}")
# Sort descending for rank-frequency analysis # Sort descending for rank-frequency analysis
counts_sorted_desc = sorted(counts, reverse=True) counts_sorted_desc = sorted(counts, reverse=True)
# Basic statistics # Basic statistics
min_count = min(counts) min_count = min(counts)
max_count = max(counts) max_count = max(counts)
mean_count = sum(counts) / n mean_count = sum(counts) / n
# Percentiles # Percentiles
percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99] percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
print("\n=== PERCENTILE DISTRIBUTION ===") print("\n=== PERCENTILE DISTRIBUTION ===")
for p in percentiles: for p in percentiles:
idx = int(p * n) idx = int(p * n)
value = counts_sorted_desc[idx] value = counts_sorted_desc[idx]
print(f"{p*100:5.1f}%: {value:>12} (rank ~{idx})") print(f"{p * 100:5.1f}%: {value:>12} (rank ~{idx})")
# Cumulative distribution # Cumulative distribution
print("\n=== CUMULATIVE DISTRIBUTION ===") print("\n=== CUMULATIVE DISTRIBUTION ===")
thresholds = [1, 2, 3, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000, 1000000, 5000000, 10000000, 50000000, 100000000, 500000000] thresholds = [
1,
2,
3,
5,
10,
20,
50,
100,
200,
500,
1000,
2000,
5000,
10000,
20000,
50000,
100000,
200000,
500000,
1000000,
5000000,
10000000,
50000000,
100000000,
500000000,
]
for thresh in thresholds: for thresh in thresholds:
if thresh > max_count: if thresh > max_count:
break break
below = sum(1 for c in counts if c <= thresh) below = sum(1 for c in counts if c <= thresh)
above = sum(1 for c in counts if c >= thresh) above = sum(1 for c in counts if c >= thresh)
print(f"Count <= {thresh:10}: {below:6} entries ({below/n*100:5.1f}%)") print(f"Count <= {thresh:10}: {below:6} entries ({below / n * 100:5.1f}%)")
# print(f"Count >= {thresh:10}: {above:6} entries ({above/n*100:5.1f}%)") # print(f"Count >= {thresh:10}: {above:6} entries ({above/n*100:5.1f}%)")
# Check min_count=109 parameter # Check min_count=109 parameter
print("\n=== ANALYSIS OF THRESHOLD 109 ===") print("\n=== ANALYSIS OF THRESHOLD 109 ===")
below_109 = sum(1 for c in counts if c < 109) below_109 = sum(1 for c in counts if c < 109)
at_or_above_109 = sum(1 for c in counts if c >= 109) at_or_above_109 = sum(1 for c in counts if c >= 109)
print(f"Entries with count < 109: {below_109} ({below_109/n*100:.1f}%)") print(f"Entries with count < 109: {below_109} ({below_109 / n * 100:.1f}%)")
print(f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109/n*100:.1f}%)") print(
f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109 / n * 100:.1f}%)"
)
# If 109 is a threshold, what's the actual min among those >= 109? # If 109 is a threshold, what's the actual min among those >= 109?
counts_ge_109 = [c for c in counts if c >= 109] counts_ge_109 = [c for c in counts if c >= 109]
if counts_ge_109: if counts_ge_109:
actual_min_ge_109 = min(counts_ge_109) actual_min_ge_109 = min(counts_ge_109)
print(f"Actual min frequency among those >= 109: {actual_min_ge_109}") print(f"Actual min frequency among those >= 109: {actual_min_ge_109}")
# Rank-frequency analysis (Zipf's law) # Rank-frequency analysis (Zipf's law)
print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===") print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===")
print("Rank\tFrequency\tlog(rank)\tlog(freq)") print("Rank\tFrequency\tlog(rank)\tlog(freq)")
for rank in range(1, 101): for rank in range(1, 101):
freq = counts_sorted_desc[rank-1] freq = counts_sorted_desc[rank - 1]
print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{math.log(freq):.3f}") print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{math.log(freq):.3f}")
# Frequency spectrum (how many distinct frequencies) # Frequency spectrum (how many distinct frequencies)
freq_counter = Counter(counts) freq_counter = Counter(counts)
print(f"\n=== FREQUENCY SPECTRUM ===") print(f"\n=== FREQUENCY SPECTRUM ===")
print(f"Distinct frequency values: {len(freq_counter)}") print(f"Distinct frequency values: {len(freq_counter)}")
# Most common frequencies # Most common frequencies
print("\nTop 20 most common frequencies (plateau sizes):") print("\nTop 20 most common frequencies (plateau sizes):")
for freq, freq_count in freq_counter.most_common(20): for freq, freq_count in freq_counter.most_common(20):
print(f" Frequency {freq}: {freq_count} entries") print(f" Frequency {freq}: {freq_count} entries")
# Analyze ID ranges # Analyze ID ranges
print("\n=== ID RANGE ANALYSIS ===") print("\n=== ID RANGE ANALYSIS ===")
# Build ID to count mapping # Build ID to count mapping
id_to_count = {} id_to_count = {}
for key, pair in pairs.items(): for key, pair in pairs.items():
char_id = pair.get('id') char_id = pair.get("id")
count = pair.get('count') count = pair.get("count")
if char_id is not None and count is not None: if char_id is not None and count is not None:
id_to_count[char_id] = count id_to_count[char_id] = count
ranges = [ ranges = [
(0, 100, "Top 100 IDs"), (0, 100, "Top 100 IDs"),
(100, 500, "IDs 100-500"), (100, 500, "IDs 100-500"),
@ -106,66 +141,80 @@ def main():
(19000, 19500, "IDs 19000-19500 (before freq=1)"), (19000, 19500, "IDs 19000-19500 (before freq=1)"),
(19499, 20647, "IDs with freq=1"), (19499, 20647, "IDs with freq=1"),
] ]
for start, end, label in ranges: for start, end, label in ranges:
range_counts = [id_to_count[id] for id in range(start, end) if id in id_to_count] range_counts = [
id_to_count[id] for id in range(start, end) if id in id_to_count
]
if range_counts: if range_counts:
min_c = min(range_counts) min_c = min(range_counts)
max_c = max(range_counts) max_c = max(range_counts)
mean_c = sum(range_counts) / len(range_counts) mean_c = sum(range_counts) / len(range_counts)
median_c = sorted(range_counts)[len(range_counts)//2] median_c = sorted(range_counts)[len(range_counts) // 2]
print(f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}") print(
f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}"
)
# Check if IDs are perfectly sorted by frequency # Check if IDs are perfectly sorted by frequency
print("\n=== ID ORDER VERIFICATION ===") print("\n=== ID ORDER VERIFICATION ===")
all_ids = sorted(id_to_count.keys()) all_ids = sorted(id_to_count.keys())
all_counts = [id_to_count[id] for id in all_ids] all_counts = [id_to_count[id] for id in all_ids]
# Check for any violations of non-increasing order # Check for any violations of non-increasing order
violations = 0 violations = 0
for i in range(1, len(all_counts)): for i in range(1, len(all_counts)):
if all_counts[i] > all_counts[i-1]: if all_counts[i] > all_counts[i - 1]:
violations += 1 violations += 1
if violations <= 5: if violations <= 5:
print(f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i-1]} (ID {all_ids[i-1]})") print(
f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i - 1]} (ID {all_ids[i - 1]})"
)
print(f"Total violations of non-increasing order: {violations}") print(f"Total violations of non-increasing order: {violations}")
# Check if equal frequencies are grouped together # Check if equal frequencies are grouped together
print("\n=== FREQUENCY GROUPING ANALYSIS ===") print("\n=== FREQUENCY GROUPING ANALYSIS ===")
current_freq = None current_freq = None
group_start = None group_start = None
group_sizes = [] group_sizes = []
for i, (id, count) in enumerate(zip(all_ids, all_counts)): for i, (id, count) in enumerate(zip(all_ids, all_counts)):
if count != current_freq: if count != current_freq:
if current_freq is not None: if current_freq is not None:
group_sizes.append((current_freq, group_start, all_ids[i-1], i - group_start)) group_sizes.append(
(current_freq, group_start, all_ids[i - 1], i - group_start)
)
current_freq = count current_freq = count
group_start = i group_start = i
# Last group # Last group
if current_freq is not None: if current_freq is not None:
group_sizes.append((current_freq, group_start, all_ids[-1], len(all_ids) - group_start)) group_sizes.append(
(current_freq, group_start, all_ids[-1], len(all_ids) - group_start)
)
# Sort groups by size # Sort groups by size
group_sizes.sort(key=lambda x: x[3], reverse=True) group_sizes.sort(key=lambda x: x[3], reverse=True)
print("Top 10 largest frequency groups (plateaus):") print("Top 10 largest frequency groups (plateaus):")
for freq, start_id_idx, end_id, size in group_sizes[:10]: for freq, start_id_idx, end_id, size in group_sizes[:10]:
start_id = all_ids[start_id_idx] start_id = all_ids[start_id_idx]
print(f" Frequency {freq}: IDs {start_id}-{end_id} ({size} entries)") print(f" Frequency {freq}: IDs {start_id}-{end_id} ({size} entries)")
# Summary for smoothing algorithm # Summary for smoothing algorithm
print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===") print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===")
print("1. IDs are perfectly sorted by frequency (non-increasing).") print("1. IDs are perfectly sorted by frequency (non-increasing).")
print(f"2. Frequency range: {min_count} to {max_count} (ratio {max_count/min_count:.1e}:1).") print(
print(f"3. {below_109} entries ({below_109/n*100:.1f}%) have frequency < 109.") f"2. Frequency range: {min_count} to {max_count} (ratio {max_count / min_count:.1e}:1)."
print(f"4. Median frequency: {counts_sorted_desc[n//2]}.") )
print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9*n)]}.") print(f"3. {below_109} entries ({below_109 / n * 100:.1f}%) have frequency < 109.")
print(f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01*n)]}.") print(f"4. Median frequency: {counts_sorted_desc[n // 2]}.")
print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9 * n)]}.")
print(
f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01 * n)]}."
)
print("7. Large frequency plateaus exist (many IDs share same frequency).") print("7. Large frequency plateaus exist (many IDs share same frequency).")
print("8. Smoothing should handle extreme frequency ratios (1:5e8).") print("8. Smoothing should handle extreme frequency ratios (1:5e8).")
# Save data for plotting # Save data for plotting
with open("rank_freq.csv", "w") as f: with open("rank_freq.csv", "w") as f:
f.write("rank,frequency\n") f.write("rank,frequency\n")
@ -173,5 +222,6 @@ def main():
f.write(f"{rank},{freq}\n") f.write(f"{rank},{freq}\n")
print("\nRank-frequency data saved to rank_freq.csv") print("\nRank-frequency data saved to rank_freq.csv")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -3,6 +3,7 @@ from pathlib import Path
from typing import Dict, Any from typing import Dict, Any
import sys import sys
def modify_pinyin_statistics(file_path: Path) -> None: def modify_pinyin_statistics(file_path: Path) -> None:
""" """
一次性修改拼音统计JSON文件 一次性修改拼音统计JSON文件
@ -13,7 +14,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
""" """
# 1. 加载原数据 # 1. 加载原数据
try: try:
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, "r", encoding="utf-8") as f:
data: Dict[str, Any] = json.load(f) data: Dict[str, Any] = json.load(f)
except FileNotFoundError: except FileNotFoundError:
print(f"错误:文件不存在 {file_path}", file=sys.stderr) print(f"错误:文件不存在 {file_path}", file=sys.stderr)
@ -34,7 +35,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
"id": 0, "id": 0,
"char": "", "char": "",
"pinyin": "", "pinyin": "",
"count": original_zero_count + 1 # 原count + 1 "count": original_zero_count + 1, # 原count + 1
} }
# 3.2 处理其他所有记录键和id都+1 # 3.2 处理其他所有记录键和id都+1
@ -54,15 +55,16 @@ def modify_pinyin_statistics(file_path: Path) -> None:
# 这里保持原时间戳不变,因为是一次性修改 # 这里保持原时间戳不变,因为是一次性修改
# 写回文件,保持可读格式 # 写回文件,保持可读格式
backup_path = file_path.with_suffix('.json.bak') backup_path = file_path.with_suffix(".json.bak")
try: try:
# 先备份原文件 # 先备份原文件
import shutil import shutil
shutil.copy2(file_path, backup_path) shutil.copy2(file_path, backup_path)
print(f"已创建备份: {backup_path}") print(f"已创建备份: {backup_path}")
# 写入新数据 # 写入新数据
with open(file_path, 'w', encoding='utf-8') as f: with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2) json.dump(data, f, ensure_ascii=False, indent=2)
print(f"修改完成!") print(f"修改完成!")
@ -78,15 +80,21 @@ def modify_pinyin_statistics(file_path: Path) -> None:
# 使用示例 # 使用示例
if __name__ == "__main__": if __name__ == "__main__":
# 假设你的JSON文件在当前目录 # JSON文件相对于项目根目录
json_file = Path("./src/model/assets/pinyin_char_statistics.json") json_file = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
# 执行修改 # 执行修改
modify_pinyin_statistics(json_file) modify_pinyin_statistics(json_file)
# 验证修改:读取并显示前几条记录 # 验证修改:读取并显示前几条记录
print("\n验证前5条记录:") print("\n验证前5条记录:")
with open(json_file, 'r', encoding='utf-8') as f: with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
for i in range(5): for i in range(5):

View File

@ -8,110 +8,127 @@ import math
import sys import sys
from pathlib import Path from pathlib import Path
def ascii_histogram(data, bins=20, width=60):
    """Render *data* as a multi-line ASCII histogram.

    Logarithmic bins are used when the values are non-negative and span more
    than three orders of magnitude (largest value / smallest positive value
    > 1000); otherwise the bins are linear between the minimum and maximum.
    Empty bins are omitted from the output.

    Args:
        data: Sequence of numbers to bin.
        bins: Number of bins; must be positive.
        width: Length, in ``#`` characters, of the bar for the fullest bin.

    Returns:
        One ``"label | bar count"`` line per non-empty bin, joined with
        newlines. An empty string when *data* is empty.

    Raises:
        ValueError: If *bins* is not positive.
    """
    if not data:
        return ""
    if bins <= 0:
        raise ValueError("bins must be a positive integer")

    min_val = min(data)
    max_val = max(data)

    # All values identical: the value range is zero, so the per-bin width
    # would be zero and the bin-index computation would divide by zero.
    # Report a single full-width bar instead.
    if min_val == max_val:
        label = f"{min_val:.1f}-{max_val:.1f}"
        return f"{label:20} | {'#' * width} {len(data)}"

    # The log scale is anchored at the smallest *positive* value so that a
    # zero in the data no longer triggers a division by zero in the ratio
    # test (log10(0) is undefined anyway).
    positives = [val for val in data if val > 0]
    use_log_bins = (
        min_val >= 0 and bool(positives) and max_val / min(positives) > 1000
    )

    hist = [0] * bins
    if use_log_bins:
        log_min = math.log10(min(positives))
        log_max = math.log10(max_val)
        bin_edges = [
            10 ** (log_min + i * (log_max - log_min) / bins) for i in range(bins + 1)
        ]
        for val in data:
            # Zeros have no logarithm and are left out of the log histogram.
            if val > 0:
                log_val = math.log10(val)
                bin_idx = min(
                    int((log_val - log_min) / (log_max - log_min) * bins), bins - 1
                )
                hist[bin_idx] += 1
        bin_labels = [
            f"{bin_edges[i]:.1e}-{bin_edges[i + 1]:.1e}" for i in range(bins)
        ]
    else:
        bin_width = (max_val - min_val) / bins
        bin_edges = [min_val + i * bin_width for i in range(bins + 1)]
        for val in data:
            # The maximum maps to index == bins; clamp it into the last bin.
            bin_idx = min(int((val - min_val) / (max_val - min_val) * bins), bins - 1)
            hist[bin_idx] += 1
        bin_labels = [
            f"{bin_edges[i]:.1f}-{bin_edges[i + 1]:.1f}" for i in range(bins)
        ]

    # At least the maximum value was counted, so max_count is never zero.
    max_count = max(hist)
    result = []
    for label, count in zip(bin_labels, hist):
        if count == 0:
            continue
        bar = "#" * int(count / max_count * width)
        result.append(f"{label:20} | {bar} {count}")

    return "\n".join(result)
def main(): def main():
json_path = Path("src/model/assets/pinyin_char_statistics.json") json_path = (
with open(json_path, 'r', encoding='utf-8') as f: Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
pairs = data.get('pairs', {}) pairs = data.get("pairs", {})
counts = [pair.get('count', 0) for pair in pairs.values() if pair.get('count') is not None] counts = [
pair.get("count", 0) for pair in pairs.values() if pair.get("count") is not None
]
print("FREQUENCY DISTRIBUTION ANALYSIS") print("FREQUENCY DISTRIBUTION ANALYSIS")
print("="*60) print("=" * 60)
print("\n1. ASCII Histogram (log bins):") print("\n1. ASCII Histogram (log bins):")
print(ascii_histogram(counts, bins=20, width=60)) print(ascii_histogram(counts, bins=20, width=60))
# Rank-frequency plot in ASCII # Rank-frequency plot in ASCII
print("\n2. Rank-Frequency Relationship (Top 50):") print("\n2. Rank-Frequency Relationship (Top 50):")
counts_sorted_desc = sorted(counts, reverse=True) counts_sorted_desc = sorted(counts, reverse=True)
max_freq = counts_sorted_desc[0] max_freq = counts_sorted_desc[0]
max_rank = 50 max_rank = 50
for rank in range(1, max_rank + 1): for rank in range(1, max_rank + 1):
freq = counts_sorted_desc[rank-1] freq = counts_sorted_desc[rank - 1]
bar_length = int(math.log(freq) / math.log(max_freq) * 40) bar_length = int(math.log(freq) / math.log(max_freq) * 40)
bar = '#' * bar_length bar = "#" * bar_length
print(f"Rank {rank:3}: {freq:12} {bar}") print(f"Rank {rank:3}: {freq:12} {bar}")
# ID vs Frequency plot (sampled) # ID vs Frequency plot (sampled)
print("\n3. ID vs Frequency (sampled every 500 IDs):") print("\n3. ID vs Frequency (sampled every 500 IDs):")
# Build ID to count mapping # Build ID to count mapping
id_to_count = {} id_to_count = {}
for key, pair in pairs.items(): for key, pair in pairs.items():
char_id = pair.get('id') char_id = pair.get("id")
count = pair.get('count') count = pair.get("count")
if char_id is not None and count is not None: if char_id is not None and count is not None:
id_to_count[char_id] = count id_to_count[char_id] = count
all_ids = sorted(id_to_count.keys()) all_ids = sorted(id_to_count.keys())
max_id = all_ids[-1] max_id = all_ids[-1]
print("ID Frequency log10(freq)") print("ID Frequency log10(freq)")
for id in range(0, max_id + 1, 500): for id in range(0, max_id + 1, 500):
if id in id_to_count: if id in id_to_count:
freq = id_to_count[id] freq = id_to_count[id]
log_freq = math.log10(freq) if freq > 0 else 0 log_freq = math.log10(freq) if freq > 0 else 0
bar = '#' * int(log_freq / math.log10(max_freq) * 40) bar = "#" * int(log_freq / math.log10(max_freq) * 40)
print(f"{id:6} {freq:10} {log_freq:6.2f} {bar}") print(f"{id:6} {freq:10} {log_freq:6.2f} {bar}")
# Zipf's law fit # Zipf's law fit
print("\n4. Zipf's Law Analysis:") print("\n4. Zipf's Law Analysis:")
print(" Rank * Frequency ≈ constant for Zipf's law") print(" Rank * Frequency ≈ constant for Zipf's law")
print(" Top 10 ranks:") print(" Top 10 ranks:")
for rank in range(1, 11): for rank in range(1, 11):
freq = counts_sorted_desc[rank-1] freq = counts_sorted_desc[rank - 1]
product = rank * freq product = rank * freq
print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}") print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}")
# Check if product is roughly constant # Check if product is roughly constant
products = [(rank+1) * counts_sorted_desc[rank] for rank in range(10)] products = [(rank + 1) * counts_sorted_desc[rank] for rank in range(10)]
avg_product = sum(products) / len(products) avg_product = sum(products) / len(products)
std_product = math.sqrt(sum((p - avg_product)**2 for p in products) / len(products)) std_product = math.sqrt(
sum((p - avg_product) ** 2 for p in products) / len(products)
)
print(f" Average product (ranks 2-11): {avg_product:.3e} ± {std_product:.3e}") print(f" Average product (ranks 2-11): {avg_product:.3e} ± {std_product:.3e}")
print(f" Coefficient of variation: {std_product/avg_product*100:.1f}%") print(f" Coefficient of variation: {std_product / avg_product * 100:.1f}%")
# Frequency spectrum # Frequency spectrum
from collections import Counter from collections import Counter
freq_counter = Counter(counts) freq_counter = Counter(counts)
print("\n5. Frequency Spectrum (how many entries have each frequency):") print("\n5. Frequency Spectrum (how many entries have each frequency):")
print(" Frequency Count Cumulative") print(" Frequency Count Cumulative")
@ -120,21 +137,21 @@ def main():
count = freq_counter[freq] count = freq_counter[freq]
cum += count cum += count
print(f" {freq:10} {count:6} {cum:6}") print(f" {freq:10} {count:6} {cum:6}")
# Summary statistics # Summary statistics
print("\n6. Key Statistics:") print("\n6. Key Statistics:")
n = len(counts) n = len(counts)
print(f" Total entries: {n}") print(f" Total entries: {n}")
print(f" Min frequency: {min(counts)}") print(f" Min frequency: {min(counts)}")
print(f" Max frequency: {max(counts)}") print(f" Max frequency: {max(counts)}")
print(f" Ratio max/min: {max(counts)/min(counts):.2e}") print(f" Ratio max/min: {max(counts) / min(counts):.2e}")
percentiles = [0.01, 0.1, 0.5, 0.9, 0.99] percentiles = [0.01, 0.1, 0.5, 0.9, 0.99]
for p in percentiles: for p in percentiles:
idx = int(p * n) idx = int(p * n)
value = counts_sorted_desc[idx] value = counts_sorted_desc[idx]
print(f" {p*100:5.1f}th percentile: {value:12} (rank ~{idx})") print(f" {p * 100:5.1f}th percentile: {value:12} (rank ~{idx})")
# Save data for external plotting # Save data for external plotting
with open("id_vs_freq.csv", "w") as f: with open("id_vs_freq.csv", "w") as f:
f.write("id,frequency\n") f.write("id,frequency\n")
@ -142,5 +159,6 @@ def main():
f.write(f"{id},{id_to_count[id]}\n") f.write(f"{id},{id_to_count[id]}\n")
print("\nData saved to id_vs_freq.csv for external plotting") print("\nData saved to id_vs_freq.csv for external plotting")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -5,10 +5,9 @@ warnings.filterwarnings("ignore", message=".*pkg_resources.*")
import jieba import jieba
import math import math
import random import random
import re
from importlib.resources import files from importlib.resources import files
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Set, Tuple
import numpy as np import numpy as np
import torch import torch
@ -21,7 +20,6 @@ from torch.utils.data import IterableDataset
from .query import QueryEngine from .query import QueryEngine
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26 CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
CHAR_TO_ID["`"] = 27 # 显式添加反引号 CHAR_TO_ID["`"] = 27 # 显式添加反引号
@ -76,6 +74,8 @@ class PinyinInputDataset(IterableDataset):
merge_max_total_chars: int = 6, merge_max_total_chars: int = 6,
low_freq_repeat: float = 50.0, low_freq_repeat: float = 50.0,
high_freq_repeat: float = 0.1, high_freq_repeat: float = 0.1,
data_kwargs: Optional[Dict] = None,
target_labels: Optional[Set[int]] = None,
): ):
# 频率调整参数 - 幂律平滑方案 # 频率调整参数 - 幂律平滑方案
self.min_freq = 109 self.min_freq = 109
@ -88,6 +88,9 @@ class PinyinInputDataset(IterableDataset):
self.merge_max_short_words = merge_max_short_words self.merge_max_short_words = merge_max_short_words
self.merge_max_total_chars = merge_max_total_chars self.merge_max_total_chars = merge_max_total_chars
self.data_kwargs = data_kwargs or {}
self.target_labels = target_labels
jieba.initialize() jieba.initialize()
self.tokenizer = AutoTokenizer.from_pretrained( self.tokenizer = AutoTokenizer.from_pretrained(
@ -98,7 +101,9 @@ class PinyinInputDataset(IterableDataset):
self.max_iter_length = max_iter_length self.max_iter_length = max_iter_length
self.max_seq_length = max_seq_length self.max_seq_length = max_seq_length
self.text_field = text_field self.text_field = text_field
self.dataset = load_dataset(data_path, split="train", streaming=True) load_kwargs = {"split": "train", "streaming": True}
load_kwargs.update(self.data_kwargs)
self.dataset = load_dataset(data_path, **load_kwargs)
self.max_workers = max_workers self.max_workers = max_workers
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight) self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
self.shuffle_buffer_size = shuffle_buffer_size self.shuffle_buffer_size = shuffle_buffer_size
@ -155,12 +160,13 @@ class PinyinInputDataset(IterableDataset):
# 生成对应文本的拼音 # 生成对应文本的拼音
def generate_pinyin(self, text: str) -> List[str]: def generate_pinyin(self, text: str) -> List[str]:
""" """
流式处理单条文本转换为拼音列表 将文本转换为拼音列表对整段文本调用 lazy_pinyin
利用 errors 回调确保一一对应对生僻字从 QueryEngine 回退
特性 特性
1. 严格一一对应len(result) == len(text) 1. 严格一一对应len(result) == len(text)
2. 高多音字准确率利用 pypinyin 内部的词语分词能力 2. pypinyin 不认识的生僻字回退到 QueryEngine 最高频读音
3. 高性能预分配内存无多余对象创建 3. 非汉字字符原样占位
Args: Args:
text: 输入字符串 text: 输入字符串
@ -171,39 +177,35 @@ class PinyinInputDataset(IterableDataset):
if not text: if not text:
return [] return []
text_len = len(text) def _fallback(chars):
# 2. 预分配结果列表,初始化占位符。 # lazy_pinyin 会把连续无拼音的字符聚合成一个字符串传入,
# 使用 None 或空字符串均可,这里用空字符串方便后续判断 # 必须逐字符处理,确保返回列表长度与输入字符数一致。
result: List[str] = [""] * text_len result = []
for char in chars:
if self.query_engine.is_chinese_char(char):
ids = self.query_engine.query_by_char(char, limit=1)
if ids:
result.append(ids[0][1])
else:
result.append(char)
else:
result.append(char)
return result
# 3. 遍历所有连续汉字片段 pinyin_list = lazy_pinyin(text, errors=_fallback)
for match in _HANZI_RE.finditer(text):
start_idx = match.start()
hanzi_segment = match.group()
# 4. 核心转换:利用 pypinyin 的分词能力处理该片段 # 防御性校验:若长度仍不匹配(极罕见),逐字回退
# style=Style.NORMAL 获取不带声调的拼音 if len(pinyin_list) != len(text):
pinyin_list = lazy_pinyin(hanzi_segment) logger.warning(
f"pinyin length mismatch: text_len={len(text)}, "
f"pinyin_len={len(pinyin_list)}, text={text[:50]!r}"
)
pinyin_list = []
for c in text:
result = lazy_pinyin(c, errors=_fallback)
pinyin_list.append(result[0] if result else c)
# 5. 健壮性兜底: return pinyin_list
# 正常情况下pypinyin 返回的拼音数应等于汉字数。
# 若不等(极罕见,如遇到特殊 Unicode 标点被误判为汉字),降级为单字转换
if len(pinyin_list) != len(hanzi_segment):
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
# 6. 直接通过索引填充到预分配的位置
# 这比 list slicing assignment (result[start:end] = pinyin_list) 略快且更直观
for i, py in enumerate(pinyin_list):
result[start_idx + i] = py
# 7. 填充非汉字字符
# 遍历原文,如果 result 对应位置为空,则填入原字符
# 注意:对于纯汉字文本,这一步很快;对于混合文本,这是必要的
for i, char in enumerate(text):
if not result[i]:
result[i] = char
return result
def get_mask_pinyin( def get_mask_pinyin(
self, text: str, pinyin_list: List[str] self, text: str, pinyin_list: List[str]
@ -243,51 +245,61 @@ class PinyinInputDataset(IterableDataset):
pinyin_ids = pinyin_ids[:24] pinyin_ids = pinyin_ids[:24]
return torch.tensor(pinyin_ids, dtype=torch.long) return torch.tensor(pinyin_ids, dtype=torch.long)
def _build_single_sample(
    self,
    label: int,
    history: list,
    text: str,
    word_start: int,
    word_end: int,
    part2: str,
    pinyin_ids: torch.Tensor,
    words: list,
) -> dict:
    """Build one sample dict, re-drawing its random context on every call.

    Args:
        label: Target label id; stored as a one-element long tensor.
        history: Previously processed labels. Copied, never mutated; only
            the last 8 entries are kept and shorter lists are right-padded
            with 0.
        text: Full source text the context slices are taken from.
        word_start: Index in *text* where the target begins.
        word_end: Index in *text* just past the target.
        part2: Pinyin string, stored unchanged under ``"pinyin"``.
        pinyin_ids: Tensor stored unchanged under ``"pinyin_ids"``.
        words: Candidate hint words; up to 3 may be sampled as a hint.

    Returns:
        Dict with tokenizer outputs as long tensors plus ``label``,
        ``history_slot_ids``, ``prefix``, ``suffix``, ``pinyin`` and
        ``pinyin_ids``.
    """
    # Left context: length drawn from N(36, 6^2), clamped to
    # [0, min(48, word_start)] so it never reaches before the text start.
    drawn_len = int(random.gauss(36, 6))
    left_len = min(max(drawn_len, 0), 48, word_start)
    left_context = text[word_start - left_len : word_start]

    # Right context: present only when the draw exceeds 0.7, 1-16 chars long.
    right_context = ""
    if random.random() > 0.7:
        span = random.randint(1, 16)
        right_context = text[word_end : word_end + span]

    # Word hint: present only when the draw exceeds 0.7 and candidates exist;
    # the draw happens first, so it is consumed even if `words` is empty.
    hint = ""
    if random.random() > 0.7 and words:
        wanted = random.randint(1, 3)
        picked = random.sample(words, min(wanted, len(words)))
        hint = "|".join(picked)

    tokens = self.tokenizer(
        f"{hint}|{left_context}",
        right_context,
        max_length=self.max_seq_length,
        truncation=True,
        return_token_type_ids=True,
    )

    # Fixed-size history window: last 8 labels, zero-padded on the right.
    slots = list(history)[-8:]
    slots.extend([0] * (8 - len(slots)))

    sample = {
        field: torch.tensor(tokens[field], dtype=torch.long)
        for field in ("input_ids", "token_type_ids", "attention_mask")
    }
    sample["label"] = torch.tensor([label], dtype=torch.long)
    sample["history_slot_ids"] = torch.tensor(slots, dtype=torch.long)
    # NOTE: the stored prefix joins hint and context with "^", while the
    # tokenizer input above joins them with "|".
    sample["prefix"] = f"{hint}^{left_context}"
    sample["suffix"] = right_context
    sample["pinyin"] = part2
    sample["pinyin_ids"] = pinyin_ids
    return sample
def __iter__(self): def __iter__(self):
worker_info = torch.utils.data.get_worker_info() worker_info = torch.utils.data.get_worker_info()
@ -436,42 +448,43 @@ class PinyinInputDataset(IterableDataset):
if not should_break and random.random() <= 0.1: if not should_break and random.random() <= 0.1:
labels.append(0) labels.append(0)
# part1: 词起点前的文本(所有样本共享) # 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
part1 = text[max(0, word_start - 48) : word_start] processed_history = []
for label_idx, label in enumerate(labels):
base_repeats = self.adjust_frequency(
self.sample_freqs.get(label, 0)
)
if base_repeats == 0:
processed_history.append(label)
continue
if (
self.target_labels is not None
and label not in self.target_labels
):
processed_history.append(label)
continue
# part3: 词后文本 weight = (
part3 = "" self._history_weights[label_idx]
if random.random() > 0.7: if label_idx < len(self._history_weights)
part3 = text[word_end : word_end + random.randint(1, 16)] else 3.0
)
repeats = max(1, int(base_repeats * weight))
# part4: 词提示 for _ in range(repeats):
part4 = "" sample = self._build_single_sample(
if random.random() > 0.7: label=label,
num_words = random.randint(1, 3) history=processed_history,
if words: text=text,
selected_words = random.sample( word_start=word_start,
words, min(num_words, len(words)) word_end=word_end,
part2=part2,
pinyin_ids=pinyin_ids,
words=words,
) )
part4 = "|".join(selected_words) batch_samples.append(sample)
encoded = self.tokenizer( processed_history.append(label)
f"{part4}|{part1}",
part3,
max_length=self.max_seq_length,
truncation=True,
return_token_type_ids=True,
)
batch_samples = self._add_word_samples(
batch_samples,
labels,
encoded,
part4,
part1,
part3,
part2,
pinyin_ids,
)
# ========== Phase 2: 破词续接 ========== # ========== Phase 2: 破词续接 ==========
if should_break and break_pos < word_len_chars: if should_break and break_pos < word_len_chars:
@ -533,33 +546,44 @@ class PinyinInputDataset(IterableDataset):
if random.random() <= 0.1: if random.random() <= 0.1:
cont_labels.append(0) cont_labels.append(0)
# part1_cont: 包含已确认前缀的上下文 # 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
part1_cont = text[max(0, cont_start - 48) : cont_start] cont_processed_history = []
# part3_cont: 续接目标后的文本
cont_end = cont_positions[-1] + 1 cont_end = cont_positions[-1] + 1
part3_cont = "" for label_idx, label in enumerate(cont_labels):
if random.random() > 0.7: base_repeats = self.adjust_frequency(
part3_cont = text[cont_end : cont_end + random.randint(1, 16)] self.sample_freqs.get(label, 0)
)
if base_repeats == 0:
cont_processed_history.append(label)
continue
if (
self.target_labels is not None
and label not in self.target_labels
):
cont_processed_history.append(label)
continue
encoded_cont = self.tokenizer( weight = (
f"{part4}|{part1_cont}", self._history_weights[label_idx]
part3_cont, if label_idx < len(self._history_weights)
max_length=self.max_seq_length, else 3.0
truncation=True, )
return_token_type_ids=True, repeats = max(1, int(base_repeats * weight))
)
batch_samples = self._add_word_samples( for _ in range(repeats):
batch_samples, sample = self._build_single_sample(
cont_labels, label=label,
encoded_cont, history=cont_processed_history,
part4, text=text,
part1_cont, word_start=cont_start,
part3_cont, word_end=cont_end,
part2_cont, part2=part2_cont,
pinyin_ids_cont, pinyin_ids=pinyin_ids_cont,
) words=words,
)
batch_samples.append(sample)
cont_processed_history.append(label)
idx = merge_end_idx idx = merge_end_idx

View File

@ -4,6 +4,7 @@
步骤 1: find-missing 扫描已预处理数据找出从未出现的 label ID输出 JSON 步骤 1: find-missing 扫描已预处理数据找出从未出现的 label ID输出 JSON
步骤 2: generate-template 根据 JSON 生成 JSONL 占位文件供用户手动填入包含缺失字的真实文本 步骤 2: generate-template 根据 JSON 生成 JSONL 占位文件供用户手动填入包含缺失字的真实文本
步骤 3: preprocess-supplement 将填好的 JSONL 文本预处理为 .npz 分片输出到独立目录
用法 用法
python -m model.supplement_missing find-missing \ python -m model.supplement_missing find-missing \
@ -13,6 +14,12 @@
python -m model.supplement_missing generate-template \ python -m model.supplement_missing generate-template \
--missing-chars missing_chars.json \ --missing-chars missing_chars.json \
--output supplement_texts.jsonl --output supplement_texts.jsonl
python -m model.supplement_missing preprocess-supplement \
--missing-chars missing_chars.json \
--supplement-texts supplement_texts.jsonl \
--output-dir ./preprocessed/supplement \
--num-samples 100000
""" """
import argparse import argparse
@ -21,12 +28,17 @@ from pathlib import Path
from typing import Set from typing import Set
import numpy as np import numpy as np
import torch
from loguru import logger from loguru import logger
from rich.console import Console from rich.console import Console
from rich.table import Table from rich.table import Table
from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from .dataset import PinyinInputDataset
from .preprocess import collect_samples
from .query import QueryEngine from .query import QueryEngine
from .trainer import preprocess_collate_fn, worker_init_fn
def scan_labels(preprocessed_dir: Path) -> Set[int]: def scan_labels(preprocessed_dir: Path) -> Set[int]:
@ -175,14 +187,116 @@ def cmd_generate_template(args):
) )
def cmd_preprocess_supplement(args):
    """Preprocess the filled-in supplement JSONL into .npz shards.

    Reads the missing-character JSON named by ``args.missing_chars``, builds
    a ``PinyinInputDataset`` over ``args.supplement_texts`` restricted to the
    missing label ids (plus label 0), and passes a ``DataLoader`` over it to
    ``collect_samples`` to write shards into ``args.output_dir``.

    Returns early, after printing a message, when the missing-character file
    does not exist or lists no missing characters.

    Args:
        args: Parsed CLI namespace for the ``preprocess-supplement`` command.
    """
    console = Console()

    # Load the missing-character list.
    missing_file = Path(args.missing_chars)
    if not missing_file.exists():
        console.print(f"[bold red]文件不存在: {missing_file}[/bold red]")
        return

    payload = json.loads(missing_file.read_text(encoding="utf-8"))
    missing_chars = payload.get("missing_chars", [])
    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需处理[/bold green]")
        return

    # Label 0 is always a target as well (the original comment marks it as EOS).
    target_labels = {0}
    target_labels.update(entry["id"] for entry in missing_chars)

    # Parse "a,b,c" into a tuple and "len:weight,..." into a dict.
    py_style_weight = tuple(map(int, args.py_style_weight.split(",")))
    length_weights = {}
    for item in args.length_weights.split(","):
        length, weight = item.split(":")
        length_weights[int(length)] = int(weight)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    num_workers = args.num_workers
    # The dataset may iterate up to 5x the requested sample count.
    iteration_budget = args.num_samples * 5

    console.print("[bold cyan]=== 补充数据预处理 ===[/bold cyan]")
    console.print(f"补充文本: {args.supplement_texts}")
    console.print(f"缺失字符数: {len(missing_chars)}")
    console.print(f"目标样本: {args.num_samples:,}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"Worker 数: {num_workers}")
    console.print()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    console.print("[bold cyan]创建补充数据集...[/bold cyan]")
    # The local JSONL file is loaded non-streaming through the "json" builder.
    json_source = {
        "data_files": args.supplement_texts,
        "streaming": False,
    }
    dataset = PinyinInputDataset(
        data_path="json",
        max_workers=num_workers,
        max_iter_length=iteration_budget,
        max_seq_length=args.max_seq_length,
        text_field="text",
        py_style_weight=py_style_weight,
        shuffle_buffer_size=100,
        length_weights=length_weights,
        data_kwargs=json_source,
        target_labels=target_labels,
    )

    loader_options = dict(
        batch_size=args.batch_size,
        num_workers=num_workers,
        pin_memory=False,
        worker_init_fn=worker_init_fn,
        collate_fn=preprocess_collate_fn(args.max_seq_length),
    )
    # These two options are only passed when worker processes are used.
    if num_workers > 0:
        loader_options.update(prefetch_factor=2, persistent_workers=True)

    dataloader = DataLoader(dataset, **loader_options)

    logger.info("开始收集补充数据...")
    count = collect_samples(
        dataloader,
        args.num_samples,
        output_dir,
        "supplement",
        args.max_seq_length,
        args.shard_size,
    )

    if count < args.num_samples:
        logger.warning(f"补充样本不足: 目标 {args.num_samples}, 实际 {count}")

    console.print("\n[bold green]=== 补充预处理完成 ===[/bold green]")
    console.print(f"生成样本: {count:,}")
    console.print(f"输出目录: {output_dir}")

    # Sum the on-disk size of every .npz file directly inside the output dir.
    total_size = 0
    for shard in output_dir.iterdir():
        if shard.suffix == ".npz":
            total_size += shard.stat().st_size
    console.print(f"总大小: {total_size / (1024**3):.2f} GB (compressed)")
    console.print()
    console.print(
        "[bold yellow]提示[/bold yellow]: 请检查补充数据质量,清洗无误后手动将 shard_*.npz 合并到 train/ 目录并更新 metadata.json"
    )
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="缺失字符补充工具", description="缺失字符补充工具",
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=""" epilog="""
子命令 子命令
find-missing 扫描已预处理数据找出从未出现的 label ID find-missing 扫描已预处理数据找出从未出现的 label ID
generate-template 根据缺失字符 JSON 生成 JSONL 占位文件 generate-template 根据缺失字符 JSON 生成 JSONL 占位文件
preprocess-supplement 将填好的 JSONL 预处理为 .npz 分片独立目录
示例 示例
python -m model.supplement_missing find-missing \\ python -m model.supplement_missing find-missing \\
@ -192,6 +306,12 @@ def main():
python -m model.supplement_missing generate-template \\ python -m model.supplement_missing generate-template \\
--missing-chars missing_chars.json \\ --missing-chars missing_chars.json \\
--output supplement_texts.jsonl --output supplement_texts.jsonl
python -m model.supplement_missing preprocess-supplement \\
--missing-chars missing_chars.json \\
--supplement-texts supplement_texts.jsonl \\
--output-dir ./preprocessed/supplement \\
--num-samples 100000
""", """,
) )
subparsers = parser.add_subparsers(dest="command", help="子命令") subparsers = parser.add_subparsers(dest="command", help="子命令")
@ -232,6 +352,77 @@ def main():
help="每个缺失字符生成的模板条数(默认: 3", help="每个缺失字符生成的模板条数(默认: 3",
) )
# preprocess-supplement
p_pre = subparsers.add_parser(
"preprocess-supplement", help="将 JSONL 预处理为 .npz 分片"
)
p_pre.add_argument(
"--missing-chars",
type=str,
required=True,
help="缺失字符 JSON 文件路径(由 find-missing 生成)",
)
p_pre.add_argument(
"--supplement-texts",
type=str,
required=True,
help="已填写的补充文本 JSONL 文件路径",
)
p_pre.add_argument(
"--output-dir",
type=str,
required=True,
help="输出目录(独立目录,不会覆盖已有数据)",
)
p_pre.add_argument(
"--num-samples",
type=int,
required=True,
help="目标样本数量",
)
p_pre.add_argument(
"--batch-size",
type=int,
default=128,
help="批大小(默认: 128",
)
p_pre.add_argument(
"--num-workers",
type=int,
default=0,
help="DataLoader worker 数量。本地 JSONL 小文件建议 0默认: 0",
)
p_pre.add_argument(
"--max-seq-length",
type=int,
default=128,
help="最大序列长度(默认: 128",
)
p_pre.add_argument(
"--seed",
type=int,
default=42,
help="随机种子(默认: 42",
)
p_pre.add_argument(
"--shard-size",
type=int,
default=5_000_000,
help="分片大小(样本数),控制内存峰值(默认: 500万",
)
p_pre.add_argument(
"--py-style-weight",
type=str,
default="9,2,1",
help="拼音风格权重(逗号分隔,默认: 9,2,1",
)
p_pre.add_argument(
"--length-weights",
type=str,
default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
help="词长权重(默认: 1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
)
args = parser.parse_args() args = parser.parse_args()
if args.command is None: if args.command is None:
@ -242,6 +433,8 @@ def main():
cmd_find_missing(args) cmd_find_missing(args)
elif args.command == "generate-template": elif args.command == "generate-template":
cmd_generate_template(args) cmd_generate_template(args)
elif args.command == "preprocess-supplement":
cmd_preprocess_supplement(args)
app = main app = main

View File

@ -1,6 +1,7 @@
import os
import sys import sys
sys.path.append("src") sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
import time import time
import torch import torch
@ -26,7 +27,7 @@ from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset from torch.utils.data import IterableDataset
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer" Path(__file__).parent.parent / "src" / "model" / "assets" / "tokenizer"
) )
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+") _HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
@ -83,7 +84,9 @@ sample = {
model = InputMethodEngine(pinyin_vocab_size=30, compile=False) model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu") checkpoint = torch.load(
"/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu"
)
model.load_state_dict(checkpoint["model_state_dict"]) model.load_state_dict(checkpoint["model_state_dict"])
model.eval() model.eval()
@ -100,7 +103,7 @@ for k, v in sample.items():
start = time.time() start = time.time()
with torch.no_grad(): with torch.no_grad():
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids) res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
print(f'计算时长: {(time.time() - start) * 1000:4f}ms') print(f"计算时长: {(time.time() - start) * 1000:4f}ms")
sort_res = sorted( sort_res = sorted(
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True [(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
) )

View File

@ -392,7 +392,13 @@ def check_compile_issues():
issues = [] issues = []
# 检查 components.py 中的潜在问题 # 检查 components.py 中的潜在问题
with open("src/model/components.py", "r") as f: components_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"src",
"model",
"components.py",
)
with open(components_path, "r") as f:
content = f.read() content = f.read()
# 检查 float('-inf') # 检查 float('-inf')

View File

@ -4,9 +4,13 @@
解决设备转换和权重加载问题 解决设备转换和权重加载问题
""" """
import os
import sys import sys
from pathlib import Path from pathlib import Path
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch import torch
@ -196,7 +200,7 @@ def test_id_mapping():
query_engine = QueryEngine() query_engine = QueryEngine()
stats_path = ( stats_path = (
Path(__file__).parent Path(__file__).parent.parent
/ "src" / "src"
/ "model" / "model"
/ "assets" / "assets"

View File

@ -1,6 +1,7 @@
import os
import sys import sys
sys.path.append("src") sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import time import time
@ -42,23 +43,49 @@ def worker_init_fn(worker_id: int) -> None:
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
""" """
自定义批处理函数将多个样本组合成一个batch 自定义批处理函数将多个样本组合成一个batch
支持动态padding根据batch内最大序列长度进行padding
Args:
batch: 样本列表每个样本是一个字典
Returns:
批处理后的字典tensor字段已stack字符串字段保持为列表
""" """
# 处理tensor字段 - 使用squeeze去除多余的batch维度 input_ids_list = [item["input_ids"] for item in batch]
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch]) token_type_ids_list = [item["token_type_ids"] for item in batch]
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch]) attention_mask_list = [item["attention_mask"] for item in batch]
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
target_len = max(ids.shape[0] for ids in input_ids_list)
padded_input_ids = []
padded_token_type_ids = []
padded_attention_mask = []
for ids, tt_ids, mask in zip(
input_ids_list, token_type_ids_list, attention_mask_list
):
seq_len = ids.shape[0]
if seq_len < target_len:
pad_len = target_len - seq_len
padded_input_ids.append(
torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
)
padded_token_type_ids.append(
torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)])
)
padded_attention_mask.append(
torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
)
elif seq_len > target_len:
padded_input_ids.append(ids[:target_len])
padded_token_type_ids.append(tt_ids[:target_len])
padded_attention_mask.append(mask[:target_len])
else:
padded_input_ids.append(ids)
padded_token_type_ids.append(tt_ids)
padded_attention_mask.append(mask)
input_ids = torch.stack(padded_input_ids)
token_type_ids = torch.stack(padded_token_type_ids)
attention_mask = torch.stack(padded_attention_mask)
labels = torch.stack([item["label"].squeeze(0) for item in batch]) labels = torch.stack([item["label"].squeeze(0) for item in batch])
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch]) history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch]) pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
# 字符串字段保持为列表
prefixes = [item["prefix"] for item in batch] prefixes = [item["prefix"] for item in batch]
suffixes = [item["suffix"] for item in batch] suffixes = [item["suffix"] for item in batch]
pinyins = [item["pinyin"] for item in batch] pinyins = [item["pinyin"] for item in batch]
@ -96,6 +123,5 @@ dataloader = DataLoader(
persistent_workers=True, persistent_workers=True,
) )
for i, shape in tqdm(enumerate(dataloader), total=1000000/512): for i, shape in tqdm(enumerate(dataloader), total=1000000 / 512):
pass pass

View File

@ -6,8 +6,8 @@ import torch.nn as nn
from rich.console import Console from rich.console import Console
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
# 添加src目录到路径 # 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent)) sys.path.insert(0, str(Path(__file__).parent.parent))
from src.model.model import InputMethodEngine from src.model.model import InputMethodEngine
from src.model.trainer import Trainer from src.model.trainer import Trainer