refactor(generate_pinyin): 优化拼音生成逻辑,利用 pypinyin 分词能力处理多音字

This commit is contained in:
songsenand 2026-05-09 13:36:48 +08:00
parent 8b41bcdc6f
commit e8eab1f260
15 changed files with 739 additions and 369 deletions

31
eval.py
View File

@ -14,7 +14,6 @@ eval.py - 评估模型在给定文本上的表现
import argparse
import random
import re
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Optional
@ -33,8 +32,6 @@ from src.model.model import InputMethodEngine
from src.model.query import QueryEngine
from src.model.dataset import text_to_pinyin_ids
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
class TextEvaluator:
def __init__(
@ -171,34 +168,20 @@ class TextEvaluator:
def generate_pinyin(self, text: str) -> List[str]:
    """
    Convert ``text`` into a list of pinyin strings, one entry per character.

    The whole string is passed to ``lazy_pinyin`` in a single call so that
    pypinyin's phrase segmentation can pick the right reading for
    polyphonic characters.

    Args:
        text: Input string; may mix Chinese characters with other characters.

    Returns:
        A list with ``len(text)`` entries. Characters pypinyin cannot
        convert are kept unchanged at their own position. An empty input
        yields an empty list.
    """
    if not text:
        return []

    def _keep_chars(chars):
        # pypinyin hands each run of consecutive non-convertible characters
        # to the ``errors`` callback as ONE string. Returning a list makes
        # every character its own entry, so the output stays aligned with
        # ``text``. With the default handling, a run such as "abc" becomes
        # a single entry and the length check below would fail for
        # ordinary mixed text.
        return list(chars)

    pinyin_list = lazy_pinyin(text, errors=_keep_chars)

    # Defensive fallback: if the lengths still differ, convert character by
    # character. This loses phrase context, so it is a last resort only.
    if len(pinyin_list) != len(text):
        pinyin_list = []
        for char in text:
            converted = lazy_pinyin(char, errors=_keep_chars)
            pinyin_list.append(converted[0] if converted else char)

    return pinyin_list
def get_mask_pinyin(
self, text: str, pinyin_list: List[str]

View File

@ -10,15 +10,22 @@ import math
from collections import defaultdict
from pathlib import Path
def main():
# Path to the JSON file
json_path = Path("src/model/assets/pinyin_char_statistics.json")
json_path = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
if not json_path.exists():
print(f"Error: File not found: {json_path}")
sys.exit(1)
print(f"Loading {json_path}...")
with open(json_path, 'r', encoding='utf-8') as f:
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
print(f"Timestamp: {data.get('timestamp')}")
@ -26,7 +33,7 @@ def main():
print(f"Total pinyins: {data.get('total_pinyins')}")
print(f"Valid input character count: {data.get('valid_input_character_count')}")
pairs = data.get('pairs', {})
pairs = data.get("pairs", {})
print(f"Number of pairs: {len(pairs)}")
# Extract counts and IDs
@ -35,9 +42,9 @@ def main():
char_to_count = {}
for key, pair in pairs.items():
try:
char_id = pair.get('id')
count = pair.get('count')
char = pair.get('char', '')
char_id = pair.get("id")
count = pair.get("count")
char = pair.get("char", "")
if char_id is not None and count is not None:
counts.append(count)
id_to_count[char_id] = count
@ -131,7 +138,9 @@ def main():
avg_rank_diff = sum(rank_diffs) / len(rank_diffs)
max_rank_diff = max(rank_diffs)
print(f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}")
print(
f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}"
)
print(f"Maximum rank difference: {max_rank_diff}")
# Analyze specific ID range 5000-5500
@ -163,7 +172,9 @@ def main():
# Find IDs with min frequency in this range
min_in_range_ids = [id for id in range_ids if id_to_count[id] == range_min]
print(f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}")
print(
f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}"
)
else:
print("No IDs found in range 5000-5500")
@ -173,13 +184,19 @@ def main():
log_min = math.log10(min_count) if min_count > 0 else 0
log_max = math.log10(max_count)
num_bins = 20
bin_edges = [10**(log_min + i*(log_max-log_min)/num_bins) for i in range(num_bins+1)]
bin_edges = [
10 ** (log_min + i * (log_max - log_min) / num_bins)
for i in range(num_bins + 1)
]
hist = [0] * num_bins
for count in counts:
if count > 0:
log_val = math.log10(count)
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * num_bins), num_bins-1)
bin_idx = min(
int((log_val - log_min) / (log_max - log_min) * num_bins),
num_bins - 1,
)
hist[bin_idx] += 1
print("Log-scale histogram (count range -> frequency count):")
@ -204,20 +221,28 @@ def main():
if non_zero_counts:
actual_min = min(non_zero_counts)
print(f"Actual min frequency (non-zero): {actual_min}")
actual_min_ids = [id for id, count in id_to_count.items() if count == actual_min]
print(f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}")
actual_min_ids = [
id for id, count in id_to_count.items() if count == actual_min
]
print(
f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}"
)
# Summary for smoothing algorithm design
print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===")
print(f"Frequency range spans {max_count/min_count if min_count>0 else 'inf'}:1 ratio")
print(
f"Frequency range spans {max_count / min_count if min_count > 0 else 'inf'}:1 ratio"
)
print(f"Most entries ({p50}) have frequency around {p50}")
print(f"Top 10% of entries have frequency > {p90}")
print(f"Bottom 10% of entries have frequency < {p10}")
print(f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency")
print(
f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency"
)
# Save detailed data for further analysis
output_file = "frequency_analysis_results.txt"
with open(output_file, 'w', encoding='utf-8') as f:
with open(output_file, "w", encoding="utf-8") as f:
f.write("Frequency Analysis Results\n")
f.write("=" * 50 + "\n")
f.write(f"Min frequency: {min_count}\n")
@ -227,10 +252,15 @@ def main():
f.write(f"10th percentile: {p10}\n")
f.write(f"50th percentile: {p50}\n")
f.write(f"90th percentile: {p90}\n")
f.write(f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n")
f.write(f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n")
f.write(
f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n"
)
f.write(
f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n"
)
print(f"\nDetailed results saved to {output_file}")
if __name__ == "__main__":
main()

View File

@ -7,18 +7,25 @@ import json
import sys
from pathlib import Path
def main():
json_path = Path("src/model/assets/pinyin_char_statistics.json")
with open(json_path, 'r', encoding='utf-8') as f:
json_path = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
pairs = data.get('pairs', {})
pairs = data.get("pairs", {})
# Build ID to count mapping
id_to_count = {}
for key, pair in pairs.items():
char_id = pair.get('id')
count = pair.get('count')
char_id = pair.get("id")
count = pair.get("count")
if char_id is not None and count is not None:
id_to_count[char_id] = count
@ -31,10 +38,10 @@ def main():
if id in id_to_count:
# Find the pair to get char and pinyin
for key, pair in pairs.items():
if pair.get('id') == id:
char = pair.get('char', '')
pinyin = pair.get('pinyin', '')
count = pair.get('count', 0)
if pair.get("id") == id:
char = pair.get("char", "")
pinyin = pair.get("pinyin", "")
count = pair.get("count", 0)
range_data.append((id, count, char, pinyin))
if id % 100 == 0: # Print every 100th for overview
print(f"{id}\t{count}\t{char}\t{pinyin}")
@ -44,8 +51,12 @@ def main():
if range_data:
min_item = min(range_data, key=lambda x: x[1])
max_item = max(range_data, key=lambda x: x[1])
print(f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'")
print(f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'")
print(
f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'"
)
print(
f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'"
)
# Check if frequencies are monotonic in this range
counts = [item[1] for item in range_data]
@ -55,6 +66,7 @@ def main():
# Check for frequency plateaus
from collections import Counter
freq_count = Counter(counts)
most_common = freq_count.most_common(5)
print(f"Most common frequencies in range: {most_common}")
@ -70,15 +82,19 @@ def main():
# Check if they're contiguous
sorted_ids = sorted(freq_one_ids)
contiguous = all(sorted_ids[i] + 1 == sorted_ids[i+1] for i in range(len(sorted_ids)-1))
contiguous = all(
sorted_ids[i] + 1 == sorted_ids[i + 1] for i in range(len(sorted_ids) - 1)
)
print(f"Are they contiguous IDs? {contiguous}")
# Sample some characters
print("\nSample characters with frequency=1:")
sample_count = 0
for key, pair in pairs.items():
if pair.get('count') == 1 and sample_count < 10:
print(f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'")
if pair.get("count") == 1 and sample_count < 10:
print(
f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'"
)
sample_count += 1
# Check overall ID-frequency ordering
@ -104,12 +120,16 @@ def main():
# Check for frequency plateaus overall
from collections import Counter
overall_freq_count = Counter(all_counts)
plateaus = [(freq, count) for freq, count in overall_freq_count.items() if count > 1]
plateaus = [
(freq, count) for freq, count in overall_freq_count.items() if count > 1
]
plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10]
print(f"Top 10 frequency plateaus (freq: count of IDs sharing that freq):")
for freq, count in plateaus_sorted:
print(f" {freq}: {count} IDs")
if __name__ == "__main__":
main()

View File

@ -1,10 +1,15 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from model.dataset import PinyinInputDataset
from torch.utils.data import DataLoader
from model.trainer import collate_fn, worker_init_fn
data = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
data = PinyinInputDataset("/home/songsenand/Data/corpus/CCI-Data/")
dataloader = DataLoader(
data,
@ -18,5 +23,5 @@ dataloader = DataLoader(
)
for i in dataloader:
print((i['labels'] == 1).sum())
print((i["labels"] == 1).sum())
break

View File

@ -9,17 +9,24 @@ import math
from collections import Counter
from pathlib import Path
def main():
json_path = Path("src/model/assets/pinyin_char_statistics.json")
with open(json_path, 'r', encoding='utf-8') as f:
json_path = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
pairs = data.get('pairs', {})
pairs = data.get("pairs", {})
# Extract counts
counts = []
for key, pair in pairs.items():
count = pair.get('count')
count = pair.get("count")
if count is not None:
counts.append(count)
@ -44,7 +51,33 @@ def main():
# Cumulative distribution
print("\n=== CUMULATIVE DISTRIBUTION ===")
thresholds = [1, 2, 3, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000, 1000000, 5000000, 10000000, 50000000, 100000000, 500000000]
thresholds = [
1,
2,
3,
5,
10,
20,
50,
100,
200,
500,
1000,
2000,
5000,
10000,
20000,
50000,
100000,
200000,
500000,
1000000,
5000000,
10000000,
50000000,
100000000,
500000000,
]
for thresh in thresholds:
if thresh > max_count:
break
@ -58,7 +91,9 @@ def main():
below_109 = sum(1 for c in counts if c < 109)
at_or_above_109 = sum(1 for c in counts if c >= 109)
print(f"Entries with count < 109: {below_109} ({below_109 / n * 100:.1f}%)")
print(f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109/n*100:.1f}%)")
print(
f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109 / n * 100:.1f}%)"
)
# If 109 is a threshold, what's the actual min among those >= 109?
counts_ge_109 = [c for c in counts if c >= 109]
@ -88,8 +123,8 @@ def main():
# Build ID to count mapping
id_to_count = {}
for key, pair in pairs.items():
char_id = pair.get('id')
count = pair.get('count')
char_id = pair.get("id")
count = pair.get("count")
if char_id is not None and count is not None:
id_to_count[char_id] = count
@ -108,13 +143,17 @@ def main():
]
for start, end, label in ranges:
range_counts = [id_to_count[id] for id in range(start, end) if id in id_to_count]
range_counts = [
id_to_count[id] for id in range(start, end) if id in id_to_count
]
if range_counts:
min_c = min(range_counts)
max_c = max(range_counts)
mean_c = sum(range_counts) / len(range_counts)
median_c = sorted(range_counts)[len(range_counts) // 2]
print(f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}")
print(
f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}"
)
# Check if IDs are perfectly sorted by frequency
print("\n=== ID ORDER VERIFICATION ===")
@ -127,7 +166,9 @@ def main():
if all_counts[i] > all_counts[i - 1]:
violations += 1
if violations <= 5:
print(f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i-1]} (ID {all_ids[i-1]})")
print(
f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i - 1]} (ID {all_ids[i - 1]})"
)
print(f"Total violations of non-increasing order: {violations}")
@ -140,13 +181,17 @@ def main():
for i, (id, count) in enumerate(zip(all_ids, all_counts)):
if count != current_freq:
if current_freq is not None:
group_sizes.append((current_freq, group_start, all_ids[i-1], i - group_start))
group_sizes.append(
(current_freq, group_start, all_ids[i - 1], i - group_start)
)
current_freq = count
group_start = i
# Last group
if current_freq is not None:
group_sizes.append((current_freq, group_start, all_ids[-1], len(all_ids) - group_start))
group_sizes.append(
(current_freq, group_start, all_ids[-1], len(all_ids) - group_start)
)
# Sort groups by size
group_sizes.sort(key=lambda x: x[3], reverse=True)
@ -158,11 +203,15 @@ def main():
# Summary for smoothing algorithm
print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===")
print("1. IDs are perfectly sorted by frequency (non-increasing).")
print(f"2. Frequency range: {min_count} to {max_count} (ratio {max_count/min_count:.1e}:1).")
print(
f"2. Frequency range: {min_count} to {max_count} (ratio {max_count / min_count:.1e}:1)."
)
print(f"3. {below_109} entries ({below_109 / n * 100:.1f}%) have frequency < 109.")
print(f"4. Median frequency: {counts_sorted_desc[n // 2]}.")
print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9 * n)]}.")
print(f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01*n)]}.")
print(
f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01 * n)]}."
)
print("7. Large frequency plateaus exist (many IDs share same frequency).")
print("8. Smoothing should handle extreme frequency ratios (1:5e8).")
@ -173,5 +222,6 @@ def main():
f.write(f"{rank},{freq}\n")
print("\nRank-frequency data saved to rank_freq.csv")
if __name__ == "__main__":
main()

View File

@ -3,6 +3,7 @@ from pathlib import Path
from typing import Dict, Any
import sys
def modify_pinyin_statistics(file_path: Path) -> None:
"""
一次性修改拼音统计JSON文件
@ -13,7 +14,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
"""
# 1. 加载原数据
try:
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
data: Dict[str, Any] = json.load(f)
except FileNotFoundError:
print(f"错误:文件不存在 {file_path}", file=sys.stderr)
@ -34,7 +35,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
"id": 0,
"char": "",
"pinyin": "",
"count": original_zero_count + 1 # 原count + 1
"count": original_zero_count + 1, # 原count + 1
}
# 3.2 处理其他所有记录键和id都+1
@ -54,15 +55,16 @@ def modify_pinyin_statistics(file_path: Path) -> None:
# 这里保持原时间戳不变,因为是一次性修改
# 写回文件,保持可读格式
backup_path = file_path.with_suffix('.json.bak')
backup_path = file_path.with_suffix(".json.bak")
try:
# 先备份原文件
import shutil
shutil.copy2(file_path, backup_path)
print(f"已创建备份: {backup_path}")
# 写入新数据
with open(file_path, 'w', encoding='utf-8') as f:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
print(f"修改完成!")
@ -78,15 +80,21 @@ def modify_pinyin_statistics(file_path: Path) -> None:
# 使用示例
if __name__ == "__main__":
# 假设你的JSON文件在当前目录
json_file = Path("./src/model/assets/pinyin_char_statistics.json")
# JSON文件相对于项目根目录
json_file = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
# 执行修改
modify_pinyin_statistics(json_file)
# 验证修改:读取并显示前几条记录
print("\n验证前5条记录:")
with open(json_file, 'r', encoding='utf-8') as f:
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
for i in range(5):

View File

@ -8,6 +8,7 @@ import math
import sys
from pathlib import Path
def ascii_histogram(data, bins=20, width=60):
"""Create ASCII histogram"""
if not data:
@ -20,12 +21,16 @@ def ascii_histogram(data, bins=20, width=60):
if max_val / min_val > 1000:
log_min = math.log10(min_val) if min_val > 0 else 0
log_max = math.log10(max_val)
bin_edges = [10**(log_min + i*(log_max-log_min)/bins) for i in range(bins+1)]
bin_edges = [
10 ** (log_min + i * (log_max - log_min) / bins) for i in range(bins + 1)
]
hist = [0] * bins
for val in data:
if val > 0:
log_val = math.log10(val)
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * bins), bins-1)
bin_idx = min(
int((log_val - log_min) / (log_max - log_min) * bins), bins - 1
)
hist[bin_idx] += 1
bin_labels = [f"{bin_edges[i]:.1e}-{bin_edges[i + 1]:.1e}" for i in range(bins)]
else:
@ -42,18 +47,27 @@ def ascii_histogram(data, bins=20, width=60):
for i in range(bins):
if hist[i] == 0:
continue
bar = '#' * int(hist[i] / max_count * width)
bar = "#" * int(hist[i] / max_count * width)
result.append(f"{bin_labels[i]:20} | {bar} {hist[i]}")
return "\n".join(result)
def main():
json_path = Path("src/model/assets/pinyin_char_statistics.json")
with open(json_path, 'r', encoding='utf-8') as f:
json_path = (
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"
/ "pinyin_char_statistics.json"
)
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
pairs = data.get('pairs', {})
counts = [pair.get('count', 0) for pair in pairs.values() if pair.get('count') is not None]
pairs = data.get("pairs", {})
counts = [
pair.get("count", 0) for pair in pairs.values() if pair.get("count") is not None
]
print("FREQUENCY DISTRIBUTION ANALYSIS")
print("=" * 60)
@ -70,7 +84,7 @@ def main():
for rank in range(1, max_rank + 1):
freq = counts_sorted_desc[rank - 1]
bar_length = int(math.log(freq) / math.log(max_freq) * 40)
bar = '#' * bar_length
bar = "#" * bar_length
print(f"Rank {rank:3}: {freq:12} {bar}")
# ID vs Frequency plot (sampled)
@ -78,8 +92,8 @@ def main():
# Build ID to count mapping
id_to_count = {}
for key, pair in pairs.items():
char_id = pair.get('id')
count = pair.get('count')
char_id = pair.get("id")
count = pair.get("count")
if char_id is not None and count is not None:
id_to_count[char_id] = count
@ -91,7 +105,7 @@ def main():
if id in id_to_count:
freq = id_to_count[id]
log_freq = math.log10(freq) if freq > 0 else 0
bar = '#' * int(log_freq / math.log10(max_freq) * 40)
bar = "#" * int(log_freq / math.log10(max_freq) * 40)
print(f"{id:6} {freq:10} {log_freq:6.2f} {bar}")
# Zipf's law fit
@ -106,12 +120,15 @@ def main():
# Check if product is roughly constant
products = [(rank + 1) * counts_sorted_desc[rank] for rank in range(10)]
avg_product = sum(products) / len(products)
std_product = math.sqrt(sum((p - avg_product)**2 for p in products) / len(products))
std_product = math.sqrt(
sum((p - avg_product) ** 2 for p in products) / len(products)
)
print(f" Average product (ranks 2-11): {avg_product:.3e} ± {std_product:.3e}")
print(f" Coefficient of variation: {std_product / avg_product * 100:.1f}%")
# Frequency spectrum
from collections import Counter
freq_counter = Counter(counts)
print("\n5. Frequency Spectrum (how many entries have each frequency):")
print(" Frequency Count Cumulative")
@ -142,5 +159,6 @@ def main():
f.write(f"{id},{id_to_count[id]}\n")
print("\nData saved to id_vs_freq.csv for external plotting")
if __name__ == "__main__":
main()

View File

@ -5,10 +5,9 @@ warnings.filterwarnings("ignore", message=".*pkg_resources.*")
import jieba
import math
import random
import re
from importlib.resources import files
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
import torch
@ -21,7 +20,6 @@ from torch.utils.data import IterableDataset
from .query import QueryEngine
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
CHAR_TO_ID["`"] = 27 # 显式添加反引号
@ -76,6 +74,8 @@ class PinyinInputDataset(IterableDataset):
merge_max_total_chars: int = 6,
low_freq_repeat: float = 50.0,
high_freq_repeat: float = 0.1,
data_kwargs: Optional[Dict] = None,
target_labels: Optional[Set[int]] = None,
):
# 频率调整参数 - 幂律平滑方案
self.min_freq = 109
@ -88,6 +88,9 @@ class PinyinInputDataset(IterableDataset):
self.merge_max_short_words = merge_max_short_words
self.merge_max_total_chars = merge_max_total_chars
self.data_kwargs = data_kwargs or {}
self.target_labels = target_labels
jieba.initialize()
self.tokenizer = AutoTokenizer.from_pretrained(
@ -98,7 +101,9 @@ class PinyinInputDataset(IterableDataset):
self.max_iter_length = max_iter_length
self.max_seq_length = max_seq_length
self.text_field = text_field
self.dataset = load_dataset(data_path, split="train", streaming=True)
load_kwargs = {"split": "train", "streaming": True}
load_kwargs.update(self.data_kwargs)
self.dataset = load_dataset(data_path, **load_kwargs)
self.max_workers = max_workers
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
self.shuffle_buffer_size = shuffle_buffer_size
@ -155,12 +160,13 @@ class PinyinInputDataset(IterableDataset):
# 生成对应文本的拼音
def generate_pinyin(self, text: str) -> List[str]:
"""
流式处理单条文本转换为拼音列表
将文本转换为拼音列表对整段文本调用 lazy_pinyin
利用 errors 回调确保一一对应对生僻字从 QueryEngine 回退
特性
1. 严格一一对应len(result) == len(text)
2. 高多音字准确率利用 pypinyin 内部的词语分词能力
3. 高性能预分配内存无多余对象创建
2. pypinyin 不认识的生僻字回退到 QueryEngine 最高频读音
3. 非汉字字符原样占位
Args:
text: 输入字符串
@ -171,40 +177,36 @@ class PinyinInputDataset(IterableDataset):
if not text:
return []
text_len = len(text)
# 2. 预分配结果列表,初始化占位符。
# 使用 None 或空字符串均可,这里用空字符串方便后续判断
result: List[str] = [""] * text_len
# 3. 遍历所有连续汉字片段
for match in _HANZI_RE.finditer(text):
start_idx = match.start()
hanzi_segment = match.group()
# 4. 核心转换:利用 pypinyin 的分词能力处理该片段
# style=Style.NORMAL 获取不带声调的拼音
pinyin_list = lazy_pinyin(hanzi_segment)
# 5. 健壮性兜底:
# 正常情况下pypinyin 返回的拼音数应等于汉字数。
# 若不等(极罕见,如遇到特殊 Unicode 标点被误判为汉字),降级为单字转换
if len(pinyin_list) != len(hanzi_segment):
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
# 6. 直接通过索引填充到预分配的位置
# 这比 list slicing assignment (result[start:end] = pinyin_list) 略快且更直观
for i, py in enumerate(pinyin_list):
result[start_idx + i] = py
# 7. 填充非汉字字符
# 遍历原文,如果 result 对应位置为空,则填入原字符
# 注意:对于纯汉字文本,这一步很快;对于混合文本,这是必要的
for i, char in enumerate(text):
if not result[i]:
result[i] = char
def _fallback(chars):
# lazy_pinyin 会把连续无拼音的字符聚合成一个字符串传入,
# 必须逐字符处理,确保返回列表长度与输入字符数一致。
result = []
for char in chars:
if self.query_engine.is_chinese_char(char):
ids = self.query_engine.query_by_char(char, limit=1)
if ids:
result.append(ids[0][1])
else:
result.append(char)
else:
result.append(char)
return result
pinyin_list = lazy_pinyin(text, errors=_fallback)
# 防御性校验:若长度仍不匹配(极罕见),逐字回退
if len(pinyin_list) != len(text):
logger.warning(
f"pinyin length mismatch: text_len={len(text)}, "
f"pinyin_len={len(pinyin_list)}, text={text[:50]!r}"
)
pinyin_list = []
for c in text:
result = lazy_pinyin(c, errors=_fallback)
pinyin_list.append(result[0] if result else c)
return pinyin_list
def get_mask_pinyin(
self, text: str, pinyin_list: List[str]
) -> Tuple[int, List[str]]:
@ -243,51 +245,61 @@ class PinyinInputDataset(IterableDataset):
pinyin_ids = pinyin_ids[:24]
return torch.tensor(pinyin_ids, dtype=torch.long)
def _build_single_sample(
    self,
    label: int,
    history: list,
    text: str,
    word_start: int,
    word_end: int,
    part2: str,
    pinyin_ids: torch.Tensor,
    words: list,
) -> dict:
    """
    Build one training sample; the surrounding context is re-sampled on
    every call, so repeated calls for the same label differ.

    Args:
        label: Target label id stored under ``"label"``.
        history: Label ids already processed for the current word. It is
            not mutated; a clamped and padded copy is stored.
        text: Full source text the context windows are sliced from.
        word_start: Index in ``text`` where the target word starts.
        word_end: Index in ``text`` just past the target.
        part2: Pinyin string stored under ``"pinyin"``.
        pinyin_ids: Encoded pinyin tensor, stored as-is (not copied).
        words: Candidate hint words; may be empty.

    Returns:
        A dict with tensor fields ``input_ids``, ``token_type_ids``,
        ``attention_mask``, ``label``, ``history_slot_ids`` and
        ``pinyin_ids``, and string fields ``prefix``, ``suffix`` and
        ``pinyin``.
    """
    # part1: left context. Its length is drawn from N(36, 6^2) and clamped
    # to [0, min(48, word_start)] so the slice never starts before index 0.
    part1_len = min(max(int(random.gauss(36, 6)), 0), 48, word_start)
    part1 = text[word_start - part1_len : word_start]

    # part3: right context, rolled independently on every call.
    part3 = ""
    if random.random() > 0.7:
        part3 = text[word_end : word_end + random.randint(1, 16)]

    # part4: optional word hints, rolled independently on every call.
    part4 = ""
    if random.random() > 0.7 and words:
        num_words = random.randint(1, 3)
        selected_words = random.sample(words, min(num_words, len(words)))
        part4 = "|".join(selected_words)

    encoded = self.tokenizer(
        f"{part4}|{part1}",
        part3,
        max_length=self.max_seq_length,
        truncation=True,
        return_token_type_ids=True,
    )

    # Keep the 8 most recent history labels and right-pad with 0 to a
    # fixed length of 8. Working on a copy leaves the caller's list intact.
    hist = list(history)[-8:]
    hist.extend([0] * (8 - len(hist)))

    return {
        "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
        "token_type_ids": torch.tensor(encoded["token_type_ids"], dtype=torch.long),
        "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
        "label": torch.tensor([label], dtype=torch.long),
        "history_slot_ids": torch.tensor(hist, dtype=torch.long),
        "prefix": f"{part4}^{part1}",
        "suffix": part3,
        "pinyin": part2,
        "pinyin_ids": pinyin_ids,
    }
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
@ -436,42 +448,43 @@ class PinyinInputDataset(IterableDataset):
if not should_break and random.random() <= 0.1:
labels.append(0)
# part1: 词起点前的文本(所有样本共享)
part1 = text[max(0, word_start - 48) : word_start]
# part3: 词后文本
part3 = ""
if random.random() > 0.7:
part3 = text[word_end : word_end + random.randint(1, 16)]
# part4: 词提示
part4 = ""
if random.random() > 0.7:
num_words = random.randint(1, 3)
if words:
selected_words = random.sample(
words, min(num_words, len(words))
# 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
processed_history = []
for label_idx, label in enumerate(labels):
base_repeats = self.adjust_frequency(
self.sample_freqs.get(label, 0)
)
part4 = "|".join(selected_words)
if base_repeats == 0:
processed_history.append(label)
continue
if (
self.target_labels is not None
and label not in self.target_labels
):
processed_history.append(label)
continue
encoded = self.tokenizer(
f"{part4}|{part1}",
part3,
max_length=self.max_seq_length,
truncation=True,
return_token_type_ids=True,
weight = (
self._history_weights[label_idx]
if label_idx < len(self._history_weights)
else 3.0
)
repeats = max(1, int(base_repeats * weight))
batch_samples = self._add_word_samples(
batch_samples,
labels,
encoded,
part4,
part1,
part3,
part2,
pinyin_ids,
for _ in range(repeats):
sample = self._build_single_sample(
label=label,
history=processed_history,
text=text,
word_start=word_start,
word_end=word_end,
part2=part2,
pinyin_ids=pinyin_ids,
words=words,
)
batch_samples.append(sample)
processed_history.append(label)
# ========== Phase 2: 破词续接 ==========
if should_break and break_pos < word_len_chars:
@ -533,33 +546,44 @@ class PinyinInputDataset(IterableDataset):
if random.random() <= 0.1:
cont_labels.append(0)
# part1_cont: 包含已确认前缀的上下文
part1_cont = text[max(0, cont_start - 48) : cont_start]
# part3_cont: 续接目标后的文本
# 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
cont_processed_history = []
cont_end = cont_positions[-1] + 1
part3_cont = ""
if random.random() > 0.7:
part3_cont = text[cont_end : cont_end + random.randint(1, 16)]
encoded_cont = self.tokenizer(
f"{part4}|{part1_cont}",
part3_cont,
max_length=self.max_seq_length,
truncation=True,
return_token_type_ids=True,
for label_idx, label in enumerate(cont_labels):
base_repeats = self.adjust_frequency(
self.sample_freqs.get(label, 0)
)
if base_repeats == 0:
cont_processed_history.append(label)
continue
if (
self.target_labels is not None
and label not in self.target_labels
):
cont_processed_history.append(label)
continue
batch_samples = self._add_word_samples(
batch_samples,
cont_labels,
encoded_cont,
part4,
part1_cont,
part3_cont,
part2_cont,
pinyin_ids_cont,
weight = (
self._history_weights[label_idx]
if label_idx < len(self._history_weights)
else 3.0
)
repeats = max(1, int(base_repeats * weight))
for _ in range(repeats):
sample = self._build_single_sample(
label=label,
history=cont_processed_history,
text=text,
word_start=cont_start,
word_end=cont_end,
part2=part2_cont,
pinyin_ids=pinyin_ids_cont,
words=words,
)
batch_samples.append(sample)
cont_processed_history.append(label)
idx = merge_end_idx

View File

@ -4,6 +4,7 @@
步骤 1: find-missing 扫描已预处理数据找出从未出现的 label ID输出 JSON
步骤 2: generate-template 根据 JSON 生成 JSONL 占位文件供用户手动填入包含缺失字的真实文本
步骤 3: preprocess-supplement 将填好的 JSONL 文本预处理为 .npz 分片输出到独立目录
用法
python -m model.supplement_missing find-missing \
@ -13,6 +14,12 @@
python -m model.supplement_missing generate-template \
--missing-chars missing_chars.json \
--output supplement_texts.jsonl
python -m model.supplement_missing preprocess-supplement \
--missing-chars missing_chars.json \
--supplement-texts supplement_texts.jsonl \
--output-dir ./preprocessed/supplement \
--num-samples 100000
"""
import argparse
@ -21,12 +28,17 @@ from pathlib import Path
from typing import Set
import numpy as np
import torch
from loguru import logger
from rich.console import Console
from rich.table import Table
from torch.utils.data import DataLoader
from tqdm import tqdm
from .dataset import PinyinInputDataset
from .preprocess import collect_samples
from .query import QueryEngine
from .trainer import preprocess_collate_fn, worker_init_fn
def scan_labels(preprocessed_dir: Path) -> Set[int]:
@ -175,6 +187,107 @@ def cmd_generate_template(args):
)
def cmd_preprocess_supplement(args):
    """
    Preprocess the filled-in supplement JSONL into ``.npz`` shards.

    Reads the missing-character JSON, builds a ``PinyinInputDataset`` over
    the supplement texts restricted to the missing label ids (plus label 0),
    and writes the collected samples into ``args.output_dir``.

    Args:
        args: Parsed CLI namespace. Attributes read here: ``missing_chars``,
            ``supplement_texts``, ``output_dir``, ``num_samples``,
            ``num_workers``, ``batch_size``, ``max_seq_length``, ``seed``,
            ``shard_size``, ``py_style_weight`` (comma-separated ints) and
            ``length_weights`` (comma-separated ``length:weight`` pairs).

    Returns:
        None. Returns early, after printing a message, when the
        missing-character file does not exist or lists no characters.
    """
    # Function-local import so this block does not depend on the module's
    # top-level import list.
    import random

    console = Console()

    # Load the missing characters.
    missing_path = Path(args.missing_chars)
    if not missing_path.exists():
        console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
        return

    with open(missing_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    missing_chars = data.get("missing_chars", [])
    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需处理[/bold green]")
        return

    target_labels = {entry["id"] for entry in missing_chars}
    target_labels.add(0)  # also keep label 0 (EOS)

    # Parse the string-encoded CLI parameters.
    py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
    length_weights = {
        int(k): int(v)
        for k, v in (item.split(":") for item in args.length_weights.split(","))
    }

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Let the dataset iterate well past the target, since only samples for
    # the target labels are kept.
    max_iter = args.num_samples * 5
    num_workers = args.num_workers

    console.print("[bold cyan]=== 补充数据预处理 ===[/bold cyan]")
    console.print(f"补充文本: {args.supplement_texts}")
    console.print(f"缺失字符数: {len(missing_chars)}")
    console.print(f"目标样本: {args.num_samples:,}")
    console.print(f"输出目录: {output_dir}")
    console.print(f"Worker 数: {num_workers}")
    console.print()

    # Seed every RNG in use. The dataset samples with the stdlib ``random``
    # module, and with num_workers == 0 the DataLoader never invokes
    # worker_init_fn, so ``random`` must be seeded here for --seed to make
    # runs reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    console.print("[bold cyan]创建补充数据集...[/bold cyan]")
    dataset = PinyinInputDataset(
        data_path="json",
        max_workers=num_workers,
        max_iter_length=max_iter,
        max_seq_length=args.max_seq_length,
        text_field="text",
        py_style_weight=py_style_weight,
        shuffle_buffer_size=100,
        length_weights=length_weights,
        data_kwargs={
            "data_files": args.supplement_texts,
            "streaming": False,
        },
        target_labels=target_labels,
    )

    dataloader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": num_workers,
        "pin_memory": False,
        "worker_init_fn": worker_init_fn,
        "collate_fn": preprocess_collate_fn(args.max_seq_length),
    }
    # prefetch_factor / persistent_workers are only valid with worker
    # processes, so they are set only when num_workers > 0.
    if num_workers > 0:
        dataloader_kwargs["prefetch_factor"] = 2
        dataloader_kwargs["persistent_workers"] = True

    dataloader = DataLoader(dataset, **dataloader_kwargs)

    logger.info("开始收集补充数据...")
    count = collect_samples(
        dataloader,
        args.num_samples,
        output_dir,
        "supplement",
        args.max_seq_length,
        args.shard_size,
    )

    if count < args.num_samples:
        logger.warning(f"补充样本不足: 目标 {args.num_samples}, 实际 {count}")

    console.print("\n[bold green]=== 补充预处理完成 ===[/bold green]")
    console.print(f"生成样本: {count:,}")
    console.print(f"输出目录: {output_dir}")

    # Report the on-disk size of the generated shards.
    total_size = sum(
        f.stat().st_size for f in output_dir.iterdir() if f.suffix == ".npz"
    )
    console.print(f"总大小: {total_size / (1024**3):.2f} GB (compressed)")
    console.print()
    console.print(
        "[bold yellow]提示[/bold yellow]: 请检查补充数据质量,清洗无误后手动将 shard_*.npz 合并到 train/ 目录并更新 metadata.json"
    )
def main():
parser = argparse.ArgumentParser(
description="缺失字符补充工具",
@ -183,6 +296,7 @@ def main():
子命令
find-missing 扫描已预处理数据找出从未出现的 label ID
generate-template 根据缺失字符 JSON 生成 JSONL 占位文件
preprocess-supplement 将填好的 JSONL 预处理为 .npz 分片独立目录
示例
python -m model.supplement_missing find-missing \\
@ -192,6 +306,12 @@ def main():
python -m model.supplement_missing generate-template \\
--missing-chars missing_chars.json \\
--output supplement_texts.jsonl
python -m model.supplement_missing preprocess-supplement \\
--missing-chars missing_chars.json \\
--supplement-texts supplement_texts.jsonl \\
--output-dir ./preprocessed/supplement \\
--num-samples 100000
""",
)
subparsers = parser.add_subparsers(dest="command", help="子命令")
@ -232,6 +352,77 @@ def main():
help="每个缺失字符生成的模板条数(默认: 3",
)
# preprocess-supplement
p_pre = subparsers.add_parser(
"preprocess-supplement", help="将 JSONL 预处理为 .npz 分片"
)
p_pre.add_argument(
"--missing-chars",
type=str,
required=True,
help="缺失字符 JSON 文件路径(由 find-missing 生成)",
)
p_pre.add_argument(
"--supplement-texts",
type=str,
required=True,
help="已填写的补充文本 JSONL 文件路径",
)
p_pre.add_argument(
"--output-dir",
type=str,
required=True,
help="输出目录(独立目录,不会覆盖已有数据)",
)
p_pre.add_argument(
"--num-samples",
type=int,
required=True,
help="目标样本数量",
)
p_pre.add_argument(
"--batch-size",
type=int,
default=128,
help="批大小(默认: 128",
)
p_pre.add_argument(
"--num-workers",
type=int,
default=0,
help="DataLoader worker 数量。本地 JSONL 小文件建议 0默认: 0",
)
p_pre.add_argument(
"--max-seq-length",
type=int,
default=128,
help="最大序列长度(默认: 128",
)
p_pre.add_argument(
"--seed",
type=int,
default=42,
help="随机种子(默认: 42",
)
p_pre.add_argument(
"--shard-size",
type=int,
default=5_000_000,
help="分片大小(样本数),控制内存峰值(默认: 500万",
)
p_pre.add_argument(
"--py-style-weight",
type=str,
default="9,2,1",
help="拼音风格权重(逗号分隔,默认: 9,2,1",
)
p_pre.add_argument(
"--length-weights",
type=str,
default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
help="词长权重(默认: 1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
)
args = parser.parse_args()
if args.command is None:
@ -242,6 +433,8 @@ def main():
cmd_find_missing(args)
elif args.command == "generate-template":
cmd_generate_template(args)
elif args.command == "preprocess-supplement":
cmd_preprocess_supplement(args)
app = main

View File

@ -1,6 +1,7 @@
import os
import sys
sys.path.append("src")
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
import time
import torch
@ -26,7 +27,7 @@ from pypinyin.contrib.tone_convert import to_initials
from torch.utils.data import IterableDataset
tokenizer = AutoTokenizer.from_pretrained(
Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer"
Path(__file__).parent.parent / "src" / "model" / "assets" / "tokenizer"
)
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
@ -83,7 +84,9 @@ sample = {
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu")
checkpoint = torch.load(
"/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu"
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
@ -100,7 +103,7 @@ for k, v in sample.items():
# Time a single forward pass; gradients are disabled for this call.
start = time.time()
with torch.no_grad():
    res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
# Report elapsed wall-clock time in milliseconds with 4 decimal places.
# The previous spec ":4f" meant "minimum width 4, default 6 decimals",
# not the 4-decimal precision that was evidently intended.
print(f"计算时长: {(time.time() - start) * 1000:.4f}ms")
# Pair each entry of the first row of the output with its index and rank
# the pairs from highest to lowest value.
sort_res = sorted(enumerate(res[0]), key=lambda x: x[1], reverse=True)

View File

@ -392,7 +392,13 @@ def check_compile_issues():
issues = []
# 检查 components.py 中的潜在问题
with open("src/model/components.py", "r") as f:
components_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"src",
"model",
"components.py",
)
with open(components_path, "r") as f:
content = f.read()
# 检查 float('-inf')

View File

@ -4,9 +4,13 @@
解决设备转换和权重加载问题
"""
import os
import sys
from pathlib import Path
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
@ -196,7 +200,7 @@ def test_id_mapping():
query_engine = QueryEngine()
stats_path = (
Path(__file__).parent
Path(__file__).parent.parent
/ "src"
/ "model"
/ "assets"

View File

@ -1,6 +1,7 @@
import os
import sys
sys.path.append("src")
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import time
@ -42,23 +43,49 @@ def worker_init_fn(worker_id: int) -> None:
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
自定义批处理函数将多个样本组合成一个batch
Args:
batch: 样本列表每个样本是一个字典
Returns:
批处理后的字典tensor字段已stack字符串字段保持为列表
自定义批处理函数将多个样本组合成一个batch
支持动态padding根据batch内最大序列长度进行padding
"""
# 处理tensor字段 - 使用squeeze去除多余的batch维度
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
input_ids_list = [item["input_ids"] for item in batch]
token_type_ids_list = [item["token_type_ids"] for item in batch]
attention_mask_list = [item["attention_mask"] for item in batch]
target_len = max(ids.shape[0] for ids in input_ids_list)
padded_input_ids = []
padded_token_type_ids = []
padded_attention_mask = []
for ids, tt_ids, mask in zip(
input_ids_list, token_type_ids_list, attention_mask_list
):
seq_len = ids.shape[0]
if seq_len < target_len:
pad_len = target_len - seq_len
padded_input_ids.append(
torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
)
padded_token_type_ids.append(
torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)])
)
padded_attention_mask.append(
torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
)
elif seq_len > target_len:
padded_input_ids.append(ids[:target_len])
padded_token_type_ids.append(tt_ids[:target_len])
padded_attention_mask.append(mask[:target_len])
else:
padded_input_ids.append(ids)
padded_token_type_ids.append(tt_ids)
padded_attention_mask.append(mask)
input_ids = torch.stack(padded_input_ids)
token_type_ids = torch.stack(padded_token_type_ids)
attention_mask = torch.stack(padded_attention_mask)
labels = torch.stack([item["label"].squeeze(0) for item in batch])
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
# 字符串字段保持为列表
prefixes = [item["prefix"] for item in batch]
suffixes = [item["suffix"] for item in batch]
pinyins = [item["pinyin"] for item in batch]
@ -98,4 +125,3 @@ dataloader = DataLoader(
for i, shape in tqdm(enumerate(dataloader), total=1000000 / 512):
pass

View File

@ -6,8 +6,8 @@ import torch.nn as nn
from rich.console import Console
from torch.utils.data import DataLoader, Dataset
# 添加src目录到路径
sys.path.insert(0, str(Path(__file__).parent))
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.model.model import InputMethodEngine
from src.model.trainer import Trainer