feat(analyze_frequency): 添加拼音字符频率分析脚本

This commit is contained in:
songsenand 2026-04-22 22:05:15 +08:00
parent 1b7da9ddd4
commit 4ded2d656f
16 changed files with 43279 additions and 72 deletions

236
analyze_frequency.py Normal file
View File

@ -0,0 +1,236 @@
#!/usr/bin/env python3
"""
Analyze frequency distribution in pinyin_char_statistics.json
"""
import json
import sys
import os
import math
from collections import defaultdict
from pathlib import Path
def main():
# Path to the JSON file
json_path = Path("src/model/assets/pinyin_char_statistics.json")
if not json_path.exists():
print(f"Error: File not found: {json_path}")
sys.exit(1)
print(f"Loading {json_path}...")
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
print(f"Timestamp: {data.get('timestamp')}")
print(f"Total characters: {data.get('total_characters')}")
print(f"Total pinyins: {data.get('total_pinyins')}")
print(f"Valid input character count: {data.get('valid_input_character_count')}")
pairs = data.get('pairs', {})
print(f"Number of pairs: {len(pairs)}")
# Extract counts and IDs
counts = []
id_to_count = {}
char_to_count = {}
for key, pair in pairs.items():
try:
char_id = pair.get('id')
count = pair.get('count')
char = pair.get('char', '')
if char_id is not None and count is not None:
counts.append(count)
id_to_count[char_id] = count
if char:
char_to_count[char] = count
except (ValueError, TypeError) as e:
print(f"Warning: Could not parse pair {key}: {e}")
continue
if not counts:
print("No valid count data found.")
return
# Basic statistics
min_count = min(counts)
max_count = max(counts)
total_count = sum(counts)
mean_count = total_count / len(counts)
# Sort counts for percentiles
sorted_counts = sorted(counts)
n = len(sorted_counts)
# Percentiles
p10 = sorted_counts[int(0.1 * n)]
p25 = sorted_counts[int(0.25 * n)]
p50 = sorted_counts[int(0.5 * n)]
p75 = sorted_counts[int(0.75 * n)]
p90 = sorted_counts[int(0.9 * n)]
p99 = sorted_counts[int(0.99 * n)]
# Variance and std dev
variance = sum((x - mean_count) ** 2 for x in counts) / n
std_dev = math.sqrt(variance)
print("\n=== BASIC STATISTICS ===")
print(f"Min frequency: {min_count}")
print(f"Max frequency: {max_count}")
print(f"Mean frequency: {mean_count:.2f}")
print(f"Standard deviation: {std_dev:.2f}")
print(f"Total frequency sum: {total_count}")
print(f"Number of entries: {n}")
print("\n=== PERCENTILES ===")
print(f"10th percentile: {p10}")
print(f"25th percentile: {p25}")
print(f"50th percentile (median): {p50}")
print(f"75th percentile: {p75}")
print(f"90th percentile: {p90}")
print(f"99th percentile: {p99}")
# Find IDs with min and max counts
min_ids = [id for id, count in id_to_count.items() if count == min_count]
max_ids = [id for id, count in id_to_count.items() if count == max_count]
print(f"\nIDs with min frequency ({min_count}): {min_ids}")
print(f"IDs with max frequency ({max_count}): {max_ids}")
# Check if IDs are assigned in frequency order
# Compute correlation between ID and count
ids = list(id_to_count.keys())
id_counts = [id_to_count[id] for id in ids]
# Sort by ID and check if counts are decreasing
sorted_by_id = sorted(ids)
counts_by_id = [id_to_count[id] for id in sorted_by_id]
# Calculate monotonicity: count of times count decreases as ID increases
decreases = 0
increases = 0
for i in range(1, len(counts_by_id)):
if counts_by_id[i] < counts_by_id[i-1]:
decreases += 1
elif counts_by_id[i] > counts_by_id[i-1]:
increases += 1
print(f"\n=== ID ORDER ANALYSIS ===")
print(f"Total pairs: {len(counts_by_id)}")
print(f"Decreases as ID increases: {decreases} times")
print(f"Increases as ID increases: {increases} times")
print(f"Percentage decreasing: {decreases/(len(counts_by_id)-1)*100:.2f}%")
# Check if IDs are roughly sorted by frequency
# Compute Spearman rank correlation (simplified)
sorted_by_count = sorted(ids, key=lambda x: id_to_count[x], reverse=True)
rank_by_id = {id: i for i, id in enumerate(sorted_by_id)}
rank_by_count = {id: i for i, id in enumerate(sorted_by_count)}
# Average rank difference
rank_diffs = [abs(rank_by_id[id] - rank_by_count[id]) for id in ids]
avg_rank_diff = sum(rank_diffs) / len(rank_diffs)
max_rank_diff = max(rank_diffs)
print(f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}")
print(f"Maximum rank difference: {max_rank_diff}")
# Analyze specific ID range 5000-5500
print("\n=== ANALYSIS OF ID RANGE 5000-5500 ===")
range_counts = []
range_ids = []
for id in range(5000, 5501):
if id in id_to_count:
range_counts.append(id_to_count[id])
range_ids.append(id)
if range_counts:
range_min = min(range_counts)
range_max = max(range_counts)
range_mean = sum(range_counts) / len(range_counts)
range_sorted = sorted(range_counts)
range_n = len(range_counts)
range_p10 = range_sorted[int(0.1 * range_n)] if range_n > 0 else 0
range_p50 = range_sorted[int(0.5 * range_n)] if range_n > 0 else 0
range_p90 = range_sorted[int(0.9 * range_n)] if range_n > 0 else 0
print(f"IDs in range 5000-5500: {len(range_counts)}")
print(f"Min frequency in range: {range_min}")
print(f"Max frequency in range: {range_max}")
print(f"Mean frequency in range: {range_mean:.2f}")
print(f"10th percentile in range: {range_p10}")
print(f"50th percentile in range: {range_p50}")
print(f"90th percentile in range: {range_p90}")
# Find IDs with min frequency in this range
min_in_range_ids = [id for id in range_ids if id_to_count[id] == range_min]
print(f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}")
else:
print("No IDs found in range 5000-5500")
# Histogram of frequencies (log bins)
print("\n=== FREQUENCY DISTRIBUTION (LOG BINS) ===")
if max_count > 0:
log_min = math.log10(min_count) if min_count > 0 else 0
log_max = math.log10(max_count)
num_bins = 20
bin_edges = [10**(log_min + i*(log_max-log_min)/num_bins) for i in range(num_bins+1)]
hist = [0] * num_bins
for count in counts:
if count > 0:
log_val = math.log10(count)
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * num_bins), num_bins-1)
hist[bin_idx] += 1
print("Log-scale histogram (count range -> frequency count):")
for i in range(num_bins):
if hist[i] > 0:
lower = bin_edges[i]
upper = bin_edges[i+1]
print(f" {lower:.2e} - {upper:.2e}: {hist[i]} entries")
# Check for zero or near-zero frequencies
zero_count = sum(1 for c in counts if c == 0)
low_count = sum(1 for c in counts if 0 < c <= 10)
very_low_count = sum(1 for c in counts if 0 < c <= 100)
print(f"\n=== LOW FREQUENCY ANALYSIS ===")
print(f"Entries with zero frequency: {zero_count}")
print(f"Entries with frequency <= 10: {low_count}")
print(f"Entries with frequency <= 100: {very_low_count}")
# Find the actual min frequency (excluding zeros if any)
non_zero_counts = [c for c in counts if c > 0]
if non_zero_counts:
actual_min = min(non_zero_counts)
print(f"Actual min frequency (non-zero): {actual_min}")
actual_min_ids = [id for id, count in id_to_count.items() if count == actual_min]
print(f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}")
# Summary for smoothing algorithm design
print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===")
print(f"Frequency range spans {max_count/min_count if min_count>0 else 'inf'}:1 ratio")
print(f"Most entries ({p50}) have frequency around {p50}")
print(f"Top 10% of entries have frequency > {p90}")
print(f"Bottom 10% of entries have frequency < {p10}")
print(f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency")
# Save detailed data for further analysis
output_file = "frequency_analysis_results.txt"
with open(output_file, 'w', encoding='utf-8') as f:
f.write("Frequency Analysis Results\n")
f.write("="*50 + "\n")
f.write(f"Min frequency: {min_count}\n")
f.write(f"Max frequency: {max_count}\n")
f.write(f"Mean frequency: {mean_count:.2f}\n")
f.write(f"Standard deviation: {std_dev:.2f}\n")
f.write(f"10th percentile: {p10}\n")
f.write(f"50th percentile: {p50}\n")
f.write(f"90th percentile: {p90}\n")
f.write(f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n")
f.write(f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n")
print(f"\nDetailed results saved to {output_file}")
if __name__ == "__main__":
main()

115
analyze_range.py Normal file
View File

@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""
Analyze specific ID ranges in pinyin_char_statistics.json
"""
import json
import sys
from pathlib import Path
def main():
json_path = Path("src/model/assets/pinyin_char_statistics.json")
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
pairs = data.get('pairs', {})
# Build ID to count mapping
id_to_count = {}
for key, pair in pairs.items():
char_id = pair.get('id')
count = pair.get('count')
if char_id is not None and count is not None:
id_to_count[char_id] = count
# Analyze range 5000-5500 in detail
print("ID range 5000-5500 detailed analysis:")
print("ID\tCount\tChar\tPinyin")
range_data = []
for id in range(5000, 5501):
if id in id_to_count:
# Find the pair to get char and pinyin
for key, pair in pairs.items():
if pair.get('id') == id:
char = pair.get('char', '')
pinyin = pair.get('pinyin', '')
count = pair.get('count', 0)
range_data.append((id, count, char, pinyin))
if id % 100 == 0: # Print every 100th for overview
print(f"{id}\t{count}\t{char}\t{pinyin}")
break
# Print min and max in range
if range_data:
min_item = min(range_data, key=lambda x: x[1])
max_item = max(range_data, key=lambda x: x[1])
print(f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'")
print(f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'")
# Check if frequencies are monotonic in this range
counts = [item[1] for item in range_data]
increasing = all(counts[i] <= counts[i+1] for i in range(len(counts)-1))
decreasing = all(counts[i] >= counts[i+1] for i in range(len(counts)-1))
print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}")
# Check for frequency plateaus
from collections import Counter
freq_count = Counter(counts)
most_common = freq_count.most_common(5)
print(f"Most common frequencies in range: {most_common}")
# Analyze the tail (IDs with frequency 1)
print("\n\nAnalysis of frequency=1 entries:")
freq_one_ids = [id for id, count in id_to_count.items() if count == 1]
print(f"Number of entries with frequency=1: {len(freq_one_ids)}")
if freq_one_ids:
print(f"ID range of frequency=1: {min(freq_one_ids)} to {max(freq_one_ids)}")
print(f"First 10 IDs: {freq_one_ids[:10]}")
print(f"Last 10 IDs: {freq_one_ids[-10:]}")
# Check if they're contiguous
sorted_ids = sorted(freq_one_ids)
contiguous = all(sorted_ids[i] + 1 == sorted_ids[i+1] for i in range(len(sorted_ids)-1))
print(f"Are they contiguous IDs? {contiguous}")
# Sample some characters
print("\nSample characters with frequency=1:")
sample_count = 0
for key, pair in pairs.items():
if pair.get('count') == 1 and sample_count < 10:
print(f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'")
sample_count += 1
# Check overall ID-frequency ordering
print("\n\nOverall ID-frequency ordering analysis:")
all_ids = sorted(id_to_count.keys())
all_counts = [id_to_count[id] for id in all_ids]
# Count monotonic segments
non_increasing_segments = 0
current_segment_length = 1
for i in range(1, len(all_counts)):
if all_counts[i] <= all_counts[i-1]:
current_segment_length += 1
else:
if current_segment_length > 1:
non_increasing_segments += 1
current_segment_length = 1
if current_segment_length > 1:
non_increasing_segments += 1
print(f"Total IDs: {len(all_ids)}")
print(f"Non-increasing segments: {non_increasing_segments}")
# Check for frequency plateaus overall
from collections import Counter
overall_freq_count = Counter(all_counts)
plateaus = [(freq, count) for freq, count in overall_freq_count.items() if count > 1]
plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10]
print(f"Top 10 frequency plateaus (freq: count of IDs sharing that freq):")
for freq, count in plateaus_sorted:
print(f" {freq}: {count} IDs")
if __name__ == "__main__":
main()

177
comprehensive_analysis.py Normal file
View File

@ -0,0 +1,177 @@
#!/usr/bin/env python3
"""
Comprehensive frequency distribution analysis
"""
import json
import sys
import math
from collections import Counter
from pathlib import Path
def main():
json_path = Path("src/model/assets/pinyin_char_statistics.json")
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
pairs = data.get('pairs', {})
# Extract counts
counts = []
for key, pair in pairs.items():
count = pair.get('count')
if count is not None:
counts.append(count)
n = len(counts)
print(f"Total entries: {n}")
# Sort descending for rank-frequency analysis
counts_sorted_desc = sorted(counts, reverse=True)
# Basic statistics
min_count = min(counts)
max_count = max(counts)
mean_count = sum(counts) / n
# Percentiles
percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
print("\n=== PERCENTILE DISTRIBUTION ===")
for p in percentiles:
idx = int(p * n)
value = counts_sorted_desc[idx]
print(f"{p*100:5.1f}%: {value:>12} (rank ~{idx})")
# Cumulative distribution
print("\n=== CUMULATIVE DISTRIBUTION ===")
thresholds = [1, 2, 3, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000, 1000000, 5000000, 10000000, 50000000, 100000000, 500000000]
for thresh in thresholds:
if thresh > max_count:
break
below = sum(1 for c in counts if c <= thresh)
above = sum(1 for c in counts if c >= thresh)
print(f"Count <= {thresh:10}: {below:6} entries ({below/n*100:5.1f}%)")
# print(f"Count >= {thresh:10}: {above:6} entries ({above/n*100:5.1f}%)")
# Check min_count=109 parameter
print("\n=== ANALYSIS OF THRESHOLD 109 ===")
below_109 = sum(1 for c in counts if c < 109)
at_or_above_109 = sum(1 for c in counts if c >= 109)
print(f"Entries with count < 109: {below_109} ({below_109/n*100:.1f}%)")
print(f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109/n*100:.1f}%)")
# If 109 is a threshold, what's the actual min among those >= 109?
counts_ge_109 = [c for c in counts if c >= 109]
if counts_ge_109:
actual_min_ge_109 = min(counts_ge_109)
print(f"Actual min frequency among those >= 109: {actual_min_ge_109}")
# Rank-frequency analysis (Zipf's law)
print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===")
print("Rank\tFrequency\tlog(rank)\tlog(freq)")
for rank in range(1, 101):
freq = counts_sorted_desc[rank-1]
print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{math.log(freq):.3f}")
# Frequency spectrum (how many distinct frequencies)
freq_counter = Counter(counts)
print(f"\n=== FREQUENCY SPECTRUM ===")
print(f"Distinct frequency values: {len(freq_counter)}")
# Most common frequencies
print("\nTop 20 most common frequencies (plateau sizes):")
for freq, freq_count in freq_counter.most_common(20):
print(f" Frequency {freq}: {freq_count} entries")
# Analyze ID ranges
print("\n=== ID RANGE ANALYSIS ===")
# Build ID to count mapping
id_to_count = {}
for key, pair in pairs.items():
char_id = pair.get('id')
count = pair.get('count')
if char_id is not None and count is not None:
id_to_count[char_id] = count
ranges = [
(0, 100, "Top 100 IDs"),
(100, 500, "IDs 100-500"),
(500, 1000, "IDs 500-1000"),
(1000, 2000, "IDs 1000-2000"),
(2000, 5000, "IDs 2000-5000"),
(5000, 5500, "IDs 5000-5500 (user mentioned)"),
(5500, 6000, "IDs 5500-6000"),
(10000, 10500, "IDs 10000-10500"),
(15000, 15500, "IDs 15000-15500"),
(19000, 19500, "IDs 19000-19500 (before freq=1)"),
(19499, 20647, "IDs with freq=1"),
]
for start, end, label in ranges:
range_counts = [id_to_count[id] for id in range(start, end) if id in id_to_count]
if range_counts:
min_c = min(range_counts)
max_c = max(range_counts)
mean_c = sum(range_counts) / len(range_counts)
median_c = sorted(range_counts)[len(range_counts)//2]
print(f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}")
# Check if IDs are perfectly sorted by frequency
print("\n=== ID ORDER VERIFICATION ===")
all_ids = sorted(id_to_count.keys())
all_counts = [id_to_count[id] for id in all_ids]
# Check for any violations of non-increasing order
violations = 0
for i in range(1, len(all_counts)):
if all_counts[i] > all_counts[i-1]:
violations += 1
if violations <= 5:
print(f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i-1]} (ID {all_ids[i-1]})")
print(f"Total violations of non-increasing order: {violations}")
# Check if equal frequencies are grouped together
print("\n=== FREQUENCY GROUPING ANALYSIS ===")
current_freq = None
group_start = None
group_sizes = []
for i, (id, count) in enumerate(zip(all_ids, all_counts)):
if count != current_freq:
if current_freq is not None:
group_sizes.append((current_freq, group_start, all_ids[i-1], i - group_start))
current_freq = count
group_start = i
# Last group
if current_freq is not None:
group_sizes.append((current_freq, group_start, all_ids[-1], len(all_ids) - group_start))
# Sort groups by size
group_sizes.sort(key=lambda x: x[3], reverse=True)
print("Top 10 largest frequency groups (plateaus):")
for freq, start_id_idx, end_id, size in group_sizes[:10]:
start_id = all_ids[start_id_idx]
print(f" Frequency {freq}: IDs {start_id}-{end_id} ({size} entries)")
# Summary for smoothing algorithm
print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===")
print("1. IDs are perfectly sorted by frequency (non-increasing).")
print(f"2. Frequency range: {min_count} to {max_count} (ratio {max_count/min_count:.1e}:1).")
print(f"3. {below_109} entries ({below_109/n*100:.1f}%) have frequency < 109.")
print(f"4. Median frequency: {counts_sorted_desc[n//2]}.")
print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9*n)]}.")
print(f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01*n)]}.")
print("7. Large frequency plateaus exist (many IDs share same frequency).")
print("8. Smoothing should handle extreme frequency ratios (1:5e8).")
# Save data for plotting
with open("rank_freq.csv", "w") as f:
f.write("rank,frequency\n")
for rank, freq in enumerate(counts_sorted_desc, 1):
f.write(f"{rank},{freq}\n")
print("\nRank-frequency data saved to rank_freq.csv")
if __name__ == "__main__":
main()

168
docs/WORD_BREAK_DESIGN.md Normal file
View File

@ -0,0 +1,168 @@
# 破词训练设计文档
## 背景
输入法用户在实际使用中通常是**逐词输入**的,而非逐字输入。例如输入"那边的特别漂亮的女孩是我的表姐"时,用户可能分词为:
```
那边 / 的 / 特别 / 漂亮 / 的 / 女孩 / 是 / 我 / 的 / 表姐
```
但为了增强模型的泛化能力,需要模拟用户**从词中间断开**的情况。例如用户可能只输入了"漂"就开始选字"亮"。
## 破词概念
### 术语定义
| 术语 | 说明 |
|------|------|
| 整词输入 | 用户输入完整词的拼音,如"piaoliang" |
| 破词输入 | 用户只输入词的部分拼音,如"piao" |
| 前缀 | 光标前的已确认文本 |
| 拼音 | 当前待选字的拼音(可能不完整) |
| 后缀 | 光标后的原文内容 |
### 场景示例
以词"漂亮"为例:
**整词模式90%概率):**
```
光标前: 那边的特别
拼音: piaoliang
预测: 漂 → 亮
```
**破词模式10%概率):**
```
光标前: 那边的特别漂
拼音: liang
预测: 亮
```
## 实现方案
### 分词策略
使用 jieba 分词器进行词语边界识别:
```python
import jieba
words = list(jieba.cut(text, HMM=False))
# "那边的特别漂亮的女孩是我的表姐。"
# → ['那边', '的', '特别', '漂亮', '的', '女孩', '是', '我', '的', '表姐', '。']
```
### 两阶段样本生成
每个词生成样本时分为两个阶段:
#### Phase 1前缀/整词阶段
- **整词90%**`prefix_positions = 整个词的所有字符`
- **破词前缀10%**`prefix_positions = 词的前 break_pos 个字符`
```python
if should_break:
break_pos = random.randint(1, word_len_chars - 1) # 随机破开位置
else:
break_pos = word_len_chars # 整词
```
#### Phase 2破词续接阶段仅当破词时
当破词发生时,从断点位置开始继续采样:
```python
if should_break and break_pos < word_len_chars:
cont_start = char_positions[break_pos]
# 从断点开始采样后续字符
target_len = random_choice(1-8) # 采样长度
cont_positions = [cont_start, ...] # 后续字符位置
```
### 样本结构
每个字符生成一个训练样本,包含:
| 字段 | 说明 | 示例 |
|------|------|------|
| `part1` (prefix) | 光标前文本 | "那边的特别漂" |
| `part2` (pinyin) | 当前字拼音 | "liang" |
| `part3` (suffix) | 光标后文本 | "亮的女孩是我的表姐" |
| `part4` | 专有词提示 | "漂亮\|特别" |
| `label` | 目标汉字ID | 1234 |
| `history_slot_ids` | 历史已确认字 | [0, 0, 0, 0, 0, 0, 0, 0] |
### 拼音增强策略
根据 `py_style_weight` 参数,拼音有以下三种形式:
| 形式 | 概率 | 示例 |
|------|------|------|
| 完整拼音 | 75% (9/12) | "piaoliang" |
| 仅声母 | 16.7% (2/12) | "pl" (通过 to_initials) |
| 仅首字母 | 8.3% (1/12) | "p" |
参数配置:`py_style_weight=(9, 2, 1)`
## 破词概率控制
### word_break_prob 参数
控制每个词从中间断开的概率,默认为 **10%**
```python
self.word_break_prob = 0.10 # 10%概率从词中间破开
```
### 破词位置分布
对于长度为 N 的词,破开位置 `break_pos` 的分布:
```python
break_pos = random.randint(1, N - 1)
```
- 2字词break_pos = 1100%在第1字后破开
- 3字词break_pos = 1 或 2各50%
- 4字词break_pos = 1, 2, 或 3各33%
## 数据分布预期
### 理想分布
| 类别 | 预期比例 |
|------|----------|
| 单字样本 | ~15% |
| 2字词整词 | ~30% |
| 3字词整词 | ~20% |
| 破词样本 | ~10% |
| 其他 | ~25% |
### 拼音不完整率
由于 `py_style_weight=(9, 2, 1)`
- 声母initials~16.7%
- 首字母:~8.3%
- **总计不完整**~25%
## 代码实现位置
主要实现文件:`src/model/dataset.py`
| 函数/类 | 行号 | 功能 |
|---------|------|------|
| `segment_text()` | ~30 | jieba分词 |
| `build_word_boundaries()` | ~35 | 建立词边界映射 |
| `PinyinInputDataset.__iter__()` | ~280 | 核心迭代逻辑 |
| `get_mask_pinyin()` | ~215 | 拼音加强处理 |
| `_add_word_samples()` | ~240 | 样本构建 |
## 注意事项
1. **破词仅针对多字词**:单字词(如"的"、“是”)不会破词
2. **破词保持语义完整**:破词后仍能根据上下文预测正确汉字
3. **历史槽位模拟逐步确认**:同一词内已确认的字会填入 `history_slot_ids`
4. **10% EOS标记**词尾有10%概率追加ID=0表示句子结束

View File

@ -0,0 +1,11 @@
Frequency Analysis Results
==================================================
Min frequency: 1
Max frequency: 494748360
Mean frequency: 560007.90
Standard deviation: 5730144.34
10th percentile: 3
50th percentile: 93
90th percentile: 331538
IDs in range 5000-5500 min: 5594
IDs in range 5000-5500 max: 9569

20648
id_vs_freq.csv Normal file

File diff suppressed because it is too large Load Diff

View File

@ -30,6 +30,8 @@ dependencies = [
[project.scripts]
train-model = "model.trainer:app"
monitor-training = "model.monitor:app"
preprocess-model = "model.preprocess:main"
inspect-preprocessed = "model.inspect_preprocessed:main"
[tool.uv]
# 设置当前项目的默认索引源

20648
rank_freq.csv Normal file

File diff suppressed because it is too large Load Diff

View File

@ -22,9 +22,26 @@ import torch
from torch.utils.data import DataLoader
from model.dataset import PinyinInputDataset
from model.query import QueryEngine
def analyze_label_distribution(dataset: PinyinInputDataset, sample_size: int = 10000):
_id2char_cache = {}
def get_char_by_id(query_engine: QueryEngine, char_id: int) -> str:
if char_id == 0:
return "<EOS>"
if char_id not in _id2char_cache:
info = query_engine.query_by_id(char_id)
_id2char_cache[char_id] = info.char if info else f"<ID:{char_id}>"
return _id2char_cache[char_id]
def analyze_label_distribution(
dataset: PinyinInputDataset,
sample_size: int = 10000,
query_engine: QueryEngine = None,
):
"""分析 label 在指定区间的分布"""
target_ranges = [
(0, 10),
@ -63,13 +80,23 @@ def analyze_label_distribution(dataset: PinyinInputDataset, sample_size: int = 1
in_target_range = True
if len(all_examples) < 200:
label_char = (
get_char_by_id(query_engine, label) if query_engine else f"<ID:{label}>"
)
history_chars = (
[get_char_by_id(query_engine, hid) for hid in history]
if query_engine
else history
)
all_examples.append(
{
"label": label,
"label_char": label_char,
"prefix": prefix,
"suffix": suffix,
"pinyin": pinyin,
"history": history,
"history_chars": history_chars,
"part4": part4,
}
)
@ -96,12 +123,13 @@ def analyze_label_distribution(dataset: PinyinInputDataset, sample_size: int = 1
random.shuffle(all_examples)
for idx, ex in enumerate(all_examples[:20], 1):
print(f"\n样本 {idx}:")
print(f" Label: {ex['label']}")
print(f" Label: {ex['label']} ({ex['label_char']})")
print(f" Part4: {ex['part4']}")
print(f" 光标前: {ex['prefix']}")
print(f" 光标后: {ex['suffix']}")
print(f" 拼音: {ex['pinyin']}")
print(f" 历史槽位: {ex['history']}")
print(f" 历史汉字: {ex['history_chars']}")
def main():
@ -127,7 +155,9 @@ def main():
help="数据集路径 (本地文件或HuggingFace路径)",
)
parser.add_argument("--sample_size", type=int, default=10000, help="采样大小")
parser.add_argument("--max_workers", type=int, default=-1, help="DataLoader workers")
parser.add_argument(
"--max_workers", type=int, default=-1, help="DataLoader workers"
)
args = parser.parse_args()
print(f"加载数据集: {args.data_path}")
@ -148,7 +178,11 @@ def main():
print(" 3. 如果是 HuggingFace 数据集,路径应该正确")
return
analyze_label_distribution(dataset, sample_size=args.sample_size)
query_engine = QueryEngine()
query_engine.load()
analyze_label_distribution(
dataset, sample_size=args.sample_size, query_engine=query_engine
)
if __name__ == "__main__":

View File

@ -62,10 +62,13 @@ class PinyinInputDataset(IterableDataset):
max_iter_length=1e6,
max_seq_length=128,
text_field: str = "text",
py_style_weight=(90, 2, 1),
py_style_weight=(9, 2, 1),
shuffle_buffer_size: int = 100000,
retention_ratio: float = 0.8,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
merge_short_words_prob: float = 0.5,
merge_max_short_words: int = 3,
merge_max_total_chars: int = 6,
):
# 频率调整参数 (可根据需要调整)
self.drop_start_freq = 10_000_000
@ -76,6 +79,9 @@ class PinyinInputDataset(IterableDataset):
self.word_break_prob = 0.10
self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
self.merge_short_words_prob = merge_short_words_prob
self.merge_max_short_words = merge_max_short_words
self.merge_max_total_chars = merge_max_total_chars
jieba.initialize()
@ -259,8 +265,10 @@ class PinyinInputDataset(IterableDataset):
repeats = max(1, int(base_repeats * weight))
history = labels[:label_idx]
len_h = len(history)
history.extend([0] * (8 - len_h))
if len(history) > 8:
history = history[-8:]
else:
history.extend([0] * (8 - len(history)))
sample_dict = {
"input_ids": encoded["input_ids"],
@ -291,7 +299,12 @@ class PinyinInputDataset(IterableDataset):
if worker_id >= num_workers:
return
worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
try:
worker_dataset = self.dataset.shard(
num_shards=num_workers, index=worker_id
)
except (IndexError, ValueError):
worker_dataset = self.dataset
total_quota = int(self.max_iter_length)
base_quota = total_quota // num_workers
@ -321,17 +334,58 @@ class PinyinInputDataset(IterableDataset):
word_boundaries = build_word_boundaries(words)
pinyin_list = self.generate_pinyin(text)
for word_start, word_end in word_boundaries:
idx = 0
while idx < len(word_boundaries):
word_start, word_end = word_boundaries[idx]
char_positions = []
for i in range(word_start, word_end):
if self.query_engine.is_chinese_char(text[i]):
char_positions.append(i)
if not char_positions:
idx += 1
continue
word_len_chars = len(char_positions)
merge_end_idx = idx + 1
if word_len_chars <= 2:
accumulated_positions = list(char_positions)
accumulated_count = 1
next_idx = idx + 1
while next_idx < len(word_boundaries):
ns, ne = word_boundaries[next_idx]
next_positions = []
for i in range(ns, ne):
if self.query_engine.is_chinese_char(text[i]):
next_positions.append(i)
next_len = len(next_positions)
if next_len == 0 or next_len > 2:
break
if (
len(accumulated_positions) + next_len
> self.merge_max_total_chars
):
break
if accumulated_count + 1 > self.merge_max_short_words:
break
if random.random() > self.merge_short_words_prob:
break
accumulated_positions.extend(next_positions)
accumulated_count += 1
next_idx += 1
if accumulated_count > 1:
char_positions = accumulated_positions
word_len_chars = len(char_positions)
merge_end_idx = next_idx
word_start = word_boundaries[idx][0]
word_end = word_boundaries[next_idx - 1][1]
should_break = (
word_len_chars > 1 and random.random() < self.word_break_prob
)
@ -364,6 +418,7 @@ class PinyinInputDataset(IterableDataset):
logger.error(
f"e: {e}, (text, pinyin): {prefix_text} - {prefix_pinyin}"
)
idx = merge_end_idx
continue
# 整词末尾 10% 概率追加 EOS破词前缀不加
@ -448,6 +503,7 @@ class PinyinInputDataset(IterableDataset):
logger.error(
f"e: {e}, (text, pinyin): {cont_text} - {cont_pinyin}"
)
idx = merge_end_idx
continue
# 续接末尾 10% 概率追加 EOS
@ -486,6 +542,8 @@ class PinyinInputDataset(IterableDataset):
pinyin_ids_cont,
)
idx = merge_end_idx
# 处理shuffle buffer - 单缓冲区半保留方案
if len(batch_samples) >= self.shuffle_buffer_size:
indices = np.random.permutation(len(batch_samples))

View File

@ -0,0 +1,401 @@
#!/usr/bin/env python3
"""
预处理数据质量分析脚本
功能
1. 统计 labels 的分布出现次数比例最大/最小未出现标签数
2. 随机抽样还原为人类可读文本导出为 CSV 文件
用法
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train --num-samples 50 --output samples.csv
"""
import argparse
import csv
import json
import random
from collections import Counter
from pathlib import Path
import numpy as np
from loguru import logger
from rich.console import Console
from rich.table import Table
from tqdm import tqdm
from .char_info import CharInfo
from .dataset import CHAR_TO_ID
from .preprocessed_dataset import PreProcessedDataset
from .query import QueryEngine
ID_TO_CHAR = {v: k for k, v in CHAR_TO_ID.items()}
def decode_pinyin_ids(pinyin_ids: list) -> str:
"""将 pinyin_ids 还原为拼音字符串"""
chars = []
for pid in pinyin_ids:
if pid == 0:
break
chars.append(ID_TO_CHAR.get(pid, "?"))
return "".join(chars)
def decode_history(history_ids: list, query_engine: QueryEngine) -> str:
"""将 history_slot_ids 还原为文字"""
parts = []
for hid in history_ids:
if hid == 0:
parts.append("<PAD>")
else:
info = query_engine.query_by_id(hid)
if info is not None:
parts.append(f"{info.char}({info.pinyin})")
else:
parts.append(f"<ID:{hid}>")
return " | ".join(parts)
def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
"""统计 labels 分布,带进度条"""
logger.info("正在统计 labels 分布...")
counter = Counter()
total = 0
num_shards = dataset._num_shards if dataset._is_sharded else 1
effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards
pbar = tqdm(range(effective_shards), desc="统计 labels", unit="shard")
for shard_idx in pbar:
if dataset._is_sharded:
shard_data = dict(np.load(dataset.data_dir / f"shard_{shard_idx:06d}.npz"))
labels = shard_data["labels"].astype(np.int64)
else:
labels = dataset.labels[:].astype(np.int64)
unique, counts = np.unique(labels, return_counts=True)
for uid, cnt in zip(unique, counts):
counter[int(uid)] += cnt
total += len(labels)
if dataset._is_sharded:
del shard_data
return counter, total
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
"""将一个样本还原为人类可读格式"""
input_ids = (
sample["input_ids"].tolist()
if hasattr(sample["input_ids"], "tolist")
else sample["input_ids"]
)
token_type_ids = (
sample["token_type_ids"].tolist()
if hasattr(sample["token_type_ids"], "tolist")
else sample["token_type_ids"]
)
labels = (
sample["labels"].item()
if hasattr(sample["labels"], "item")
else sample["labels"]
)
history_ids = (
sample["history_slot_ids"].tolist()
if hasattr(sample["history_slot_ids"], "tolist")
else sample["history_slot_ids"]
)
pinyin_ids = (
sample["pinyin_ids"].tolist()
if hasattr(sample["pinyin_ids"], "tolist")
else sample["pinyin_ids"]
)
# 还原 token 文本
token_text = tokenizer.decode(input_ids, skip_special_tokens=False)
# 找到 token_type_ids 切换点,分离 sentence A 和 sentence B
sep_positions = [i for i, tid in enumerate(token_type_ids) if tid == 1]
if sep_positions:
sep_start = sep_positions[0]
sent_a_ids = [
tid
for tid, tt in zip(input_ids[:sep_start], token_type_ids[:sep_start])
if tt == 0
]
sent_b_ids = [tid for tid, tt in zip(input_ids, token_type_ids) if tt == 1]
else:
sent_a_ids = input_ids
sent_b_ids = []
context_text = tokenizer.decode(sent_a_ids, skip_special_tokens=True)
suffix_text = (
tokenizer.decode(sent_b_ids, skip_special_tokens=True) if sent_b_ids else ""
)
pinyin_str = decode_pinyin_ids(pinyin_ids)
label_info = query_engine.query_by_id(labels)
history_str = decode_history(history_ids, query_engine)
return {
"context": context_text,
"suffix": suffix_text,
"pinyin": pinyin_str,
"label_id": labels,
"label_char": f"{label_info.char}({label_info.pinyin})"
if label_info
else f"<ID:{labels}>",
"label_count": label_info.count if label_info else 0,
"history": history_str,
"full_tokens": token_text,
}
def main():
console = Console()
parser = argparse.ArgumentParser(description="预处理数据质量分析")
parser.add_argument(
"--data-dir",
type=str,
required=True,
help="预处理数据目录train/ 或 eval/",
)
parser.add_argument(
"--num-samples",
type=int,
default=50,
help="随机抽样的样本数量默认50",
)
parser.add_argument(
"--output",
type=str,
default=None,
help="CSV 输出文件路径(默认: <data-dir>/samples.csv",
)
parser.add_argument(
"--max-shards",
type=int,
default=0,
help="统计 labels 时最多读取的分片数0=全部)",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="随机种子",
)
parser.add_argument(
"--top-k",
type=int,
default=30,
help="显示出现次数最多和最少的标签数量",
)
args = parser.parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
if args.output is None:
args.output = str(Path(args.data_dir) / "samples.csv")
# 加载数据集
logger.info(f"加载数据集: {args.data_dir}")
dataset = PreProcessedDataset(args.data_dir)
console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
if dataset._is_sharded:
console.print(
f" 分片数: {dataset._num_shards}, 每分片: {dataset._shard_size:,} 样本"
)
console.print()
# 加载 QueryEngine
logger.info("加载 QueryEngine...")
query_engine = QueryEngine()
query_engine.load()
# 加载 Tokenizer
logger.info("加载 Tokenizer...")
from importlib.resources import files as pkg_files
from modelscope import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
Path(str(pkg_files(__package__))) / "assets" / "tokenizer"
)
# ====== 1. Labels 分布分析 ======
console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]")
counter, total = analyze_labels(dataset, max_shards=args.max_shards)
# 获取词表总大小
vocab_size = len(query_engine._id_to_info) # 不含 EOS (id=0)
appeared_ids = set(counter.keys())
all_ids = set(range(0, vocab_size + 1)) # +1 包含 id=0 (EOS)
missing_ids = all_ids - appeared_ids
console.print(f"\n总样本数: {total:,}")
console.print(f"词表大小: {vocab_size + 1:,} (含 EOS)")
console.print(f"唯一标签数: {len(counter):,}")
console.print(
f"EOS (id=0) 出现次数: {counter.get(0, 0):,} ({counter.get(0, 0) / total * 100:.2f}%)"
)
console.print(
f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size + 1:,} ({len(missing_ids) / (vocab_size + 1) * 100:.2f}%)[/bold red]"
)
most_common = counter.most_common(args.top_k)
least_common = (
counter.most_common()[: -args.top_k - 1 : -1]
if len(counter) > args.top_k
else counter.most_common()
)
# 最多标签表
table_top = Table(
title=f"出现次数最多的 {args.top_k} 个标签",
show_header=True,
header_style="bold magenta",
)
table_top.add_column("排名", style="cyan", width=6)
table_top.add_column("ID", style="green", width=8)
table_top.add_column("字符(拼音)", style="yellow", width=20)
table_top.add_column("频次", style="red", width=12)
table_top.add_column("占比", style="blue", width=10)
for rank, (label_id, count) in enumerate(most_common, 1):
info = query_engine.query_by_id(label_id)
label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
pct = count / total * 100
table_top.add_row(
str(rank), str(label_id), label_str, f"{count:,}", f"{pct:.3f}%"
)
console.print(table_top)
# 最少标签表
table_bottom = Table(
title=f"出现次数最少的 {min(args.top_k, len(counter))} 个标签",
show_header=True,
header_style="bold magenta",
)
table_bottom.add_column("排名", style="cyan", width=6)
table_bottom.add_column("ID", style="green", width=8)
table_bottom.add_column("字符(拼音)", style="yellow", width=20)
table_bottom.add_column("频次", style="red", width=12)
table_bottom.add_column("占比", style="blue", width=10)
for rank, (label_id, count) in enumerate(least_common, 1):
info = query_engine.query_by_id(label_id)
label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
pct = count / total * 100
table_bottom.add_row(
str(rank), str(label_id), label_str, f"{count:,}", f"{pct:.6f}%"
)
console.print(table_bottom)
# 频次分布概览
table_dist = Table(
title="频次分布概览", show_header=True, header_style="bold magenta"
)
table_dist.add_column("频次区间", style="cyan")
table_dist.add_column("标签数", style="green")
table_dist.add_column("占总标签数比例", style="yellow")
bins = [
(1, 10),
(11, 100),
(101, 1000),
(1001, 10000),
(10001, 100000),
(100001, 1000000),
(1000001, float("inf")),
]
for lo, hi in bins:
count_in_bin = sum(1 for c in counter.values() if lo <= c <= hi)
if count_in_bin > 0:
hi_str = str(int(hi)) if hi != float("inf") else ""
table_dist.add_row(
f"{lo}-{hi_str}",
f"{count_in_bin:,}",
f"{count_in_bin / len(counter) * 100:.1f}%",
)
# 未出现
if len(missing_ids) > 0:
table_dist.add_row(
"未出现",
f"{len(missing_ids):,}",
f"{len(missing_ids) / (vocab_size + 1) * 100:.1f}%",
)
console.print(table_dist)
# ====== 2. 随机抽样还原 → CSV ======
num_samples = min(args.num_samples, len(dataset))
console.print(
f"\n[bold yellow]====== 随机抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
)
indices = random.sample(range(len(dataset)), num_samples)
csv_path = Path(args.output)
csv_path.parent.mkdir(parents=True, exist_ok=True)
csv_headers = [
"index",
"pinyin",
"label_char",
"label_id",
"label_count",
"context",
"suffix",
"history",
"full_tokens",
]
with open(csv_path, "w", encoding="utf-8", newline="") as f:
writer = csv.writer(f)
writer.writerow(csv_headers)
for i, idx in enumerate(tqdm(indices, desc="解码样本", unit="sample")):
sample = dataset[idx]
decoded = decode_sample(sample, tokenizer, query_engine)
writer.writerow(
[
idx,
decoded["pinyin"],
decoded["label_char"],
decoded["label_id"],
decoded["label_count"],
decoded["context"],
decoded["suffix"],
decoded["history"],
decoded["full_tokens"],
]
)
console.print(
f"[bold green]✓ 已导出 {num_samples} 个样本到 {csv_path}[/bold green]"
)
# 打印前 5 个样本的概要
console.print(f"\n[bold cyan]前 {min(5, num_samples)} 个样本概览:[/bold cyan]")
with open(csv_path, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader):
if i >= 5:
break
console.print(
f" [{i + 1}] 拼音={row['pinyin']} 目标={row['label_char']} "
f"上下文={row['context'][:50]}..."
)
console.print("\n[bold green]分析完成[/bold green]")
app = main
if __name__ == "__main__":
main()

335
src/model/preprocess.py Normal file
View File

@ -0,0 +1,335 @@
#!/usr/bin/env python3
"""
预处理脚本 PinyinInputDataset 的输出转换为分片压缩 .npz 文件
采用分片流式写入内存峰值固定为 shard_size 级别不随总样本数增长
每个分片使用 np.savez_compressed 保存zlib 压缩GPU 服务器无需解压到硬盘
用法
python -m model.preprocess \
--train-data-path "some/hf_dataset" \
--eval-data-path "some/hf_dataset" \
--output-dir ./preprocessed \
--num-train-samples 5000000 \
--num-eval-samples 8192
生成目录结构
output_dir/
train/
metadata.json
shard_000.npz (5M样本, 6个字段, zlib压缩)
shard_001.npz
...
eval/
metadata.json
shard_000.npz
...
"""
import argparse
import gc
import json
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
from loguru import logger
from rich.console import Console
from torch.utils.data import DataLoader
from tqdm import tqdm
from .dataset import PinyinInputDataset
from .trainer import collate_fn, worker_init_fn
FIELDS = [
"input_ids",
"token_type_ids",
"attention_mask",
"labels",
"history_slot_ids",
"pinyin_ids",
]
def _extract_batch(batch: dict, take: int) -> Dict[str, np.ndarray]:
"""从 DataLoader batch 中提取指定数量的样本,转为 int16 numpy 数组"""
result = {}
for f in FIELDS:
tensor = batch[f][:take]
arr = tensor.numpy().astype(np.int16)
if f == "labels" and arr.ndim > 1 and arr.shape[-1] == 1:
arr = arr.squeeze(-1)
result[f] = arr
return result
def collect_samples(
dataloader: DataLoader,
num_samples: int,
output_dir: Path,
split_name: str,
max_seq_length: int = 128,
shard_size: int = 5_000_000,
) -> int:
"""
分片流式收集样本每累积 shard_size 个样本保存为一个压缩 .npz 分片
内存峰值 = shard_size × 每样本字节数约578字节 @ shard_size=5M 约2.9GB
"""
split_dir = output_dir / split_name
split_dir.mkdir(parents=True, exist_ok=True)
shard_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
shard_count = 0
shard_idx = 0
total = 0
pbar = tqdm(total=num_samples, desc=f"Processing {split_name}", unit="samples")
for batch in dataloader:
batch_size = batch["input_ids"].size(0)
remaining = num_samples - total
if remaining <= 0:
break
take = min(batch_size, remaining)
extracted = _extract_batch(batch, take)
for f in FIELDS:
shard_buffers[f].append(extracted[f])
shard_count += take
total += take
pbar.update(take)
if shard_count >= shard_size:
merged = {}
for f in FIELDS:
merged[f] = np.concatenate(shard_buffers[f], axis=0)
np.savez_compressed(split_dir / f"shard_{shard_idx:06d}.npz", **merged)
logger.debug(f"Saved {split_name} shard {shard_idx}: {shard_count} samples")
shard_idx += 1
shard_buffers = {f: [] for f in FIELDS}
shard_count = 0
del merged
gc.collect()
if total >= num_samples:
break
# 写入最后一个不满的分片
if shard_count > 0:
merged = {}
for f in FIELDS:
merged[f] = np.concatenate(shard_buffers[f], axis=0)
np.savez_compressed(split_dir / f"shard_{shard_idx:06d}.npz", **merged)
logger.debug(f"Saved {split_name} shard {shard_idx}: {shard_count} samples")
shard_idx += 1
pbar.close()
actual_count = min(total, num_samples)
num_shards = shard_idx
metadata = {
"num_samples": actual_count,
"max_seq_length": max_seq_length,
"dtype": "int16",
"fields": FIELDS,
"shard_size": shard_size,
"num_shards": num_shards,
}
with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
json.dump(metadata, fp, indent=2, ensure_ascii=False)
total_size = sum(
f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
)
logger.info(
f"{split_name}: {actual_count} samples in {num_shards} shards, "
f"{total_size / (1024**3):.2f} GB (compressed)"
)
return actual_count
def main():
console = Console()
parser = argparse.ArgumentParser(description="预处理数据集为分片压缩npz文件")
parser.add_argument(
"--train-data-path",
type=str,
required=True,
help="训练数据集路径HuggingFace格式",
)
parser.add_argument(
"--eval-data-path",
type=str,
required=True,
help="评估数据集路径HuggingFace格式",
)
parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
parser.add_argument(
"--num-train-samples", type=int, required=True, help="训练集样本数量"
)
parser.add_argument(
"--num-eval-samples", type=int, required=True, help="评估集样本数量"
)
parser.add_argument("--batch-size", type=int, default=128, help="批大小")
parser.add_argument(
"--num-workers", type=int, default=2, help="DataLoader worker数量"
)
parser.add_argument("--max-seq-length", type=int, default=128, help="最大序列长度")
parser.add_argument("--seed", type=int, default=42, help="随机种子")
parser.add_argument(
"--shard-size",
type=int,
default=5_000_000,
help="分片大小样本数控制内存峰值默认500万约2.9GB/分片未压缩)",
)
parser.add_argument(
"--py-style-weight",
type=str,
default="9,2,1",
help="拼音风格权重(逗号分隔)",
)
parser.add_argument(
"--shuffle-buffer-size",
type=int,
default=2000000,
help="数据集shuffle缓冲区大小",
)
parser.add_argument(
"--length-weights",
type=str,
default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
help="词长权重",
)
args = parser.parse_args()
torch.manual_seed(args.seed)
np.random.seed(args.seed)
py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
length_weights = {
int(k): int(v)
for k, v in (item.split(":") for item in args.length_weights.split(","))
}
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
train_max_iter = args.num_train_samples * 5
eval_max_iter = args.num_eval_samples * 5
shard_mem_gb = args.shard_size * 578 / (1024**3)
console.print("[bold cyan]=== 数据预处理 ===[/bold cyan]")
console.print(f"训练集目标: {args.num_train_samples:,} 样本")
console.print(f"评估集目标: {args.num_eval_samples:,} 样本")
console.print(f"输出目录: {output_dir}")
console.print(f"数据类型: int16")
console.print(
f"分片大小: {args.shard_size:,} 样本 (约 {shard_mem_gb:.1f} GB/分片 未压缩)"
)
console.print()
num_train_workers = args.num_workers
num_eval_workers = max(1, args.num_workers // 2)
console.print("[bold cyan]创建训练数据集...[/bold cyan]")
train_dataset = PinyinInputDataset(
data_path=args.train_data_path,
max_workers=num_train_workers,
max_iter_length=train_max_iter,
max_seq_length=args.max_seq_length,
text_field="text",
py_style_weight=py_style_weight,
shuffle_buffer_size=100,
length_weights=length_weights,
)
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
num_workers=num_train_workers,
pin_memory=False,
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=2,
persistent_workers=True if num_train_workers > 0 else False,
)
console.print("[bold cyan]创建评估数据集...[/bold cyan]")
eval_dataset = PinyinInputDataset(
data_path=args.eval_data_path,
max_workers=num_eval_workers,
max_iter_length=eval_max_iter,
max_seq_length=args.max_seq_length,
text_field="text",
py_style_weight=py_style_weight,
shuffle_buffer_size=100,
length_weights=length_weights,
)
eval_dataloader = DataLoader(
eval_dataset,
batch_size=args.batch_size,
num_workers=num_eval_workers,
pin_memory=False,
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=2,
persistent_workers=True if num_eval_workers > 0 else False,
)
logger.info("开始收集训练数据...")
train_count = collect_samples(
train_dataloader,
args.num_train_samples,
output_dir,
"train",
args.max_seq_length,
args.shard_size,
)
if train_count < args.num_train_samples:
logger.warning(
f"训练集样本不足: 目标 {args.num_train_samples}, 实际 {train_count}"
)
logger.info("开始收集评估数据...")
eval_count = collect_samples(
eval_dataloader,
args.num_eval_samples,
output_dir,
"eval",
args.max_seq_length,
args.shard_size,
)
if eval_count < args.num_eval_samples:
logger.warning(
f"评估集样本不足: 目标 {args.num_eval_samples}, 实际 {eval_count}"
)
console.print("\n[bold green]=== 预处理完成 ===[/bold green]")
console.print(f"训练集: {train_count:,} 样本")
console.print(f"评估集: {eval_count:,} 样本")
console.print(f"输出目录: {output_dir}")
for split in ["train", "eval"]:
split_dir = output_dir / split
if split_dir.exists():
total_size = sum(
f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
)
console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")
app = main
if __name__ == "__main__":
main()

View File

@ -0,0 +1,188 @@
"""
预处理数据集加载器
支持两种格式
1. 分片压缩格式.npz从压缩文件中按需加载分片LRU 缓存最多持有 max_cache_shards 个分片
2. 单体格式.npy向后兼容使用 mmap 零拷贝加载
GPU 服务器仅需存放压缩后的 .npz 文件无需解压到硬盘
"""
import gc
import json
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Optional
import numpy as np
import torch
from loguru import logger
from torch.utils.data import Dataset
FIELDS = [
"input_ids",
"token_type_ids",
"attention_mask",
"labels",
"history_slot_ids",
"pinyin_ids",
]
def is_preprocessed_data(path: str) -> bool:
"""判断路径是否为预处理数据目录"""
p = Path(path)
return p.is_dir() and (p / "metadata.json").exists()
class _ShardCache:
"""LRU 缓存,管理按需加载的 .npz 分片,最多持有 max_size 个分片"""
def __init__(self, max_size: int = 2):
self.max_size = max_size
self._cache: OrderedDict[int, Dict[str, np.ndarray]] = OrderedDict()
def get(self, shard_idx: int, loader_fn) -> Dict[str, np.ndarray]:
if shard_idx in self._cache:
self._cache.move_to_end(shard_idx)
return self._cache[shard_idx]
data = loader_fn(shard_idx)
self._cache[shard_idx] = data
self._cache.move_to_end(shard_idx)
while len(self._cache) > self.max_size:
evicted_key, evicted_data = self._cache.popitem(last=False)
del evicted_data
gc.collect()
return data
def clear(self):
self._cache.clear()
gc.collect()
class PreProcessedDataset(Dataset):
"""
预处理数据集加载器自动检测数据格式
- 分片压缩格式metadata.json shard_size 字段
.npz 分片按需加载LRU 缓存控制内存
- 单体格式向后兼容
mmap 零拷贝加载 .npy 文件
所有数据以 int16 存储读取时转为 torch.long (int64)
"""
def __init__(self, data_dir: str, max_cache_shards: int = 2):
self.data_dir = Path(data_dir)
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
self.metadata = json.load(f)
self.num_samples = self.metadata["num_samples"]
self.max_seq_length = self.metadata["max_seq_length"]
self._shard_size: Optional[int] = self.metadata.get("shard_size")
self._num_shards: Optional[int] = self.metadata.get("num_shards")
if self._shard_size is not None and self._num_shards is not None:
self._is_sharded = True
self._cache = _ShardCache(max_size=max_cache_shards)
logger.info(
f"Loaded sharded dataset: {self.num_samples:,} samples, "
f"{self._num_shards} shards, shard_size={self._shard_size:,}"
)
else:
self._is_sharded = False
self._load_single_files()
logger.info(
f"Loaded single-file dataset: {self.num_samples:,} samples (mmap)"
)
def _load_single_files(self):
"""向后兼容:加载单体 .npy 文件mmap 模式)"""
self.input_ids = np.load(self.data_dir / "input_ids.npy", mmap_mode="r")
self.token_type_ids = np.load(
self.data_dir / "token_type_ids.npy", mmap_mode="r"
)
self.attention_mask = np.load(
self.data_dir / "attention_mask.npy", mmap_mode="r"
)
self.labels = np.load(self.data_dir / "labels.npy", mmap_mode="r")
self.history_slot_ids = np.load(
self.data_dir / "history_slot_ids.npy", mmap_mode="r"
)
self.pinyin_ids = np.load(self.data_dir / "pinyin_ids.npy", mmap_mode="r")
def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]:
"""加载一个 .npz 分片到内存"""
shard_path = self.data_dir / f"shard_{shard_idx:06d}.npz"
data = dict(np.load(shard_path))
for key in data:
data[key] = data[key].astype(np.int64)
return data
def __len__(self) -> int:
return self.num_samples
def __getitem__(self, idx: int) -> dict:
if not 0 <= idx < self.num_samples:
raise IndexError(
f"Index {idx} out of range for dataset with {self.num_samples} samples"
)
if self._is_sharded:
shard_idx = idx // self._shard_size
local_idx = idx % self._shard_size
shard_data = self._cache.get(shard_idx, self._load_shard)
return {
"input_ids": torch.from_numpy(
shard_data["input_ids"][local_idx].copy()
),
"token_type_ids": torch.from_numpy(
shard_data["token_type_ids"][local_idx].copy()
),
"attention_mask": torch.from_numpy(
shard_data["attention_mask"][local_idx].copy()
),
"labels": torch.tensor(
shard_data["labels"][local_idx], dtype=torch.long
),
"history_slot_ids": torch.from_numpy(
shard_data["history_slot_ids"][local_idx].copy()
),
"pinyin_ids": torch.from_numpy(
shard_data["pinyin_ids"][local_idx].copy()
),
}
else:
return {
"input_ids": torch.from_numpy(self.input_ids[idx].astype(np.int64)),
"token_type_ids": torch.from_numpy(
self.token_type_ids[idx].astype(np.int64)
),
"attention_mask": torch.from_numpy(
self.attention_mask[idx].astype(np.int64)
),
"labels": torch.tensor(self.labels[idx], dtype=torch.long),
"history_slot_ids": torch.from_numpy(
self.history_slot_ids[idx].astype(np.int64)
),
"pinyin_ids": torch.from_numpy(self.pinyin_ids[idx].astype(np.int64)),
}
def preprocessed_collate_fn(batch):
"""
预处理数据的 collate 函数
不含 string 字段prefix/suffix/pinyin仅处理 tensor 字段
"""
return {
"input_ids": torch.stack([item["input_ids"] for item in batch]),
"token_type_ids": torch.stack([item["token_type_ids"] for item in batch]),
"attention_mask": torch.stack([item["attention_mask"] for item in batch]),
"labels": torch.stack([item["labels"] for item in batch]),
"history_slot_ids": torch.stack([item["history_slot_ids"] for item in batch]),
"pinyin_ids": torch.stack([item["pinyin_ids"] for item in batch]),
}

View File

@ -28,6 +28,11 @@ from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .dataset import PinyinInputDataset
from .preprocessed_dataset import (
PreProcessedDataset,
is_preprocessed_data,
preprocessed_collate_fn,
)
# 导入模型和数据
from .model import InputMethodEngine
@ -97,7 +102,12 @@ class Trainer:
"""
self.model = model
self.train_dataloader = train_dataloader
self.eval_dataloader = list([i for i in eval_dataloader])
if isinstance(eval_dataloader, DataLoader) and not isinstance(
eval_dataloader.dataset, torch.utils.data.IterableDataset
):
self.eval_dataloader = eval_dataloader
else:
self.eval_dataloader = list([i for i in eval_dataloader])
self.output_dir = Path(output_dir)
self.num_epochs = num_epochs
self.learning_rate = learning_rate
@ -1003,7 +1013,7 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
# Typer CLI应用
def create_dataloader(
dataset: PinyinInputDataset,
dataset,
batch_size: int,
num_workers: int = 2,
pin_memory: bool = True,
@ -1011,20 +1021,23 @@ def create_dataloader(
max_iter_length: Optional[int] = None,
) -> Any:
"""
创建数据加载器优先使用DataLoader2如果不可用则回退到DataLoader
专门针对流式数据集优化
创建数据加载器自动识别数据集类型
Args:
dataset: PinyinInputDataset实例
batch_size: 批次大小
num_workers: worker数量对于流式数据集建议为2
pin_memory: 是否固定内存
shuffle: 是否打乱流式数据集内部处理打乱
max_iter_length: 最大迭代长度用于计算总步数
Returns:
数据加载器实例
- PinyinInputDatasetIterableDataset使用流式加载
- PreProcessedDatasetmap-style使用标准加载支持 shuffle
"""
if isinstance(dataset, PreProcessedDataset):
logger.info(f"📊 使用预处理数据集,样本数: {len(dataset)}")
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=pin_memory,
collate_fn=preprocessed_collate_fn,
persistent_workers=True if num_workers > 0 else False,
)
logger.info(f"📊 使用标准DataLoaderworker数量: {num_workers}")
dataloader = DataLoader(
dataset,
@ -1164,12 +1177,82 @@ def train(
config_table.add_row("训练", "混合精度", str(mixed_precision))
config_table.add_row("其他", "自动恢复", str(auto_resume))
console.print(config_table)
# 创建输出目录
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# 检测数据类型并创建数据加载器
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
is_train_preprocessed = is_preprocessed_data(train_data_path)
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
if is_train_preprocessed:
train_dataset = PreProcessedDataset(train_data_path)
total_steps = (len(train_dataset) // batch_size) * num_epochs
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
shuffle=True,
)
config_table.add_row("数据", "训练数据类型", "预处理数据")
else:
train_dataset = PinyinInputDataset(
data_path=train_data_path,
max_workers=-1,
max_iter_length=max_iter_length,
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=2000000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
total_steps = int(max_iter_length * num_epochs / batch_size)
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
max_iter_length=max_iter_length,
)
config_table.add_row("数据", "训练数据类型", "流式数据")
if is_eval_preprocessed:
eval_dataset = PreProcessedDataset(eval_data_path)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=batch_size,
num_workers=2,
pin_memory=torch.cuda.is_available(),
shuffle=False,
)
config_table.add_row("数据", "评估数据类型", "预处理数据")
else:
eval_dataset = PinyinInputDataset(
data_path=eval_data_path,
max_workers=-1,
max_iter_length=batch_size * 64,
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=2000000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=batch_size,
num_workers=2,
pin_memory=torch.cuda.is_available(),
max_iter_length=batch_size * 64,
)
config_table.add_row("数据", "评估数据类型", "流式数据")
config_table.add_row("数据", "总步数", str(total_steps))
console.print(config_table)
# 保存配置
config = {
"train_data_path": train_data_path,
@ -1202,6 +1285,9 @@ def train(
"auto_resume": auto_resume,
"max_iter_length": max_iter_length,
"compile": compile,
"is_train_preprocessed": is_train_preprocessed,
"is_eval_preprocessed": is_eval_preprocessed,
"total_steps": total_steps,
}
config_file = output_path / "training_config.json"
@ -1210,52 +1296,6 @@ def train(
logger.info(f"Configuration saved to {config_file}")
# 创建数据加载器
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
# 训练数据集
train_dataset = PinyinInputDataset(
data_path=train_data_path,
max_workers=-1, # 自动选择worker数量
max_iter_length=max_iter_length,
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=2000000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
# 训练数据加载器
# 注意PinyinInputDataset是IterableDataset所以不能使用shuffle参数
# 多worker配置每个worker处理数据集的一个分片由dataset.__iter__中的shard处理
train_dataloader = create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
max_iter_length=max_iter_length,
)
# 评估数据集(使用相同的设置,但可以调整参数)
eval_dataset = PinyinInputDataset(
data_path=eval_data_path,
max_workers=-1,
max_iter_length=batch_size * 64, # 评估集较小
max_seq_length=max_seq_len,
text_field="text",
py_style_weight=(9, 2, 1),
shuffle_buffer_size=2000000,
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
)
eval_dataloader = create_dataloader(
dataset=eval_dataset,
batch_size=batch_size,
num_workers=2, # 评估使用较少的worker
pin_memory=torch.cuda.is_available(),
max_iter_length=batch_size * 64,
)
console.print("[bold cyan]正在创建模型...[/bold cyan]")
model = InputMethodEngine(
vocab_size=vocab_size,
@ -1279,7 +1319,7 @@ def train(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
total_steps=int(max_iter_length * num_epochs / batch_size),
total_steps=total_steps,
output_dir=output_dir,
num_epochs=num_epochs,
learning_rate=learning_rate,

View File

@ -89,7 +89,7 @@ train_dataset = PinyinInputDataset(
dataloader = DataLoader(
train_dataset,
batch_size=512,
num_workers=2,
num_workers=16,
worker_init_fn=worker_init_fn,
collate_fn=collate_fn,
prefetch_factor=2, # 减少预取以避免内存问题

146
visualize_distribution.py Normal file
View File

@ -0,0 +1,146 @@
#!/usr/bin/env python3
"""
Visualize frequency distribution with ASCII plots
"""
import json
import math
import sys
from pathlib import Path
def ascii_histogram(data, bins=20, width=60):
"""Create ASCII histogram"""
if not data:
return ""
min_val = min(data)
max_val = max(data)
# Use log bins for wide range
if max_val / min_val > 1000:
log_min = math.log10(min_val) if min_val > 0 else 0
log_max = math.log10(max_val)
bin_edges = [10**(log_min + i*(log_max-log_min)/bins) for i in range(bins+1)]
hist = [0] * bins
for val in data:
if val > 0:
log_val = math.log10(val)
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * bins), bins-1)
hist[bin_idx] += 1
bin_labels = [f"{bin_edges[i]:.1e}-{bin_edges[i+1]:.1e}" for i in range(bins)]
else:
bin_width = (max_val - min_val) / bins
bin_edges = [min_val + i*bin_width for i in range(bins+1)]
hist = [0] * bins
for val in data:
bin_idx = min(int((val - min_val) / (max_val - min_val) * bins), bins-1)
hist[bin_idx] += 1
bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i+1]:.1f}" for i in range(bins)]
max_count = max(hist)
result = []
for i in range(bins):
if hist[i] == 0:
continue
bar = '#' * int(hist[i] / max_count * width)
result.append(f"{bin_labels[i]:20} | {bar} {hist[i]}")
return "\n".join(result)
def main():
json_path = Path("src/model/assets/pinyin_char_statistics.json")
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
pairs = data.get('pairs', {})
counts = [pair.get('count', 0) for pair in pairs.values() if pair.get('count') is not None]
print("FREQUENCY DISTRIBUTION ANALYSIS")
print("="*60)
print("\n1. ASCII Histogram (log bins):")
print(ascii_histogram(counts, bins=20, width=60))
# Rank-frequency plot in ASCII
print("\n2. Rank-Frequency Relationship (Top 50):")
counts_sorted_desc = sorted(counts, reverse=True)
max_freq = counts_sorted_desc[0]
max_rank = 50
for rank in range(1, max_rank + 1):
freq = counts_sorted_desc[rank-1]
bar_length = int(math.log(freq) / math.log(max_freq) * 40)
bar = '#' * bar_length
print(f"Rank {rank:3}: {freq:12} {bar}")
# ID vs Frequency plot (sampled)
print("\n3. ID vs Frequency (sampled every 500 IDs):")
# Build ID to count mapping
id_to_count = {}
for key, pair in pairs.items():
char_id = pair.get('id')
count = pair.get('count')
if char_id is not None and count is not None:
id_to_count[char_id] = count
all_ids = sorted(id_to_count.keys())
max_id = all_ids[-1]
print("ID Frequency log10(freq)")
for id in range(0, max_id + 1, 500):
if id in id_to_count:
freq = id_to_count[id]
log_freq = math.log10(freq) if freq > 0 else 0
bar = '#' * int(log_freq / math.log10(max_freq) * 40)
print(f"{id:6} {freq:10} {log_freq:6.2f} {bar}")
# Zipf's law fit
print("\n4. Zipf's Law Analysis:")
print(" Rank * Frequency ≈ constant for Zipf's law")
print(" Top 10 ranks:")
for rank in range(1, 11):
freq = counts_sorted_desc[rank-1]
product = rank * freq
print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}")
# Check if product is roughly constant
products = [(rank+1) * counts_sorted_desc[rank] for rank in range(10)]
avg_product = sum(products) / len(products)
std_product = math.sqrt(sum((p - avg_product)**2 for p in products) / len(products))
print(f" Average product (ranks 2-11): {avg_product:.3e} ± {std_product:.3e}")
print(f" Coefficient of variation: {std_product/avg_product*100:.1f}%")
# Frequency spectrum
from collections import Counter
freq_counter = Counter(counts)
print("\n5. Frequency Spectrum (how many entries have each frequency):")
print(" Frequency Count Cumulative")
cum = 0
for freq in sorted(freq_counter.keys())[:20]:
count = freq_counter[freq]
cum += count
print(f" {freq:10} {count:6} {cum:6}")
# Summary statistics
print("\n6. Key Statistics:")
n = len(counts)
print(f" Total entries: {n}")
print(f" Min frequency: {min(counts)}")
print(f" Max frequency: {max(counts)}")
print(f" Ratio max/min: {max(counts)/min(counts):.2e}")
percentiles = [0.01, 0.1, 0.5, 0.9, 0.99]
for p in percentiles:
idx = int(p * n)
value = counts_sorted_desc[idx]
print(f" {p*100:5.1f}th percentile: {value:12} (rank ~{idx})")
# Save data for external plotting
with open("id_vs_freq.csv", "w") as f:
f.write("id,frequency\n")
for id in sorted(id_to_count.keys()):
f.write(f"{id},{id_to_count[id]}\n")
print("\nData saved to id_vs_freq.csv for external plotting")
if __name__ == "__main__":
main()