Compare commits
No commits in common. "8b41bcdc6f60562429c22ea583765ffb016e1e80" and "1b7da9ddd41d6eda30bf60d4c28ad3c51ed7ec2c" have entirely different histories.
8b41bcdc6f
...
1b7da9ddd4
|
|
@ -1,236 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Analyze frequency distribution in pinyin_char_statistics.json
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import math
|
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
def main():
    """Print a frequency-distribution report for pinyin_char_statistics.json.

    Reads ``src/model/assets/pinyin_char_statistics.json`` (relative to the
    current working directory) and prints basic statistics, percentiles, an
    ID-versus-frequency ordering analysis, a closer look at IDs 5000-5500,
    a log-binned histogram and a low-frequency summary. A short summary is
    also written to ``frequency_analysis_results.txt``.

    Exits with status 1 when the JSON file does not exist, and returns early
    when no pair carries both an ``id`` and a ``count``.
    """
    # Path to the JSON file
    json_path = Path("src/model/assets/pinyin_char_statistics.json")
    if not json_path.exists():
        print(f"Error: File not found: {json_path}")
        sys.exit(1)

    print(f"Loading {json_path}...")
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Timestamp: {data.get('timestamp')}")
    print(f"Total characters: {data.get('total_characters')}")
    print(f"Total pinyins: {data.get('total_pinyins')}")
    print(f"Valid input character count: {data.get('valid_input_character_count')}")

    pairs = data.get('pairs', {})
    print(f"Number of pairs: {len(pairs)}")

    # Extract counts and IDs
    counts = []
    id_to_count = {}
    for key, pair in pairs.items():
        try:
            char_id = pair.get('id')
            count = pair.get('count')
            if char_id is not None and count is not None:
                counts.append(count)
                id_to_count[char_id] = count
        except (AttributeError, ValueError, TypeError) as e:
            # AttributeError covers entries that are not dicts (no .get),
            # so a malformed pair is reported and skipped, not fatal.
            print(f"Warning: Could not parse pair {key}: {e}")
            continue

    if not counts:
        print("No valid count data found.")
        return

    # Basic statistics
    min_count = min(counts)
    max_count = max(counts)
    total_count = sum(counts)
    mean_count = total_count / len(counts)

    # Sort counts for percentiles
    sorted_counts = sorted(counts)
    n = len(sorted_counts)

    def percentile(fraction):
        # Nearest-rank lookup; int(fraction * n) stays below n for fraction < 1.
        return sorted_counts[int(fraction * n)]

    p10 = percentile(0.1)
    p25 = percentile(0.25)
    p50 = percentile(0.5)
    p75 = percentile(0.75)
    p90 = percentile(0.9)
    p99 = percentile(0.99)

    # Population variance and standard deviation
    variance = sum((x - mean_count) ** 2 for x in counts) / n
    std_dev = math.sqrt(variance)

    print("\n=== BASIC STATISTICS ===")
    print(f"Min frequency: {min_count}")
    print(f"Max frequency: {max_count}")
    print(f"Mean frequency: {mean_count:.2f}")
    print(f"Standard deviation: {std_dev:.2f}")
    print(f"Total frequency sum: {total_count}")
    print(f"Number of entries: {n}")

    print("\n=== PERCENTILES ===")
    print(f"10th percentile: {p10}")
    print(f"25th percentile: {p25}")
    print(f"50th percentile (median): {p50}")
    print(f"75th percentile: {p75}")
    print(f"90th percentile: {p90}")
    print(f"99th percentile: {p99}")

    # Find IDs with min and max counts
    min_ids = [cid for cid, count in id_to_count.items() if count == min_count]
    max_ids = [cid for cid, count in id_to_count.items() if count == max_count]

    print(f"\nIDs with min frequency ({min_count}): {min_ids}")
    print(f"IDs with max frequency ({max_count}): {max_ids}")

    # Check if IDs are assigned in frequency order:
    # walk the counts in ascending-ID order and tally the direction of each step.
    ids = list(id_to_count.keys())
    sorted_by_id = sorted(ids)
    counts_by_id = [id_to_count[cid] for cid in sorted_by_id]

    decreases = 0
    increases = 0
    for previous, current in zip(counts_by_id, counts_by_id[1:]):
        if current < previous:
            decreases += 1
        elif current > previous:
            increases += 1

    # With a single ID there are no transitions; report 0% instead of
    # dividing by zero.
    transitions = len(counts_by_id) - 1
    pct_decreasing = decreases / transitions * 100 if transitions > 0 else 0.0

    print("\n=== ID ORDER ANALYSIS ===")
    print(f"Total pairs: {len(counts_by_id)}")
    print(f"Decreases as ID increases: {decreases} times")
    print(f"Increases as ID increases: {increases} times")
    print(f"Percentage decreasing: {pct_decreasing:.2f}%")

    # Compare each ID's rank in ID order with its rank in descending
    # frequency order (a simplified stand-in for a rank correlation).
    sorted_by_count = sorted(ids, key=lambda x: id_to_count[x], reverse=True)
    rank_by_id = {cid: i for i, cid in enumerate(sorted_by_id)}
    rank_by_count = {cid: i for i, cid in enumerate(sorted_by_count)}

    # Average rank difference
    rank_diffs = [abs(rank_by_id[cid] - rank_by_count[cid]) for cid in ids]
    avg_rank_diff = sum(rank_diffs) / len(rank_diffs)
    max_rank_diff = max(rank_diffs)

    print(f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}")
    print(f"Maximum rank difference: {max_rank_diff}")

    # Analyze specific ID range 5000-5500
    print("\n=== ANALYSIS OF ID RANGE 5000-5500 ===")
    range_ids = [cid for cid in range(5000, 5501) if cid in id_to_count]
    range_counts = [id_to_count[cid] for cid in range_ids]
    # Stay None when the range is empty so the report file can print 'N/A'.
    range_min = None
    range_max = None

    if range_counts:
        range_min = min(range_counts)
        range_max = max(range_counts)
        range_mean = sum(range_counts) / len(range_counts)
        range_sorted = sorted(range_counts)
        range_n = len(range_counts)
        range_p10 = range_sorted[int(0.1 * range_n)]
        range_p50 = range_sorted[int(0.5 * range_n)]
        range_p90 = range_sorted[int(0.9 * range_n)]

        print(f"IDs in range 5000-5500: {len(range_counts)}")
        print(f"Min frequency in range: {range_min}")
        print(f"Max frequency in range: {range_max}")
        print(f"Mean frequency in range: {range_mean:.2f}")
        print(f"10th percentile in range: {range_p10}")
        print(f"50th percentile in range: {range_p50}")
        print(f"90th percentile in range: {range_p90}")

        # Find IDs with min frequency in this range
        min_in_range_ids = [cid for cid in range_ids if id_to_count[cid] == range_min]
        print(f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}")
    else:
        print("No IDs found in range 5000-5500")

    # Histogram of frequencies (log bins)
    print("\n=== FREQUENCY DISTRIBUTION (LOG BINS) ===")
    if max_count > 0:
        log_min = math.log10(min_count) if min_count > 0 else 0
        log_max = math.log10(max_count)
        log_span = log_max - log_min
        num_bins = 20
        bin_edges = [10**(log_min + i*log_span/num_bins) for i in range(num_bins+1)]

        hist = [0] * num_bins
        for count in counts:
            if count > 0:
                if log_span > 0:
                    raw_idx = int((math.log10(count) - log_min) / log_span * num_bins)
                    # Clamp into the valid bin range (the max lands on the last bin).
                    bin_idx = max(0, min(raw_idx, num_bins-1))
                else:
                    # All positive counts share one value: a single bin,
                    # and no division by a zero-width log range.
                    bin_idx = 0
                hist[bin_idx] += 1

        print("Log-scale histogram (count range -> frequency count):")
        for i in range(num_bins):
            if hist[i] > 0:
                lower = bin_edges[i]
                upper = bin_edges[i+1]
                print(f"  {lower:.2e} - {upper:.2e}: {hist[i]} entries")

    # Check for zero or near-zero frequencies
    zero_count = sum(1 for c in counts if c == 0)
    low_count = sum(1 for c in counts if 0 < c <= 10)
    very_low_count = sum(1 for c in counts if 0 < c <= 100)

    print("\n=== LOW FREQUENCY ANALYSIS ===")
    print(f"Entries with zero frequency: {zero_count}")
    print(f"Entries with frequency <= 10: {low_count}")
    print(f"Entries with frequency <= 100: {very_low_count}")

    # Find the actual min frequency (excluding zeros if any)
    non_zero_counts = [c for c in counts if c > 0]
    if non_zero_counts:
        actual_min = min(non_zero_counts)
        print(f"Actual min frequency (non-zero): {actual_min}")
        actual_min_ids = [cid for cid, count in id_to_count.items() if count == actual_min]
        print(f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}")

    # Summary for smoothing algorithm design
    print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===")
    print(f"Frequency range spans {max_count/min_count if min_count>0 else 'inf'}:1 ratio")
    print(f"Most entries ({p50}) have frequency around {p50}")
    print(f"Top 10% of entries have frequency > {p90}")
    print(f"Bottom 10% of entries have frequency < {p10}")
    print(f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency")

    # Save detailed data for further analysis
    output_file = "frequency_analysis_results.txt"
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("Frequency Analysis Results\n")
        f.write("="*50 + "\n")
        f.write(f"Min frequency: {min_count}\n")
        f.write(f"Max frequency: {max_count}\n")
        f.write(f"Mean frequency: {mean_count:.2f}\n")
        f.write(f"Standard deviation: {std_dev:.2f}\n")
        f.write(f"10th percentile: {p10}\n")
        f.write(f"50th percentile: {p50}\n")
        f.write(f"90th percentile: {p90}\n")
        f.write(f"IDs in range 5000-5500 min: {range_min if range_min is not None else 'N/A'}\n")
        f.write(f"IDs in range 5000-5500 max: {range_max if range_max is not None else 'N/A'}\n")

    print(f"\nDetailed results saved to {output_file}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
115
analyze_range.py
115
analyze_range.py
|
|
@ -1,115 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Analyze specific ID ranges in pinyin_char_statistics.json
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
def main():
    """Print a detailed look at selected ID ranges of pinyin_char_statistics.json.

    Reads ``src/model/assets/pinyin_char_statistics.json`` (relative to the
    current working directory) and prints:

    * every 100th entry in the ID range 5000-5500, plus that range's
      min/max entry, monotonicity and most common counts;
    * where the entries with ``count == 1`` sit and a few sample characters;
    * how many non-increasing runs the counts form in ascending-ID order,
      and the ten largest groups of IDs sharing one count.

    Exits with status 1 when the JSON file does not exist.
    """
    from collections import Counter

    json_path = Path("src/model/assets/pinyin_char_statistics.json")
    if not json_path.exists():
        print(f"Error: File not found: {json_path}")
        sys.exit(1)

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    pairs = data.get('pairs', {})

    # Build ID -> count (last pair wins) and ID -> first pair carrying that
    # ID. The second index replaces a full rescan of `pairs` for every ID
    # in the inspected range.
    id_to_count = {}
    first_pair_by_id = {}
    for pair in pairs.values():
        char_id = pair.get('id')
        if char_id is None:
            continue
        first_pair_by_id.setdefault(char_id, pair)
        count = pair.get('count')
        if count is not None:
            id_to_count[char_id] = count

    # Analyze range 5000-5500 in detail
    print("ID range 5000-5500 detailed analysis:")
    print("ID\tCount\tChar\tPinyin")

    range_data = []
    for char_id in range(5000, 5501):
        if char_id not in id_to_count:
            continue
        pair = first_pair_by_id[char_id]
        char = pair.get('char', '')
        pinyin = pair.get('pinyin', '')
        count = pair.get('count', 0)
        range_data.append((char_id, count, char, pinyin))
        if char_id % 100 == 0:  # Print every 100th for overview
            print(f"{char_id}\t{count}\t{char}\t{pinyin}")

    # Print min and max in range
    if range_data:
        min_item = min(range_data, key=lambda x: x[1])
        max_item = max(range_data, key=lambda x: x[1])
        print(f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'")
        print(f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'")

        # Check if frequencies are monotonic in this range
        counts = [item[1] for item in range_data]
        increasing = all(a <= b for a, b in zip(counts, counts[1:]))
        decreasing = all(a >= b for a, b in zip(counts, counts[1:]))
        print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}")

        # Check for frequency plateaus
        most_common = Counter(counts).most_common(5)
        print(f"Most common frequencies in range: {most_common}")

    # Analyze the tail (IDs with frequency 1)
    print("\n\nAnalysis of frequency=1 entries:")
    freq_one_ids = [char_id for char_id, count in id_to_count.items() if count == 1]
    print(f"Number of entries with frequency=1: {len(freq_one_ids)}")
    if freq_one_ids:
        print(f"ID range of frequency=1: {min(freq_one_ids)} to {max(freq_one_ids)}")
        print(f"First 10 IDs: {freq_one_ids[:10]}")
        print(f"Last 10 IDs: {freq_one_ids[-10:]}")

        # Check if they're contiguous
        sorted_ids = sorted(freq_one_ids)
        contiguous = all(a + 1 == b for a, b in zip(sorted_ids, sorted_ids[1:]))
        print(f"Are they contiguous IDs? {contiguous}")

        # Sample some characters
        print("\nSample characters with frequency=1:")
        sample_count = 0
        for pair in pairs.values():
            if sample_count >= 10:
                break  # enough samples; no need to scan the rest
            if pair.get('count') == 1:
                print(f"  ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'")
                sample_count += 1

    # Check overall ID-frequency ordering
    print("\n\nOverall ID-frequency ordering analysis:")
    all_ids = sorted(id_to_count.keys())
    all_counts = [id_to_count[char_id] for char_id in all_ids]

    # Count maximal runs (length > 1) in which the count never increases
    # as the ID increases.
    non_increasing_segments = 0
    current_segment_length = 1
    for previous, current in zip(all_counts, all_counts[1:]):
        if current <= previous:
            current_segment_length += 1
        else:
            if current_segment_length > 1:
                non_increasing_segments += 1
            current_segment_length = 1
    if current_segment_length > 1:
        non_increasing_segments += 1

    print(f"Total IDs: {len(all_ids)}")
    print(f"Non-increasing segments: {non_increasing_segments}")

    # Check for frequency plateaus overall
    overall_freq_count = Counter(all_counts)
    plateaus = [(freq, count) for freq, count in overall_freq_count.items() if count > 1]
    plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10]
    print("Top 10 frequency plateaus (freq: count of IDs sharing that freq):")
    for freq, count in plateaus_sorted:
        print(f"  {freq}: {count} IDs")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -1,177 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Comprehensive frequency distribution analysis
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
from collections import Counter
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
def main():
    """Print a comprehensive frequency-distribution report and dump rank data.

    Reads ``src/model/assets/pinyin_char_statistics.json`` (relative to the
    current working directory) and prints percentile and cumulative
    distributions, an analysis of the count threshold 109, the top of the
    rank-frequency table, the frequency spectrum, per-ID-range statistics,
    an ID-order verification and the largest groups of consecutive IDs that
    share one count. The full rank/frequency table is written to
    ``rank_freq.csv``.

    Exits with status 1 when the JSON file does not exist, and returns early
    when no pair has a ``count``.
    """
    json_path = Path("src/model/assets/pinyin_char_statistics.json")
    if not json_path.exists():
        print(f"Error: File not found: {json_path}")
        sys.exit(1)

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    pairs = data.get('pairs', {})

    # Extract counts
    counts = [pair.get('count') for pair in pairs.values()]
    counts = [count for count in counts if count is not None]

    n = len(counts)
    print(f"Total entries: {n}")
    if not counts:
        # min()/max() and every division below need at least one entry.
        print("No valid count data found.")
        return

    # Sort descending for rank-frequency analysis
    counts_sorted_desc = sorted(counts, reverse=True)

    # Basic statistics
    min_count = min(counts)
    max_count = max(counts)

    # Percentiles (position p*n in the descending ranking)
    percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
    print("\n=== PERCENTILE DISTRIBUTION ===")
    for p in percentiles:
        idx = int(p * n)
        value = counts_sorted_desc[idx]
        print(f"{p*100:5.1f}%: {value:>12} (rank ~{idx})")

    # Cumulative distribution
    print("\n=== CUMULATIVE DISTRIBUTION ===")
    thresholds = [1, 2, 3, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000, 1000000, 5000000, 10000000, 50000000, 100000000, 500000000]
    for thresh in thresholds:
        if thresh > max_count:
            break
        below = sum(1 for c in counts if c <= thresh)
        print(f"Count <= {thresh:10}: {below:6} entries ({below/n*100:5.1f}%)")

    # Check min_count=109 parameter
    print("\n=== ANALYSIS OF THRESHOLD 109 ===")
    counts_ge_109 = [c for c in counts if c >= 109]
    at_or_above_109 = len(counts_ge_109)
    below_109 = sum(1 for c in counts if c < 109)
    print(f"Entries with count < 109: {below_109} ({below_109/n*100:.1f}%)")
    print(f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109/n*100:.1f}%)")

    # If 109 is a threshold, what's the actual min among those >= 109?
    if counts_ge_109:
        actual_min_ge_109 = min(counts_ge_109)
        print(f"Actual min frequency among those >= 109: {actual_min_ge_109}")

    # Rank-frequency analysis (Zipf's law)
    print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===")
    print("Rank\tFrequency\tlog(rank)\tlog(freq)")
    # Cap at n so data sets with fewer than 100 entries do not index past
    # the end of the list.
    for rank, freq in enumerate(counts_sorted_desc[:100], 1):
        # log is undefined for non-positive counts; show -inf instead of raising.
        log_freq = math.log(freq) if freq > 0 else float("-inf")
        print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{log_freq:.3f}")

    # Frequency spectrum (how many distinct frequencies)
    freq_counter = Counter(counts)
    print("\n=== FREQUENCY SPECTRUM ===")
    print(f"Distinct frequency values: {len(freq_counter)}")

    # Most common frequencies
    print("\nTop 20 most common frequencies (plateau sizes):")
    for freq, freq_count in freq_counter.most_common(20):
        print(f"  Frequency {freq}: {freq_count} entries")

    # Analyze ID ranges
    print("\n=== ID RANGE ANALYSIS ===")
    # Build ID to count mapping
    id_to_count = {}
    for pair in pairs.values():
        char_id = pair.get('id')
        count = pair.get('count')
        if char_id is not None and count is not None:
            id_to_count[char_id] = count

    # (start, end, label); end is exclusive.
    ranges = [
        (0, 100, "Top 100 IDs"),
        (100, 500, "IDs 100-500"),
        (500, 1000, "IDs 500-1000"),
        (1000, 2000, "IDs 1000-2000"),
        (2000, 5000, "IDs 2000-5000"),
        (5000, 5500, "IDs 5000-5500 (user mentioned)"),
        (5500, 6000, "IDs 5500-6000"),
        (10000, 10500, "IDs 10000-10500"),
        (15000, 15500, "IDs 15000-15500"),
        (19000, 19500, "IDs 19000-19500 (before freq=1)"),
        (19499, 20647, "IDs with freq=1"),
    ]

    for start, end, label in ranges:
        range_counts = [id_to_count[cid] for cid in range(start, end) if cid in id_to_count]
        if range_counts:
            min_c = min(range_counts)
            max_c = max(range_counts)
            mean_c = sum(range_counts) / len(range_counts)
            median_c = sorted(range_counts)[len(range_counts)//2]
            print(f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}")

    # Check if IDs are perfectly sorted by frequency
    print("\n=== ID ORDER VERIFICATION ===")
    all_ids = sorted(id_to_count.keys())
    all_counts = [id_to_count[cid] for cid in all_ids]

    # Check for any violations of non-increasing order
    violations = 0
    for i in range(1, len(all_counts)):
        if all_counts[i] > all_counts[i-1]:
            violations += 1
            if violations <= 5:  # only show the first few offenders
                print(f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i-1]} (ID {all_ids[i-1]})")

    print(f"Total violations of non-increasing order: {violations}")

    # Check if equal frequencies are grouped together: collect runs of
    # consecutive IDs (in ID order) that share one count as
    # (count, first ID, last ID, run length).
    print("\n=== FREQUENCY GROUPING ANALYSIS ===")
    current_freq = None
    group_start = None
    group_sizes = []

    for i, count in enumerate(all_counts):
        if count != current_freq:
            if current_freq is not None:
                group_sizes.append((current_freq, all_ids[group_start], all_ids[i-1], i - group_start))
            current_freq = count
            group_start = i

    # Last group
    if current_freq is not None:
        group_sizes.append((current_freq, all_ids[group_start], all_ids[-1], len(all_ids) - group_start))

    # Sort groups by size
    group_sizes.sort(key=lambda x: x[3], reverse=True)
    print("Top 10 largest frequency groups (plateaus):")
    for freq, start_id, end_id, size in group_sizes[:10]:
        print(f"  Frequency {freq}: IDs {start_id}-{end_id} ({size} entries)")

    # Summary for smoothing algorithm
    print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===")
    # Report the ordering that was actually measured above, not a fixed claim.
    if violations == 0:
        print("1. IDs are perfectly sorted by frequency (non-increasing).")
    else:
        print(f"1. IDs are NOT perfectly sorted by frequency ({violations} violations of non-increasing order).")
    # A zero minimum makes the ratio unbounded; avoid dividing by zero.
    ratio = f"{max_count/min_count:.1e}" if min_count > 0 else "inf"
    print(f"2. Frequency range: {min_count} to {max_count} (ratio {ratio}:1).")
    print(f"3. {below_109} entries ({below_109/n*100:.1f}%) have frequency < 109.")
    print(f"4. Median frequency: {counts_sorted_desc[n//2]}.")
    # The list is sorted descending, so the value that bounds the lower 90%
    # from above sits 10% of the way in, not 90%.
    print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.1*n)]}.")
    print(f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01*n)]}.")
    print("7. Large frequency plateaus exist (many IDs share same frequency).")
    print("8. Smoothing should handle extreme frequency ratios (1:5e8).")

    # Save data for plotting
    with open("rank_freq.csv", "w") as f:
        f.write("rank,frequency\n")
        f.writelines(f"{rank},{freq}\n" for rank, freq in enumerate(counts_sorted_desc, 1))
    print("\nRank-frequency data saved to rank_freq.csv")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -1,168 +0,0 @@
|
||||||
# 破词训练设计文档
|
|
||||||
|
|
||||||
## 背景
|
|
||||||
|
|
||||||
输入法用户在实际使用中通常是**逐词输入**的,而非逐字输入。例如输入"那边的特别漂亮的女孩是我的表姐"时,用户可能分词为:
|
|
||||||
|
|
||||||
```
|
|
||||||
那边 / 的 / 特别 / 漂亮 / 的 / 女孩 / 是 / 我 / 的 / 表姐
|
|
||||||
```
|
|
||||||
|
|
||||||
但为了增强模型的泛化能力,需要模拟用户**从词中间断开**的情况。例如用户可能只输入了"漂"就开始选字"亮"。
|
|
||||||
|
|
||||||
## 破词概念
|
|
||||||
|
|
||||||
### 术语定义
|
|
||||||
|
|
||||||
| 术语 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| 整词输入 | 用户输入完整词的拼音,如"piaoliang" |
|
|
||||||
| 破词输入 | 用户只输入词的部分拼音,如"piao" |
|
|
||||||
| 前缀 | 光标前的已确认文本 |
|
|
||||||
| 拼音 | 当前待选字的拼音(可能不完整) |
|
|
||||||
| 后缀 | 光标后的原文内容 |
|
|
||||||
|
|
||||||
### 场景示例
|
|
||||||
|
|
||||||
以词"漂亮"为例:
|
|
||||||
|
|
||||||
**整词模式(90%概率):**
|
|
||||||
```
|
|
||||||
光标前: 那边的特别
|
|
||||||
拼音: piaoliang
|
|
||||||
预测: 漂 → 亮
|
|
||||||
```
|
|
||||||
|
|
||||||
**破词模式(10%概率):**
|
|
||||||
```
|
|
||||||
光标前: 那边的特别漂
|
|
||||||
拼音: liang
|
|
||||||
预测: 亮
|
|
||||||
```
|
|
||||||
|
|
||||||
## 实现方案
|
|
||||||
|
|
||||||
### 分词策略
|
|
||||||
|
|
||||||
使用 jieba 分词器进行词语边界识别:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import jieba
|
|
||||||
words = list(jieba.cut(text, HMM=False))
|
|
||||||
# "那边的特别漂亮的女孩是我的表姐。"
|
|
||||||
# → ['那边', '的', '特别', '漂亮', '的', '女孩', '是', '我', '的', '表姐', '。']
|
|
||||||
```
|
|
||||||
|
|
||||||
### 两阶段样本生成
|
|
||||||
|
|
||||||
每个词生成样本时分为两个阶段:
|
|
||||||
|
|
||||||
#### Phase 1:前缀/整词阶段
|
|
||||||
|
|
||||||
- **整词(90%)**:`prefix_positions = 整个词的所有字符`
|
|
||||||
- **破词前缀(10%)**:`prefix_positions = 词的前 break_pos 个字符`
|
|
||||||
|
|
||||||
```python
|
|
||||||
if should_break:
|
|
||||||
break_pos = random.randint(1, word_len_chars - 1) # 随机破开位置
|
|
||||||
else:
|
|
||||||
break_pos = word_len_chars # 整词
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Phase 2:破词续接阶段(仅当破词时)
|
|
||||||
|
|
||||||
当破词发生时,从断点位置开始继续采样:
|
|
||||||
|
|
||||||
```python
|
|
||||||
if should_break and break_pos < word_len_chars:
|
|
||||||
cont_start = char_positions[break_pos]
|
|
||||||
# 从断点开始采样后续字符
|
|
||||||
target_len = random_choice(1-8) # 采样长度
|
|
||||||
cont_positions = [cont_start, ...] # 后续字符位置
|
|
||||||
```
|
|
||||||
|
|
||||||
### 样本结构
|
|
||||||
|
|
||||||
每个字符生成一个训练样本,包含:
|
|
||||||
|
|
||||||
| 字段 | 说明 | 示例 |
|
|
||||||
|------|------|------|
|
|
||||||
| `part1` (prefix) | 光标前文本 | "那边的特别漂" |
|
|
||||||
| `part2` (pinyin) | 当前字拼音 | "liang" |
|
|
||||||
| `part3` (suffix) | 光标后文本 | "亮的女孩是我的表姐" |
|
|
||||||
| `part4` | 专有词提示 | "漂亮\|特别" |
|
|
||||||
| `label` | 目标汉字ID | 1234 |
|
|
||||||
| `history_slot_ids` | 历史已确认字 | [0, 0, 0, 0, 0, 0, 0, 0] |
|
|
||||||
|
|
||||||
### 拼音增强策略
|
|
||||||
|
|
||||||
根据 `py_style_weight` 参数,拼音有以下三种形式:
|
|
||||||
|
|
||||||
| 形式 | 概率 | 示例 |
|
|
||||||
|------|------|------|
|
|
||||||
| 完整拼音 | 75% (9/12) | "piaoliang" |
|
|
||||||
| 仅声母 | 16.7% (2/12) | "pl" (通过 to_initials) |
|
|
||||||
| 仅首字母 | 8.3% (1/12) | "p" |
|
|
||||||
|
|
||||||
参数配置:`py_style_weight=(9, 2, 1)`
|
|
||||||
|
|
||||||
## 破词概率控制
|
|
||||||
|
|
||||||
### word_break_prob 参数
|
|
||||||
|
|
||||||
控制每个词从中间断开的概率,默认为 **10%**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
self.word_break_prob = 0.10 # 10%概率从词中间破开
|
|
||||||
```
|
|
||||||
|
|
||||||
### 破词位置分布
|
|
||||||
|
|
||||||
对于长度为 N 的词,破开位置 `break_pos` 的分布:
|
|
||||||
|
|
||||||
```python
|
|
||||||
break_pos = random.randint(1, N - 1)
|
|
||||||
```
|
|
||||||
|
|
||||||
- 2字词:break_pos = 1(100%在第1字后破开)
|
|
||||||
- 3字词:break_pos = 1 或 2(各50%)
|
|
||||||
- 4字词:break_pos = 1, 2, 或 3(各33%)
|
|
||||||
|
|
||||||
## 数据分布预期
|
|
||||||
|
|
||||||
### 理想分布
|
|
||||||
|
|
||||||
| 类别 | 预期比例 |
|
|
||||||
|------|----------|
|
|
||||||
| 单字样本 | ~15% |
|
|
||||||
| 2字词整词 | ~30% |
|
|
||||||
| 3字词整词 | ~20% |
|
|
||||||
| 破词样本 | ~10% |
|
|
||||||
| 其他 | ~25% |
|
|
||||||
|
|
||||||
### 拼音不完整率
|
|
||||||
|
|
||||||
由于 `py_style_weight=(9, 2, 1)`:
|
|
||||||
|
|
||||||
- 声母(initials):~16.7%
|
|
||||||
- 首字母:~8.3%
|
|
||||||
- **总计不完整**:~25%
|
|
||||||
|
|
||||||
## 代码实现位置
|
|
||||||
|
|
||||||
主要实现文件:`src/model/dataset.py`
|
|
||||||
|
|
||||||
| 函数/类 | 行号 | 功能 |
|
|
||||||
|---------|------|------|
|
|
||||||
| `segment_text()` | ~30 | jieba分词 |
|
|
||||||
| `build_word_boundaries()` | ~35 | 建立词边界映射 |
|
|
||||||
| `PinyinInputDataset.__iter__()` | ~280 | 核心迭代逻辑 |
|
|
||||||
| `get_mask_pinyin()` | ~215 | 拼音增强处理 |
|
|
||||||
| `_add_word_samples()` | ~240 | 样本构建 |
|
|
||||||
|
|
||||||
## 注意事项
|
|
||||||
|
|
||||||
1. **破词仅针对多字词**:单字词(如"的"、"是")不会破词
|
|
||||||
2. **破词保持语义完整**:破词后仍能根据上下文预测正确汉字
|
|
||||||
3. **历史槽位模拟逐步确认**:同一词内已确认的字会填入 `history_slot_ids`
|
|
||||||
4. **10% EOS标记**:词尾有10%概率追加ID=0表示句子结束
|
|
||||||
|
|
@ -1,11 +0,0 @@
|
||||||
Frequency Analysis Results
|
|
||||||
==================================================
|
|
||||||
Min frequency: 1
|
|
||||||
Max frequency: 494748360
|
|
||||||
Mean frequency: 560007.90
|
|
||||||
Standard deviation: 5730144.34
|
|
||||||
10th percentile: 3
|
|
||||||
50th percentile: 93
|
|
||||||
90th percentile: 331538
|
|
||||||
IDs in range 5000-5500 min: 5594
|
|
||||||
IDs in range 5000-5500 max: 9569
|
|
||||||
20648
id_vs_freq.csv
20648
id_vs_freq.csv
File diff suppressed because it is too large
Load Diff
|
|
@ -30,8 +30,6 @@ dependencies = [
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
train-model = "model.trainer:app"
|
train-model = "model.trainer:app"
|
||||||
monitor-training = "model.monitor:app"
|
monitor-training = "model.monitor:app"
|
||||||
preprocess-model = "model.preprocess:main"
|
|
||||||
inspect-preprocessed = "model.inspect_preprocessed:main"
|
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
# 设置当前项目的默认索引源
|
# 设置当前项目的默认索引源
|
||||||
|
|
|
||||||
20648
rank_freq.csv
20648
rank_freq.csv
File diff suppressed because it is too large
Load Diff
|
|
@ -22,26 +22,9 @@ import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from model.dataset import PinyinInputDataset
|
from model.dataset import PinyinInputDataset
|
||||||
from model.query import QueryEngine
|
|
||||||
|
|
||||||
|
|
||||||
_id2char_cache = {}
|
def analyze_label_distribution(dataset: PinyinInputDataset, sample_size: int = 10000):
|
||||||
|
|
||||||
|
|
||||||
def get_char_by_id(query_engine: QueryEngine, char_id: int) -> str:
    """Return the display text for a character ID, memoised per ID.

    ID 0 always maps to the literal "<EOS>". Any other ID is looked up once
    through ``query_engine.query_by_id`` and the result is stored in the
    module-level ``_id2char_cache``. When the lookup returns a falsy value,
    the placeholder "<ID:{char_id}>" is cached and returned instead.
    """
    if char_id == 0:
        return "<EOS>"
    if char_id in _id2char_cache:
        return _id2char_cache[char_id]
    info = query_engine.query_by_id(char_id)
    text = f"<ID:{char_id}>" if not info else info.char
    _id2char_cache[char_id] = text
    return text
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_label_distribution(
|
|
||||||
dataset: PinyinInputDataset,
|
|
||||||
sample_size: int = 10000,
|
|
||||||
query_engine: QueryEngine = None,
|
|
||||||
):
|
|
||||||
"""分析 label 在指定区间的分布"""
|
"""分析 label 在指定区间的分布"""
|
||||||
target_ranges = [
|
target_ranges = [
|
||||||
(0, 10),
|
(0, 10),
|
||||||
|
|
@ -80,23 +63,13 @@ def analyze_label_distribution(
|
||||||
in_target_range = True
|
in_target_range = True
|
||||||
|
|
||||||
if len(all_examples) < 200:
|
if len(all_examples) < 200:
|
||||||
label_char = (
|
|
||||||
get_char_by_id(query_engine, label) if query_engine else f"<ID:{label}>"
|
|
||||||
)
|
|
||||||
history_chars = (
|
|
||||||
[get_char_by_id(query_engine, hid) for hid in history]
|
|
||||||
if query_engine
|
|
||||||
else history
|
|
||||||
)
|
|
||||||
all_examples.append(
|
all_examples.append(
|
||||||
{
|
{
|
||||||
"label": label,
|
"label": label,
|
||||||
"label_char": label_char,
|
|
||||||
"prefix": prefix,
|
"prefix": prefix,
|
||||||
"suffix": suffix,
|
"suffix": suffix,
|
||||||
"pinyin": pinyin,
|
"pinyin": pinyin,
|
||||||
"history": history,
|
"history": history,
|
||||||
"history_chars": history_chars,
|
|
||||||
"part4": part4,
|
"part4": part4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -123,13 +96,12 @@ def analyze_label_distribution(
|
||||||
random.shuffle(all_examples)
|
random.shuffle(all_examples)
|
||||||
for idx, ex in enumerate(all_examples[:20], 1):
|
for idx, ex in enumerate(all_examples[:20], 1):
|
||||||
print(f"\n样本 {idx}:")
|
print(f"\n样本 {idx}:")
|
||||||
print(f" Label: {ex['label']} ({ex['label_char']})")
|
print(f" Label: {ex['label']}")
|
||||||
print(f" Part4: {ex['part4']}")
|
print(f" Part4: {ex['part4']}")
|
||||||
print(f" 光标前: {ex['prefix']}")
|
print(f" 光标前: {ex['prefix']}")
|
||||||
print(f" 光标后: {ex['suffix']}")
|
print(f" 光标后: {ex['suffix']}")
|
||||||
print(f" 拼音: {ex['pinyin']}")
|
print(f" 拼音: {ex['pinyin']}")
|
||||||
print(f" 历史槽位: {ex['history']}")
|
print(f" 历史槽位: {ex['history']}")
|
||||||
print(f" 历史汉字: {ex['history_chars']}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -155,9 +127,7 @@ def main():
|
||||||
help="数据集路径 (本地文件或HuggingFace路径)",
|
help="数据集路径 (本地文件或HuggingFace路径)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--sample_size", type=int, default=10000, help="采样大小")
|
parser.add_argument("--sample_size", type=int, default=10000, help="采样大小")
|
||||||
parser.add_argument(
|
parser.add_argument("--max_workers", type=int, default=-1, help="DataLoader workers")
|
||||||
"--max_workers", type=int, default=-1, help="DataLoader workers"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print(f"加载数据集: {args.data_path}")
|
print(f"加载数据集: {args.data_path}")
|
||||||
|
|
@ -178,11 +148,7 @@ def main():
|
||||||
print(" 3. 如果是 HuggingFace 数据集,路径应该正确")
|
print(" 3. 如果是 HuggingFace 数据集,路径应该正确")
|
||||||
return
|
return
|
||||||
|
|
||||||
query_engine = QueryEngine()
|
analyze_label_distribution(dataset, sample_size=args.sample_size)
|
||||||
query_engine.load()
|
|
||||||
analyze_label_distribution(
|
|
||||||
dataset, sample_size=args.sample_size, query_engine=query_engine
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,4 @@
|
||||||
import warnings
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", message=".*pkg_resources.*")
|
|
||||||
|
|
||||||
import jieba
|
import jieba
|
||||||
import math
|
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
|
|
@ -67,26 +62,20 @@ class PinyinInputDataset(IterableDataset):
|
||||||
max_iter_length=1e6,
|
max_iter_length=1e6,
|
||||||
max_seq_length=128,
|
max_seq_length=128,
|
||||||
text_field: str = "text",
|
text_field: str = "text",
|
||||||
py_style_weight=(9, 2, 1),
|
py_style_weight=(90, 2, 1),
|
||||||
shuffle_buffer_size: int = 100000,
|
shuffle_buffer_size: int = 100000,
|
||||||
retention_ratio: float = 0.8,
|
retention_ratio: float = 0.8,
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
merge_short_words_prob: float = 0.5,
|
|
||||||
merge_max_short_words: int = 3,
|
|
||||||
merge_max_total_chars: int = 6,
|
|
||||||
low_freq_repeat: float = 50.0,
|
|
||||||
high_freq_repeat: float = 0.1,
|
|
||||||
):
|
):
|
||||||
# 频率调整参数 - 幂律平滑方案
|
# 频率调整参数 (可根据需要调整)
|
||||||
|
self.drop_start_freq = 10_000_000
|
||||||
|
self.max_drop_prob = 0.9
|
||||||
|
self.repeat_end_freq = 10_000
|
||||||
|
self.max_repeat_expect = 50
|
||||||
self.min_freq = 109
|
self.min_freq = 109
|
||||||
self.low_freq_repeat = low_freq_repeat
|
|
||||||
self.high_freq_repeat = high_freq_repeat
|
|
||||||
self.word_break_prob = 0.10
|
self.word_break_prob = 0.10
|
||||||
self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
|
self.cont_length_probs = [0.05, 0.16, 0.30, 0.20, 0.12, 0.08, 0.05, 0.04]
|
||||||
self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
|
self._history_weights = [0.2, 0.2, 0.2, 0.9, 1.2, 1.8, 2.5, 3.5, 4.0]
|
||||||
self.merge_short_words_prob = merge_short_words_prob
|
|
||||||
self.merge_max_short_words = merge_max_short_words
|
|
||||||
self.merge_max_total_chars = merge_max_total_chars
|
|
||||||
|
|
||||||
jieba.initialize()
|
jieba.initialize()
|
||||||
|
|
||||||
|
|
@ -122,35 +111,50 @@ class PinyinInputDataset(IterableDataset):
|
||||||
self.sample_freqs = self.query_engine.get_all_weights()
|
self.sample_freqs = self.query_engine.get_all_weights()
|
||||||
self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0
|
self.max_freq = max(self.sample_freqs.values()) if self.sample_freqs else 0
|
||||||
|
|
||||||
# 计算幂律平滑参数
|
|
||||||
if self.max_freq > self.min_freq:
|
|
||||||
self.alpha = math.log(
|
|
||||||
self.low_freq_repeat / self.high_freq_repeat
|
|
||||||
) / math.log(self.max_freq / self.min_freq)
|
|
||||||
self.C = self.low_freq_repeat * (self.min_freq**self.alpha)
|
|
||||||
else:
|
|
||||||
self.alpha = 0.0
|
|
||||||
self.C = 1.0
|
|
||||||
|
|
||||||
def adjust_frequency(self, freq: int) -> int:
|
def adjust_frequency(self, freq: int) -> int:
|
||||||
"""削峰填谷 - 根据频率调整采样次数,0表示丢弃
|
"""削峰填谷 - 根据频率调整采样次数,0表示丢弃"""
|
||||||
使用幂律平滑方案:E(freq) = C × freq^(-α)
|
# 1. 削峰处理(高频字)
|
||||||
保持频率排序关系,单个连续函数
|
if freq >= self.drop_start_freq:
|
||||||
"""
|
# 线性丢弃概率计算
|
||||||
if freq <= 0:
|
max_freq = self.max_freq # 使用预计算的最大频率值
|
||||||
return 0
|
if max_freq <= self.drop_start_freq:
|
||||||
|
drop_prob = 0.0
|
||||||
# 计算期望采样次数
|
|
||||||
expected = self.C * (freq ** (-self.alpha))
|
|
||||||
|
|
||||||
# 采样策略
|
|
||||||
if expected >= 1.0:
|
|
||||||
# 泊松分布重复
|
|
||||||
repeat_count = np.random.poisson(expected)
|
|
||||||
return max(1, repeat_count)
|
|
||||||
else:
|
else:
|
||||||
# 伯努利采样:以概率expected返回1,否则返回0
|
drop_prob = (
|
||||||
return 1 if random.random() < expected else 0
|
self.max_drop_prob
|
||||||
|
* (freq - self.drop_start_freq)
|
||||||
|
/ (max_freq - self.drop_start_freq)
|
||||||
|
)
|
||||||
|
if random.random() < drop_prob:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# 2. 填谷处理(低频字)
|
||||||
|
elif freq <= self.repeat_end_freq:
|
||||||
|
# 线性重复期望计算
|
||||||
|
if freq <= self.min_freq:
|
||||||
|
repeat_expect = self.max_repeat_expect
|
||||||
|
else:
|
||||||
|
if self.repeat_end_freq == self.min_freq:
|
||||||
|
repeat_expect = 0
|
||||||
|
else:
|
||||||
|
repeat_expect = (
|
||||||
|
self.max_repeat_expect
|
||||||
|
* (self.repeat_end_freq - freq)
|
||||||
|
/ (self.repeat_end_freq - self.min_freq)
|
||||||
|
)
|
||||||
|
# 使用泊松分布实现随机重复
|
||||||
|
repeat_count = np.random.poisson(repeat_expect)
|
||||||
|
if repeat_expect < 1.0:
|
||||||
|
# 小期望值时,以概率 repeat_expect 采样 1 次
|
||||||
|
return 1 if random.random() < repeat_expect else 0
|
||||||
|
else:
|
||||||
|
return max(1, repeat_count) # 原逻辑
|
||||||
|
|
||||||
|
# 3. 中间频率字
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
# 生成对应文本的拼音
|
# 生成对应文本的拼音
|
||||||
def generate_pinyin(self, text: str) -> List[str]:
|
def generate_pinyin(self, text: str) -> List[str]:
|
||||||
|
|
@ -205,32 +209,21 @@ class PinyinInputDataset(IterableDataset):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# 生成需要预测汉字对应的拼音,并进行加强
|
||||||
def get_mask_pinyin(
|
def get_mask_pinyin(
|
||||||
self, text: str, pinyin_list: List[str]
|
self, text: str, pinyin_list: List[str]
|
||||||
) -> Tuple[int, List[str]]:
|
) -> Tuple[int, List[str]]:
|
||||||
# 整词统一拼音风格,避免多字词完整拼音概率指数衰减
|
|
||||||
style = random.random()
|
|
||||||
cumulative = 0.0
|
|
||||||
style_idx = 0
|
|
||||||
for i, w in enumerate(self.py_style_weight):
|
|
||||||
cumulative += w
|
|
||||||
if style < cumulative:
|
|
||||||
style_idx = i
|
|
||||||
break
|
|
||||||
|
|
||||||
mask_pinyin = []
|
mask_pinyin = []
|
||||||
for i in range(len(text)):
|
for i in range(len(text)):
|
||||||
if not self.query_engine.is_chinese_char(text[i]):
|
if not self.query_engine.is_chinese_char(text[i]):
|
||||||
break
|
break
|
||||||
full_py = pinyin_list[i]
|
|
||||||
if style_idx == 0:
|
|
||||||
py = full_py
|
|
||||||
elif style_idx == 1:
|
|
||||||
py = to_initials(full_py)
|
|
||||||
if py == "":
|
|
||||||
py = full_py[0]
|
|
||||||
else:
|
else:
|
||||||
py = full_py[0]
|
py = np.random.choice(
|
||||||
|
(pinyin_list[i], to_initials(pinyin_list[i]), pinyin_list[i][0]),
|
||||||
|
p=self.py_style_weight,
|
||||||
|
)
|
||||||
|
if py == "":
|
||||||
|
py = pinyin_list[i][0]
|
||||||
mask_pinyin.append(py)
|
mask_pinyin.append(py)
|
||||||
return len(mask_pinyin), mask_pinyin
|
return len(mask_pinyin), mask_pinyin
|
||||||
|
|
||||||
|
|
@ -266,19 +259,13 @@ class PinyinInputDataset(IterableDataset):
|
||||||
repeats = max(1, int(base_repeats * weight))
|
repeats = max(1, int(base_repeats * weight))
|
||||||
|
|
||||||
history = labels[:label_idx]
|
history = labels[:label_idx]
|
||||||
if len(history) > 8:
|
len_h = len(history)
|
||||||
history = history[-8:]
|
history.extend([0] * (8 - len_h))
|
||||||
else:
|
|
||||||
history.extend([0] * (8 - len(history)))
|
|
||||||
|
|
||||||
sample_dict = {
|
sample_dict = {
|
||||||
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
|
"input_ids": encoded["input_ids"],
|
||||||
"token_type_ids": torch.tensor(
|
"token_type_ids": encoded["token_type_ids"],
|
||||||
encoded["token_type_ids"], dtype=torch.long
|
"attention_mask": encoded["attention_mask"],
|
||||||
),
|
|
||||||
"attention_mask": torch.tensor(
|
|
||||||
encoded["attention_mask"], dtype=torch.long
|
|
||||||
),
|
|
||||||
"label": torch.tensor([label], dtype=torch.long),
|
"label": torch.tensor([label], dtype=torch.long),
|
||||||
"history_slot_ids": torch.tensor(history, dtype=torch.long),
|
"history_slot_ids": torch.tensor(history, dtype=torch.long),
|
||||||
"prefix": f"{part4}^{part1}",
|
"prefix": f"{part4}^{part1}",
|
||||||
|
|
@ -304,12 +291,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if worker_id >= num_workers:
|
if worker_id >= num_workers:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id)
|
||||||
worker_dataset = self.dataset.shard(
|
|
||||||
num_shards=num_workers, index=worker_id
|
|
||||||
)
|
|
||||||
except (IndexError, ValueError):
|
|
||||||
worker_dataset = self.dataset
|
|
||||||
|
|
||||||
total_quota = int(self.max_iter_length)
|
total_quota = int(self.max_iter_length)
|
||||||
base_quota = total_quota // num_workers
|
base_quota = total_quota // num_workers
|
||||||
|
|
@ -339,58 +321,17 @@ class PinyinInputDataset(IterableDataset):
|
||||||
word_boundaries = build_word_boundaries(words)
|
word_boundaries = build_word_boundaries(words)
|
||||||
pinyin_list = self.generate_pinyin(text)
|
pinyin_list = self.generate_pinyin(text)
|
||||||
|
|
||||||
idx = 0
|
for word_start, word_end in word_boundaries:
|
||||||
while idx < len(word_boundaries):
|
|
||||||
word_start, word_end = word_boundaries[idx]
|
|
||||||
|
|
||||||
char_positions = []
|
char_positions = []
|
||||||
for i in range(word_start, word_end):
|
for i in range(word_start, word_end):
|
||||||
if self.query_engine.is_chinese_char(text[i]):
|
if self.query_engine.is_chinese_char(text[i]):
|
||||||
char_positions.append(i)
|
char_positions.append(i)
|
||||||
|
|
||||||
if not char_positions:
|
if not char_positions:
|
||||||
idx += 1
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
word_len_chars = len(char_positions)
|
word_len_chars = len(char_positions)
|
||||||
|
|
||||||
merge_end_idx = idx + 1
|
|
||||||
if word_len_chars <= 2:
|
|
||||||
accumulated_positions = list(char_positions)
|
|
||||||
accumulated_count = 1
|
|
||||||
next_idx = idx + 1
|
|
||||||
|
|
||||||
while next_idx < len(word_boundaries):
|
|
||||||
ns, ne = word_boundaries[next_idx]
|
|
||||||
next_positions = []
|
|
||||||
for i in range(ns, ne):
|
|
||||||
if self.query_engine.is_chinese_char(text[i]):
|
|
||||||
next_positions.append(i)
|
|
||||||
next_len = len(next_positions)
|
|
||||||
|
|
||||||
if next_len == 0 or next_len > 2:
|
|
||||||
break
|
|
||||||
if (
|
|
||||||
len(accumulated_positions) + next_len
|
|
||||||
> self.merge_max_total_chars
|
|
||||||
):
|
|
||||||
break
|
|
||||||
if accumulated_count + 1 > self.merge_max_short_words:
|
|
||||||
break
|
|
||||||
if random.random() > self.merge_short_words_prob:
|
|
||||||
break
|
|
||||||
|
|
||||||
accumulated_positions.extend(next_positions)
|
|
||||||
accumulated_count += 1
|
|
||||||
next_idx += 1
|
|
||||||
|
|
||||||
if accumulated_count > 1:
|
|
||||||
char_positions = accumulated_positions
|
|
||||||
word_len_chars = len(char_positions)
|
|
||||||
merge_end_idx = next_idx
|
|
||||||
word_start = word_boundaries[idx][0]
|
|
||||||
word_end = word_boundaries[next_idx - 1][1]
|
|
||||||
|
|
||||||
should_break = (
|
should_break = (
|
||||||
word_len_chars > 1 and random.random() < self.word_break_prob
|
word_len_chars > 1 and random.random() < self.word_break_prob
|
||||||
)
|
)
|
||||||
|
|
@ -406,15 +347,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
prefix_pinyin = [pinyin_list[i] for i in prefix_positions]
|
prefix_pinyin = [pinyin_list[i] for i in prefix_positions]
|
||||||
|
|
||||||
_, mask_pinyin = self.get_mask_pinyin(prefix_text, prefix_pinyin)
|
_, mask_pinyin = self.get_mask_pinyin(prefix_text, prefix_pinyin)
|
||||||
r = random.random()
|
split_char = np.random.choice(
|
||||||
if r < 0.9:
|
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
|
||||||
split_char = ""
|
)
|
||||||
elif r < 0.94:
|
|
||||||
split_char = "`"
|
|
||||||
elif r < 0.98:
|
|
||||||
split_char = "'"
|
|
||||||
else:
|
|
||||||
split_char = "-"
|
|
||||||
part2 = split_char.join(mask_pinyin)
|
part2 = split_char.join(mask_pinyin)
|
||||||
pinyin_ids = self._compute_pinyin_ids(part2)
|
pinyin_ids = self._compute_pinyin_ids(part2)
|
||||||
|
|
||||||
|
|
@ -429,7 +364,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"e: {e}, (text, pinyin): {prefix_text} - {prefix_pinyin}"
|
f"e: {e}, (text, pinyin): {prefix_text} - {prefix_pinyin}"
|
||||||
)
|
)
|
||||||
idx = merge_end_idx
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 整词末尾 10% 概率追加 EOS(破词前缀不加)
|
# 整词末尾 10% 概率追加 EOS(破词前缀不加)
|
||||||
|
|
@ -442,7 +376,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# part3: 词后文本
|
# part3: 词后文本
|
||||||
part3 = ""
|
part3 = ""
|
||||||
if random.random() > 0.7:
|
if random.random() > 0.7:
|
||||||
part3 = text[word_end : word_end + random.randint(1, 16)]
|
part3 = text[word_end : word_end + np.random.choice(range(1, 17))]
|
||||||
|
|
||||||
# part4: 词提示
|
# part4: 词提示
|
||||||
part4 = ""
|
part4 = ""
|
||||||
|
|
@ -458,7 +392,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
f"{part4}|{part1}",
|
f"{part4}|{part1}",
|
||||||
part3,
|
part3,
|
||||||
max_length=self.max_seq_length,
|
max_length=self.max_seq_length,
|
||||||
|
padding="max_length",
|
||||||
truncation=True,
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
return_token_type_ids=True,
|
return_token_type_ids=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -478,15 +414,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
cont_start = char_positions[break_pos]
|
cont_start = char_positions[break_pos]
|
||||||
|
|
||||||
# 续接目标:从断点开始,可延伸到后续词,遇到非汉字停止
|
# 续接目标:从断点开始,可延伸到后续词,遇到非汉字停止
|
||||||
cont_r = random.random()
|
target_len = np.random.choice(range(1, 9), p=self.cont_length_probs)
|
||||||
cont_probs = self.cont_length_probs
|
|
||||||
cont_cumulative = 0.0
|
|
||||||
target_len = 4
|
|
||||||
for cont_len, cont_p in enumerate(cont_probs):
|
|
||||||
cont_cumulative += cont_p
|
|
||||||
if cont_r < cont_cumulative:
|
|
||||||
target_len = cont_len + 1
|
|
||||||
break
|
|
||||||
cont_positions = []
|
cont_positions = []
|
||||||
pos = cont_start
|
pos = cont_start
|
||||||
while len(cont_positions) < target_len and pos < len(text):
|
while len(cont_positions) < target_len and pos < len(text):
|
||||||
|
|
@ -503,15 +431,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
cont_pinyin = [pinyin_list[i] for i in cont_positions]
|
cont_pinyin = [pinyin_list[i] for i in cont_positions]
|
||||||
|
|
||||||
_, mask_pinyin_cont = self.get_mask_pinyin(cont_text, cont_pinyin)
|
_, mask_pinyin_cont = self.get_mask_pinyin(cont_text, cont_pinyin)
|
||||||
r2 = random.random()
|
split_char_cont = np.random.choice(
|
||||||
if r2 < 0.9:
|
["", "`", "'", "-"], p=[0.9, 0.04, 0.04, 0.02]
|
||||||
split_char_cont = ""
|
)
|
||||||
elif r2 < 0.94:
|
|
||||||
split_char_cont = "`"
|
|
||||||
elif r2 < 0.98:
|
|
||||||
split_char_cont = "'"
|
|
||||||
else:
|
|
||||||
split_char_cont = "-"
|
|
||||||
part2_cont = split_char_cont.join(mask_pinyin_cont)
|
part2_cont = split_char_cont.join(mask_pinyin_cont)
|
||||||
pinyin_ids_cont = self._compute_pinyin_ids(part2_cont)
|
pinyin_ids_cont = self._compute_pinyin_ids(part2_cont)
|
||||||
|
|
||||||
|
|
@ -526,7 +448,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"e: {e}, (text, pinyin): {cont_text} - {cont_pinyin}"
|
f"e: {e}, (text, pinyin): {cont_text} - {cont_pinyin}"
|
||||||
)
|
)
|
||||||
idx = merge_end_idx
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 续接末尾 10% 概率追加 EOS
|
# 续接末尾 10% 概率追加 EOS
|
||||||
|
|
@ -540,13 +461,17 @@ class PinyinInputDataset(IterableDataset):
|
||||||
cont_end = cont_positions[-1] + 1
|
cont_end = cont_positions[-1] + 1
|
||||||
part3_cont = ""
|
part3_cont = ""
|
||||||
if random.random() > 0.7:
|
if random.random() > 0.7:
|
||||||
part3_cont = text[cont_end : cont_end + random.randint(1, 16)]
|
part3_cont = text[
|
||||||
|
cont_end : cont_end + np.random.choice(range(1, 17))
|
||||||
|
]
|
||||||
|
|
||||||
encoded_cont = self.tokenizer(
|
encoded_cont = self.tokenizer(
|
||||||
f"{part4}|{part1_cont}",
|
f"{part4}|{part1_cont}",
|
||||||
part3_cont,
|
part3_cont,
|
||||||
max_length=self.max_seq_length,
|
max_length=self.max_seq_length,
|
||||||
|
padding="max_length",
|
||||||
truncation=True,
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
return_token_type_ids=True,
|
return_token_type_ids=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -561,8 +486,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
pinyin_ids_cont,
|
pinyin_ids_cont,
|
||||||
)
|
)
|
||||||
|
|
||||||
idx = merge_end_idx
|
|
||||||
|
|
||||||
# 处理shuffle buffer - 单缓冲区半保留方案
|
# 处理shuffle buffer - 单缓冲区半保留方案
|
||||||
if len(batch_samples) >= self.shuffle_buffer_size:
|
if len(batch_samples) >= self.shuffle_buffer_size:
|
||||||
indices = np.random.permutation(len(batch_samples))
|
indices = np.random.permutation(len(batch_samples))
|
||||||
|
|
|
||||||
|
|
@ -1,401 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
预处理数据质量分析脚本
|
|
||||||
|
|
||||||
功能:
|
|
||||||
1. 统计 labels 的分布(出现次数、比例,最大/最小,未出现标签数)
|
|
||||||
2. 随机抽样还原为人类可读文本,导出为 CSV 文件
|
|
||||||
|
|
||||||
用法:
|
|
||||||
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train
|
|
||||||
python -m model.inspect_preprocessed --data-dir /path/to/preprocessed/train --num-samples 50 --output samples.csv
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import csv
|
|
||||||
import json
|
|
||||||
import random
|
|
||||||
from collections import Counter
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from loguru import logger
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.table import Table
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .char_info import CharInfo
|
|
||||||
from .dataset import CHAR_TO_ID
|
|
||||||
from .preprocessed_dataset import PreProcessedDataset
|
|
||||||
from .query import QueryEngine
|
|
||||||
|
|
||||||
ID_TO_CHAR = {v: k for k, v in CHAR_TO_ID.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def decode_pinyin_ids(pinyin_ids: list) -> str:
    """Turn a sequence of pinyin character ids back into a pinyin string.

    Decoding stops at the first id equal to 0; ids missing from
    ``ID_TO_CHAR`` are rendered as ``"?"``.

    Args:
        pinyin_ids: Ids to decode, in order.

    Returns:
        The concatenated characters found before the first 0 id.
    """
    pieces = []
    for token_id in pinyin_ids:
        if token_id != 0:
            pieces.append(ID_TO_CHAR.get(token_id, "?"))
            continue
        # A zero id terminates the sequence; everything after it is ignored.
        break
    return "".join(pieces)
|
|
||||||
|
|
||||||
|
|
||||||
def decode_history(history_ids: list, query_engine: QueryEngine) -> str:
    """Render ``history_slot_ids`` as a human-readable string.

    Each slot becomes ``"<PAD>"`` when its id is 0, ``"char(pinyin)"`` when
    ``query_engine.query_by_id`` returns an entry, and ``"<ID:n>"`` when the
    lookup returns ``None``.

    Args:
        history_ids: Slot ids, in order.
        query_engine: Object exposing ``query_by_id(id)``.

    Returns:
        The rendered slots joined with ``" | "``.
    """

    def _render(slot_id):
        # Id 0 is never looked up; it is shown as padding.
        if slot_id == 0:
            return "<PAD>"
        entry = query_engine.query_by_id(slot_id)
        if entry is None:
            return f"<ID:{slot_id}>"
        return f"{entry.char}({entry.pinyin})"

    return " | ".join([_render(slot_id) for slot_id in history_ids])
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_labels(dataset: PreProcessedDataset, max_shards: int = 0):
    """Count how often each label id occurs, showing a progress bar.

    For a sharded dataset (``dataset._is_sharded``) the labels are read from
    ``shard_NNNNNN.npz`` files under ``dataset.data_dir``; otherwise they are
    read from ``dataset.labels``.

    Args:
        dataset: The preprocessed dataset to scan.
        max_shards: Maximum number of shards to read. Values ``<= 0`` mean
            "all shards".

    Returns:
        A ``(counter, total)`` tuple: ``counter`` maps label id (``int``) to
        its occurrence count (``int``), ``total`` is the number of labels
        scanned.
    """
    logger.info("正在统计 labels 分布...")
    counter = Counter()
    total = 0

    num_shards = dataset._num_shards if dataset._is_sharded else 1
    effective_shards = min(num_shards, max_shards) if max_shards > 0 else num_shards

    for shard_idx in tqdm(range(effective_shards), desc="统计 labels", unit="shard"):
        if dataset._is_sharded:
            shard_path = dataset.data_dir / f"shard_{shard_idx:06d}.npz"
            # NpzFile loads members lazily: only "labels" is decompressed, and
            # the context manager closes the underlying file right away
            # instead of leaving the handle open until garbage collection.
            with np.load(shard_path) as shard:
                labels = shard["labels"].astype(np.int64)
        else:
            labels = dataset.labels[:].astype(np.int64)

        unique, counts = np.unique(labels, return_counts=True)
        for uid, cnt in zip(unique, counts):
            # Store plain Python ints so the Counter holds no numpy scalars.
            counter[int(uid)] += int(cnt)
        total += len(labels)

    return counter, total
|
|
||||||
|
|
||||||
|
|
||||||
def decode_sample(sample: dict, tokenizer, query_engine: QueryEngine) -> dict:
    """Convert one preprocessed sample into a human-readable record.

    Args:
        sample: Mapping with the keys ``input_ids``, ``token_type_ids``,
            ``labels``, ``history_slot_ids`` and ``pinyin_ids``. Values that
            expose ``tolist()`` (``item()`` for ``labels``) are converted to
            plain Python objects; other values are used as they are.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        query_engine: Object exposing ``query_by_id(id)``.

    Returns:
        A dict with the keys ``context``, ``suffix``, ``pinyin``,
        ``label_id``, ``label_char``, ``label_count``, ``history`` and
        ``full_tokens``.
    """

    def _plain(key, method="tolist"):
        # Use the value's own converter when it has one, else pass it through.
        raw = sample[key]
        return getattr(raw, method)() if hasattr(raw, method) else raw

    input_ids = _plain("input_ids")
    token_type_ids = _plain("token_type_ids")
    labels = _plain("labels", "item")
    history_ids = _plain("history_slot_ids")
    pinyin_ids = _plain("pinyin_ids")

    # Full token text, special tokens included.
    token_text = tokenizer.decode(input_ids, skip_special_tokens=False)

    # Sentence A is every type-0 token before the first type-1 position;
    # sentence B is every token whose type id is 1.
    first_b = next(
        (pos for pos, type_id in enumerate(token_type_ids) if type_id == 1),
        None,
    )
    if first_b is None:
        sent_a_ids = input_ids
        sent_b_ids = []
    else:
        sent_a_ids = [
            tok
            for tok, type_id in zip(input_ids[:first_b], token_type_ids[:first_b])
            if type_id == 0
        ]
        sent_b_ids = [
            tok for tok, type_id in zip(input_ids, token_type_ids) if type_id == 1
        ]

    context_text = tokenizer.decode(sent_a_ids, skip_special_tokens=True)
    if sent_b_ids:
        suffix_text = tokenizer.decode(sent_b_ids, skip_special_tokens=True)
    else:
        suffix_text = ""

    pinyin_str = decode_pinyin_ids(pinyin_ids)
    label_info = query_engine.query_by_id(labels)
    history_str = decode_history(history_ids, query_engine)

    if label_info:
        label_char = f"{label_info.char}({label_info.pinyin})"
        label_count = label_info.count
    else:
        label_char = f"<ID:{labels}>"
        label_count = 0

    return {
        "context": context_text,
        "suffix": suffix_text,
        "pinyin": pinyin_str,
        "label_id": labels,
        "label_char": label_char,
        "label_count": label_count,
        "history": history_str,
        "full_tokens": token_text,
    }
|
|
||||||
|
|
||||||
|
|
||||||
def _percent(part, whole):
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


def _build_label_table(title, rows, total, query_engine, pct_digits):
    """Build a Rich table listing ``(label_id, count)`` rows with their share of ``total``."""
    table = Table(title=title, show_header=True, header_style="bold magenta")
    table.add_column("排名", style="cyan", width=6)
    table.add_column("ID", style="green", width=8)
    table.add_column("字符(拼音)", style="yellow", width=20)
    table.add_column("频次", style="red", width=12)
    table.add_column("占比", style="blue", width=10)

    for rank, (label_id, count) in enumerate(rows, 1):
        info = query_engine.query_by_id(label_id)
        label_str = f"{info.char}({info.pinyin})" if info else f"<ID:{label_id}>"
        pct = _percent(count, total)
        table.add_row(
            str(rank),
            str(label_id),
            label_str,
            f"{count:,}",
            f"{pct:.{pct_digits}f}%",
        )
    return table


def main():
    """Command-line entry point: label statistics plus a decoded sample export.

    Steps:
      1. Parse the arguments and seed ``random`` and ``numpy.random``.
      2. Load the preprocessed dataset, the query engine and the tokenizer.
      3. Print the label distribution (most/least common labels and a
         frequency histogram).
      4. Decode a random sample of records, write them to a CSV file and
         print a short preview of the first five.

    Percentages are reported as 0 when the denominator is 0, so an empty
    dataset no longer aborts the report.
    """
    console = Console()

    parser = argparse.ArgumentParser(description="预处理数据质量分析")
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="预处理数据目录(train/ 或 eval/)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=50,
        help="随机抽样的样本数量(默认50)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="CSV 输出文件路径(默认: <data-dir>/samples.csv)",
    )
    parser.add_argument(
        "--max-shards",
        type=int,
        default=0,
        help="统计 labels 时最多读取的分片数(0=全部)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="随机种子",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=30,
        help="显示出现次数最多和最少的标签数量",
    )

    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.output is None:
        args.output = str(Path(args.data_dir) / "samples.csv")

    # Load the dataset.
    logger.info(f"加载数据集: {args.data_dir}")
    dataset = PreProcessedDataset(args.data_dir)
    console.print(f"[bold cyan]数据集: {len(dataset):,} 个样本[/bold cyan]")
    if dataset._is_sharded:
        console.print(
            f" 分片数: {dataset._num_shards}, 每分片: {dataset._shard_size:,} 样本"
        )
    console.print()

    # Load the query engine.
    logger.info("加载 QueryEngine...")
    query_engine = QueryEngine()
    query_engine.load()

    # Load the tokenizer (imports kept local to this function).
    logger.info("加载 Tokenizer...")
    from importlib.resources import files as pkg_files

    from modelscope import AutoTokenizer
    from rich.markup import escape

    tokenizer = AutoTokenizer.from_pretrained(
        Path(str(pkg_files(__package__))) / "assets" / "tokenizer"
    )

    # ====== 1. Label distribution ======
    console.print("[bold yellow]====== Labels 分布分析 ======[/bold yellow]")
    counter, total = analyze_labels(dataset, max_shards=args.max_shards)

    # Vocabulary size, taken from the query engine's id table.
    vocab_size = len(query_engine._id_to_info)

    appeared_ids = set(counter.keys())
    all_ids = set(query_engine._id_to_info.keys())
    missing_ids = all_ids - appeared_ids

    eos_count = counter.get(0, 0)
    console.print(f"\n总样本数: {total:,}")
    console.print(f"词表大小: {vocab_size:,} (含 EOS)")
    console.print(f"唯一标签数: {len(counter):,}")
    console.print(
        f"EOS (id=0) 出现次数: {eos_count:,} ({_percent(eos_count, total):.2f}%)"
    )
    console.print(
        f"[bold red]未出现的标签数: {len(missing_ids):,} / {vocab_size:,} ({_percent(len(missing_ids), vocab_size):.2f}%)[/bold red]"
    )

    ranked = counter.most_common()
    most_common = ranked[: args.top_k] if args.top_k > 0 else []
    # The reversed slice clamps at the start of the list, so it always yields
    # the rarest labels first, even when there are fewer than top_k labels.
    least_common = ranked[: -args.top_k - 1 : -1] if args.top_k > 0 else []
    shown = min(max(args.top_k, 0), len(counter))

    # Most frequent labels.
    console.print(
        _build_label_table(
            f"出现次数最多的 {shown} 个标签", most_common, total, query_engine, 3
        )
    )

    # Least frequent labels.
    console.print(
        _build_label_table(
            f"出现次数最少的 {shown} 个标签", least_common, total, query_engine, 6
        )
    )

    # Histogram of labels by occurrence count.
    table_dist = Table(
        title="频次分布概览", show_header=True, header_style="bold magenta"
    )
    table_dist.add_column("频次区间", style="cyan")
    table_dist.add_column("标签数", style="green")
    table_dist.add_column("占总标签数比例", style="yellow")

    bins = [
        (1, 10),
        (11, 100),
        (101, 1000),
        (1001, 10000),
        (10001, 100000),
        (100001, 1000000),
        (1000001, float("inf")),
    ]
    for lo, hi in bins:
        count_in_bin = sum(1 for c in counter.values() if lo <= c <= hi)
        if count_in_bin > 0:
            hi_str = str(int(hi)) if hi != float("inf") else "∞"
            table_dist.add_row(
                f"{lo}-{hi_str}",
                f"{count_in_bin:,}",
                f"{count_in_bin / len(counter) * 100:.1f}%",
            )
    # Labels that never occur; this row is relative to the vocabulary size.
    if len(missing_ids) > 0:
        table_dist.add_row(
            "未出现",
            f"{len(missing_ids):,}",
            f"{_percent(len(missing_ids), vocab_size):.1f}%",
        )
    console.print(table_dist)

    # ====== 2. Random sample -> CSV ======
    num_samples = min(args.num_samples, len(dataset))
    console.print(
        f"\n[bold yellow]====== 随机抽样还原 ({num_samples} 个样本) → {args.output} ======[/bold yellow]"
    )

    indices = random.sample(range(len(dataset)), num_samples)

    csv_path = Path(args.output)
    csv_path.parent.mkdir(parents=True, exist_ok=True)

    csv_headers = [
        "index",
        "pinyin",
        "label_char",
        "label_id",
        "label_count",
        "context",
        "suffix",
        "history",
        "full_tokens",
    ]

    # The first few decoded rows are kept so the preview below does not need
    # to re-open and re-parse the file that was just written.
    preview_rows = []
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(csv_headers)

        for idx in tqdm(indices, desc="解码样本", unit="sample"):
            sample = dataset[idx]
            decoded = decode_sample(sample, tokenizer, query_engine)
            writer.writerow(
                [
                    idx,
                    decoded["pinyin"],
                    decoded["label_char"],
                    decoded["label_id"],
                    decoded["label_count"],
                    decoded["context"],
                    decoded["suffix"],
                    decoded["history"],
                    decoded["full_tokens"],
                ]
            )
            if len(preview_rows) < 5:
                preview_rows.append(decoded)

    console.print(
        f"[bold green]✓ 已导出 {num_samples} 个样本到 {csv_path}[/bold green]"
    )

    # Preview of the first five samples.
    console.print(f"\n[bold cyan]前 {min(5, num_samples)} 个样本概览:[/bold cyan]")
    for i, row in enumerate(preview_rows):
        # Dataset text is escaped so bracketed fragments are printed verbatim
        # instead of being parsed as Rich markup tags.
        console.print(
            f" [{i + 1}] 拼音={escape(str(row['pinyin']))} 目标={escape(str(row['label_char']))} "
            f"上下文={escape(str(row['context'])[:50])}..."
        )

    console.print("\n[bold green]分析完成[/bold green]")
|
|
||||||
|
|
||||||
|
|
||||||
app = main
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -1,335 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
预处理脚本:将 PinyinInputDataset 的输出转换为分片压缩 .npz 文件。
|
|
||||||
|
|
||||||
采用分片流式写入,内存峰值固定为 shard_size 级别,不随总样本数增长。
|
|
||||||
每个分片使用 np.savez_compressed 保存(zlib 压缩),GPU 服务器无需解压到硬盘。
|
|
||||||
|
|
||||||
用法:
|
|
||||||
python -m model.preprocess \
|
|
||||||
--train-data-path "some/hf_dataset" \
|
|
||||||
--eval-data-path "some/hf_dataset" \
|
|
||||||
--output-dir ./preprocessed \
|
|
||||||
--num-train-samples 5000000 \
|
|
||||||
--num-eval-samples 8192
|
|
||||||
|
|
||||||
生成目录结构:
|
|
||||||
output_dir/
|
|
||||||
train/
|
|
||||||
metadata.json
|
|
||||||
shard_000000.npz (5M样本, 6个字段, zlib压缩)
|
|
||||||
shard_000001.npz
|
|
||||||
...
|
|
||||||
eval/
|
|
||||||
metadata.json
|
|
||||||
shard_000000.npz
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import gc
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from loguru import logger
|
|
||||||
from rich.console import Console
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .dataset import PinyinInputDataset
|
|
||||||
from .trainer import preprocess_collate_fn, worker_init_fn
|
|
||||||
|
|
||||||
FIELDS = [
|
|
||||||
"input_ids",
|
|
||||||
"token_type_ids",
|
|
||||||
"attention_mask",
|
|
||||||
"labels",
|
|
||||||
"history_slot_ids",
|
|
||||||
"pinyin_ids",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_batch(batch: dict, take: int) -> Dict[str, np.ndarray]:
|
|
||||||
"""从 DataLoader batch 中提取指定数量的样本,转为 int16 numpy 数组"""
|
|
||||||
result = {}
|
|
||||||
for f in FIELDS:
|
|
||||||
tensor = batch[f][:take]
|
|
||||||
arr = tensor.numpy().astype(np.int16)
|
|
||||||
if f == "labels" and arr.ndim > 1 and arr.shape[-1] == 1:
|
|
||||||
arr = arr.squeeze(-1)
|
|
||||||
result[f] = arr
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _flush_shard(
    split_dir: Path,
    split_name: str,
    shard_idx: int,
    shard_buffers: Dict[str, List[np.ndarray]],
    shard_count: int,
) -> None:
    """Concatenate the buffered chunks of every field and write one compressed shard."""
    merged = {f: np.concatenate(shard_buffers[f], axis=0) for f in FIELDS}
    np.savez_compressed(split_dir / f"shard_{shard_idx:06d}.npz", **merged)
    logger.debug(f"Saved {split_name} shard {shard_idx}: {shard_count} samples")


def collect_samples(
    dataloader: DataLoader,
    num_samples: int,
    output_dir: Path,
    split_name: str,
    max_seq_length: int = 128,
    shard_size: int = 5_000_000,
) -> int:
    """Stream samples from ``dataloader`` into compressed ``.npz`` shards.

    Samples are buffered in memory and written to
    ``output_dir/split_name/shard_XXXXXX.npz`` every time exactly
    ``shard_size`` samples have accumulated; a final partial shard holds the
    remainder. A ``metadata.json`` describing the split is written next to
    the shards.

    Every shard except the last holds exactly ``shard_size`` samples, even
    when ``shard_size`` is not a multiple of the batch size. Readers that
    locate a sample with ``idx // shard_size`` and ``idx % shard_size``
    depend on this.

    Args:
        dataloader: Yields batches that ``_extract_batch`` can convert.
        num_samples: Maximum number of samples to collect.
        output_dir: Root output directory.
        split_name: Sub-directory name, also used in progress/log messages.
        max_seq_length: Recorded in the metadata only.
        shard_size: Number of samples per shard; bounds peak memory.

    Returns:
        The number of samples actually written (less than ``num_samples`` if
        the dataloader is exhausted early).

    Raises:
        ValueError: If ``shard_size`` is not positive.
    """
    if shard_size <= 0:
        raise ValueError(f"shard_size must be positive, got {shard_size}")

    split_dir = output_dir / split_name
    split_dir.mkdir(parents=True, exist_ok=True)

    shard_buffers: Dict[str, List[np.ndarray]] = {f: [] for f in FIELDS}
    shard_count = 0
    shard_idx = 0
    total = 0

    # The context manager closes the progress bar even if a batch fails.
    with tqdm(
        total=num_samples, desc=f"Processing {split_name}", unit="samples"
    ) as pbar:
        for batch in dataloader:
            remaining = num_samples - total
            if remaining <= 0:
                break
            take = min(batch["input_ids"].size(0), remaining)
            extracted = _extract_batch(batch, take)

            # Split the batch at the shard boundary so no shard ever grows
            # beyond shard_size.
            offset = 0
            while offset < take:
                chunk = min(shard_size - shard_count, take - offset)
                for f in FIELDS:
                    shard_buffers[f].append(extracted[f][offset : offset + chunk])
                shard_count += chunk
                offset += chunk

                if shard_count == shard_size:
                    _flush_shard(
                        split_dir, split_name, shard_idx, shard_buffers, shard_count
                    )
                    shard_idx += 1
                    shard_buffers = {f: [] for f in FIELDS}
                    shard_count = 0
                    # Release the just-written buffers before the next shard grows.
                    gc.collect()

            total += take
            pbar.update(take)

            if total >= num_samples:
                break

        # Write the trailing shard that did not reach shard_size.
        if shard_count > 0:
            _flush_shard(split_dir, split_name, shard_idx, shard_buffers, shard_count)
            shard_idx += 1

    actual_count = total
    num_shards = shard_idx

    metadata = {
        "num_samples": actual_count,
        "max_seq_length": max_seq_length,
        "dtype": "int16",
        "fields": FIELDS,
        "shard_size": shard_size,
        "num_shards": num_shards,
    }
    with open(split_dir / "metadata.json", "w", encoding="utf-8") as fp:
        json.dump(metadata, fp, indent=2, ensure_ascii=False)

    total_size = sum(
        f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
    )
    logger.info(
        f"{split_name}: {actual_count} samples in {num_shards} shards, "
        f"{total_size / (1024**3):.2f} GB (compressed)"
    )

    return actual_count
|
|
||||||
|
|
||||||
|
|
||||||
def _build_split_loader(
    args,
    data_path: str,
    num_workers: int,
    max_iter_length: int,
    py_style_weight: tuple,
    length_weights: Dict[int, int],
) -> DataLoader:
    """Create the streaming dataset for one split and wrap it in a DataLoader."""
    dataset = PinyinInputDataset(
        data_path=data_path,
        max_workers=num_workers,
        max_iter_length=max_iter_length,
        max_seq_length=args.max_seq_length,
        text_field="text",
        py_style_weight=py_style_weight,
        # Honour --shuffle-buffer-size; it used to be parsed and then ignored.
        shuffle_buffer_size=args.shuffle_buffer_size,
        length_weights=length_weights,
    )

    # prefetch_factor / persistent_workers are only valid with worker
    # processes, so leave them at their defaults when num_workers == 0.
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs = {"prefetch_factor": 2, "persistent_workers": True}

    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=num_workers,
        pin_memory=False,
        worker_init_fn=worker_init_fn,
        collate_fn=preprocess_collate_fn(args.max_seq_length),
        **worker_kwargs,
    )


def main():
    """Command line entry point: preprocess the train and eval splits into shards.

    Parses the command line, seeds torch and numpy, builds one streaming
    dataset/DataLoader per split, writes both splits with ``collect_samples``
    and prints a summary including the compressed size of each split.
    """
    console = Console()

    parser = argparse.ArgumentParser(description="预处理数据集为分片压缩npz文件")
    parser.add_argument(
        "--train-data-path",
        type=str,
        required=True,
        help="训练数据集路径(HuggingFace格式)",
    )
    parser.add_argument(
        "--eval-data-path",
        type=str,
        required=True,
        help="评估数据集路径(HuggingFace格式)",
    )
    parser.add_argument("--output-dir", type=str, required=True, help="输出目录")
    parser.add_argument(
        "--num-train-samples", type=int, required=True, help="训练集样本数量"
    )
    parser.add_argument(
        "--num-eval-samples", type=int, required=True, help="评估集样本数量"
    )
    parser.add_argument("--batch-size", type=int, default=128, help="批大小")
    parser.add_argument(
        "--num-workers", type=int, default=2, help="DataLoader worker数量"
    )
    parser.add_argument("--max-seq-length", type=int, default=128, help="最大序列长度")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument(
        "--shard-size",
        type=int,
        default=5_000_000,
        help="分片大小(样本数),控制内存峰值(默认500万,约2.9GB/分片未压缩)",
    )
    parser.add_argument(
        "--py-style-weight",
        type=str,
        default="9,2,1",
        help="拼音风格权重(逗号分隔)",
    )
    parser.add_argument(
        "--shuffle-buffer-size",
        type=int,
        default=2000000,
        help="数据集shuffle缓冲区大小",
    )
    parser.add_argument(
        "--length-weights",
        type=str,
        default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
        help="词长权重",
    )

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # "9,2,1" -> (9, 2, 1)
    py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
    # "1:10,2:50" -> {1: 10, 2: 50}
    length_weights = {
        int(k): int(v)
        for k, v in (item.split(":") for item in args.length_weights.split(","))
    }

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Let each dataset iterate over up to five times the requested number of samples.
    train_max_iter = args.num_train_samples * 5
    eval_max_iter = args.num_eval_samples * 5

    # 578 bytes per sample is the estimate used throughout this script.
    shard_mem_gb = args.shard_size * 578 / (1024**3)
    console.print("[bold cyan]=== 数据预处理 ===[/bold cyan]")
    console.print(f"训练集目标: {args.num_train_samples:,} 样本")
    console.print(f"评估集目标: {args.num_eval_samples:,} 样本")
    console.print(f"输出目录: {output_dir}")
    console.print("数据类型: int16")
    console.print(
        f"分片大小: {args.shard_size:,} 样本 (约 {shard_mem_gb:.1f} GB/分片 未压缩)"
    )
    console.print()

    num_train_workers = args.num_workers
    num_eval_workers = max(1, args.num_workers // 2)

    console.print("[bold cyan]创建训练数据集...[/bold cyan]")
    train_dataloader = _build_split_loader(
        args,
        args.train_data_path,
        num_train_workers,
        train_max_iter,
        py_style_weight,
        length_weights,
    )

    console.print("[bold cyan]创建评估数据集...[/bold cyan]")
    eval_dataloader = _build_split_loader(
        args,
        args.eval_data_path,
        num_eval_workers,
        eval_max_iter,
        py_style_weight,
        length_weights,
    )

    logger.info("开始收集训练数据...")
    train_count = collect_samples(
        train_dataloader,
        args.num_train_samples,
        output_dir,
        "train",
        args.max_seq_length,
        args.shard_size,
    )

    if train_count < args.num_train_samples:
        logger.warning(
            f"训练集样本不足: 目标 {args.num_train_samples}, 实际 {train_count}"
        )

    logger.info("开始收集评估数据...")
    eval_count = collect_samples(
        eval_dataloader,
        args.num_eval_samples,
        output_dir,
        "eval",
        args.max_seq_length,
        args.shard_size,
    )

    if eval_count < args.num_eval_samples:
        logger.warning(
            f"评估集样本不足: 目标 {args.num_eval_samples}, 实际 {eval_count}"
        )

    console.print("\n[bold green]=== 预处理完成 ===[/bold green]")
    console.print(f"训练集: {train_count:,} 样本")
    console.print(f"评估集: {eval_count:,} 样本")
    console.print(f"输出目录: {output_dir}")

    # Report the on-disk size of every split that was produced.
    for split in ["train", "eval"]:
        split_dir = output_dir / split
        if split_dir.exists():
            total_size = sum(
                f.stat().st_size for f in split_dir.iterdir() if f.suffix == ".npz"
            )
            console.print(f"{split}/: {total_size / (1024**3):.2f} GB (compressed)")


# Second public name for the entry point.
app = main


if __name__ == "__main__":
    main()
|
|
||||||
|
|
@ -1,188 +0,0 @@
|
||||||
"""
|
|
||||||
预处理数据集加载器
|
|
||||||
|
|
||||||
支持两种格式:
|
|
||||||
1. 分片压缩格式(.npz):从压缩文件中按需加载分片,LRU 缓存最多持有 max_cache_shards 个分片
|
|
||||||
2. 单体格式(.npy):向后兼容,使用 mmap 零拷贝加载
|
|
||||||
|
|
||||||
GPU 服务器仅需存放压缩后的 .npz 文件,无需解压到硬盘。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gc
|
|
||||||
import json
|
|
||||||
from collections import OrderedDict
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from loguru import logger
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
# Names of the arrays that make up one preprocessed sample; these are the
# same keys that PreProcessedDataset.__getitem__ returns.
FIELDS = [
    "input_ids",
    "token_type_ids",
    "attention_mask",
    "labels",
    "history_slot_ids",
    "pinyin_ids",
]
|
|
||||||
|
|
||||||
|
|
||||||
def is_preprocessed_data(path: str) -> bool:
    """Return True when ``path`` is a directory that contains ``metadata.json``."""
    candidate = Path(path)
    if not candidate.is_dir():
        return False
    return (candidate / "metadata.json").exists()
|
|
||||||
|
|
||||||
|
|
||||||
class _ShardCache:
|
|
||||||
"""LRU 缓存,管理按需加载的 .npz 分片,最多持有 max_size 个分片"""
|
|
||||||
|
|
||||||
def __init__(self, max_size: int = 2):
|
|
||||||
self.max_size = max_size
|
|
||||||
self._cache: OrderedDict[int, Dict[str, np.ndarray]] = OrderedDict()
|
|
||||||
|
|
||||||
def get(self, shard_idx: int, loader_fn) -> Dict[str, np.ndarray]:
|
|
||||||
if shard_idx in self._cache:
|
|
||||||
self._cache.move_to_end(shard_idx)
|
|
||||||
return self._cache[shard_idx]
|
|
||||||
|
|
||||||
data = loader_fn(shard_idx)
|
|
||||||
self._cache[shard_idx] = data
|
|
||||||
self._cache.move_to_end(shard_idx)
|
|
||||||
|
|
||||||
while len(self._cache) > self.max_size:
|
|
||||||
evicted_key, evicted_data = self._cache.popitem(last=False)
|
|
||||||
del evicted_data
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
self._cache.clear()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
|
|
||||||
class PreProcessedDataset(Dataset):
    """Map-style dataset over preprocessed samples with automatic format detection.

    - Sharded format (``metadata.json`` contains both ``shard_size`` and
      ``num_shards``): ``shard_XXXXXX.npz`` files are loaded on demand and
      kept in an LRU cache of at most ``max_cache_shards`` shards.
    - Single-file format (backward compatible): one ``.npy`` file per field,
      opened with ``mmap_mode="r"``.

    Every field of a sample is returned as a ``torch.long`` tensor.

    Sample ``idx`` of a sharded dataset is looked up in shard
    ``idx // shard_size`` at row ``idx % shard_size``, so every shard except
    the last must hold exactly ``shard_size`` samples.
    """

    # Keys of a sample, in the order they appear in the returned dict.
    _FIELD_NAMES = (
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "labels",
        "history_slot_ids",
        "pinyin_ids",
    )

    def __init__(self, data_dir: str, max_cache_shards: int = 2):
        """Read ``metadata.json`` from ``data_dir`` and prepare the matching loader.

        Args:
            data_dir: Directory that holds ``metadata.json`` and the data files.
            max_cache_shards: Maximum number of shards kept in memory at once
                (sharded format only).
        """
        self.data_dir = Path(data_dir)

        with open(self.data_dir / "metadata.json", "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

        self.num_samples = self.metadata["num_samples"]
        self.max_seq_length = self.metadata["max_seq_length"]
        self._shard_size: Optional[int] = self.metadata.get("shard_size")
        self._num_shards: Optional[int] = self.metadata.get("num_shards")

        if self._shard_size is not None and self._num_shards is not None:
            self._is_sharded = True
            self._cache = _ShardCache(max_size=max_cache_shards)
            logger.info(
                f"Loaded sharded dataset: {self.num_samples:,} samples, "
                f"{self._num_shards} shards, shard_size={self._shard_size:,}"
            )
        else:
            self._is_sharded = False
            self._load_single_files()
            logger.info(
                f"Loaded single-file dataset: {self.num_samples:,} samples (mmap)"
            )

    def _load_single_files(self):
        """Backward compatibility: memory-map one ``<field>.npy`` file per field."""
        for name in self._FIELD_NAMES:
            setattr(
                self, name, np.load(self.data_dir / f"{name}.npy", mmap_mode="r")
            )

    def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]:
        """Load every array of one ``.npz`` shard into memory.

        The arrays keep their stored dtype; ``__getitem__`` upcasts a single
        sample at a time, so a cached shard is not inflated to int64. The
        archive is closed as soon as the arrays have been read.
        """
        shard_path = self.data_dir / f"shard_{shard_idx:06d}.npz"
        with np.load(shard_path) as npz:
            return {key: npz[key] for key in npz.files}

    def __len__(self) -> int:
        """Return the number of samples recorded in the metadata."""
        return self.num_samples

    def __getitem__(self, idx: int) -> dict:
        """Return sample ``idx`` as a dict of ``torch.long`` tensors.

        Raises:
            IndexError: If ``idx`` is outside ``[0, num_samples)``.
        """
        if not 0 <= idx < self.num_samples:
            raise IndexError(
                f"Index {idx} out of range for dataset with {self.num_samples} samples"
            )

        if self._is_sharded:
            shard_idx = idx // self._shard_size
            row = idx % self._shard_size
            arrays = self._cache.get(shard_idx, self._load_shard)
        else:
            row = idx
            arrays = {name: getattr(self, name) for name in self._FIELD_NAMES}

        sample = {}
        for name in self._FIELD_NAMES:
            value = arrays[name][row]
            if name == "labels":
                sample[name] = torch.tensor(value, dtype=torch.long)
            else:
                # astype always copies, so the tensor owns writable int64
                # memory that is independent of the cache / memory map.
                sample[name] = torch.from_numpy(value.astype(np.int64))
        return sample
|
|
||||||
|
|
||||||
|
|
||||||
def preprocessed_collate_fn(batch):
    """Stack the tensor fields of preprocessed samples into one batch dict.

    Preprocessed samples carry no string fields (prefix/suffix/pinyin), so
    every field is stacked along a new leading axis.
    """
    field_names = (
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "labels",
        "history_slot_ids",
        "pinyin_ids",
    )
    return {
        name: torch.stack([sample[name] for sample in batch])
        for name in field_names
    }
|
|
||||||
|
|
@ -1,250 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
缺失字符补充工具
|
|
||||||
|
|
||||||
步骤 1: find-missing — 扫描已预处理数据,找出从未出现的 label ID,输出 JSON
|
|
||||||
步骤 2: generate-template — 根据 JSON 生成 JSONL 占位文件,供用户手动填入包含缺失字的真实文本
|
|
||||||
|
|
||||||
用法:
|
|
||||||
python -m model.supplement_missing find-missing \
|
|
||||||
--preprocessed-dir ./preprocessed/train \
|
|
||||||
--output missing_chars.json
|
|
||||||
|
|
||||||
python -m model.supplement_missing generate-template \
|
|
||||||
--missing-chars missing_chars.json \
|
|
||||||
--output supplement_texts.jsonl
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Set
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from loguru import logger
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.table import Table
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .query import QueryEngine
|
|
||||||
|
|
||||||
|
|
||||||
def scan_labels(preprocessed_dir: Path) -> Set[int]:
    """Collect every label ID that occurs in the ``shard_*.npz`` files of a directory.

    Args:
        preprocessed_dir: Directory that is searched (non-recursively) for
            ``shard_*.npz`` files.

    Returns:
        The set of distinct values found in the ``labels`` array of all
        shards; empty (after a warning) when no shard file exists.
    """
    appeared: Set[int] = set()

    shard_files = sorted(preprocessed_dir.glob("shard_*.npz"))
    if not shard_files:
        logger.warning(f"未找到 .npz 分片文件: {preprocessed_dir}")
        return appeared

    for shard_path in tqdm(shard_files, desc="扫描分片", unit="shard"):
        # Close each archive as soon as its labels have been read.
        with np.load(shard_path) as data:
            labels = data["labels"]
        # np.unique flattens its input, so no squeeze is needed; only the
        # small array of unique values is widened to plain Python ints.
        appeared.update(np.unique(labels).astype(np.int64).tolist())

    return appeared
|
|
||||||
|
|
||||||
|
|
||||||
def cmd_find_missing(args):
    """Handle ``find-missing``: report vocabulary entries that never occur as a label.

    Reads ``metadata.json`` from ``args.preprocessed_dir``, scans the shards
    for the label IDs that occur, compares them with the IDs known to
    ``QueryEngine`` and writes the missing entries (ID 0 excluded) as JSON to
    ``args.output``. A summary and a table of the missing characters are
    printed to the console.
    """
    console = Console()
    data_dir = Path(args.preprocessed_dir)

    if not data_dir.exists():
        console.print(f"[bold red]目录不存在: {data_dir}[/bold red]")
        return

    meta_file = data_dir / "metadata.json"
    if not meta_file.exists():
        console.print(f"[bold red]未找到 metadata.json: {meta_file}[/bold red]")
        return

    with open(meta_file, "r", encoding="utf-8") as fh:
        metadata = json.load(fh)
    console.print(
        f"[bold cyan]预处理数据: {metadata['num_samples']:,} 样本, {metadata['num_shards']} 分片[/bold cyan]"
    )

    console.print("[bold cyan]扫描 labels...[/bold cyan]")
    appeared = scan_labels(data_dir)

    console.print("[bold cyan]加载 QueryEngine...[/bold cyan]")
    engine = QueryEngine()
    engine.load()

    all_ids = set(engine._id_to_info.keys())
    missing_ids = all_ids - appeared

    # ID 0 is skipped; the summary below labels it as EOS.
    candidates = (
        engine.query_by_id(label_id)
        for label_id in sorted(missing_ids)
        if label_id != 0
    )
    missing_chars = [
        {
            "id": info.id,
            "char": info.char,
            "pinyin": info.pinyin,
            "count": info.count,
        }
        for info in candidates
        if info is not None
    ]

    report = {
        "missing_count": len(missing_chars),
        "missing_chars": missing_chars,
    }

    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(destination, "w", encoding="utf-8") as fh:
        json.dump(report, fh, ensure_ascii=False, indent=2)

    console.print("\n[bold green]=== 扫描完成 ===[/bold green]")
    console.print(f"词表大小: {len(all_ids):,} (含 EOS)")
    console.print(f"已出现标签: {len(appeared):,}")
    console.print(
        f"[bold red]缺失标签: {len(missing_ids):,}[/bold red] (其中非 EOS: {len(missing_chars)})"
    )

    if missing_chars:
        table = Table(
            title=f"缺失字符 (共 {len(missing_chars)} 个)",
            show_header=True,
            header_style="bold magenta",
        )
        table.add_column("ID", style="cyan", width=8)
        table.add_column("字符", style="yellow", width=6)
        table.add_column("拼音", style="green", width=12)
        table.add_column("语料频次", style="red", width=12)
        for item in missing_chars:
            table.add_row(
                str(item["id"]),
                item["char"],
                item["pinyin"],
                f"{item['count']:,}",
            )
        console.print(table)

    console.print(f"\n已输出到: {destination}")
|
|
||||||
|
|
||||||
|
|
||||||
def cmd_generate_template(args):
    """Handle ``generate-template``: write placeholder JSONL lines for missing characters.

    Reads the JSON file at ``args.missing_chars`` and writes
    ``args.num_entries`` placeholder ``{"text": ...}`` lines per missing
    character to ``args.output``; the user is expected to replace them with
    real text afterwards.
    """
    console = Console()

    source = Path(args.missing_chars)
    if not source.exists():
        console.print(f"[bold red]文件不存在: {source}[/bold red]")
        return

    with open(source, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    missing_chars = payload.get("missing_chars", [])
    if not missing_chars:
        console.print("[bold green]没有缺失字符,无需生成模板[/bold green]")
        return

    num_entries = args.num_entries
    total_lines = len(missing_chars) * num_entries

    console.print(f"[bold cyan]缺失字符数: {len(missing_chars)}[/bold cyan]")
    console.print(f"[bold cyan]每字符模板数: {num_entries}[/bold cyan]")
    console.print(f"[bold cyan]总模板行数: {total_lines}[/bold cyan]")

    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)

    # One JSON object per line, grouped by character, numbered from 1.
    placeholder_lines = (
        json.dumps(
            {"text": f"请在这里输入包含「{entry['char']}」字的第{i + 1}条文本"},
            ensure_ascii=False,
        )
        + "\n"
        for entry in missing_chars
        for i in range(num_entries)
    )
    with open(destination, "w", encoding="utf-8") as fh:
        fh.writelines(placeholder_lines)

    console.print(f"[bold green]模板已生成: {destination}[/bold green]")
    console.print(
        f"共 {total_lines} 条({len(missing_chars)} 字符 × {num_entries} 条/字符),"
        f"请手动编辑该文件,将占位文本替换为包含对应字符的真实文本。"
    )
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Parse the command line and run the selected sub-command.

    Prints the help text and returns when no sub-command is given.
    """
    parser = argparse.ArgumentParser(
        description="缺失字符补充工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
子命令:
find-missing 扫描已预处理数据,找出从未出现的 label ID
generate-template 根据缺失字符 JSON 生成 JSONL 占位文件

示例:
python -m model.supplement_missing find-missing \\
--preprocessed-dir ./preprocessed/train \\
--output missing_chars.json

python -m model.supplement_missing generate-template \\
--missing-chars missing_chars.json \\
--output supplement_texts.jsonl
""",
    )
    subparsers = parser.add_subparsers(dest="command", help="子命令")

    # Options of the "find-missing" sub-command.
    find_parser = subparsers.add_parser(
        "find-missing", help="扫描预处理数据,找出缺失标签"
    )
    find_options = (
        (
            "--preprocessed-dir",
            {
                "type": str,
                "required": True,
                "help": "预处理数据目录(包含 shard_*.npz 和 metadata.json)",
            },
        ),
        (
            "--output",
            {
                "type": str,
                "default": "missing_chars.json",
                "help": "输出 JSON 文件路径(默认: missing_chars.json)",
            },
        ),
    )
    for flag, spec in find_options:
        find_parser.add_argument(flag, **spec)

    # Options of the "generate-template" sub-command.
    template_parser = subparsers.add_parser(
        "generate-template", help="生成补充文本模板"
    )
    template_options = (
        (
            "--missing-chars",
            {
                "type": str,
                "required": True,
                "help": "缺失字符 JSON 文件路径(由 find-missing 生成)",
            },
        ),
        (
            "--output",
            {
                "type": str,
                "default": "supplement_texts.jsonl",
                "help": "输出 JSONL 文件路径(默认: supplement_texts.jsonl)",
            },
        ),
        (
            "--num-entries",
            {
                "type": int,
                "default": 3,
                "help": "每个缺失字符生成的模板条数(默认: 3)",
            },
        ),
    )
    for flag, spec in template_options:
        template_parser.add_argument(flag, **spec)

    args = parser.parse_args()
    command = args.command

    if command is None:
        parser.print_help()
        return

    if command == "find-missing":
        cmd_find_missing(args)
    elif command == "generate-template":
        cmd_generate_template(args)


# Second public name for the entry point.
app = main


if __name__ == "__main__":
    main()
|
|
||||||
|
|
@ -28,11 +28,6 @@ from torch.utils.data import DataLoader
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from .dataset import PinyinInputDataset
|
from .dataset import PinyinInputDataset
|
||||||
from .preprocessed_dataset import (
|
|
||||||
PreProcessedDataset,
|
|
||||||
is_preprocessed_data,
|
|
||||||
preprocessed_collate_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 导入模型和数据
|
# 导入模型和数据
|
||||||
from .model import InputMethodEngine
|
from .model import InputMethodEngine
|
||||||
|
|
@ -102,11 +97,6 @@ class Trainer:
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
self.train_dataloader = train_dataloader
|
self.train_dataloader = train_dataloader
|
||||||
if isinstance(eval_dataloader, DataLoader) and not isinstance(
|
|
||||||
eval_dataloader.dataset, torch.utils.data.IterableDataset
|
|
||||||
):
|
|
||||||
self.eval_dataloader = eval_dataloader
|
|
||||||
else:
|
|
||||||
self.eval_dataloader = list([i for i in eval_dataloader])
|
self.eval_dataloader = list([i for i in eval_dataloader])
|
||||||
self.output_dir = Path(output_dir)
|
self.output_dir = Path(output_dir)
|
||||||
self.num_epochs = num_epochs
|
self.num_epochs = num_epochs
|
||||||
|
|
@ -975,62 +965,25 @@ def worker_init_fn(worker_id: int) -> None:
|
||||||
random.seed(worker_seed)
|
random.seed(worker_seed)
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(batch: List[Dict[str, Any]], max_seq_length: int = 0) -> Dict[str, Any]:
|
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
自定义批处理函数,将多个样本组合成一个batch。
|
自定义批处理函数,将多个样本组合成一个batch
|
||||||
支持动态填充:根据batch内最大序列长度进行padding,而非固定max_length。
|
|
||||||
当 max_seq_length > 0 时,pad到指定长度(用于预处理)。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: 样本列表,每个样本是一个字典
|
batch: 样本列表,每个样本是一个字典
|
||||||
max_seq_length: 目标序列长度,0表示动态padding
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
批处理后的字典,tensor字段已stack,字符串字段保持为列表
|
批处理后的字典,tensor字段已stack,字符串字段保持为列表
|
||||||
"""
|
"""
|
||||||
input_ids_list = [item["input_ids"] for item in batch]
|
# 处理tensor字段 - 使用squeeze去除多余的batch维度
|
||||||
token_type_ids_list = [item["token_type_ids"] for item in batch]
|
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
|
||||||
attention_mask_list = [item["attention_mask"] for item in batch]
|
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
|
||||||
|
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
|
||||||
if max_seq_length > 0:
|
|
||||||
target_len = max_seq_length
|
|
||||||
else:
|
|
||||||
target_len = max(ids.shape[0] for ids in input_ids_list)
|
|
||||||
|
|
||||||
padded_input_ids = []
|
|
||||||
padded_token_type_ids = []
|
|
||||||
padded_attention_mask = []
|
|
||||||
for ids, tt_ids, mask in zip(
|
|
||||||
input_ids_list, token_type_ids_list, attention_mask_list
|
|
||||||
):
|
|
||||||
seq_len = ids.shape[0]
|
|
||||||
if seq_len < target_len:
|
|
||||||
pad_len = target_len - seq_len
|
|
||||||
padded_input_ids.append(
|
|
||||||
torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
|
|
||||||
)
|
|
||||||
padded_token_type_ids.append(
|
|
||||||
torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)])
|
|
||||||
)
|
|
||||||
padded_attention_mask.append(
|
|
||||||
torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
|
|
||||||
)
|
|
||||||
elif seq_len > target_len:
|
|
||||||
padded_input_ids.append(ids[:target_len])
|
|
||||||
padded_token_type_ids.append(tt_ids[:target_len])
|
|
||||||
padded_attention_mask.append(mask[:target_len])
|
|
||||||
else:
|
|
||||||
padded_input_ids.append(ids)
|
|
||||||
padded_token_type_ids.append(tt_ids)
|
|
||||||
padded_attention_mask.append(mask)
|
|
||||||
|
|
||||||
input_ids = torch.stack(padded_input_ids)
|
|
||||||
token_type_ids = torch.stack(padded_token_type_ids)
|
|
||||||
attention_mask = torch.stack(padded_attention_mask)
|
|
||||||
labels = torch.stack([item["label"].squeeze(0) for item in batch])
|
labels = torch.stack([item["label"].squeeze(0) for item in batch])
|
||||||
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
|
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
|
||||||
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
|
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
|
||||||
|
|
||||||
|
# 字符串字段保持为列表
|
||||||
prefixes = [item["prefix"] for item in batch]
|
prefixes = [item["prefix"] for item in batch]
|
||||||
suffixes = [item["suffix"] for item in batch]
|
suffixes = [item["suffix"] for item in batch]
|
||||||
pinyins = [item["pinyin"] for item in batch]
|
pinyins = [item["pinyin"] for item in batch]
|
||||||
|
|
@ -1048,18 +1001,9 @@ def collate_fn(batch: List[Dict[str, Any]], max_seq_length: int = 0) -> Dict[str
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def preprocess_collate_fn(max_seq_length: int):
    """Build a collate function for preprocessing.

    The returned callable takes a batch and forwards it to ``collate_fn``
    with ``max_seq_length`` bound to the value given here.

    Args:
        max_seq_length: Sequence length passed through to ``collate_fn``.

    Returns:
        A one-argument callable ``batch -> collate_fn(batch, ...)``.
    """
    # ``collate_fn`` is looked up at call time, exactly as in a nested def.
    return lambda batch: collate_fn(batch, max_seq_length=max_seq_length)
|
|
||||||
|
|
||||||
|
|
||||||
# Typer CLI应用
|
# Typer CLI应用
|
||||||
def create_dataloader(
|
def create_dataloader(
|
||||||
dataset,
|
dataset: PinyinInputDataset,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
num_workers: int = 2,
|
num_workers: int = 2,
|
||||||
pin_memory: bool = True,
|
pin_memory: bool = True,
|
||||||
|
|
@ -1067,23 +1011,20 @@ def create_dataloader(
|
||||||
max_iter_length: Optional[int] = None,
|
max_iter_length: Optional[int] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
创建数据加载器,自动识别数据集类型。
|
创建数据加载器,优先使用DataLoader2,如果不可用则回退到DataLoader。
|
||||||
|
专门针对流式数据集优化。
|
||||||
|
|
||||||
- PinyinInputDataset(IterableDataset):使用流式加载
|
Args:
|
||||||
- PreProcessedDataset(map-style):使用标准加载,支持 shuffle
|
dataset: PinyinInputDataset实例
|
||||||
|
batch_size: 批次大小
|
||||||
|
num_workers: worker数量(对于流式数据集建议为2)
|
||||||
|
pin_memory: 是否固定内存
|
||||||
|
shuffle: 是否打乱(流式数据集内部处理打乱)
|
||||||
|
max_iter_length: 最大迭代长度,用于计算总步数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
数据加载器实例
|
||||||
"""
|
"""
|
||||||
if isinstance(dataset, PreProcessedDataset):
|
|
||||||
logger.info(f"📊 使用预处理数据集,样本数: {len(dataset)}")
|
|
||||||
return DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=shuffle,
|
|
||||||
num_workers=num_workers,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
collate_fn=preprocessed_collate_fn,
|
|
||||||
persistent_workers=True if num_workers > 0 else False,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}")
|
logger.info(f"📊 使用标准DataLoader,worker数量: {num_workers}")
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
|
@ -1223,82 +1164,12 @@ def train(
|
||||||
config_table.add_row("训练", "混合精度", str(mixed_precision))
|
config_table.add_row("训练", "混合精度", str(mixed_precision))
|
||||||
|
|
||||||
config_table.add_row("其他", "自动恢复", str(auto_resume))
|
config_table.add_row("其他", "自动恢复", str(auto_resume))
|
||||||
|
console.print(config_table)
|
||||||
|
|
||||||
# 创建输出目录
|
# 创建输出目录
|
||||||
output_path = Path(output_dir)
|
output_path = Path(output_dir)
|
||||||
output_path.mkdir(parents=True, exist_ok=True)
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# 检测数据类型并创建数据加载器
|
|
||||||
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
|
|
||||||
|
|
||||||
is_train_preprocessed = is_preprocessed_data(train_data_path)
|
|
||||||
is_eval_preprocessed = is_preprocessed_data(eval_data_path)
|
|
||||||
|
|
||||||
if is_train_preprocessed:
|
|
||||||
train_dataset = PreProcessedDataset(train_data_path)
|
|
||||||
total_steps = (len(train_dataset) // batch_size) * num_epochs
|
|
||||||
train_dataloader = create_dataloader(
|
|
||||||
dataset=train_dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_workers=num_workers,
|
|
||||||
pin_memory=torch.cuda.is_available(),
|
|
||||||
shuffle=True,
|
|
||||||
)
|
|
||||||
config_table.add_row("数据", "训练数据类型", "预处理数据")
|
|
||||||
else:
|
|
||||||
train_dataset = PinyinInputDataset(
|
|
||||||
data_path=train_data_path,
|
|
||||||
max_workers=-1,
|
|
||||||
max_iter_length=max_iter_length,
|
|
||||||
max_seq_length=max_seq_len,
|
|
||||||
text_field="text",
|
|
||||||
py_style_weight=(9, 2, 1),
|
|
||||||
shuffle_buffer_size=2000000,
|
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
|
||||||
)
|
|
||||||
total_steps = int(max_iter_length * num_epochs / batch_size)
|
|
||||||
train_dataloader = create_dataloader(
|
|
||||||
dataset=train_dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_workers=num_workers,
|
|
||||||
pin_memory=torch.cuda.is_available(),
|
|
||||||
max_iter_length=max_iter_length,
|
|
||||||
)
|
|
||||||
config_table.add_row("数据", "训练数据类型", "流式数据")
|
|
||||||
|
|
||||||
if is_eval_preprocessed:
|
|
||||||
eval_dataset = PreProcessedDataset(eval_data_path)
|
|
||||||
eval_dataloader = create_dataloader(
|
|
||||||
dataset=eval_dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_workers=2,
|
|
||||||
pin_memory=torch.cuda.is_available(),
|
|
||||||
shuffle=False,
|
|
||||||
)
|
|
||||||
config_table.add_row("数据", "评估数据类型", "预处理数据")
|
|
||||||
else:
|
|
||||||
eval_dataset = PinyinInputDataset(
|
|
||||||
data_path=eval_data_path,
|
|
||||||
max_workers=-1,
|
|
||||||
max_iter_length=batch_size * 64,
|
|
||||||
max_seq_length=max_seq_len,
|
|
||||||
text_field="text",
|
|
||||||
py_style_weight=(9, 2, 1),
|
|
||||||
shuffle_buffer_size=2000000,
|
|
||||||
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
|
||||||
)
|
|
||||||
eval_dataloader = create_dataloader(
|
|
||||||
dataset=eval_dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_workers=2,
|
|
||||||
pin_memory=torch.cuda.is_available(),
|
|
||||||
max_iter_length=batch_size * 64,
|
|
||||||
)
|
|
||||||
config_table.add_row("数据", "评估数据类型", "流式数据")
|
|
||||||
|
|
||||||
config_table.add_row("数据", "总步数", str(total_steps))
|
|
||||||
console.print(config_table)
|
|
||||||
|
|
||||||
# 保存配置
|
# 保存配置
|
||||||
config = {
|
config = {
|
||||||
"train_data_path": train_data_path,
|
"train_data_path": train_data_path,
|
||||||
|
|
@ -1331,9 +1202,6 @@ def train(
|
||||||
"auto_resume": auto_resume,
|
"auto_resume": auto_resume,
|
||||||
"max_iter_length": max_iter_length,
|
"max_iter_length": max_iter_length,
|
||||||
"compile": compile,
|
"compile": compile,
|
||||||
"is_train_preprocessed": is_train_preprocessed,
|
|
||||||
"is_eval_preprocessed": is_eval_preprocessed,
|
|
||||||
"total_steps": total_steps,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config_file = output_path / "training_config.json"
|
config_file = output_path / "training_config.json"
|
||||||
|
|
@ -1342,6 +1210,52 @@ def train(
|
||||||
|
|
||||||
logger.info(f"Configuration saved to {config_file}")
|
logger.info(f"Configuration saved to {config_file}")
|
||||||
|
|
||||||
|
# 创建数据加载器
|
||||||
|
console.print("[bold cyan]正在创建数据加载器...[/bold cyan]")
|
||||||
|
|
||||||
|
# 训练数据集
|
||||||
|
train_dataset = PinyinInputDataset(
|
||||||
|
data_path=train_data_path,
|
||||||
|
max_workers=-1, # 自动选择worker数量
|
||||||
|
max_iter_length=max_iter_length,
|
||||||
|
max_seq_length=max_seq_len,
|
||||||
|
text_field="text",
|
||||||
|
py_style_weight=(9, 2, 1),
|
||||||
|
shuffle_buffer_size=2000000,
|
||||||
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 训练数据加载器
|
||||||
|
# 注意:PinyinInputDataset是IterableDataset,所以不能使用shuffle参数
|
||||||
|
# 多worker配置:每个worker处理数据集的一个分片,由dataset.__iter__中的shard处理
|
||||||
|
train_dataloader = create_dataloader(
|
||||||
|
dataset=train_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
max_iter_length=max_iter_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 评估数据集(使用相同的设置,但可以调整参数)
|
||||||
|
eval_dataset = PinyinInputDataset(
|
||||||
|
data_path=eval_data_path,
|
||||||
|
max_workers=-1,
|
||||||
|
max_iter_length=batch_size * 64, # 评估集较小
|
||||||
|
max_seq_length=max_seq_len,
|
||||||
|
text_field="text",
|
||||||
|
py_style_weight=(9, 2, 1),
|
||||||
|
shuffle_buffer_size=2000000,
|
||||||
|
length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_dataloader = create_dataloader(
|
||||||
|
dataset=eval_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=2, # 评估使用较少的worker
|
||||||
|
pin_memory=torch.cuda.is_available(),
|
||||||
|
max_iter_length=batch_size * 64,
|
||||||
|
)
|
||||||
|
|
||||||
console.print("[bold cyan]正在创建模型...[/bold cyan]")
|
console.print("[bold cyan]正在创建模型...[/bold cyan]")
|
||||||
model = InputMethodEngine(
|
model = InputMethodEngine(
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
|
|
@ -1365,7 +1279,7 @@ def train(
|
||||||
model=model,
|
model=model,
|
||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
eval_dataloader=eval_dataloader,
|
eval_dataloader=eval_dataloader,
|
||||||
total_steps=total_steps,
|
total_steps=int(max_iter_length * num_epochs / batch_size),
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
num_epochs=num_epochs,
|
num_epochs=num_epochs,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ train_dataset = PinyinInputDataset(
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=512,
|
batch_size=512,
|
||||||
num_workers=16,
|
num_workers=2,
|
||||||
worker_init_fn=worker_init_fn,
|
worker_init_fn=worker_init_fn,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
prefetch_factor=2, # 减少预取以避免内存问题
|
prefetch_factor=2, # 减少预取以避免内存问题
|
||||||
|
|
|
||||||
|
|
@ -1,146 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Visualize frequency distribution with ASCII plots
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
def ascii_histogram(data, bins=20, width=60):
    """Render *data* as a multi-line ASCII histogram.

    Logarithmic bins are used when the data is non-negative and spans more
    than three orders of magnitude (largest value divided by the smallest
    positive value exceeds 1000); otherwise equal-width linear bins are used.
    In the logarithmic case non-positive values cannot be placed on the axis
    and are left out of the histogram.

    Args:
        data: Sequence of numbers to bin.
        bins: Number of bins.
        width: Length, in characters, of the bar for the fullest bin.

    Returns:
        One line per non-empty bin, formatted ``"<range> | <bar> <count>"``
        and joined with newlines. An empty string when *data* is empty.
    """
    if not data:
        return ""

    min_val = min(data)
    max_val = max(data)

    # The ratio test must use the smallest *positive* value: dividing by a
    # minimum of 0 would raise ZeroDivisionError.
    positives = [val for val in data if val > 0]
    smallest_positive = min(positives) if positives else None
    use_log_bins = (
        min_val >= 0
        and smallest_positive is not None
        and max_val / smallest_positive > 1000
    )

    hist = [0] * bins
    if use_log_bins:
        log_min = math.log10(smallest_positive)
        log_max = math.log10(max_val)
        # Strictly positive here because the ratio above exceeds 1000.
        log_span = log_max - log_min
        bin_edges = [10 ** (log_min + i * log_span / bins) for i in range(bins + 1)]
        for val in positives:
            position = (math.log10(val) - log_min) / log_span * bins
            # The maximum lands exactly on the last edge; clamp it into the
            # final bin.
            hist[min(int(position), bins - 1)] += 1
        bin_labels = [
            f"{bin_edges[i]:.1e}-{bin_edges[i + 1]:.1e}" for i in range(bins)
        ]
    else:
        span = max_val - min_val
        bin_width = span / bins
        bin_edges = [min_val + i * bin_width for i in range(bins + 1)]
        if span == 0:
            # Every value is identical: there is no range to divide, so all
            # of them go into the first bin.
            hist[0] = len(data)
        else:
            for val in data:
                position = (val - min_val) / span * bins
                hist[min(int(position), bins - 1)] += 1
        bin_labels = [
            f"{bin_edges[i]:.1f}-{bin_edges[i + 1]:.1f}" for i in range(bins)
        ]

    max_count = max(hist)
    result = []
    for label, count in zip(bin_labels, hist):
        if count == 0:
            continue
        bar = '#' * int(count / max_count * width)
        result.append(f"{label:20} | {bar} {count}")

    return "\n".join(result)
|
|
||||||
|
|
||||||
def _log_bar(freq, max_freq, log_fn=math.log, width=40):
    """Return a '#' bar whose length is log(freq) / log(max_freq) * width.

    Returns an empty bar for non-positive frequencies, and avoids dividing
    by ``log(1) == 0`` when the largest frequency is 1.
    """
    if freq <= 0 or max_freq <= 0:
        return ""
    denominator = log_fn(max_freq)
    if denominator == 0:
        return "#" * width if freq >= max_freq else ""
    return "#" * int(log_fn(freq) / denominator * width)


def main():
    """Print an ASCII report of the pinyin/character frequency statistics.

    Reads ``src/model/assets/pinyin_char_statistics.json`` (relative to the
    current working directory), takes the ``count`` of every entry under
    ``pairs`` and prints: a histogram, the top-ranked frequencies, a sample of
    id-versus-frequency, a rank*frequency (Zipf) check, the low end of the
    frequency spectrum and summary statistics. Finally writes
    ``id_vs_freq.csv`` to the current working directory.

    Exits with status 1 when the JSON file does not exist, and returns early
    when it contains no usable counts.
    """
    json_path = Path("src/model/assets/pinyin_char_statistics.json")
    if not json_path.exists():
        print(f"Error: File not found: {json_path}")
        sys.exit(1)

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    pairs = data.get('pairs', {})
    counts = [pair['count'] for pair in pairs.values() if pair.get('count') is not None]
    if not counts:
        # Everything below indexes into the counts; nothing to report.
        print("No valid count data found.")
        return

    print("FREQUENCY DISTRIBUTION ANALYSIS")
    print("="*60)

    print("\n1. ASCII Histogram (log bins):")
    print(ascii_histogram(counts, bins=20, width=60))

    # Rank-frequency plot in ASCII
    print("\n2. Rank-Frequency Relationship (Top 50):")
    counts_sorted_desc = sorted(counts, reverse=True)
    n = len(counts_sorted_desc)
    max_freq = counts_sorted_desc[0]

    # Slicing caps the listing at the number of entries actually available.
    for rank, freq in enumerate(counts_sorted_desc[:50], start=1):
        bar = _log_bar(freq, max_freq, math.log)
        print(f"Rank {rank:3}: {freq:12} {bar}")

    # ID vs Frequency plot (sampled)
    print("\n3. ID vs Frequency (sampled every 500 IDs):")
    id_to_count = {}
    for pair in pairs.values():
        char_id = pair.get('id')
        count = pair.get('count')
        if char_id is not None and count is not None:
            id_to_count[char_id] = count

    all_ids = sorted(id_to_count)

    print("ID Frequency log10(freq)")
    if all_ids:
        max_id = all_ids[-1]
        for sample_id in range(0, max_id + 1, 500):
            if sample_id in id_to_count:
                freq = id_to_count[sample_id]
                log_freq = math.log10(freq) if freq > 0 else 0
                bar = _log_bar(freq, max_freq, math.log10)
                print(f"{sample_id:6} {freq:10} {log_freq:6.2f} {bar}")

    # Zipf's law fit
    print("\n4. Zipf's Law Analysis:")
    print(" Rank * Frequency ≈ constant for Zipf's law")
    print(" Top 10 ranks:")
    top_n = min(10, n)
    products = []
    for rank, freq in enumerate(counts_sorted_desc[:top_n], start=1):
        product = rank * freq
        products.append(product)
        print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}")

    # Check if product is roughly constant (population standard deviation)
    avg_product = sum(products) / len(products)
    std_product = math.sqrt(sum((p - avg_product)**2 for p in products) / len(products))
    print(f" Average product (ranks 1-{top_n}): {avg_product:.3e} ± {std_product:.3e}")
    if avg_product != 0:
        print(f" Coefficient of variation: {std_product/avg_product*100:.1f}%")
    else:
        print(" Coefficient of variation: n/a")

    # Frequency spectrum
    from collections import Counter
    freq_counter = Counter(counts)
    print("\n5. Frequency Spectrum (how many entries have each frequency):")
    print(" Frequency Count Cumulative")
    cum = 0
    for freq in sorted(freq_counter)[:20]:
        count = freq_counter[freq]
        cum += count
        print(f" {freq:10} {count:6} {cum:6}")

    # Summary statistics
    print("\n6. Key Statistics:")
    lowest = counts_sorted_desc[-1]
    print(f" Total entries: {n}")
    print(f" Min frequency: {lowest}")
    print(f" Max frequency: {max_freq}")
    if lowest != 0:
        print(f" Ratio max/min: {max_freq/lowest:.2e}")
    else:
        print(" Ratio max/min: n/a")

    # The list is sorted in descending order, so index p*n is the cutoff
    # frequency of the top p fraction of entries (not the p-th percentile).
    top_fractions = [0.01, 0.1, 0.5, 0.9, 0.99]
    for p in top_fractions:
        idx = int(p * n)
        value = counts_sorted_desc[idx]
        print(f" Top {p*100:5.1f}% cutoff: {value:12} (rank ~{idx})")

    # Save data for external plotting
    with open("id_vs_freq.csv", "w", encoding="utf-8") as f:
        f.write("id,frequency\n")
        f.writelines(f"{char_id},{id_to_count[char_id]}\n" for char_id in all_ids)
    print("\nData saved to id_vs_freq.csv for external plotting")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
Loading…
Reference in New Issue