feat(analyze_frequency): 添加拼音字符频率分析脚本
This commit is contained in:
parent
1b7da9ddd4
commit
4ded2d656f
|
|
@ -0,0 +1,236 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Analyze frequency distribution in pinyin_char_statistics.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
def main():
|
||||
# Path to the JSON file
|
||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
||||
if not json_path.exists():
|
||||
print(f"Error: File not found: {json_path}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loading {json_path}...")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
print(f"Timestamp: {data.get('timestamp')}")
|
||||
print(f"Total characters: {data.get('total_characters')}")
|
||||
print(f"Total pinyins: {data.get('total_pinyins')}")
|
||||
print(f"Valid input character count: {data.get('valid_input_character_count')}")
|
||||
|
||||
pairs = data.get('pairs', {})
|
||||
print(f"Number of pairs: {len(pairs)}")
|
||||
|
||||
# Extract counts and IDs
|
||||
counts = []
|
||||
id_to_count = {}
|
||||
char_to_count = {}
|
||||
for key, pair in pairs.items():
|
||||
try:
|
||||
char_id = pair.get('id')
|
||||
count = pair.get('count')
|
||||
char = pair.get('char', '')
|
||||
if char_id is not None and count is not None:
|
||||
counts.append(count)
|
||||
id_to_count[char_id] = count
|
||||
if char:
|
||||
char_to_count[char] = count
|
||||
except (ValueError, TypeError) as e:
|
||||
print(f"Warning: Could not parse pair {key}: {e}")
|
||||
continue
|
||||
|
||||
if not counts:
|
||||
print("No valid count data found.")
|
||||
return
|
||||
|
||||
# Basic statistics
|
||||
min_count = min(counts)
|
||||
max_count = max(counts)
|
||||
total_count = sum(counts)
|
||||
mean_count = total_count / len(counts)
|
||||
|
||||
# Sort counts for percentiles
|
||||
sorted_counts = sorted(counts)
|
||||
n = len(sorted_counts)
|
||||
|
||||
# Percentiles
|
||||
p10 = sorted_counts[int(0.1 * n)]
|
||||
p25 = sorted_counts[int(0.25 * n)]
|
||||
p50 = sorted_counts[int(0.5 * n)]
|
||||
p75 = sorted_counts[int(0.75 * n)]
|
||||
p90 = sorted_counts[int(0.9 * n)]
|
||||
p99 = sorted_counts[int(0.99 * n)]
|
||||
|
||||
# Variance and std dev
|
||||
variance = sum((x - mean_count) ** 2 for x in counts) / n
|
||||
std_dev = math.sqrt(variance)
|
||||
|
||||
print("\n=== BASIC STATISTICS ===")
|
||||
print(f"Min frequency: {min_count}")
|
||||
print(f"Max frequency: {max_count}")
|
||||
print(f"Mean frequency: {mean_count:.2f}")
|
||||
print(f"Standard deviation: {std_dev:.2f}")
|
||||
print(f"Total frequency sum: {total_count}")
|
||||
print(f"Number of entries: {n}")
|
||||
|
||||
print("\n=== PERCENTILES ===")
|
||||
print(f"10th percentile: {p10}")
|
||||
print(f"25th percentile: {p25}")
|
||||
print(f"50th percentile (median): {p50}")
|
||||
print(f"75th percentile: {p75}")
|
||||
print(f"90th percentile: {p90}")
|
||||
print(f"99th percentile: {p99}")
|
||||
|
||||
# Find IDs with min and max counts
|
||||
min_ids = [id for id, count in id_to_count.items() if count == min_count]
|
||||
max_ids = [id for id, count in id_to_count.items() if count == max_count]
|
||||
|
||||
print(f"\nIDs with min frequency ({min_count}): {min_ids}")
|
||||
print(f"IDs with max frequency ({max_count}): {max_ids}")
|
||||
|
||||
# Check if IDs are assigned in frequency order
|
||||
# Compute correlation between ID and count
|
||||
ids = list(id_to_count.keys())
|
||||
id_counts = [id_to_count[id] for id in ids]
|
||||
|
||||
# Sort by ID and check if counts are decreasing
|
||||
sorted_by_id = sorted(ids)
|
||||
counts_by_id = [id_to_count[id] for id in sorted_by_id]
|
||||
|
||||
# Calculate monotonicity: count of times count decreases as ID increases
|
||||
decreases = 0
|
||||
increases = 0
|
||||
for i in range(1, len(counts_by_id)):
|
||||
if counts_by_id[i] < counts_by_id[i-1]:
|
||||
decreases += 1
|
||||
elif counts_by_id[i] > counts_by_id[i-1]:
|
||||
increases += 1
|
||||
|
||||
print(f"\n=== ID ORDER ANALYSIS ===")
|
||||
print(f"Total pairs: {len(counts_by_id)}")
|
||||
print(f"Decreases as ID increases: {decreases} times")
|
||||
print(f"Increases as ID increases: {increases} times")
|
||||
print(f"Percentage decreasing: {decreases/(len(counts_by_id)-1)*100:.2f}%")
|
||||
|
||||
# Check if IDs are roughly sorted by frequency
|
||||
# Compute Spearman rank correlation (simplified)
|
||||
sorted_by_count = sorted(ids, key=lambda x: id_to_count[x], reverse=True)
|
||||
rank_by_id = {id: i for i, id in enumerate(sorted_by_id)}
|
||||
rank_by_count = {id: i for i, id in enumerate(sorted_by_count)}
|
||||
|
||||
# Average rank difference
|
||||
rank_diffs = [abs(rank_by_id[id] - rank_by_count[id]) for id in ids]
|
||||
avg_rank_diff = sum(rank_diffs) / len(rank_diffs)
|
||||
max_rank_diff = max(rank_diffs)
|
||||
|
||||
print(f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}")
|
||||
print(f"Maximum rank difference: {max_rank_diff}")
|
||||
|
||||
# Analyze specific ID range 5000-5500
|
||||
print("\n=== ANALYSIS OF ID RANGE 5000-5500 ===")
|
||||
range_counts = []
|
||||
range_ids = []
|
||||
for id in range(5000, 5501):
|
||||
if id in id_to_count:
|
||||
range_counts.append(id_to_count[id])
|
||||
range_ids.append(id)
|
||||
|
||||
if range_counts:
|
||||
range_min = min(range_counts)
|
||||
range_max = max(range_counts)
|
||||
range_mean = sum(range_counts) / len(range_counts)
|
||||
range_sorted = sorted(range_counts)
|
||||
range_n = len(range_counts)
|
||||
range_p10 = range_sorted[int(0.1 * range_n)] if range_n > 0 else 0
|
||||
range_p50 = range_sorted[int(0.5 * range_n)] if range_n > 0 else 0
|
||||
range_p90 = range_sorted[int(0.9 * range_n)] if range_n > 0 else 0
|
||||
|
||||
print(f"IDs in range 5000-5500: {len(range_counts)}")
|
||||
print(f"Min frequency in range: {range_min}")
|
||||
print(f"Max frequency in range: {range_max}")
|
||||
print(f"Mean frequency in range: {range_mean:.2f}")
|
||||
print(f"10th percentile in range: {range_p10}")
|
||||
print(f"50th percentile in range: {range_p50}")
|
||||
print(f"90th percentile in range: {range_p90}")
|
||||
|
||||
# Find IDs with min frequency in this range
|
||||
min_in_range_ids = [id for id in range_ids if id_to_count[id] == range_min]
|
||||
print(f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}")
|
||||
else:
|
||||
print("No IDs found in range 5000-5500")
|
||||
|
||||
# Histogram of frequencies (log bins)
|
||||
print("\n=== FREQUENCY DISTRIBUTION (LOG BINS) ===")
|
||||
if max_count > 0:
|
||||
log_min = math.log10(min_count) if min_count > 0 else 0
|
||||
log_max = math.log10(max_count)
|
||||
num_bins = 20
|
||||
bin_edges = [10**(log_min + i*(log_max-log_min)/num_bins) for i in range(num_bins+1)]
|
||||
|
||||
hist = [0] * num_bins
|
||||
for count in counts:
|
||||
if count > 0:
|
||||
log_val = math.log10(count)
|
||||
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * num_bins), num_bins-1)
|
||||
hist[bin_idx] += 1
|
||||
|
||||
print("Log-scale histogram (count range -> frequency count):")
|
||||
for i in range(num_bins):
|
||||
if hist[i] > 0:
|
||||
lower = bin_edges[i]
|
||||
upper = bin_edges[i+1]
|
||||
print(f" {lower:.2e} - {upper:.2e}: {hist[i]} entries")
|
||||
|
||||
# Check for zero or near-zero frequencies
|
||||
zero_count = sum(1 for c in counts if c == 0)
|
||||
low_count = sum(1 for c in counts if 0 < c <= 10)
|
||||
very_low_count = sum(1 for c in counts if 0 < c <= 100)
|
||||
|
||||
print(f"\n=== LOW FREQUENCY ANALYSIS ===")
|
||||
print(f"Entries with zero frequency: {zero_count}")
|
||||
print(f"Entries with frequency <= 10: {low_count}")
|
||||
print(f"Entries with frequency <= 100: {very_low_count}")
|
||||
|
||||
# Find the actual min frequency (excluding zeros if any)
|
||||
non_zero_counts = [c for c in counts if c > 0]
|
||||
if non_zero_counts:
|
||||
actual_min = min(non_zero_counts)
|
||||
print(f"Actual min frequency (non-zero): {actual_min}")
|
||||
actual_min_ids = [id for id, count in id_to_count.items() if count == actual_min]
|
||||
print(f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}")
|
||||
|
||||
# Summary for smoothing algorithm design
|
||||
print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===")
|
||||
print(f"Frequency range spans {max_count/min_count if min_count>0 else 'inf'}:1 ratio")
|
||||
print(f"Most entries ({p50}) have frequency around {p50}")
|
||||
print(f"Top 10% of entries have frequency > {p90}")
|
||||
print(f"Bottom 10% of entries have frequency < {p10}")
|
||||
print(f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency")
|
||||
|
||||
# Save detailed data for further analysis
|
||||
output_file = "frequency_analysis_results.txt"
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write("Frequency Analysis Results\n")
|
||||
f.write("="*50 + "\n")
|
||||
f.write(f"Min frequency: {min_count}\n")
|
||||
f.write(f"Max frequency: {max_count}\n")
|
||||
f.write(f"Mean frequency: {mean_count:.2f}\n")
|
||||
f.write(f"Standard deviation: {std_dev:.2f}\n")
|
||||
f.write(f"10th percentile: {p10}\n")
|
||||
f.write(f"50th percentile: {p50}\n")
|
||||
f.write(f"90th percentile: {p90}\n")
|
||||
f.write(f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n")
|
||||
f.write(f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n")
|
||||
|
||||
print(f"\nDetailed results saved to {output_file}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Analyze specific ID ranges in pinyin_char_statistics.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
def main():
|
||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
pairs = data.get('pairs', {})
|
||||
|
||||
# Build ID to count mapping
|
||||
id_to_count = {}
|
||||
for key, pair in pairs.items():
|
||||
char_id = pair.get('id')
|
||||
count = pair.get('count')
|
||||
if char_id is not None and count is not None:
|
||||
id_to_count[char_id] = count
|
||||
|
||||
# Analyze range 5000-5500 in detail
|
||||
print("ID range 5000-5500 detailed analysis:")
|
||||
print("ID\tCount\tChar\tPinyin")
|
||||
|
||||
range_data = []
|
||||
for id in range(5000, 5501):
|
||||
if id in id_to_count:
|
||||
# Find the pair to get char and pinyin
|
||||
for key, pair in pairs.items():
|
||||
if pair.get('id') == id:
|
||||
char = pair.get('char', '')
|
||||
pinyin = pair.get('pinyin', '')
|
||||
count = pair.get('count', 0)
|
||||
range_data.append((id, count, char, pinyin))
|
||||
if id % 100 == 0: # Print every 100th for overview
|
||||
print(f"{id}\t{count}\t{char}\t{pinyin}")
|
||||
break
|
||||
|
||||
# Print min and max in range
|
||||
if range_data:
|
||||
min_item = min(range_data, key=lambda x: x[1])
|
||||
max_item = max(range_data, key=lambda x: x[1])
|
||||
print(f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'")
|
||||
print(f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'")
|
||||
|
||||
# Check if frequencies are monotonic in this range
|
||||
counts = [item[1] for item in range_data]
|
||||
increasing = all(counts[i] <= counts[i+1] for i in range(len(counts)-1))
|
||||
decreasing = all(counts[i] >= counts[i+1] for i in range(len(counts)-1))
|
||||
print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}")
|
||||
|
||||
# Check for frequency plateaus
|
||||
from collections import Counter
|
||||
freq_count = Counter(counts)
|
||||
most_common = freq_count.most_common(5)
|
||||
print(f"Most common frequencies in range: {most_common}")
|
||||
|
||||
# Analyze the tail (IDs with frequency 1)
|
||||
print("\n\nAnalysis of frequency=1 entries:")
|
||||
freq_one_ids = [id for id, count in id_to_count.items() if count == 1]
|
||||
print(f"Number of entries with frequency=1: {len(freq_one_ids)}")
|
||||
if freq_one_ids:
|
||||
print(f"ID range of frequency=1: {min(freq_one_ids)} to {max(freq_one_ids)}")
|
||||
print(f"First 10 IDs: {freq_one_ids[:10]}")
|
||||
print(f"Last 10 IDs: {freq_one_ids[-10:]}")
|
||||
|
||||
# Check if they're contiguous
|
||||
sorted_ids = sorted(freq_one_ids)
|
||||
contiguous = all(sorted_ids[i] + 1 == sorted_ids[i+1] for i in range(len(sorted_ids)-1))
|
||||
print(f"Are they contiguous IDs? {contiguous}")
|
||||
|
||||
# Sample some characters
|
||||
print("\nSample characters with frequency=1:")
|
||||
sample_count = 0
|
||||
for key, pair in pairs.items():
|
||||
if pair.get('count') == 1 and sample_count < 10:
|
||||
print(f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'")
|
||||
sample_count += 1
|
||||
|
||||
# Check overall ID-frequency ordering
|
||||
print("\n\nOverall ID-frequency ordering analysis:")
|
||||
all_ids = sorted(id_to_count.keys())
|
||||
all_counts = [id_to_count[id] for id in all_ids]
|
||||
|
||||
# Count monotonic segments
|
||||
non_increasing_segments = 0
|
||||
current_segment_length = 1
|
||||
for i in range(1, len(all_counts)):
|
||||
if all_counts[i] <= all_counts[i-1]:
|
||||
current_segment_length += 1
|
||||
else:
|
||||
if current_segment_length > 1:
|
||||
non_increasing_segments += 1
|
||||
current_segment_length = 1
|
||||
if current_segment_length > 1:
|
||||
non_increasing_segments += 1
|
||||
|
||||
print(f"Total IDs: {len(all_ids)}")
|
||||
print(f"Non-increasing segments: {non_increasing_segments}")
|
||||
|
||||
# Check for frequency plateaus overall
|
||||
from collections import Counter
|
||||
overall_freq_count = Counter(all_counts)
|
||||
plateaus = [(freq, count) for freq, count in overall_freq_count.items() if count > 1]
|
||||
plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10]
|
||||
print(f"Top 10 frequency plateaus (freq: count of IDs sharing that freq):")
|
||||
for freq, count in plateaus_sorted:
|
||||
print(f" {freq}: {count} IDs")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,177 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive frequency distribution analysis
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import math
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
def main():
|
||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
pairs = data.get('pairs', {})
|
||||
|
||||
# Extract counts
|
||||
counts = []
|
||||
for key, pair in pairs.items():
|
||||
count = pair.get('count')
|
||||
if count is not None:
|
||||
counts.append(count)
|
||||
|
||||
n = len(counts)
|
||||
print(f"Total entries: {n}")
|
||||
|
||||
# Sort descending for rank-frequency analysis
|
||||
counts_sorted_desc = sorted(counts, reverse=True)
|
||||
|
||||
# Basic statistics
|
||||
min_count = min(counts)
|
||||
max_count = max(counts)
|
||||
mean_count = sum(counts) / n
|
||||
|
||||
# Percentiles
|
||||
percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
|
||||
print("\n=== PERCENTILE DISTRIBUTION ===")
|
||||
for p in percentiles:
|
||||
idx = int(p * n)
|
||||
value = counts_sorted_desc[idx]
|
||||
print(f"{p*100:5.1f}%: {value:>12} (rank ~{idx})")
|
||||
|
||||
# Cumulative distribution
|
||||
print("\n=== CUMULATIVE DISTRIBUTION ===")
|
||||
thresholds = [1, 2, 3, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000, 1000000, 5000000, 10000000, 50000000, 100000000, 500000000]
|
||||
for thresh in thresholds:
|
||||
if thresh > max_count:
|
||||
break
|
||||
below = sum(1 for c in counts if c <= thresh)
|
||||
above = sum(1 for c in counts if c >= thresh)
|
||||
print(f"Count <= {thresh:10}: {below:6} entries ({below/n*100:5.1f}%)")
|
||||
# print(f"Count >= {thresh:10}: {above:6} entries ({above/n*100:5.1f}%)")
|
||||
|
||||
# Check min_count=109 parameter
|
||||
print("\n=== ANALYSIS OF THRESHOLD 109 ===")
|
||||
below_109 = sum(1 for c in counts if c < 109)
|
||||
at_or_above_109 = sum(1 for c in counts if c >= 109)
|
||||
print(f"Entries with count < 109: {below_109} ({below_109/n*100:.1f}%)")
|
||||
print(f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109/n*100:.1f}%)")
|
||||
|
||||
# If 109 is a threshold, what's the actual min among those >= 109?
|
||||
counts_ge_109 = [c for c in counts if c >= 109]
|
||||
if counts_ge_109:
|
||||
actual_min_ge_109 = min(counts_ge_109)
|
||||
print(f"Actual min frequency among those >= 109: {actual_min_ge_109}")
|
||||
|
||||
# Rank-frequency analysis (Zipf's law)
|
||||
print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===")
|
||||
print("Rank\tFrequency\tlog(rank)\tlog(freq)")
|
||||
for rank in range(1, 101):
|
||||
freq = counts_sorted_desc[rank-1]
|
||||
print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{math.log(freq):.3f}")
|
||||
|
||||
# Frequency spectrum (how many distinct frequencies)
|
||||
freq_counter = Counter(counts)
|
||||
print(f"\n=== FREQUENCY SPECTRUM ===")
|
||||
print(f"Distinct frequency values: {len(freq_counter)}")
|
||||
|
||||
# Most common frequencies
|
||||
print("\nTop 20 most common frequencies (plateau sizes):")
|
||||
for freq, freq_count in freq_counter.most_common(20):
|
||||
print(f" Frequency {freq}: {freq_count} entries")
|
||||
|
||||
# Analyze ID ranges
|
||||
print("\n=== ID RANGE ANALYSIS ===")
|
||||
# Build ID to count mapping
|
||||
id_to_count = {}
|
||||
for key, pair in pairs.items():
|
||||
char_id = pair.get('id')
|
||||
count = pair.get('count')
|
||||
if char_id is not None and count is not None:
|
||||
id_to_count[char_id] = count
|
||||
|
||||
ranges = [
|
||||
(0, 100, "Top 100 IDs"),
|
||||
(100, 500, "IDs 100-500"),
|
||||
(500, 1000, "IDs 500-1000"),
|
||||
(1000, 2000, "IDs 1000-2000"),
|
||||
(2000, 5000, "IDs 2000-5000"),
|
||||
(5000, 5500, "IDs 5000-5500 (user mentioned)"),
|
||||
(5500, 6000, "IDs 5500-6000"),
|
||||
(10000, 10500, "IDs 10000-10500"),
|
||||
(15000, 15500, "IDs 15000-15500"),
|
||||
(19000, 19500, "IDs 19000-19500 (before freq=1)"),
|
||||
(19499, 20647, "IDs with freq=1"),
|
||||
]
|
||||
|
||||
for start, end, label in ranges:
|
||||
range_counts = [id_to_count[id] for id in range(start, end) if id in id_to_count]
|
||||
if range_counts:
|
||||
min_c = min(range_counts)
|
||||
max_c = max(range_counts)
|
||||
mean_c = sum(range_counts) / len(range_counts)
|
||||
median_c = sorted(range_counts)[len(range_counts)//2]
|
||||
print(f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}")
|
||||
|
||||
# Check if IDs are perfectly sorted by frequency
|
||||
print("\n=== ID ORDER VERIFICATION ===")
|
||||
all_ids = sorted(id_to_count.keys())
|
||||
all_counts = [id_to_count[id] for id in all_ids]
|
||||
|
||||
# Check for any violations of non-increasing order
|
||||
violations = 0
|
||||
for i in range(1, len(all_counts)):
|
||||
if all_counts[i] > all_counts[i-1]:
|
||||
violations += 1
|
||||
if violations <= 5:
|
||||
print(f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i-1]} (ID {all_ids[i-1]})")
|
||||
|
||||
print(f"Total violations of non-increasing order: {violations}")
|
||||
|
||||
# Check if equal frequencies are grouped together
|
||||
print("\n=== FREQUENCY GROUPING ANALYSIS ===")
|
||||
current_freq = None
|
||||
group_start = None
|
||||
group_sizes = []
|
||||
|
||||
for i, (id, count) in enumerate(zip(all_ids, all_counts)):
|
||||
if count != current_freq:
|
||||
if current_freq is not None:
|
||||
group_sizes.append((current_freq, group_start, all_ids[i-1], i - group_start))
|
||||
current_freq = count
|
||||
group_start = i
|
||||
|
||||
# Last group
|
||||
if current_freq is not None:
|
||||
group_sizes.append((current_freq, group_start, all_ids[-1], len(all_ids) - group_start))
|
||||
|
||||
# Sort groups by size
|
||||
group_sizes.sort(key=lambda x: x[3], reverse=True)
|
||||
print("Top 10 largest frequency groups (plateaus):")
|
||||
for freq, start_id_idx, end_id, size in group_sizes[:10]:
|
||||
start_id = all_ids[start_id_idx]
|
||||
print(f" Frequency {freq}: IDs {start_id}-{end_id} ({size} entries)")
|
||||
|
||||
# Summary for smoothing algorithm
|
||||
print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===")
|
||||
print("1. IDs are perfectly sorted by frequency (non-increasing).")
|
||||
print(f"2. Frequency range: {min_count} to {max_count} (ratio {max_count/min_count:.1e}:1).")
|
||||
print(f"3. {below_109} entries ({below_109/n*100:.1f}%) have frequency < 109.")
|
||||
print(f"4. Median frequency: {counts_sorted_desc[n//2]}.")
|
||||
print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9*n)]}.")
|
||||
print(f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01*n)]}.")
|
||||
print("7. Large frequency plateaus exist (many IDs share same frequency).")
|
||||
print("8. Smoothing should handle extreme frequency ratios (1:5e8).")
|
||||
|
||||
# Save data for plotting
|
||||
with open("rank_freq.csv", "w") as f:
|
||||
f.write("rank,frequency\n")
|
||||
for rank, freq in enumerate(counts_sorted_desc, 1):
|
||||
f.write(f"{rank},{freq}\n")
|
||||
print("\nRank-frequency data saved to rank_freq.csv")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
# 破词训练设计文档
|
||||
|
||||
## 背景
|
||||
|
||||
输入法用户在实际使用中通常是**逐词输入**的,而非逐字输入。例如输入"那边的特别漂亮的女孩是我的表姐"时,用户可能分词为:
|
||||
|
||||
```
|
||||
那边 / 的 / 特别 / 漂亮 / 的 / 女孩 / 是 / 我 / 的 / 表姐
|
||||
```
|
||||
|
||||
但为了增强模型的泛化能力,需要模拟用户**从词中间断开**的情况。例如用户可能只输入了"漂"就开始选字"亮"。
|
||||
|
||||
## 破词概念
|
||||
|
||||
### 术语定义
|
||||
|
||||
| 术语 | 说明 |
|
||||
|------|------|
|
||||
| 整词输入 | 用户输入完整词的拼音,如"piaoliang" |
|
||||
| 破词输入 | 用户只输入词的部分拼音,如"piao" |
|
||||
| 前缀 | 光标前的已确认文本 |
|
||||
| 拼音 | 当前待选字的拼音(可能不完整) |
|
||||
| 后缀 | 光标后的原文内容 |
|
||||
|
||||
### 场景示例
|
||||
|
||||
以词"漂亮"为例:
|
||||
|
||||
**整词模式(90%概率):**
|
||||
```
|
||||
光标前: 那边的特别
|
||||
拼音: piaoliang
|
||||
预测: 漂 → 亮
|
||||
```
|
||||
|
||||
**破词模式(10%概率):**
|
||||
```
|
||||
光标前: 那边的特别漂
|
||||
拼音: liang
|
||||
预测: 亮
|
||||
```
|
||||
|
||||
## 实现方案
|
||||
|
||||
### 分词策略
|
||||
|
||||
使用 jieba 分词器进行词语边界识别:
|
||||
|
||||
```python
|
||||
import jieba
|
||||
words = list(jieba.cut(text, HMM=False))
|
||||
# "那边的特别漂亮的女孩是我的表姐。"
|
||||
# → ['那边', '的', '特别', '漂亮', '的', '女孩', '是', '我', '的', '表姐', '。']
|
||||
```
|
||||
|
||||
### 两阶段样本生成
|
||||
|
||||
每个词生成样本时分为两个阶段:
|
||||
|
||||
#### Phase 1:前缀/整词阶段
|
||||
|
||||
- **整词(90%)**:`prefix_positions = 整个词的所有字符`
|
||||
- **破词前缀(10%)**:`prefix_positions = 词的前 break_pos 个字符`
|
||||
|
||||
```python
|
||||
if should_break:
|
||||
break_pos = random.randint(1, word_len_chars - 1) # 随机破开位置
|
||||
else:
|
||||
break_pos = word_len_chars # 整词
|
||||
```
|
||||
|
||||
#### Phase 2:破词续接阶段(仅当破词时)
|
||||
|
||||
当破词发生时,从断点位置开始继续采样:
|
||||
|
||||
```python
|
||||
if should_break and break_pos < word_len_chars:
|
||||
cont_start = char_positions[break_pos]
|
||||
# 从断点开始采样后续字符
|
||||
target_len = random_choice(1-8) # 采样长度
|
||||
cont_positions = [cont_start, ...] # 后续字符位置
|
||||
```
|
||||
|
||||
### 样本结构
|
||||
|
||||
每个字符生成一个训练样本,包含:
|
||||
|
||||
| 字段 | 说明 | 示例 |
|
||||
|------|------|------|
|
||||
| `part1` (prefix) | 光标前文本 | "那边的特别漂" |
|
||||
| `part2` (pinyin) | 当前字拼音 | "liang" |
|
||||
| `part3` (suffix) | 光标后文本 | "亮的女孩是我的表姐" |
|
||||
| `part4` | 专有词提示 | "漂亮\|特别" |
|
||||
| `label` | 目标汉字ID | 1234 |
|
||||
| `history_slot_ids` | 历史已确认字 | [0, 0, 0, 0, 0, 0, 0, 0] |
|
||||
|
||||
### 拼音增强策略
|
||||
|
||||
根据 `py_style_weight` 参数,拼音有以下三种形式:
|
||||
|
||||
| 形式 | 概率 | 示例 |
|
||||
|------|------|------|
|
||||
| 完整拼音 | 75% (9/12) | "piaoliang" |
|
||||
| 仅声母 | 16.7% (2/12) | "pl" (通过 to_initials) |
|
||||
| 仅首字母 | 8.3% (1/12) | "p" |
|
||||
|
||||
参数配置:`py_style_weight=(9, 2, 1)`
|
||||
|
||||
## 破词概率控制
|
||||
|
||||
### word_break_prob 参数
|
||||
|
||||
控制每个词从中间断开的概率,默认为 **10%**:
|
||||
|
||||
```python
|
||||
self.word_break_prob = 0.10 # 10%概率从词中间破开
|
||||
```
|
||||
|
||||
### 破词位置分布
|
||||
|
||||
对于长度为 N 的词,破开位置 `break_pos` 的分布:
|
||||
|
||||
```python
|
||||
break_pos = random.randint(1, N - 1)
|
||||
```
|
||||
|
||||
- 2字词:break_pos = 1(100%在第1字后破开)
|
||||
- 3字词:break_pos = 1 或 2(各50%)
|
||||
- 4字词:break_pos = 1, 2, 或 3(各33%)
|
||||
|
||||
## 数据分布预期
|
||||
|
||||
### 理想分布
|
||||
|
||||
| 类别 | 预期比例 |
|
||||
|------|----------|
|
||||
| 单字样本 | ~15% |
|
||||
| 2字词整词 | ~30% |
|
||||
| 3字词整词 | ~20% |
|
||||
| 破词样本 | ~10% |
|
||||
| 其他 | ~25% |
|
||||
|
||||
### 拼音不完整率
|
||||
|
||||
由于 `py_style_weight=(9, 2, 1)`:
|
||||
|
||||
- 声母(initials):~16.7%
|
||||
- 首字母:~8.3%
|
||||
- **总计不完整**:~25%
|
||||
|
||||
## 代码实现位置
|
||||
|
||||
主要实现文件:`src/model/dataset.py`
|
||||
|
||||
| 函数/类 | 行号 | 功能 |
|
||||
|---------|------|------|
|
||||
| `segment_text()` | ~30 | jieba分词 |
|
||||
| `build_word_boundaries()` | ~35 | 建立词边界映射 |
|
||||
| `PinyinInputDataset.__iter__()` | ~280 | 核心迭代逻辑 |
|
||||
| `get_mask_pinyin()` | ~215 | 拼音加强处理 |
|
||||
| `_add_word_samples()` | ~240 | 样本构建 |
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **破词仅针对多字词**:单字词(如"的"、“是”)不会破词
|
||||
2. **破词保持语义完整**:破词后仍能根据上下文预测正确汉字
|
||||
3. **历史槽位模拟逐步确认**:同一词内已确认的字会填入 `history_slot_ids`
|
||||
4. **10% EOS标记**:词尾有10%概率追加ID=0表示句子结束
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
Frequency Analysis Results
|
||||
==================================================
|
||||
Min frequency: 1
|
||||
Max frequency: 494748360
|
||||
Mean frequency: 560007.90
|
||||
Standard deviation: 5730144.34
|
||||
10th percentile: 3
|
||||
50th percentile: 93
|
||||
90th percentile: 331538
|
||||
IDs in range 5000-5500 min: 5594
|
||||
IDs in range 5000-5500 max: 9569
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -30,6 +30,8 @@ dependencies = [
|
|||
[project.scripts]
|
||||
train-model = "model.trainer:app"
|
||||
monitor-training = "model.monitor:app"
|
||||
preprocess-model = "model.preprocess:main"
|
||||
inspect-preprocessed = "model.inspect_preprocessed:main"
|
||||
|
||||
[tool.uv]
|
||||
# 设置当前项目的默认索引源
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -22,9 +22,26 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from model.dataset import PinyinInputDataset
|
||||
from model.query import QueryEngine
|
||||
|
||||
|
||||
def analyze_label_distribution(dataset: PinyinInputDataset, sample_size: int = 10000):
|
||||
_id2char_cache = {}
|
||||
|
||||
|
||||
def get_char_by_id(query_engine: QueryEngine, char_id: int) -> str:
|
||||
if char_id == 0:
|
||||
return "<EOS>"
|
||||
if char_id not in _id2char_cache:
|
||||
info = query_engine.query_by_id(char_id)
|
||||
_id2char_cache[char_id] = info.char if info else f"<ID:{char_id}>"
|
||||
return _id2char_cache[char_id]
|
||||
|
||||
|
||||
def analyze_label_distribution(
|
||||
dataset: PinyinInputDataset,
|
||||
sample_size: int = 10000,
|
||||
query_engine: QueryEngine = None,
|
||||
):
|
||||
"""分析 label 在指定区间的分布"""
|
||||
target_ranges = [
|
||||
(0, 10),
|
||||
|
|
@ -63,13 +80,23 @@ def analyze_label_distribution(dataset: PinyinInputDataset, sample_size: int = 1
|
|||
in_target_range = True
|
||||
|
||||
if len(all_examples) < 200:
|
||||
label_char = (
|
||||
get_char_by_id(query_engine, label) if query_engine else f"<ID:{label}>"
|
||||
)
|
||||
history_chars = (
|
||||
[get_char_by_id(query_engine, hid) for hid in history]
|
||||
if query_engine
|
||||
else history
|
||||
)
|
||||
all_examples.append(
|
||||
{
|
||||
"label": label,
|
||||
"label_char": label_char,
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"pinyin": pinyin,
|
||||
"history": history,
|
||||
"history_chars": history_chars,
|
||||
"part4": part4,
|
||||
}
|
||||
)
|
||||
|
|
@ -96,12 +123,13 @@ def analyze_label_distribution(dataset: PinyinInputDataset, sample_size: int = 1
|
|||
random.shuffle(all_examples)
|
||||
for idx, ex in enumerate(all_examples[:20], 1):
|
||||
print(f"\n样本 {idx}:")
|
||||
print(f" Label: {ex['label']}")
|
||||
print(f" Label: {ex['label']} ({ex['label_char']})")
|
||||
print(f" Part4: {ex['part4']}")
|
||||
print(f" 光标前: {ex['prefix']}")
|
||||
print(f" 光标后: {ex['suffix']}")
|
||||
print(f" 拼音: {ex['pinyin']}")
|
||||
print(f" 历史槽位: {ex['history']}")
|
||||
print(f" 历史汉字: {ex['history_chars']}")
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -127,7 +155,9 @@ def main():
|
|||
help="数据集路径 (本地文件或HuggingFace路径)",
|
||||
)
|
||||
parser.add_argument("--sample_size", type=int, default=10000, help="采样大小")
|
||||
parser.add_argument("--max_workers", type=int, default=-1, help="DataLoader workers")
|
||||
parser.add_argument(
|
||||
"--max_workers", type=int, default=-1, help="DataLoader workers"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"加载数据集: {args.data_path}")
|
||||
|
|
@ -148,7 +178,11 @@ def main():
|
|||
print(" 3. 如果是 HuggingFace 数据集,路径应该正确")
|
||||
return
|
||||
|
||||
analyze_label_distribution(dataset, sample_size=args.sample_size)
|
||||
query_engine = QueryEngine()
|
||||
query_engine.load()
|
||||
analyze_label_distribution(
|
||||
dataset, sample_size=args.sample_size, query_engine=query_engine
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -62,10 +62,13 @@ class PinyinInputDataset(IterableDataset):
|
|||
max_iter_length=1e6,
|
||||
max_seq_length=128,
|
||||
text_field: str = "text",
|
||||
py_style_weight=(90, 2, 1),
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size: int = 100000,
|
||||
retention_ratio: float = 0.8,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
merge_short_words_prob: float = 0.5,
|
||||
merge_max_short_words: int = 3,
|
||||
merge_max_total_chars: int = 6,
|
||||
):
|
||||
# 频率调整参数 (可根据需要调整)
|
||||
self.drop_start_freq = 10_000_000
|
||||
|
|
@ -76,6 +79,9 @@ class PinyinInputDataset(IterableDataset):
|
|||
self.word_break_prob = 0.10
|
||||
self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
|
||||
self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
|
||||
self.merge_short_words_prob = merge_short_words_prob
|
||||
self.merge_max_short_words = merge_max_short_words
|
||||
self.merge_max_total_chars = merge_max_total_chars
|
||||
|
||||
jieba.initialize()
|
||||
|
||||
|
|
@ -259,8 +265,10 @@ class PinyinInputDataset(IterableDataset):
|
|||
repeats = max(1, int(base_repeats * weight))
|
||||
|
||||
history = labels[:label_idx]
|
||||
len_h = len(history)
|
||||
history.extend([0] * (8 - len_h))
|
||||
if len(history) > 8:
|
||||
history = history[-8:]
|
||||
else:
|
||||
history.extend([0] * (8 - len(history)))
|
||||
|
||||
sample_dict = {
|
||||
"input_ids": encoded["input_ids"],
|
||||
|
|
@ -291,7 +299,12 @@ class PinyinInputDataset(IterableDataset):
|
|||
if worker_id >= num_workers:
|
||||
return
|
||||
|
||||
worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
|
||||
try:
|
||||
worker_dataset = self.dataset.shard(
|
||||
num_shards=num_workers, index=worker_id
|
||||
)
|
||||
except (IndexError, ValueError):
|
||||
worker_dataset = self.dataset
|
||||
|
||||
total_quota = int(self.max_iter_length)
|
||||
base_quota = total_quota // num_workers
|
||||
|
|
@ -321,17 +334,58 @@ class PinyinInputDataset(IterableDataset):
|
|||
word_boundaries = build_word_boundaries(words)
|
||||
pinyin_list = self.generate_pinyin(text)
|
||||
|
||||
for word_start, word_end in word_boundaries:
|
||||
idx = 0
|
||||
while idx < len(word_boundaries):
|
||||
word_start, word_end = word_boundaries[idx]
|
||||
|
||||
char_positions = []
|
||||
for i in range(word_start, word_end):
|
||||
if self.query_engine.is_chinese_char(text[i]):
|
||||
char_positions.append(i)
|
||||
|
||||
if not char_positions:
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
word_len_chars = len(char_positions)
|
||||
|
||||
merge_end_idx = idx + 1
|
||||
if word_len_chars <= 2:
|
||||
accumulated_positions = list(char_positions)
|
||||
accumulated_count = 1
|
||||
next_idx = idx + 1
|
||||
|
||||
while next_idx < len(word_boundaries):
|
||||
ns, ne = word_boundaries[next_idx]
|
||||
next_positions = []
|
||||
for i in range(ns, ne):
|
||||
if self.query_engine.is_chinese_char(text[i]):
|
||||
next_positions.append(i)
|
||||
next_len = len(next_positions)
|
||||
|
||||
if next_len == 0 or next_len > 2:
|
||||
break
|
||||
if (
|
||||
len(accumulated_positions) + next_len
|
||||
> self.merge_max_total_chars
|
||||
):
|
||||
break
|
||||
if accumulated_count + 1 > self.merge_max_short_words:
|
||||
break
|
||||
if random.random() > self.merge_short_words_prob:
|
||||
break
|
||||
|
||||
accumulated_positions.extend(next_positions)
|
||||
accumulated_count += 1
|
||||
next_idx += 1
|
||||
|
||||
if accumulated_count > 1:
|
||||
char_positions = accumulated_positions
|
||||
word_len_chars = len(char_positions)
|
||||
merge_end_idx = next_idx
|
||||
word_start = word_boundaries[idx][0]
|
||||
word_end = word_boundaries[next_idx - 1][1]
|
||||
|
||||
should_break = (
|
||||
word_len_chars > 1 and random.random() < self.word_break_prob
|
||||
)
|
||||
|
|
@ -364,6 +418,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
logger.error(
|
||||
f"e: {e}, (text, pinyin): {prefix_text} - {prefix_pinyin}"
|
||||
)
|
||||
idx = merge_end_idx
|
||||
continue
|
||||
|
||||
# 整词末尾 10% 概率追加 EOS(破词前缀不加)
|
||||
|
|
@ -448,6 +503,7 @@ class PinyinInputDataset(IterableDataset):
|
|||
logger.error(
|
||||
f"e: {e}, (text, pinyin): {cont_text} - {cont_pinyin}"
|
||||
)
|
||||
idx = merge_end_idx
|
||||
continue
|
||||
|
||||
# 续接末尾 10% 概率追加 EOS
|
||||
|
|
@ -486,6 +542,8 @@ class PinyinInputDataset(IterableDataset):
|
|||
pinyin_ids_cont,
|
||||
)
|
||||
|
||||
idx = merge_end_idx
|
||||
|
||||
# 处理shuffle buffer - 单缓冲区半保留方案
|
||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||
indices = np.random.permutation(len(batch_samples))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,401 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
预处理数据质量分析脚本
|
||||
|
||||
功能:
|
||||
1. 统计 labels 的分布(出现次数、比例,最大/最小,未出现标签数)
|
||||
2. 随机抽样还原为人类可读文本,导出为 CSV 文件
|
||||
|
||||
用法:
|
||||
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
|
||||
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train --num-samples 50 --output samples.csv
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import random
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from tqdm import tqdm
|
||||
|
||||
from .char_info import CharInfo
|
||||
from .dataset import CHAR_TO_ID
|
||||
from .preprocessed_dataset import PreProcessedDataset
|
||||
from .query import QueryEngine
|
||||
|
||||
ID_TO_CHAR = {v: k for k, v in CHAR_TO_ID.items()}
|
||||
|
||||
|
||||
def decode_pinyin_ids(pinyin_ids: list) -> str:
|
||||
"""将 pinyin_ids 还原为拼音字符串"""
|
||||
chars = []
|
||||
for pid in pinyin_ids:
|
||||
if pid == 0:
|
||||
break
|
||||
chars.append(ID_TO_CHAR.get(pid, "?"))
|
||||
return "".join(chars)
|
||||
|
||||
|
||||
def decode_history(history_ids: list, query_engine: QueryEngine) -> str:
|
||||
"""将 history_slot_ids 还原为文字"""
|
||||
parts = []
|
||||
for hid in history_ids:
|
||||
if hid == 0:
|
||||
parts.append("<PAD>")
|
||||
else:
|
||||
info = query_engine.query_by_id(hid)
|
||||
if info is not None:
|
||||
parts.append(f"{info.char}({info.pinyin})")
|
||||
else:
|
||||
parts.append(f"<ID:{hid}>")
|
||||
return " | ".join(parts)
|
||||
|
||||
|
||||
def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
|
||||
"""统计 labels 分布,带进度条"""
|
||||
logger.info("正在统计 labels 分布...")
|
||||
counter = Counter()
|
||||
total = 0
|
||||
|
||||
num_shards = dataset._num_shards if dataset._is_sharded else 1
|
||||
effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards
|
||||
|
||||
pbar = tqdm(range(effective_shards), desc="统计 labels", unit="shard")
|
||||
|
||||
for shard_idx in pbar:
|
||||
if dataset._is_sharded:
|
||||
shard_data = dict(np.load(dataset.data_dir / f"shard_{shard_idx:06d}.npz"))
|
||||
labels = shard_data["labels"].astype(np.int64)
|
||||
else:
|
||||
labels = dataset.labels[:].astype(np.int64)
|
||||
|
||||
unique, counts = np.unique(labels, return_counts=True)
|
||||
for uid, cnt in zip(unique, counts):
|
||||
counter[int(uid)] += cnt
|
||||
total += len(labels)
|
||||
|
||||
if dataset._is_sharded:
|
||||
del shard_data
|
||||
|
||||
return counter, total
|
||||
|
||||
|
||||
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
|
||||
"""将一个样本还原为人类可读格式"""
|
||||
input_ids = (
|
||||
sample["input_ids"].tolist()
|
||||
if hasattr(sample["input_ids"], "tolist")
|
||||
else sample["input_ids"]
|
||||
)
|
||||
token_type_ids = (
|
||||
sample["token_type_ids"].tolist()
|
||||
if hasattr(sample["token_type_ids"], "tolist")
|
||||
else sample["token_type_ids"]
|
||||
)
|
||||
labels = (
|
||||
sample["labels"].item()
|
||||
if hasattr(sample["labels"], "item")
|
||||
else sample["labels"]
|
||||
)
|
||||
history_ids = (
|
||||
sample["history_slot_ids"].tolist()
|
||||
if hasattr(sample["history_slot_ids"], "tolist")
|
||||
else sample["history_slot_ids"]
|
||||
)
|
||||
pinyin_ids = (
|
||||
sample["pinyin_ids"].tolist()
|
||||
if hasattr(sample["pinyin_ids"], "tolist")
|
||||
else sample["pinyin_ids"]
|
||||
)
|
||||
|
||||
# 还原 token 文本
|
||||
token_text = tokenizer.decode(input_ids, skip_special_tokens=False)
|
||||
|
||||
# 找到 token_type_ids 切换点,分离 sentence A 和 sentence B
|
||||
sep_positions = [i for i, tid in enumerate(token_type_ids) if tid == 1]
|
||||
if sep_positions:
|
||||
sep_start = sep_positions[0]
|
||||
sent_a_ids = [
|
||||
tid
|
||||
for tid, tt in zip(input_ids[:sep_start], token_type_ids[:sep_start])
|
||||
if tt == 0
|
||||
]
|
||||
sent_b_ids = [tid for tid, tt in zip(input_ids, token_type_ids) if tt == 1]
|
||||
else:
|
||||
sent_a_ids = input_ids
|
||||
sent_b_ids = []
|
||||
|
||||
context_text = tokenizer.decode(sent_a_ids, skip_special_tokens=True)
|
||||
suffix_text = (
|
||||
tokenizer.decode(sent_b_ids, skip_special_tokens=True) if sent_b_ids else ""
|
||||
)
|
||||
|
||||
pinyin_str = decode_pinyin_ids(pinyin_ids)
|
||||
label_info = query_engine.query_by_id(labels)
|
||||
history_str = decode_history(history_ids, query_engine)
|
||||
|
||||
return {
|
||||
"context": context_text,
|
||||
"suffix": suffix_text,
|
||||
"pinyin": pinyin_str,
|
||||
"label_id": labels,
|
||||
"label_char": f"{label_info.char}({label_info.pinyin})"
|
||||
if label_info
|
||||
else f"<ID:{labels}>",
|
||||
"label_count": label_info.count if label_info else 0,
|
||||
"history": history_str,
|
||||
"full_tokens": token_text,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
console = Console()
|
||||
|
||||
parser = argparse.ArgumentParser(description="预处理数据质量分析")
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="预处理数据目录(train/ 或 eval/)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=50,
|
||||
help="随机抽样的样本数量(默认50)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default=None,
|
||||
help="CSV 输出文件路径(默认: <data-dir>/samples.csv)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-shards",
|
||||
type=int,
|
||||
default=0,
|
||||
help="统计 labels 时最多读取的分片数(0=全部)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="随机种子",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=30,
|
||||
help="显示出现次数最多和最少的标签数量",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
if args.output is None:
|
||||
args.output = str(Path(args.data_dir) / "samples.csv")
|
||||
|
||||
# 加载数据集
|
||||
logger.info(f"加载数据集: {args.data_dir}")
|
||||
dataset = PreProcessedDataset(args.data_dir)
|
||||
console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
|
||||
if dataset._is_sharded:
|
||||
console.print(
|
||||
f" 分片数: {dataset._num_shards}, 每分片: {dataset._shard_size:,} 样本"
|
||||
)
|
||||
console.print()
|
||||
|
||||
# 加载 QueryEngine
|
||||
logger.info("加载 QueryEngine...")
|
||||
query_engine = QueryEngine()
|
||||
query_engine.load()
|
||||
|
||||
# 加载 Tokenizer
|
||||
logger.info("加载 Tokenizer...")
|
||||
from importlib.resources import files as pkg_files
|
||||
from modelscope import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
Path(str(pkg_files(__package__))) / "assets" / "tokenizer"
|
||||
)
|
||||
|
||||
# ====== 1. Labels 分布分析 ======
|
||||
console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]")
|
||||
counter, total = analyze_labels(dataset, max_shards=args.max_shards)
|
||||
|
||||
# 获取词表总大小
|
||||
vocab_size = len(query_engine._id_to_info) # 不含 EOS (id=0)
|
||||
|
||||
appeared_ids = set(counter.keys())
|
||||
all_ids = set(range(0, vocab_size + 1)) # +1 包含 id=0 (EOS)
|
||||
missing_ids = all_ids - appeared_ids
|
||||
|
||||
console.print(f"\n总样本数: {total:,}")
|
||||
console.print(f"词表大小: {vocab_size + 1:,} (含 EOS)")
|
||||
console.print(f"唯一标签数: {len(counter):,}")
|
||||
console.print(
|
||||
f"EOS (id=0) 出现次数: {counter.get(0, 0):,} ({counter.get(0, 0) / total * 100:.2f}%)"
|
||||
)
|
||||
console.print(
|
||||
f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size + 1:,} ({len(missing_ids) / (vocab_size + 1) * 100:.2f}%)[/bold red]"
|
||||
)
|
||||
|
||||
most_common = counter.most_common(args.top_k)
|
||||
least_common = (
|
||||
counter.most_common()[: -args.top_k - 1 : -1]
|
||||
if len(counter) > args.top_k
|
||||
else counter.most_common()
|
||||
)
|
||||
|
||||
# 最多标签表
|
||||
table_top = Table(
|
||||
title=f"出现次数最多的 {args.top_k} 个标签",
|
||||
show_header=True,
|
||||
header_style="bold magenta",
|
||||
)
|
||||
table_top.add_column("排名", style="cyan", width=6)
|
||||
table_top.add_column("ID", style="green", width=8)
|
||||
table_top.add_column("字符(拼音)", style="yellow", width=20)
|
||||
table_top.add_column("频次", style="red", width=12)
|
||||
table_top.add_column("占比", style="blue", width=10)
|
||||
|
||||
for rank, (label_id, count) in enumerate(most_common, 1):
|
||||
info = query_engine.query_by_id(label_id)
|
||||
label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
|
||||
pct = count / total * 100
|
||||
table_top.add_row(
|
||||
str(rank), str(label_id), label_str, f"{count:,}", f"{pct:.3f}%"
|
||||
)
|
||||
console.print(table_top)
|
||||
|
||||
# 最少标签表
|
||||
table_bottom = Table(
|
||||
title=f"出现次数最少的 {min(args.top_k, len(counter))} 个标签",
|
||||
show_header=True,
|
||||
header_style="bold magenta",
|
||||
)
|
||||
table_bottom.add_column("排名", style="cyan", width=6)
|
||||
table_bottom.add_column("ID", style="green", width=8)
|
||||
table_bottom.add_column("字符(拼音)", style="yellow", width=20)
|
||||
table_bottom.add_column("频次", style="red", width=12)
|
||||
table_bottom.add_column("占比", style="blue", width=10)
|
||||
|
||||
for rank, (label_id, count) in enumerate(least_common, 1):
|
||||
info = query_engine.query_by_id(label_id)
|
||||
label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
|
||||
pct = count / total * 100
|
||||
table_bottom.add_row(
|
||||
str(rank), str(label_id), label_str, f"{count:,}", f"{pct:.6f}%"
|
||||
)
|
||||
console.print(table_bottom)
|
||||
|
||||
# 频次分布概览
|
||||
table_dist = Table(
|
||||
title="频次分布概览", show_header=True, header_style="bold magenta"
|
||||
)
|
||||
table_dist.add_column("频次区间", style="cyan")
|
||||
table_dist.add_column("标签数", style="green")
|
||||
table_dist.add_column("占总标签数比例", style="yellow")
|
||||
|
||||
bins = [
|
||||
(1, 10),
|
||||
(11, 100),
|
||||
(101, 1000),
|
||||
(1001, 10000),
|
||||
(10001, 100000),
|
||||
(100001, 1000000),
|
||||
(1000001, float("inf")),
|
||||
]
|
||||
for lo, hi in bins:
|
||||
count_in_bin = sum(1 for c in counter.values() if lo <= c <= hi)
|
||||
if count_in_bin > 0:
|
||||
hi_str = str(int(hi)) if hi != float("inf") else "∞"
|
||||
table_dist.add_row(
|
||||
f"{lo}-{hi_str}",
|
||||
f"{count_in_bin:,}",
|
||||
f"{count_in_bin / len(counter) * 100:.1f}%",
|
||||
)
|
||||
# 未出现
|
||||
if len(missing_ids) > 0:
|
||||
table_dist.add_row(
|
||||
"未出现",
|
||||
f"{len(missing_ids):,}",
|
||||
f"{len(missing_ids) / (vocab_size + 1) * 100:.1f}%",
|
||||
)
|
||||
console.print(table_dist)
|
||||
|
||||
# ====== 2. 随机抽样还原 → CSV ======
|
||||
num_samples = min(args.num_samples, len(dataset))
|
||||
console.print(
|
||||
f"\n[bold yellow]====== 随机抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
|
||||
)
|
||||
|
||||
indices = random.sample(range(len(dataset)), num_samples)
|
||||
|
||||
csv_path = Path(args.output)
|
||||
csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
csv_headers = [
|
||||
"index",
|
||||
"pinyin",
|
||||
"label_char",
|
||||
"label_id",
|
||||
"label_count",
|
||||
"context",
|
||||
"suffix",
|
||||
"history",
|
||||
"full_tokens",
|
||||
]
|
||||
|
||||
with open(csv_path, "w", encoding="utf-8", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(csv_headers)
|
||||
|
||||
for i, idx in enumerate(tqdm(indices, desc="解码样本", unit="sample")):
|
||||
sample = dataset[idx]
|
||||
decoded = decode_sample(sample, tokenizer, query_engine)
|
||||
writer.writerow(
|
||||
[
|
||||
idx,
|
||||
decoded["pinyin"],
|
||||
decoded["label_char"],
|
||||
decoded["label_id"],
|
||||
decoded["label_count"],
|
||||
decoded["context"],
|
||||
decoded["suffix"],
|
||||
decoded["history"],
|
||||
decoded["full_tokens"],
|
||||
]
|
||||
)
|
||||
|
||||
console.print(
|
||||
f"[bold green]✓ 已导出 {num_samples} 个样本到 {csv_path}[/bold green]"
|
||||
)
|
||||
|
||||
# 打印前 5 个样本的概要
|
||||
console.print(f"\n[bold cyan]前 {min(5, num_samples)} 个样本概览:[/bold cyan]")
|
||||
with open(csv_path, "r", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for i, row in enumerate(reader):
|
||||
if i >= 5:
|
||||
break
|
||||
console.print(
|
||||
f" [{i + 1}] 拼音={row['pinyin']} 目标={row['label_char']} "
|
||||
f"上下文={row['context'][:50]}..."
|
||||
)
|
||||
|
||||
console.print("\n[bold green]分析完成[/bold green]")
|
||||
|
||||
|
||||
app = main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,335 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
预处理脚本:将 PinyinInputDataset 的输出转换为分片压缩 .npz 文件。
|
||||
|
||||
采用分片流式写入,内存峰值固定为 shard_size 级别,不随总样本数增长。
|
||||
每个分片使用 np.savez_compressed 保存(zlib 压缩),GPU 服务器无需解压到硬盘。
|
||||
|
||||
用法:
|
||||
python -m model.preprocess \
|
||||
--train-data-path "some/hf_dataset" \
|
||||
--eval-data-path "some/hf_dataset" \
|
||||
--output-dir ./preprocessed \
|
||||
--num-train-samples 5000000 \
|
||||
--num-eval-samples 8192
|
||||
|
||||
生成目录结构:
|
||||
output_dir/
|
||||
train/
|
||||
metadata.json
|
||||
shard_000.npz (5M样本, 6个字段, zlib压缩)
|
||||
shard_001.npz
|
||||
...
|
||||
eval/
|
||||
metadata.json
|
||||
shard_000.npz
|
||||
...
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .dataset import PinyinInputDataset
|
||||
from .trainer import collate_fn, worker_init_fn
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
"token_type_ids",
|
||||
"attention_mask",
|
||||
"labels",
|
||||
"history_slot_ids",
|
||||
"pinyin_ids",
|
||||
]
|
||||
|
||||
|
||||
def _extract_batch(batch: dict, take: int) -> Dict[str, np.ndarray]:
|
||||
"""从 DataLoader batch 中提取指定数量的样本,转为 int16 numpy 数组"""
|
||||
result = {}
|
||||
for f in FIELDS:
|
||||
tensor = batch[f][:take]
|
||||
arr = tensor.numpy().astype(np.int16)
|
||||
if f == "labels" and arr.ndim > 1 and arr.shape[-1] == 1:
|
||||
arr = arr.squeeze(-1)
|
||||
result[f] = arr
|
||||
return result
|
||||
|
||||
|
||||
def collect_samples(
|
||||
dataloader: DataLoader,
|
||||
num_samples: int,
|
||||
output_dir: Path,
|
||||
split_name: str,
|
||||
max_seq_length: int = 128,
|
||||
shard_size: int = 5_000_000,
|
||||
) -> int:
|
||||
"""
|
||||
分片流式收集样本,每累积 shard_size 个样本保存为一个压缩 .npz 分片。
|
||||
|
||||
内存峰值 = shard_size × 每样本字节数(约578字节 @ shard_size=5M → 约2.9GB)
|
||||
"""
|
||||
split_dir = output_dir / split_name
|
||||
split_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
shard_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
|
||||
shard_count = 0
|
||||
shard_idx = 0
|
||||
total = 0
|
||||
|
||||
pbar = tqdm(total=num_samples, desc=f"Processing {split_name}", unit="samples")
|
||||
|
||||
for batch in dataloader:
|
||||
batch_size = batch["input_ids"].size(0)
|
||||
remaining = num_samples - total
|
||||
if remaining <= 0:
|
||||
break
|
||||
take = min(batch_size, remaining)
|
||||
|
||||
extracted = _extract_batch(batch, take)
|
||||
for f in FIELDS:
|
||||
shard_buffers[f].append(extracted[f])
|
||||
|
||||
shard_count += take
|
||||
total += take
|
||||
pbar.update(take)
|
||||
|
||||
if shard_count >= shard_size:
|
||||
merged = {}
|
||||
for f in FIELDS:
|
||||
merged[f] = np.concatenate(shard_buffers[f], axis=0)
|
||||
np.savez_compressed(split_dir / f"shard_{shard_idx:06d}.npz", **merged)
|
||||
logger.debug(f"Saved {split_name} shard {shard_idx}: {shard_count} samples")
|
||||
|
||||
shard_idx += 1
|
||||
shard_buffers = {f: [] for f in FIELDS}
|
||||
shard_count = 0
|
||||
del merged
|
||||
gc.collect()
|
||||
|
||||
if total >= num_samples:
|
||||
break
|
||||
|
||||
# 写入最后一个不满的分片
|
||||
if shard_count > 0:
|
||||
merged = {}
|
||||
for f in FIELDS:
|
||||
merged[f] = np.concatenate(shard_buffers[f], axis=0)
|
||||
np.savez_compressed(split_dir / f"shard_{shard_idx:06d}.npz", **merged)
|
||||
logger.debug(f"Saved {split_name} shard {shard_idx}: {shard_count} samples")
|
||||
shard_idx += 1
|
||||
|
||||
pbar.close()
|
||||
|
||||
actual_count = min(total, num_samples)
|
||||
num_shards = shard_idx
|
||||
|
||||
metadata = {
|
||||
"num_samples": actual_count,
|
||||
"max_seq_length": max_seq_length,
|
||||
"dtype": "int16",
|
||||
"fields": FIELDS,
|
||||
"shard_size": shard_size,
|
||||
"num_shards": num_shards,
|
||||
}
|
||||
with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
|
||||
json.dump(metadata, fp, indent=2, ensure_ascii=False)
|
||||
|
||||
total_size = sum(
|
||||
f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
|
||||
)
|
||||
logger.info(
|
||||
f"{split_name}: {actual_count} samples in {num_shards} shards, "
|
||||
f"{total_size / (1024**3):.2f} GB (compressed)"
|
||||
)
|
||||
|
||||
return actual_count
|
||||
|
||||
|
||||
def main():
|
||||
console = Console()
|
||||
|
||||
parser = argparse.ArgumentParser(description="预处理数据集为分片压缩npz文件")
|
||||
parser.add_argument(
|
||||
"--train-data-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="训练数据集路径(HuggingFace格式)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval-data-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="评估数据集路径(HuggingFace格式)",
|
||||
)
|
||||
parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
|
||||
parser.add_argument(
|
||||
"--num-train-samples", type=int, required=True, help="训练集样本数量"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-eval-samples", type=int, required=True, help="评估集样本数量"
|
||||
)
|
||||
parser.add_argument("--batch-size", type=int, default=128, help="批大小")
|
||||
parser.add_argument(
|
||||
"--num-workers", type=int, default=2, help="DataLoader worker数量"
|
||||
)
|
||||
parser.add_argument("--max-seq-length", type=int, default=128, help="最大序列长度")
|
||||
parser.add_argument("--seed", type=int, default=42, help="随机种子")
|
||||
parser.add_argument(
|
||||
"--shard-size",
|
||||
type=int,
|
||||
default=5_000_000,
|
||||
help="分片大小(样本数),控制内存峰值(默认500万,约2.9GB/分片未压缩)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--py-style-weight",
|
||||
type=str,
|
||||
default="9,2,1",
|
||||
help="拼音风格权重(逗号分隔)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shuffle-buffer-size",
|
||||
type=int,
|
||||
default=2000000,
|
||||
help="数据集shuffle缓冲区大小",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--length-weights",
|
||||
type=str,
|
||||
default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
|
||||
help="词长权重",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
|
||||
length_weights = {
|
||||
int(k): int(v)
|
||||
for k, v in (item.split(":") for item in args.length_weights.split(","))
|
||||
}
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
train_max_iter = args.num_train_samples * 5
|
||||
eval_max_iter = args.num_eval_samples * 5
|
||||
|
||||
shard_mem_gb = args.shard_size * 578 / (1024**3)
|
||||
console.print("[bold cyan]=== 数据预处理 ===[/bold cyan]")
|
||||
console.print(f"训练集目标: {args.num_train_samples:,} 样本")
|
||||
console.print(f"评估集目标: {args.num_eval_samples:,} 样本")
|
||||
console.print(f"输出目录: {output_dir}")
|
||||
console.print(f"数据类型: int16")
|
||||
console.print(
|
||||
f"分片大小: {args.shard_size:,} 样本 (约 {shard_mem_gb:.1f} GB/分片 未压缩)"
|
||||
)
|
||||
console.print()
|
||||
|
||||
num_train_workers = args.num_workers
|
||||
num_eval_workers = max(1, args.num_workers // 2)
|
||||
|
||||
console.print("[bold cyan]创建训练数据集...[/bold cyan]")
|
||||
train_dataset = PinyinInputDataset(
|
||||
data_path=args.train_data_path,
|
||||
max_workers=num_train_workers,
|
||||
max_iter_length=train_max_iter,
|
||||
max_seq_length=args.max_seq_length,
|
||||
text_field="text",
|
||||
py_style_weight=py_style_weight,
|
||||
shuffle_buffer_size=100,
|
||||
length_weights=length_weights,
|
||||
)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=num_train_workers,
|
||||
pin_memory=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
collate_fn=collate_fn,
|
||||
prefetch_factor=2,
|
||||
persistent_workers=True if num_train_workers > 0 else False,
|
||||
)
|
||||
|
||||
console.print("[bold cyan]创建评估数据集...[/bold cyan]")
|
||||
eval_dataset = PinyinInputDataset(
|
||||
data_path=args.eval_data_path,
|
||||
max_workers=num_eval_workers,
|
||||
max_iter_length=eval_max_iter,
|
||||
max_seq_length=args.max_seq_length,
|
||||
text_field="text",
|
||||
py_style_weight=py_style_weight,
|
||||
shuffle_buffer_size=100,
|
||||
length_weights=length_weights,
|
||||
)
|
||||
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=num_eval_workers,
|
||||
pin_memory=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
collate_fn=collate_fn,
|
||||
prefetch_factor=2,
|
||||
persistent_workers=True if num_eval_workers > 0 else False,
|
||||
)
|
||||
|
||||
logger.info("开始收集训练数据...")
|
||||
train_count = collect_samples(
|
||||
train_dataloader,
|
||||
args.num_train_samples,
|
||||
output_dir,
|
||||
"train",
|
||||
args.max_seq_length,
|
||||
args.shard_size,
|
||||
)
|
||||
|
||||
if train_count < args.num_train_samples:
|
||||
logger.warning(
|
||||
f"训练集样本不足: 目标 {args.num_train_samples}, 实际 {train_count}"
|
||||
)
|
||||
|
||||
logger.info("开始收集评估数据...")
|
||||
eval_count = collect_samples(
|
||||
eval_dataloader,
|
||||
args.num_eval_samples,
|
||||
output_dir,
|
||||
"eval",
|
||||
args.max_seq_length,
|
||||
args.shard_size,
|
||||
)
|
||||
|
||||
if eval_count < args.num_eval_samples:
|
||||
logger.warning(
|
||||
f"评估集样本不足: 目标 {args.num_eval_samples}, 实际 {eval_count}"
|
||||
)
|
||||
|
||||
console.print("\n[bold green]=== 预处理完成 ===[/bold green]")
|
||||
console.print(f"训练集: {train_count:,} 样本")
|
||||
console.print(f"评估集: {eval_count:,} 样本")
|
||||
console.print(f"输出目录: {output_dir}")
|
||||
|
||||
for split in ["train", "eval"]:
|
||||
split_dir = output_dir / split
|
||||
if split_dir.exists():
|
||||
total_size = sum(
|
||||
f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
|
||||
)
|
||||
console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")
|
||||
|
||||
|
||||
app = main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
"""
|
||||
预处理数据集加载器
|
||||
|
||||
支持两种格式:
|
||||
1. 分片压缩格式(.npz):从压缩文件中按需加载分片,LRU 缓存最多持有 max_cache_shards 个分片
|
||||
2. 单体格式(.npy):向后兼容,使用 mmap 零拷贝加载
|
||||
|
||||
GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。
|
||||
"""
|
||||
|
||||
import gc
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
FIELDS = [
|
||||
"input_ids",
|
||||
"token_type_ids",
|
||||
"attention_mask",
|
||||
"labels",
|
||||
"history_slot_ids",
|
||||
"pinyin_ids",
|
||||
]
|
||||
|
||||
|
||||
def is_preprocessed_data(path: str) -> bool:
|
||||
"""判断路径是否为预处理数据目录"""
|
||||
p = Path(path)
|
||||
return p.is_dir() and (p / "metadata.json").exists()
|
||||
|
||||
|
||||
class _ShardCache:
|
||||
"""LRU 缓存,管理按需加载的 .npz 分片,最多持有 max_size 个分片"""
|
||||
|
||||
def __init__(self, max_size: int = 2):
|
||||
self.max_size = max_size
|
||||
self._cache: OrderedDict[int, Dict[str, np.ndarray]] = OrderedDict()
|
||||
|
||||
def get(self, shard_idx: int, loader_fn) -> Dict[str, np.ndarray]:
|
||||
if shard_idx in self._cache:
|
||||
self._cache.move_to_end(shard_idx)
|
||||
return self._cache[shard_idx]
|
||||
|
||||
data = loader_fn(shard_idx)
|
||||
self._cache[shard_idx] = data
|
||||
self._cache.move_to_end(shard_idx)
|
||||
|
||||
while len(self._cache) > self.max_size:
|
||||
evicted_key, evicted_data = self._cache.popitem(last=False)
|
||||
del evicted_data
|
||||
gc.collect()
|
||||
|
||||
return data
|
||||
|
||||
def clear(self):
|
||||
self._cache.clear()
|
||||
gc.collect()
|
||||
|
||||
|
||||
class PreProcessedDataset(Dataset):
|
||||
"""
|
||||
预处理数据集加载器,自动检测数据格式:
|
||||
|
||||
- 分片压缩格式(metadata.json 含 shard_size 字段):
|
||||
从 .npz 分片按需加载,LRU 缓存控制内存
|
||||
- 单体格式(向后兼容):
|
||||
mmap 零拷贝加载 .npy 文件
|
||||
|
||||
所有数据以 int16 存储,读取时转为 torch.long (int64)。
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: str, max_cache_shards: int = 2):
|
||||
self.data_dir = Path(data_dir)
|
||||
|
||||
with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
|
||||
self.metadata = json.load(f)
|
||||
|
||||
self.num_samples = self.metadata["num_samples"]
|
||||
self.max_seq_length = self.metadata["max_seq_length"]
|
||||
self._shard_size: Optional[int] = self.metadata.get("shard_size")
|
||||
self._num_shards: Optional[int] = self.metadata.get("num_shards")
|
||||
|
||||
if self._shard_size is not None and self._num_shards is not None:
|
||||
self._is_sharded = True
|
||||
self._cache = _ShardCache(max_size=max_cache_shards)
|
||||
logger.info(
|
||||
f"Loaded sharded dataset: {self.num_samples:,} samples, "
|
||||
f"{self._num_shards} shards, shard_size={self._shard_size:,}"
|
||||
)
|
||||
else:
|
||||
self._is_sharded = False
|
||||
self._load_single_files()
|
||||
logger.info(
|
||||
f"Loaded single-file dataset: {self.num_samples:,} samples (mmap)"
|
||||
)
|
||||
|
||||
def _load_single_files(self):
|
||||
"""向后兼容:加载单体 .npy 文件(mmap 模式)"""
|
||||
self.input_ids = np.load(self.data_dir / "input_ids.npy", mmap_mode="r")
|
||||
self.token_type_ids = np.load(
|
||||
self.data_dir / "token_type_ids.npy", mmap_mode="r"
|
||||
)
|
||||
self.attention_mask = np.load(
|
||||
self.data_dir / "attention_mask.npy", mmap_mode="r"
|
||||
)
|
||||
self.labels = np.load(self.data_dir / "labels.npy", mmap_mode="r")
|
||||
self.history_slot_ids = np.load(
|
||||
self.data_dir / "history_slot_ids.npy", mmap_mode="r"
|
||||
)
|
||||
self.pinyin_ids = np.load(self.data_dir / "pinyin_ids.npy", mmap_mode="r")
|
||||
|
||||
def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]:
|
||||
"""加载一个 .npz 分片到内存"""
|
||||
shard_path = self.data_dir / f"shard_{shard_idx:06d}.npz"
|
||||
data = dict(np.load(shard_path))
|
||||
for key in data:
|
||||
data[key] = data[key].astype(np.int64)
|
||||
return data
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
if not 0 <= idx < self.num_samples:
|
||||
raise IndexError(
|
||||
f"Index {idx} out of range for dataset with {self.num_samples} samples"
|
||||
)
|
||||
|
||||
if self._is_sharded:
|
||||
shard_idx = idx // self._shard_size
|
||||
local_idx = idx % self._shard_size
|
||||
shard_data = self._cache.get(shard_idx, self._load_shard)
|
||||
return {
|
||||
"input_ids": torch.from_numpy(
|
||||
shard_data["input_ids"][local_idx].copy()
|
||||
),
|
||||
"token_type_ids": torch.from_numpy(
|
||||
shard_data["token_type_ids"][local_idx].copy()
|
||||
),
|
||||
"attention_mask": torch.from_numpy(
|
||||
shard_data["attention_mask"][local_idx].copy()
|
||||
),
|
||||
"labels": torch.tensor(
|
||||
shard_data["labels"][local_idx], dtype=torch.long
|
||||
),
|
||||
"history_slot_ids": torch.from_numpy(
|
||||
shard_data["history_slot_ids"][local_idx].copy()
|
||||
),
|
||||
"pinyin_ids": torch.from_numpy(
|
||||
shard_data["pinyin_ids"][local_idx].copy()
|
||||
),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"input_ids": torch.from_numpy(self.input_ids[idx].astype(np.int64)),
|
||||
"token_type_ids": torch.from_numpy(
|
||||
self.token_type_ids[idx].astype(np.int64)
|
||||
),
|
||||
"attention_mask": torch.from_numpy(
|
||||
self.attention_mask[idx].astype(np.int64)
|
||||
),
|
||||
"labels": torch.tensor(self.labels[idx], dtype=torch.long),
|
||||
"history_slot_ids": torch.from_numpy(
|
||||
self.history_slot_ids[idx].astype(np.int64)
|
||||
),
|
||||
"pinyin_ids": torch.from_numpy(self.pinyin_ids[idx].astype(np.int64)),
|
||||
}
|
||||
|
||||
|
||||
def preprocessed_collate_fn(batch):
|
||||
"""
|
||||
预处理数据的 collate 函数。
|
||||
不含 string 字段(prefix/suffix/pinyin),仅处理 tensor 字段。
|
||||
"""
|
||||
return {
|
||||
"input_ids": torch.stack([item["input_ids"] for item in batch]),
|
||||
"token_type_ids": torch.stack([item["token_type_ids"] for item in batch]),
|
||||
"attention_mask": torch.stack([item["attention_mask"] for item in batch]),
|
||||
"labels": torch.stack([item["labels"] for item in batch]),
|
||||
"history_slot_ids": torch.stack([item["history_slot_ids"] for item in batch]),
|
||||
"pinyin_ids": torch.stack([item["pinyin_ids"] for item in batch]),
|
||||
}
|
||||
|
|
@ -28,6 +28,11 @@ from torch.utils.data import DataLoader
|
|||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from .dataset import PinyinInputDataset
|
||||
from .preprocessed_dataset import (
|
||||
PreProcessedDataset,
|
||||
is_preprocessed_data,
|
||||
preprocessed_collate_fn,
|
||||
)
|
||||
|
||||
# 导入模型和数据
|
||||
from .model import InputMethodEngine
|
||||
|
|
@ -97,6 +102,11 @@ class Trainer:
|
|||
"""
|
||||
self.model = model
|
||||
self.train_dataloader = train_dataloader
|
||||
if isinstance(eval_dataloader, DataLoader) and not isinstance(
|
||||
eval_dataloader.dataset, torch.utils.data.IterableDataset
|
||||
):
|
||||
self.eval_dataloader = eval_dataloader
|
||||
else:
|
||||
self.eval_dataloader = list([i for i in eval_dataloader])
|
||||
self.output_dir = Path(output_dir)
|
||||
self.num_epochs = num_epochs
|
||||
|
|
@ -1003,7 +1013,7 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|||
|
||||
# Typer CLI应用
|
||||
def create_dataloader(
|
||||
dataset: PinyinInputDataset,
|
||||
dataset,
|
||||
batch_size: int,
|
||||
num_workers: int = 2,
|
||||
pin_memory: bool = True,
|
||||
|
|
@ -1011,20 +1021,23 @@ def create_dataloader(
|
|||
max_iter_length: Optional[int] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
创建数据加载器,优先使用DataLoader2,如果不可用则回退到DataLoader。
|
||||
专门针对流式数据集优化。
|
||||
创建数据加载器,自动识别数据集类型。
|
||||
|
||||
Args:
|
||||
dataset: PinyinInputDataset实例
|
||||
batch_size: 批次大小
|
||||
num_workers: worker数量(对于流式数据集建议为2)
|
||||
pin_memory: 是否固定内存
|
||||
shuffle: 是否打乱(流式数据集内部处理打乱)
|
||||
max_iter_length: 最大迭代长度,用于计算总步数
|
||||
|
||||
Returns:
|
||||
数据加载器实例
|
||||
- PinyinInputDataset(IterableDataset):使用流式加载
|
||||
- PreProcessedDataset(map-style):使用标准加载,支持 shuffle
|
||||
"""
|
||||
if isinstance(dataset, PreProcessedDataset):
|
||||
logger.info(f"📊 使用预处理数据集,样本数: {len(dataset)}")
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=preprocessed_collate_fn,
|
||||
persistent_workers=True if num_workers > 0 else False,
|
||||
)
|
||||
|
||||
logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}")
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
|
|
@ -1164,12 +1177,82 @@ def train(
|
|||
config_table.add_row("训练", "混合精度", str(mixed_precision))
|
||||
|
||||
config_table.add_row("其他", "自动恢复", str(auto_resume))
|
||||
console.print(config_table)
|
||||
|
||||
# 创建输出目录
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 检测数据类型并创建数据加载器
|
||||
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
|
||||
|
||||
is_train_preprocessed = is_preprocessed_data(train_data_path)
|
||||
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
|
||||
|
||||
if is_train_preprocessed:
|
||||
train_dataset = PreProcessedDataset(train_data_path)
|
||||
total_steps = (len(train_dataset) // batch_size) * num_epochs
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=True,
|
||||
)
|
||||
config_table.add_row("数据", "训练数据类型", "预处理数据")
|
||||
else:
|
||||
train_dataset = PinyinInputDataset(
|
||||
data_path=train_data_path,
|
||||
max_workers=-1,
|
||||
max_iter_length=max_iter_length,
|
||||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=2000000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
total_steps = int(max_iter_length * num_epochs / batch_size)
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
max_iter_length=max_iter_length,
|
||||
)
|
||||
config_table.add_row("数据", "训练数据类型", "流式数据")
|
||||
|
||||
if is_eval_preprocessed:
|
||||
eval_dataset = PreProcessedDataset(eval_data_path)
|
||||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=2,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
shuffle=False,
|
||||
)
|
||||
config_table.add_row("数据", "评估数据类型", "预处理数据")
|
||||
else:
|
||||
eval_dataset = PinyinInputDataset(
|
||||
data_path=eval_data_path,
|
||||
max_workers=-1,
|
||||
max_iter_length=batch_size * 64,
|
||||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=2000000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=2,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
max_iter_length=batch_size * 64,
|
||||
)
|
||||
config_table.add_row("数据", "评估数据类型", "流式数据")
|
||||
|
||||
config_table.add_row("数据", "总步数", str(total_steps))
|
||||
console.print(config_table)
|
||||
|
||||
# 保存配置
|
||||
config = {
|
||||
"train_data_path": train_data_path,
|
||||
|
|
@ -1202,6 +1285,9 @@ def train(
|
|||
"auto_resume": auto_resume,
|
||||
"max_iter_length": max_iter_length,
|
||||
"compile": compile,
|
||||
"is_train_preprocessed": is_train_preprocessed,
|
||||
"is_eval_preprocessed": is_eval_preprocessed,
|
||||
"total_steps": total_steps,
|
||||
}
|
||||
|
||||
config_file = output_path / "training_config.json"
|
||||
|
|
@ -1210,52 +1296,6 @@ def train(
|
|||
|
||||
logger.info(f"Configuration saved to {config_file}")
|
||||
|
||||
# 创建数据加载器
|
||||
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
|
||||
|
||||
# 训练数据集
|
||||
train_dataset = PinyinInputDataset(
|
||||
data_path=train_data_path,
|
||||
max_workers=-1, # 自动选择worker数量
|
||||
max_iter_length=max_iter_length,
|
||||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=2000000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
# 训练数据加载器
|
||||
# 注意:PinyinInputDataset是IterableDataset,所以不能使用shuffle参数
|
||||
# 多worker配置:每个worker处理数据集的一个分片,由dataset.__iter__中的shard处理
|
||||
train_dataloader = create_dataloader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
max_iter_length=max_iter_length,
|
||||
)
|
||||
|
||||
# 评估数据集(使用相同的设置,但可以调整参数)
|
||||
eval_dataset = PinyinInputDataset(
|
||||
data_path=eval_data_path,
|
||||
max_workers=-1,
|
||||
max_iter_length=batch_size * 64, # 评估集较小
|
||||
max_seq_length=max_seq_len,
|
||||
text_field="text",
|
||||
py_style_weight=(9, 2, 1),
|
||||
shuffle_buffer_size=2000000,
|
||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||
)
|
||||
|
||||
eval_dataloader = create_dataloader(
|
||||
dataset=eval_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=2, # 评估使用较少的worker
|
||||
pin_memory=torch.cuda.is_available(),
|
||||
max_iter_length=batch_size * 64,
|
||||
)
|
||||
|
||||
console.print("[bold cyan]正在创建模型...[/bold cyan]")
|
||||
model = InputMethodEngine(
|
||||
vocab_size=vocab_size,
|
||||
|
|
@ -1279,7 +1319,7 @@ def train(
|
|||
model=model,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
total_steps=int(max_iter_length * num_epochs / batch_size),
|
||||
total_steps=total_steps,
|
||||
output_dir=output_dir,
|
||||
num_epochs=num_epochs,
|
||||
learning_rate=learning_rate,
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ train_dataset = PinyinInputDataset(
|
|||
dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=512,
|
||||
num_workers=2,
|
||||
num_workers=16,
|
||||
worker_init_fn=worker_init_fn,
|
||||
collate_fn=collate_fn,
|
||||
prefetch_factor=2, # 减少预取以避免内存问题
|
||||
|
|
|
|||
|
|
@ -0,0 +1,146 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Visualize frequency distribution with ASCII plots
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
def ascii_histogram(data, bins=20, width=60):
|
||||
"""Create ASCII histogram"""
|
||||
if not data:
|
||||
return ""
|
||||
|
||||
min_val = min(data)
|
||||
max_val = max(data)
|
||||
|
||||
# Use log bins for wide range
|
||||
if max_val / min_val > 1000:
|
||||
log_min = math.log10(min_val) if min_val > 0 else 0
|
||||
log_max = math.log10(max_val)
|
||||
bin_edges = [10**(log_min + i*(log_max-log_min)/bins) for i in range(bins+1)]
|
||||
hist = [0] * bins
|
||||
for val in data:
|
||||
if val > 0:
|
||||
log_val = math.log10(val)
|
||||
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * bins), bins-1)
|
||||
hist[bin_idx] += 1
|
||||
bin_labels = [f"{bin_edges[i]:.1e}-{bin_edges[i+1]:.1e}" for i in range(bins)]
|
||||
else:
|
||||
bin_width = (max_val - min_val) / bins
|
||||
bin_edges = [min_val + i*bin_width for i in range(bins+1)]
|
||||
hist = [0] * bins
|
||||
for val in data:
|
||||
bin_idx = min(int((val - min_val) / (max_val - min_val) * bins), bins-1)
|
||||
hist[bin_idx] += 1
|
||||
bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i+1]:.1f}" for i in range(bins)]
|
||||
|
||||
max_count = max(hist)
|
||||
result = []
|
||||
for i in range(bins):
|
||||
if hist[i] == 0:
|
||||
continue
|
||||
bar = '#' * int(hist[i] / max_count * width)
|
||||
result.append(f"{bin_labels[i]:20} | {bar} {hist[i]}")
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
def main():
|
||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
pairs = data.get('pairs', {})
|
||||
counts = [pair.get('count', 0) for pair in pairs.values() if pair.get('count') is not None]
|
||||
|
||||
print("FREQUENCY DISTRIBUTION ANALYSIS")
|
||||
print("="*60)
|
||||
|
||||
print("\n1. ASCII Histogram (log bins):")
|
||||
print(ascii_histogram(counts, bins=20, width=60))
|
||||
|
||||
# Rank-frequency plot in ASCII
|
||||
print("\n2. Rank-Frequency Relationship (Top 50):")
|
||||
counts_sorted_desc = sorted(counts, reverse=True)
|
||||
max_freq = counts_sorted_desc[0]
|
||||
max_rank = 50
|
||||
|
||||
for rank in range(1, max_rank + 1):
|
||||
freq = counts_sorted_desc[rank-1]
|
||||
bar_length = int(math.log(freq) / math.log(max_freq) * 40)
|
||||
bar = '#' * bar_length
|
||||
print(f"Rank {rank:3}: {freq:12} {bar}")
|
||||
|
||||
# ID vs Frequency plot (sampled)
|
||||
print("\n3. ID vs Frequency (sampled every 500 IDs):")
|
||||
# Build ID to count mapping
|
||||
id_to_count = {}
|
||||
for key, pair in pairs.items():
|
||||
char_id = pair.get('id')
|
||||
count = pair.get('count')
|
||||
if char_id is not None and count is not None:
|
||||
id_to_count[char_id] = count
|
||||
|
||||
all_ids = sorted(id_to_count.keys())
|
||||
max_id = all_ids[-1]
|
||||
|
||||
print("ID Frequency log10(freq)")
|
||||
for id in range(0, max_id + 1, 500):
|
||||
if id in id_to_count:
|
||||
freq = id_to_count[id]
|
||||
log_freq = math.log10(freq) if freq > 0 else 0
|
||||
bar = '#' * int(log_freq / math.log10(max_freq) * 40)
|
||||
print(f"{id:6} {freq:10} {log_freq:6.2f} {bar}")
|
||||
|
||||
# Zipf's law fit
|
||||
print("\n4. Zipf's Law Analysis:")
|
||||
print(" Rank * Frequency ≈ constant for Zipf's law")
|
||||
print(" Top 10 ranks:")
|
||||
for rank in range(1, 11):
|
||||
freq = counts_sorted_desc[rank-1]
|
||||
product = rank * freq
|
||||
print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}")
|
||||
|
||||
# Check if product is roughly constant
|
||||
products = [(rank+1) * counts_sorted_desc[rank] for rank in range(10)]
|
||||
avg_product = sum(products) / len(products)
|
||||
std_product = math.sqrt(sum((p - avg_product)**2 for p in products) / len(products))
|
||||
print(f" Average product (ranks 2-11): {avg_product:.3e} ± {std_product:.3e}")
|
||||
print(f" Coefficient of variation: {std_product/avg_product*100:.1f}%")
|
||||
|
||||
# Frequency spectrum
|
||||
from collections import Counter
|
||||
freq_counter = Counter(counts)
|
||||
print("\n5. Frequency Spectrum (how many entries have each frequency):")
|
||||
print(" Frequency Count Cumulative")
|
||||
cum = 0
|
||||
for freq in sorted(freq_counter.keys())[:20]:
|
||||
count = freq_counter[freq]
|
||||
cum += count
|
||||
print(f" {freq:10} {count:6} {cum:6}")
|
||||
|
||||
# Summary statistics
|
||||
print("\n6. Key Statistics:")
|
||||
n = len(counts)
|
||||
print(f" Total entries: {n}")
|
||||
print(f" Min frequency: {min(counts)}")
|
||||
print(f" Max frequency: {max(counts)}")
|
||||
print(f" Ratio max/min: {max(counts)/min(counts):.2e}")
|
||||
|
||||
percentiles = [0.01, 0.1, 0.5, 0.9, 0.99]
|
||||
for p in percentiles:
|
||||
idx = int(p * n)
|
||||
value = counts_sorted_desc[idx]
|
||||
print(f" {p*100:5.1f}th percentile: {value:12} (rank ~{idx})")
|
||||
|
||||
# Save data for external plotting
|
||||
with open("id_vs_freq.csv", "w") as f:
|
||||
f.write("id,frequency\n")
|
||||
for id in sorted(id_to_count.keys()):
|
||||
f.write(f"{id},{id_to_count[id]}\n")
|
||||
print("\nData saved to id_vs_freq.csv for external plotting")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue