refactor(generate_pinyin): 优化拼音生成逻辑,利用 pypinyin 分词能力处理多音字
This commit is contained in:
parent
8b41bcdc6f
commit
e8eab1f260
33
eval.py
33
eval.py
|
|
@ -14,7 +14,6 @@ eval.py - 评估模型在给定文本上的表现
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
import random
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple, Optional
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
|
@ -33,8 +32,6 @@ from src.model.model import InputMethodEngine
|
||||||
from src.model.query import QueryEngine
|
from src.model.query import QueryEngine
|
||||||
from src.model.dataset import text_to_pinyin_ids
|
from src.model.dataset import text_to_pinyin_ids
|
||||||
|
|
||||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
|
||||||
|
|
||||||
|
|
||||||
class TextEvaluator:
|
class TextEvaluator:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -171,34 +168,20 @@ class TextEvaluator:
|
||||||
|
|
||||||
def generate_pinyin(self, text: str) -> List[str]:
|
def generate_pinyin(self, text: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
流式处理单条文本,转换为拼音列表。
|
将文本转换为拼音列表。对整段文本调用 lazy_pinyin,
|
||||||
参考dataset.py中的generate_pinyin方法。
|
利用 pypinyin 内部的分词能力处理多音字。
|
||||||
|
参考 dataset.py 中的 generate_pinyin 方法。
|
||||||
"""
|
"""
|
||||||
if not text:
|
if not text:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
text_len = len(text)
|
pinyin_list = lazy_pinyin(text)
|
||||||
result: List[str] = [""] * text_len
|
|
||||||
|
|
||||||
# 遍历所有连续汉字片段
|
# 健壮性兜底:若长度不匹配(极罕见),降级为逐字转换
|
||||||
for match in _HANZI_RE.finditer(text):
|
if len(pinyin_list) != len(text):
|
||||||
start_idx = match.start()
|
pinyin_list = [lazy_pinyin(c)[0] for c in text]
|
||||||
hanzi_segment = match.group()
|
|
||||||
|
|
||||||
pinyin_list = lazy_pinyin(hanzi_segment)
|
return pinyin_list
|
||||||
|
|
||||||
if len(pinyin_list) != len(hanzi_segment):
|
|
||||||
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
|
|
||||||
|
|
||||||
for i, py in enumerate(pinyin_list):
|
|
||||||
result[start_idx + i] = py
|
|
||||||
|
|
||||||
# 填充非汉字字符
|
|
||||||
for i, char in enumerate(text):
|
|
||||||
if not result[i]:
|
|
||||||
result[i] = char
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_mask_pinyin(
|
def get_mask_pinyin(
|
||||||
self, text: str, pinyin_list: List[str]
|
self, text: str, pinyin_list: List[str]
|
||||||
|
|
|
||||||
|
|
@ -10,34 +10,41 @@ import math
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Path to the JSON file
|
# Path to the JSON file
|
||||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
json_path = (
|
||||||
|
Path(__file__).parent.parent
|
||||||
|
/ "src"
|
||||||
|
/ "model"
|
||||||
|
/ "assets"
|
||||||
|
/ "pinyin_char_statistics.json"
|
||||||
|
)
|
||||||
if not json_path.exists():
|
if not json_path.exists():
|
||||||
print(f"Error: File not found: {json_path}")
|
print(f"Error: File not found: {json_path}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
print(f"Loading {json_path}...")
|
print(f"Loading {json_path}...")
|
||||||
with open(json_path, 'r', encoding='utf-8') as f:
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
print(f"Timestamp: {data.get('timestamp')}")
|
print(f"Timestamp: {data.get('timestamp')}")
|
||||||
print(f"Total characters: {data.get('total_characters')}")
|
print(f"Total characters: {data.get('total_characters')}")
|
||||||
print(f"Total pinyins: {data.get('total_pinyins')}")
|
print(f"Total pinyins: {data.get('total_pinyins')}")
|
||||||
print(f"Valid input character count: {data.get('valid_input_character_count')}")
|
print(f"Valid input character count: {data.get('valid_input_character_count')}")
|
||||||
|
|
||||||
pairs = data.get('pairs', {})
|
pairs = data.get("pairs", {})
|
||||||
print(f"Number of pairs: {len(pairs)}")
|
print(f"Number of pairs: {len(pairs)}")
|
||||||
|
|
||||||
# Extract counts and IDs
|
# Extract counts and IDs
|
||||||
counts = []
|
counts = []
|
||||||
id_to_count = {}
|
id_to_count = {}
|
||||||
char_to_count = {}
|
char_to_count = {}
|
||||||
for key, pair in pairs.items():
|
for key, pair in pairs.items():
|
||||||
try:
|
try:
|
||||||
char_id = pair.get('id')
|
char_id = pair.get("id")
|
||||||
count = pair.get('count')
|
count = pair.get("count")
|
||||||
char = pair.get('char', '')
|
char = pair.get("char", "")
|
||||||
if char_id is not None and count is not None:
|
if char_id is not None and count is not None:
|
||||||
counts.append(count)
|
counts.append(count)
|
||||||
id_to_count[char_id] = count
|
id_to_count[char_id] = count
|
||||||
|
|
@ -46,21 +53,21 @@ def main():
|
||||||
except (ValueError, TypeError) as e:
|
except (ValueError, TypeError) as e:
|
||||||
print(f"Warning: Could not parse pair {key}: {e}")
|
print(f"Warning: Could not parse pair {key}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not counts:
|
if not counts:
|
||||||
print("No valid count data found.")
|
print("No valid count data found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Basic statistics
|
# Basic statistics
|
||||||
min_count = min(counts)
|
min_count = min(counts)
|
||||||
max_count = max(counts)
|
max_count = max(counts)
|
||||||
total_count = sum(counts)
|
total_count = sum(counts)
|
||||||
mean_count = total_count / len(counts)
|
mean_count = total_count / len(counts)
|
||||||
|
|
||||||
# Sort counts for percentiles
|
# Sort counts for percentiles
|
||||||
sorted_counts = sorted(counts)
|
sorted_counts = sorted(counts)
|
||||||
n = len(sorted_counts)
|
n = len(sorted_counts)
|
||||||
|
|
||||||
# Percentiles
|
# Percentiles
|
||||||
p10 = sorted_counts[int(0.1 * n)]
|
p10 = sorted_counts[int(0.1 * n)]
|
||||||
p25 = sorted_counts[int(0.25 * n)]
|
p25 = sorted_counts[int(0.25 * n)]
|
||||||
|
|
@ -68,11 +75,11 @@ def main():
|
||||||
p75 = sorted_counts[int(0.75 * n)]
|
p75 = sorted_counts[int(0.75 * n)]
|
||||||
p90 = sorted_counts[int(0.9 * n)]
|
p90 = sorted_counts[int(0.9 * n)]
|
||||||
p99 = sorted_counts[int(0.99 * n)]
|
p99 = sorted_counts[int(0.99 * n)]
|
||||||
|
|
||||||
# Variance and std dev
|
# Variance and std dev
|
||||||
variance = sum((x - mean_count) ** 2 for x in counts) / n
|
variance = sum((x - mean_count) ** 2 for x in counts) / n
|
||||||
std_dev = math.sqrt(variance)
|
std_dev = math.sqrt(variance)
|
||||||
|
|
||||||
print("\n=== BASIC STATISTICS ===")
|
print("\n=== BASIC STATISTICS ===")
|
||||||
print(f"Min frequency: {min_count}")
|
print(f"Min frequency: {min_count}")
|
||||||
print(f"Max frequency: {max_count}")
|
print(f"Max frequency: {max_count}")
|
||||||
|
|
@ -80,7 +87,7 @@ def main():
|
||||||
print(f"Standard deviation: {std_dev:.2f}")
|
print(f"Standard deviation: {std_dev:.2f}")
|
||||||
print(f"Total frequency sum: {total_count}")
|
print(f"Total frequency sum: {total_count}")
|
||||||
print(f"Number of entries: {n}")
|
print(f"Number of entries: {n}")
|
||||||
|
|
||||||
print("\n=== PERCENTILES ===")
|
print("\n=== PERCENTILES ===")
|
||||||
print(f"10th percentile: {p10}")
|
print(f"10th percentile: {p10}")
|
||||||
print(f"25th percentile: {p25}")
|
print(f"25th percentile: {p25}")
|
||||||
|
|
@ -88,52 +95,54 @@ def main():
|
||||||
print(f"75th percentile: {p75}")
|
print(f"75th percentile: {p75}")
|
||||||
print(f"90th percentile: {p90}")
|
print(f"90th percentile: {p90}")
|
||||||
print(f"99th percentile: {p99}")
|
print(f"99th percentile: {p99}")
|
||||||
|
|
||||||
# Find IDs with min and max counts
|
# Find IDs with min and max counts
|
||||||
min_ids = [id for id, count in id_to_count.items() if count == min_count]
|
min_ids = [id for id, count in id_to_count.items() if count == min_count]
|
||||||
max_ids = [id for id, count in id_to_count.items() if count == max_count]
|
max_ids = [id for id, count in id_to_count.items() if count == max_count]
|
||||||
|
|
||||||
print(f"\nIDs with min frequency ({min_count}): {min_ids}")
|
print(f"\nIDs with min frequency ({min_count}): {min_ids}")
|
||||||
print(f"IDs with max frequency ({max_count}): {max_ids}")
|
print(f"IDs with max frequency ({max_count}): {max_ids}")
|
||||||
|
|
||||||
# Check if IDs are assigned in frequency order
|
# Check if IDs are assigned in frequency order
|
||||||
# Compute correlation between ID and count
|
# Compute correlation between ID and count
|
||||||
ids = list(id_to_count.keys())
|
ids = list(id_to_count.keys())
|
||||||
id_counts = [id_to_count[id] for id in ids]
|
id_counts = [id_to_count[id] for id in ids]
|
||||||
|
|
||||||
# Sort by ID and check if counts are decreasing
|
# Sort by ID and check if counts are decreasing
|
||||||
sorted_by_id = sorted(ids)
|
sorted_by_id = sorted(ids)
|
||||||
counts_by_id = [id_to_count[id] for id in sorted_by_id]
|
counts_by_id = [id_to_count[id] for id in sorted_by_id]
|
||||||
|
|
||||||
# Calculate monotonicity: count of times count decreases as ID increases
|
# Calculate monotonicity: count of times count decreases as ID increases
|
||||||
decreases = 0
|
decreases = 0
|
||||||
increases = 0
|
increases = 0
|
||||||
for i in range(1, len(counts_by_id)):
|
for i in range(1, len(counts_by_id)):
|
||||||
if counts_by_id[i] < counts_by_id[i-1]:
|
if counts_by_id[i] < counts_by_id[i - 1]:
|
||||||
decreases += 1
|
decreases += 1
|
||||||
elif counts_by_id[i] > counts_by_id[i-1]:
|
elif counts_by_id[i] > counts_by_id[i - 1]:
|
||||||
increases += 1
|
increases += 1
|
||||||
|
|
||||||
print(f"\n=== ID ORDER ANALYSIS ===")
|
print(f"\n=== ID ORDER ANALYSIS ===")
|
||||||
print(f"Total pairs: {len(counts_by_id)}")
|
print(f"Total pairs: {len(counts_by_id)}")
|
||||||
print(f"Decreases as ID increases: {decreases} times")
|
print(f"Decreases as ID increases: {decreases} times")
|
||||||
print(f"Increases as ID increases: {increases} times")
|
print(f"Increases as ID increases: {increases} times")
|
||||||
print(f"Percentage decreasing: {decreases/(len(counts_by_id)-1)*100:.2f}%")
|
print(f"Percentage decreasing: {decreases / (len(counts_by_id) - 1) * 100:.2f}%")
|
||||||
|
|
||||||
# Check if IDs are roughly sorted by frequency
|
# Check if IDs are roughly sorted by frequency
|
||||||
# Compute Spearman rank correlation (simplified)
|
# Compute Spearman rank correlation (simplified)
|
||||||
sorted_by_count = sorted(ids, key=lambda x: id_to_count[x], reverse=True)
|
sorted_by_count = sorted(ids, key=lambda x: id_to_count[x], reverse=True)
|
||||||
rank_by_id = {id: i for i, id in enumerate(sorted_by_id)}
|
rank_by_id = {id: i for i, id in enumerate(sorted_by_id)}
|
||||||
rank_by_count = {id: i for i, id in enumerate(sorted_by_count)}
|
rank_by_count = {id: i for i, id in enumerate(sorted_by_count)}
|
||||||
|
|
||||||
# Average rank difference
|
# Average rank difference
|
||||||
rank_diffs = [abs(rank_by_id[id] - rank_by_count[id]) for id in ids]
|
rank_diffs = [abs(rank_by_id[id] - rank_by_count[id]) for id in ids]
|
||||||
avg_rank_diff = sum(rank_diffs) / len(rank_diffs)
|
avg_rank_diff = sum(rank_diffs) / len(rank_diffs)
|
||||||
max_rank_diff = max(rank_diffs)
|
max_rank_diff = max(rank_diffs)
|
||||||
|
|
||||||
print(f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}")
|
print(
|
||||||
|
f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}"
|
||||||
|
)
|
||||||
print(f"Maximum rank difference: {max_rank_diff}")
|
print(f"Maximum rank difference: {max_rank_diff}")
|
||||||
|
|
||||||
# Analyze specific ID range 5000-5500
|
# Analyze specific ID range 5000-5500
|
||||||
print("\n=== ANALYSIS OF ID RANGE 5000-5500 ===")
|
print("\n=== ANALYSIS OF ID RANGE 5000-5500 ===")
|
||||||
range_counts = []
|
range_counts = []
|
||||||
|
|
@ -142,7 +151,7 @@ def main():
|
||||||
if id in id_to_count:
|
if id in id_to_count:
|
||||||
range_counts.append(id_to_count[id])
|
range_counts.append(id_to_count[id])
|
||||||
range_ids.append(id)
|
range_ids.append(id)
|
||||||
|
|
||||||
if range_counts:
|
if range_counts:
|
||||||
range_min = min(range_counts)
|
range_min = min(range_counts)
|
||||||
range_max = max(range_counts)
|
range_max = max(range_counts)
|
||||||
|
|
@ -152,7 +161,7 @@ def main():
|
||||||
range_p10 = range_sorted[int(0.1 * range_n)] if range_n > 0 else 0
|
range_p10 = range_sorted[int(0.1 * range_n)] if range_n > 0 else 0
|
||||||
range_p50 = range_sorted[int(0.5 * range_n)] if range_n > 0 else 0
|
range_p50 = range_sorted[int(0.5 * range_n)] if range_n > 0 else 0
|
||||||
range_p90 = range_sorted[int(0.9 * range_n)] if range_n > 0 else 0
|
range_p90 = range_sorted[int(0.9 * range_n)] if range_n > 0 else 0
|
||||||
|
|
||||||
print(f"IDs in range 5000-5500: {len(range_counts)}")
|
print(f"IDs in range 5000-5500: {len(range_counts)}")
|
||||||
print(f"Min frequency in range: {range_min}")
|
print(f"Min frequency in range: {range_min}")
|
||||||
print(f"Max frequency in range: {range_max}")
|
print(f"Max frequency in range: {range_max}")
|
||||||
|
|
@ -160,66 +169,82 @@ def main():
|
||||||
print(f"10th percentile in range: {range_p10}")
|
print(f"10th percentile in range: {range_p10}")
|
||||||
print(f"50th percentile in range: {range_p50}")
|
print(f"50th percentile in range: {range_p50}")
|
||||||
print(f"90th percentile in range: {range_p90}")
|
print(f"90th percentile in range: {range_p90}")
|
||||||
|
|
||||||
# Find IDs with min frequency in this range
|
# Find IDs with min frequency in this range
|
||||||
min_in_range_ids = [id for id in range_ids if id_to_count[id] == range_min]
|
min_in_range_ids = [id for id in range_ids if id_to_count[id] == range_min]
|
||||||
print(f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}")
|
print(
|
||||||
|
f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("No IDs found in range 5000-5500")
|
print("No IDs found in range 5000-5500")
|
||||||
|
|
||||||
# Histogram of frequencies (log bins)
|
# Histogram of frequencies (log bins)
|
||||||
print("\n=== FREQUENCY DISTRIBUTION (LOG BINS) ===")
|
print("\n=== FREQUENCY DISTRIBUTION (LOG BINS) ===")
|
||||||
if max_count > 0:
|
if max_count > 0:
|
||||||
log_min = math.log10(min_count) if min_count > 0 else 0
|
log_min = math.log10(min_count) if min_count > 0 else 0
|
||||||
log_max = math.log10(max_count)
|
log_max = math.log10(max_count)
|
||||||
num_bins = 20
|
num_bins = 20
|
||||||
bin_edges = [10**(log_min + i*(log_max-log_min)/num_bins) for i in range(num_bins+1)]
|
bin_edges = [
|
||||||
|
10 ** (log_min + i * (log_max - log_min) / num_bins)
|
||||||
|
for i in range(num_bins + 1)
|
||||||
|
]
|
||||||
|
|
||||||
hist = [0] * num_bins
|
hist = [0] * num_bins
|
||||||
for count in counts:
|
for count in counts:
|
||||||
if count > 0:
|
if count > 0:
|
||||||
log_val = math.log10(count)
|
log_val = math.log10(count)
|
||||||
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * num_bins), num_bins-1)
|
bin_idx = min(
|
||||||
|
int((log_val - log_min) / (log_max - log_min) * num_bins),
|
||||||
|
num_bins - 1,
|
||||||
|
)
|
||||||
hist[bin_idx] += 1
|
hist[bin_idx] += 1
|
||||||
|
|
||||||
print("Log-scale histogram (count range -> frequency count):")
|
print("Log-scale histogram (count range -> frequency count):")
|
||||||
for i in range(num_bins):
|
for i in range(num_bins):
|
||||||
if hist[i] > 0:
|
if hist[i] > 0:
|
||||||
lower = bin_edges[i]
|
lower = bin_edges[i]
|
||||||
upper = bin_edges[i+1]
|
upper = bin_edges[i + 1]
|
||||||
print(f" {lower:.2e} - {upper:.2e}: {hist[i]} entries")
|
print(f" {lower:.2e} - {upper:.2e}: {hist[i]} entries")
|
||||||
|
|
||||||
# Check for zero or near-zero frequencies
|
# Check for zero or near-zero frequencies
|
||||||
zero_count = sum(1 for c in counts if c == 0)
|
zero_count = sum(1 for c in counts if c == 0)
|
||||||
low_count = sum(1 for c in counts if 0 < c <= 10)
|
low_count = sum(1 for c in counts if 0 < c <= 10)
|
||||||
very_low_count = sum(1 for c in counts if 0 < c <= 100)
|
very_low_count = sum(1 for c in counts if 0 < c <= 100)
|
||||||
|
|
||||||
print(f"\n=== LOW FREQUENCY ANALYSIS ===")
|
print(f"\n=== LOW FREQUENCY ANALYSIS ===")
|
||||||
print(f"Entries with zero frequency: {zero_count}")
|
print(f"Entries with zero frequency: {zero_count}")
|
||||||
print(f"Entries with frequency <= 10: {low_count}")
|
print(f"Entries with frequency <= 10: {low_count}")
|
||||||
print(f"Entries with frequency <= 100: {very_low_count}")
|
print(f"Entries with frequency <= 100: {very_low_count}")
|
||||||
|
|
||||||
# Find the actual min frequency (excluding zeros if any)
|
# Find the actual min frequency (excluding zeros if any)
|
||||||
non_zero_counts = [c for c in counts if c > 0]
|
non_zero_counts = [c for c in counts if c > 0]
|
||||||
if non_zero_counts:
|
if non_zero_counts:
|
||||||
actual_min = min(non_zero_counts)
|
actual_min = min(non_zero_counts)
|
||||||
print(f"Actual min frequency (non-zero): {actual_min}")
|
print(f"Actual min frequency (non-zero): {actual_min}")
|
||||||
actual_min_ids = [id for id, count in id_to_count.items() if count == actual_min]
|
actual_min_ids = [
|
||||||
print(f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}")
|
id for id, count in id_to_count.items() if count == actual_min
|
||||||
|
]
|
||||||
|
print(
|
||||||
|
f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
# Summary for smoothing algorithm design
|
# Summary for smoothing algorithm design
|
||||||
print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===")
|
print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===")
|
||||||
print(f"Frequency range spans {max_count/min_count if min_count>0 else 'inf'}:1 ratio")
|
print(
|
||||||
|
f"Frequency range spans {max_count / min_count if min_count > 0 else 'inf'}:1 ratio"
|
||||||
|
)
|
||||||
print(f"Most entries ({p50}) have frequency around {p50}")
|
print(f"Most entries ({p50}) have frequency around {p50}")
|
||||||
print(f"Top 10% of entries have frequency > {p90}")
|
print(f"Top 10% of entries have frequency > {p90}")
|
||||||
print(f"Bottom 10% of entries have frequency < {p10}")
|
print(f"Bottom 10% of entries have frequency < {p10}")
|
||||||
print(f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency")
|
print(
|
||||||
|
f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency"
|
||||||
|
)
|
||||||
|
|
||||||
# Save detailed data for further analysis
|
# Save detailed data for further analysis
|
||||||
output_file = "frequency_analysis_results.txt"
|
output_file = "frequency_analysis_results.txt"
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
f.write("Frequency Analysis Results\n")
|
f.write("Frequency Analysis Results\n")
|
||||||
f.write("="*50 + "\n")
|
f.write("=" * 50 + "\n")
|
||||||
f.write(f"Min frequency: {min_count}\n")
|
f.write(f"Min frequency: {min_count}\n")
|
||||||
f.write(f"Max frequency: {max_count}\n")
|
f.write(f"Max frequency: {max_count}\n")
|
||||||
f.write(f"Mean frequency: {mean_count:.2f}\n")
|
f.write(f"Mean frequency: {mean_count:.2f}\n")
|
||||||
|
|
@ -227,10 +252,15 @@ def main():
|
||||||
f.write(f"10th percentile: {p10}\n")
|
f.write(f"10th percentile: {p10}\n")
|
||||||
f.write(f"50th percentile: {p50}\n")
|
f.write(f"50th percentile: {p50}\n")
|
||||||
f.write(f"90th percentile: {p90}\n")
|
f.write(f"90th percentile: {p90}\n")
|
||||||
f.write(f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n")
|
f.write(
|
||||||
f.write(f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n")
|
f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n"
|
||||||
|
)
|
||||||
|
f.write(
|
||||||
|
f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"\nDetailed results saved to {output_file}")
|
print(f"\nDetailed results saved to {output_file}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
@ -7,58 +7,70 @@ import json
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
json_path = (
|
||||||
with open(json_path, 'r', encoding='utf-8') as f:
|
Path(__file__).parent.parent
|
||||||
|
/ "src"
|
||||||
|
/ "model"
|
||||||
|
/ "assets"
|
||||||
|
/ "pinyin_char_statistics.json"
|
||||||
|
)
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
pairs = data.get('pairs', {})
|
pairs = data.get("pairs", {})
|
||||||
|
|
||||||
# Build ID to count mapping
|
# Build ID to count mapping
|
||||||
id_to_count = {}
|
id_to_count = {}
|
||||||
for key, pair in pairs.items():
|
for key, pair in pairs.items():
|
||||||
char_id = pair.get('id')
|
char_id = pair.get("id")
|
||||||
count = pair.get('count')
|
count = pair.get("count")
|
||||||
if char_id is not None and count is not None:
|
if char_id is not None and count is not None:
|
||||||
id_to_count[char_id] = count
|
id_to_count[char_id] = count
|
||||||
|
|
||||||
# Analyze range 5000-5500 in detail
|
# Analyze range 5000-5500 in detail
|
||||||
print("ID range 5000-5500 detailed analysis:")
|
print("ID range 5000-5500 detailed analysis:")
|
||||||
print("ID\tCount\tChar\tPinyin")
|
print("ID\tCount\tChar\tPinyin")
|
||||||
|
|
||||||
range_data = []
|
range_data = []
|
||||||
for id in range(5000, 5501):
|
for id in range(5000, 5501):
|
||||||
if id in id_to_count:
|
if id in id_to_count:
|
||||||
# Find the pair to get char and pinyin
|
# Find the pair to get char and pinyin
|
||||||
for key, pair in pairs.items():
|
for key, pair in pairs.items():
|
||||||
if pair.get('id') == id:
|
if pair.get("id") == id:
|
||||||
char = pair.get('char', '')
|
char = pair.get("char", "")
|
||||||
pinyin = pair.get('pinyin', '')
|
pinyin = pair.get("pinyin", "")
|
||||||
count = pair.get('count', 0)
|
count = pair.get("count", 0)
|
||||||
range_data.append((id, count, char, pinyin))
|
range_data.append((id, count, char, pinyin))
|
||||||
if id % 100 == 0: # Print every 100th for overview
|
if id % 100 == 0: # Print every 100th for overview
|
||||||
print(f"{id}\t{count}\t{char}\t{pinyin}")
|
print(f"{id}\t{count}\t{char}\t{pinyin}")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Print min and max in range
|
# Print min and max in range
|
||||||
if range_data:
|
if range_data:
|
||||||
min_item = min(range_data, key=lambda x: x[1])
|
min_item = min(range_data, key=lambda x: x[1])
|
||||||
max_item = max(range_data, key=lambda x: x[1])
|
max_item = max(range_data, key=lambda x: x[1])
|
||||||
print(f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'")
|
print(
|
||||||
print(f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'")
|
f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'"
|
||||||
|
)
|
||||||
|
|
||||||
# Check if frequencies are monotonic in this range
|
# Check if frequencies are monotonic in this range
|
||||||
counts = [item[1] for item in range_data]
|
counts = [item[1] for item in range_data]
|
||||||
increasing = all(counts[i] <= counts[i+1] for i in range(len(counts)-1))
|
increasing = all(counts[i] <= counts[i + 1] for i in range(len(counts) - 1))
|
||||||
decreasing = all(counts[i] >= counts[i+1] for i in range(len(counts)-1))
|
decreasing = all(counts[i] >= counts[i + 1] for i in range(len(counts) - 1))
|
||||||
print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}")
|
print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}")
|
||||||
|
|
||||||
# Check for frequency plateaus
|
# Check for frequency plateaus
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
freq_count = Counter(counts)
|
freq_count = Counter(counts)
|
||||||
most_common = freq_count.most_common(5)
|
most_common = freq_count.most_common(5)
|
||||||
print(f"Most common frequencies in range: {most_common}")
|
print(f"Most common frequencies in range: {most_common}")
|
||||||
|
|
||||||
# Analyze the tail (IDs with frequency 1)
|
# Analyze the tail (IDs with frequency 1)
|
||||||
print("\n\nAnalysis of frequency=1 entries:")
|
print("\n\nAnalysis of frequency=1 entries:")
|
||||||
freq_one_ids = [id for id, count in id_to_count.items() if count == 1]
|
freq_one_ids = [id for id, count in id_to_count.items() if count == 1]
|
||||||
|
|
@ -67,30 +79,34 @@ def main():
|
||||||
print(f"ID range of frequency=1: {min(freq_one_ids)} to {max(freq_one_ids)}")
|
print(f"ID range of frequency=1: {min(freq_one_ids)} to {max(freq_one_ids)}")
|
||||||
print(f"First 10 IDs: {freq_one_ids[:10]}")
|
print(f"First 10 IDs: {freq_one_ids[:10]}")
|
||||||
print(f"Last 10 IDs: {freq_one_ids[-10:]}")
|
print(f"Last 10 IDs: {freq_one_ids[-10:]}")
|
||||||
|
|
||||||
# Check if they're contiguous
|
# Check if they're contiguous
|
||||||
sorted_ids = sorted(freq_one_ids)
|
sorted_ids = sorted(freq_one_ids)
|
||||||
contiguous = all(sorted_ids[i] + 1 == sorted_ids[i+1] for i in range(len(sorted_ids)-1))
|
contiguous = all(
|
||||||
|
sorted_ids[i] + 1 == sorted_ids[i + 1] for i in range(len(sorted_ids) - 1)
|
||||||
|
)
|
||||||
print(f"Are they contiguous IDs? {contiguous}")
|
print(f"Are they contiguous IDs? {contiguous}")
|
||||||
|
|
||||||
# Sample some characters
|
# Sample some characters
|
||||||
print("\nSample characters with frequency=1:")
|
print("\nSample characters with frequency=1:")
|
||||||
sample_count = 0
|
sample_count = 0
|
||||||
for key, pair in pairs.items():
|
for key, pair in pairs.items():
|
||||||
if pair.get('count') == 1 and sample_count < 10:
|
if pair.get("count") == 1 and sample_count < 10:
|
||||||
print(f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'")
|
print(
|
||||||
|
f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'"
|
||||||
|
)
|
||||||
sample_count += 1
|
sample_count += 1
|
||||||
|
|
||||||
# Check overall ID-frequency ordering
|
# Check overall ID-frequency ordering
|
||||||
print("\n\nOverall ID-frequency ordering analysis:")
|
print("\n\nOverall ID-frequency ordering analysis:")
|
||||||
all_ids = sorted(id_to_count.keys())
|
all_ids = sorted(id_to_count.keys())
|
||||||
all_counts = [id_to_count[id] for id in all_ids]
|
all_counts = [id_to_count[id] for id in all_ids]
|
||||||
|
|
||||||
# Count monotonic segments
|
# Count monotonic segments
|
||||||
non_increasing_segments = 0
|
non_increasing_segments = 0
|
||||||
current_segment_length = 1
|
current_segment_length = 1
|
||||||
for i in range(1, len(all_counts)):
|
for i in range(1, len(all_counts)):
|
||||||
if all_counts[i] <= all_counts[i-1]:
|
if all_counts[i] <= all_counts[i - 1]:
|
||||||
current_segment_length += 1
|
current_segment_length += 1
|
||||||
else:
|
else:
|
||||||
if current_segment_length > 1:
|
if current_segment_length > 1:
|
||||||
|
|
@ -98,18 +114,22 @@ def main():
|
||||||
current_segment_length = 1
|
current_segment_length = 1
|
||||||
if current_segment_length > 1:
|
if current_segment_length > 1:
|
||||||
non_increasing_segments += 1
|
non_increasing_segments += 1
|
||||||
|
|
||||||
print(f"Total IDs: {len(all_ids)}")
|
print(f"Total IDs: {len(all_ids)}")
|
||||||
print(f"Non-increasing segments: {non_increasing_segments}")
|
print(f"Non-increasing segments: {non_increasing_segments}")
|
||||||
|
|
||||||
# Check for frequency plateaus overall
|
# Check for frequency plateaus overall
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
overall_freq_count = Counter(all_counts)
|
overall_freq_count = Counter(all_counts)
|
||||||
plateaus = [(freq, count) for freq, count in overall_freq_count.items() if count > 1]
|
plateaus = [
|
||||||
|
(freq, count) for freq, count in overall_freq_count.items() if count > 1
|
||||||
|
]
|
||||||
plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10]
|
plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10]
|
||||||
print(f"Top 10 frequency plateaus (freq: count of IDs sharing that freq):")
|
print(f"Top 10 frequency plateaus (freq: count of IDs sharing that freq):")
|
||||||
for freq, count in plateaus_sorted:
|
for freq, count in plateaus_sorted:
|
||||||
print(f" {freq}: {count} IDs")
|
print(f" {freq}: {count} IDs")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
@ -1,10 +1,15 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||||
|
|
||||||
from model.dataset import PinyinInputDataset
|
from model.dataset import PinyinInputDataset
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from model.trainer import collate_fn, worker_init_fn
|
from model.trainer import collate_fn, worker_init_fn
|
||||||
|
|
||||||
|
|
||||||
data = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
|
data = PinyinInputDataset("/home/songsenand/Data/corpus/CCI-Data/")
|
||||||
|
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
data,
|
data,
|
||||||
|
|
@ -18,5 +23,5 @@ dataloader = DataLoader(
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in dataloader:
|
for i in dataloader:
|
||||||
print((i['labels'] == 1).sum())
|
print((i["labels"] == 1).sum())
|
||||||
break
|
break
|
||||||
|
|
@ -9,90 +9,125 @@ import math
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
json_path = (
|
||||||
with open(json_path, 'r', encoding='utf-8') as f:
|
Path(__file__).parent.parent
|
||||||
|
/ "src"
|
||||||
|
/ "model"
|
||||||
|
/ "assets"
|
||||||
|
/ "pinyin_char_statistics.json"
|
||||||
|
)
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
pairs = data.get('pairs', {})
|
pairs = data.get("pairs", {})
|
||||||
|
|
||||||
# Extract counts
|
# Extract counts
|
||||||
counts = []
|
counts = []
|
||||||
for key, pair in pairs.items():
|
for key, pair in pairs.items():
|
||||||
count = pair.get('count')
|
count = pair.get("count")
|
||||||
if count is not None:
|
if count is not None:
|
||||||
counts.append(count)
|
counts.append(count)
|
||||||
|
|
||||||
n = len(counts)
|
n = len(counts)
|
||||||
print(f"Total entries: {n}")
|
print(f"Total entries: {n}")
|
||||||
|
|
||||||
# Sort descending for rank-frequency analysis
|
# Sort descending for rank-frequency analysis
|
||||||
counts_sorted_desc = sorted(counts, reverse=True)
|
counts_sorted_desc = sorted(counts, reverse=True)
|
||||||
|
|
||||||
# Basic statistics
|
# Basic statistics
|
||||||
min_count = min(counts)
|
min_count = min(counts)
|
||||||
max_count = max(counts)
|
max_count = max(counts)
|
||||||
mean_count = sum(counts) / n
|
mean_count = sum(counts) / n
|
||||||
|
|
||||||
# Percentiles
|
# Percentiles
|
||||||
percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
|
percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
|
||||||
print("\n=== PERCENTILE DISTRIBUTION ===")
|
print("\n=== PERCENTILE DISTRIBUTION ===")
|
||||||
for p in percentiles:
|
for p in percentiles:
|
||||||
idx = int(p * n)
|
idx = int(p * n)
|
||||||
value = counts_sorted_desc[idx]
|
value = counts_sorted_desc[idx]
|
||||||
print(f"{p*100:5.1f}%: {value:>12} (rank ~{idx})")
|
print(f"{p * 100:5.1f}%: {value:>12} (rank ~{idx})")
|
||||||
|
|
||||||
# Cumulative distribution
|
# Cumulative distribution
|
||||||
print("\n=== CUMULATIVE DISTRIBUTION ===")
|
print("\n=== CUMULATIVE DISTRIBUTION ===")
|
||||||
thresholds = [1, 2, 3, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000, 1000000, 5000000, 10000000, 50000000, 100000000, 500000000]
|
thresholds = [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
10,
|
||||||
|
20,
|
||||||
|
50,
|
||||||
|
100,
|
||||||
|
200,
|
||||||
|
500,
|
||||||
|
1000,
|
||||||
|
2000,
|
||||||
|
5000,
|
||||||
|
10000,
|
||||||
|
20000,
|
||||||
|
50000,
|
||||||
|
100000,
|
||||||
|
200000,
|
||||||
|
500000,
|
||||||
|
1000000,
|
||||||
|
5000000,
|
||||||
|
10000000,
|
||||||
|
50000000,
|
||||||
|
100000000,
|
||||||
|
500000000,
|
||||||
|
]
|
||||||
for thresh in thresholds:
|
for thresh in thresholds:
|
||||||
if thresh > max_count:
|
if thresh > max_count:
|
||||||
break
|
break
|
||||||
below = sum(1 for c in counts if c <= thresh)
|
below = sum(1 for c in counts if c <= thresh)
|
||||||
above = sum(1 for c in counts if c >= thresh)
|
above = sum(1 for c in counts if c >= thresh)
|
||||||
print(f"Count <= {thresh:10}: {below:6} entries ({below/n*100:5.1f}%)")
|
print(f"Count <= {thresh:10}: {below:6} entries ({below / n * 100:5.1f}%)")
|
||||||
# print(f"Count >= {thresh:10}: {above:6} entries ({above/n*100:5.1f}%)")
|
# print(f"Count >= {thresh:10}: {above:6} entries ({above/n*100:5.1f}%)")
|
||||||
|
|
||||||
# Check min_count=109 parameter
|
# Check min_count=109 parameter
|
||||||
print("\n=== ANALYSIS OF THRESHOLD 109 ===")
|
print("\n=== ANALYSIS OF THRESHOLD 109 ===")
|
||||||
below_109 = sum(1 for c in counts if c < 109)
|
below_109 = sum(1 for c in counts if c < 109)
|
||||||
at_or_above_109 = sum(1 for c in counts if c >= 109)
|
at_or_above_109 = sum(1 for c in counts if c >= 109)
|
||||||
print(f"Entries with count < 109: {below_109} ({below_109/n*100:.1f}%)")
|
print(f"Entries with count < 109: {below_109} ({below_109 / n * 100:.1f}%)")
|
||||||
print(f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109/n*100:.1f}%)")
|
print(
|
||||||
|
f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109 / n * 100:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
# If 109 is a threshold, what's the actual min among those >= 109?
|
# If 109 is a threshold, what's the actual min among those >= 109?
|
||||||
counts_ge_109 = [c for c in counts if c >= 109]
|
counts_ge_109 = [c for c in counts if c >= 109]
|
||||||
if counts_ge_109:
|
if counts_ge_109:
|
||||||
actual_min_ge_109 = min(counts_ge_109)
|
actual_min_ge_109 = min(counts_ge_109)
|
||||||
print(f"Actual min frequency among those >= 109: {actual_min_ge_109}")
|
print(f"Actual min frequency among those >= 109: {actual_min_ge_109}")
|
||||||
|
|
||||||
# Rank-frequency analysis (Zipf's law)
|
# Rank-frequency analysis (Zipf's law)
|
||||||
print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===")
|
print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===")
|
||||||
print("Rank\tFrequency\tlog(rank)\tlog(freq)")
|
print("Rank\tFrequency\tlog(rank)\tlog(freq)")
|
||||||
for rank in range(1, 101):
|
for rank in range(1, 101):
|
||||||
freq = counts_sorted_desc[rank-1]
|
freq = counts_sorted_desc[rank - 1]
|
||||||
print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{math.log(freq):.3f}")
|
print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{math.log(freq):.3f}")
|
||||||
|
|
||||||
# Frequency spectrum (how many distinct frequencies)
|
# Frequency spectrum (how many distinct frequencies)
|
||||||
freq_counter = Counter(counts)
|
freq_counter = Counter(counts)
|
||||||
print(f"\n=== FREQUENCY SPECTRUM ===")
|
print(f"\n=== FREQUENCY SPECTRUM ===")
|
||||||
print(f"Distinct frequency values: {len(freq_counter)}")
|
print(f"Distinct frequency values: {len(freq_counter)}")
|
||||||
|
|
||||||
# Most common frequencies
|
# Most common frequencies
|
||||||
print("\nTop 20 most common frequencies (plateau sizes):")
|
print("\nTop 20 most common frequencies (plateau sizes):")
|
||||||
for freq, freq_count in freq_counter.most_common(20):
|
for freq, freq_count in freq_counter.most_common(20):
|
||||||
print(f" Frequency {freq}: {freq_count} entries")
|
print(f" Frequency {freq}: {freq_count} entries")
|
||||||
|
|
||||||
# Analyze ID ranges
|
# Analyze ID ranges
|
||||||
print("\n=== ID RANGE ANALYSIS ===")
|
print("\n=== ID RANGE ANALYSIS ===")
|
||||||
# Build ID to count mapping
|
# Build ID to count mapping
|
||||||
id_to_count = {}
|
id_to_count = {}
|
||||||
for key, pair in pairs.items():
|
for key, pair in pairs.items():
|
||||||
char_id = pair.get('id')
|
char_id = pair.get("id")
|
||||||
count = pair.get('count')
|
count = pair.get("count")
|
||||||
if char_id is not None and count is not None:
|
if char_id is not None and count is not None:
|
||||||
id_to_count[char_id] = count
|
id_to_count[char_id] = count
|
||||||
|
|
||||||
ranges = [
|
ranges = [
|
||||||
(0, 100, "Top 100 IDs"),
|
(0, 100, "Top 100 IDs"),
|
||||||
(100, 500, "IDs 100-500"),
|
(100, 500, "IDs 100-500"),
|
||||||
|
|
@ -106,66 +141,80 @@ def main():
|
||||||
(19000, 19500, "IDs 19000-19500 (before freq=1)"),
|
(19000, 19500, "IDs 19000-19500 (before freq=1)"),
|
||||||
(19499, 20647, "IDs with freq=1"),
|
(19499, 20647, "IDs with freq=1"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for start, end, label in ranges:
|
for start, end, label in ranges:
|
||||||
range_counts = [id_to_count[id] for id in range(start, end) if id in id_to_count]
|
range_counts = [
|
||||||
|
id_to_count[id] for id in range(start, end) if id in id_to_count
|
||||||
|
]
|
||||||
if range_counts:
|
if range_counts:
|
||||||
min_c = min(range_counts)
|
min_c = min(range_counts)
|
||||||
max_c = max(range_counts)
|
max_c = max(range_counts)
|
||||||
mean_c = sum(range_counts) / len(range_counts)
|
mean_c = sum(range_counts) / len(range_counts)
|
||||||
median_c = sorted(range_counts)[len(range_counts)//2]
|
median_c = sorted(range_counts)[len(range_counts) // 2]
|
||||||
print(f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}")
|
print(
|
||||||
|
f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}"
|
||||||
|
)
|
||||||
|
|
||||||
# Check if IDs are perfectly sorted by frequency
|
# Check if IDs are perfectly sorted by frequency
|
||||||
print("\n=== ID ORDER VERIFICATION ===")
|
print("\n=== ID ORDER VERIFICATION ===")
|
||||||
all_ids = sorted(id_to_count.keys())
|
all_ids = sorted(id_to_count.keys())
|
||||||
all_counts = [id_to_count[id] for id in all_ids]
|
all_counts = [id_to_count[id] for id in all_ids]
|
||||||
|
|
||||||
# Check for any violations of non-increasing order
|
# Check for any violations of non-increasing order
|
||||||
violations = 0
|
violations = 0
|
||||||
for i in range(1, len(all_counts)):
|
for i in range(1, len(all_counts)):
|
||||||
if all_counts[i] > all_counts[i-1]:
|
if all_counts[i] > all_counts[i - 1]:
|
||||||
violations += 1
|
violations += 1
|
||||||
if violations <= 5:
|
if violations <= 5:
|
||||||
print(f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i-1]} (ID {all_ids[i-1]})")
|
print(
|
||||||
|
f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i - 1]} (ID {all_ids[i - 1]})"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Total violations of non-increasing order: {violations}")
|
print(f"Total violations of non-increasing order: {violations}")
|
||||||
|
|
||||||
# Check if equal frequencies are grouped together
|
# Check if equal frequencies are grouped together
|
||||||
print("\n=== FREQUENCY GROUPING ANALYSIS ===")
|
print("\n=== FREQUENCY GROUPING ANALYSIS ===")
|
||||||
current_freq = None
|
current_freq = None
|
||||||
group_start = None
|
group_start = None
|
||||||
group_sizes = []
|
group_sizes = []
|
||||||
|
|
||||||
for i, (id, count) in enumerate(zip(all_ids, all_counts)):
|
for i, (id, count) in enumerate(zip(all_ids, all_counts)):
|
||||||
if count != current_freq:
|
if count != current_freq:
|
||||||
if current_freq is not None:
|
if current_freq is not None:
|
||||||
group_sizes.append((current_freq, group_start, all_ids[i-1], i - group_start))
|
group_sizes.append(
|
||||||
|
(current_freq, group_start, all_ids[i - 1], i - group_start)
|
||||||
|
)
|
||||||
current_freq = count
|
current_freq = count
|
||||||
group_start = i
|
group_start = i
|
||||||
|
|
||||||
# Last group
|
# Last group
|
||||||
if current_freq is not None:
|
if current_freq is not None:
|
||||||
group_sizes.append((current_freq, group_start, all_ids[-1], len(all_ids) - group_start))
|
group_sizes.append(
|
||||||
|
(current_freq, group_start, all_ids[-1], len(all_ids) - group_start)
|
||||||
|
)
|
||||||
|
|
||||||
# Sort groups by size
|
# Sort groups by size
|
||||||
group_sizes.sort(key=lambda x: x[3], reverse=True)
|
group_sizes.sort(key=lambda x: x[3], reverse=True)
|
||||||
print("Top 10 largest frequency groups (plateaus):")
|
print("Top 10 largest frequency groups (plateaus):")
|
||||||
for freq, start_id_idx, end_id, size in group_sizes[:10]:
|
for freq, start_id_idx, end_id, size in group_sizes[:10]:
|
||||||
start_id = all_ids[start_id_idx]
|
start_id = all_ids[start_id_idx]
|
||||||
print(f" Frequency {freq}: IDs {start_id}-{end_id} ({size} entries)")
|
print(f" Frequency {freq}: IDs {start_id}-{end_id} ({size} entries)")
|
||||||
|
|
||||||
# Summary for smoothing algorithm
|
# Summary for smoothing algorithm
|
||||||
print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===")
|
print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===")
|
||||||
print("1. IDs are perfectly sorted by frequency (non-increasing).")
|
print("1. IDs are perfectly sorted by frequency (non-increasing).")
|
||||||
print(f"2. Frequency range: {min_count} to {max_count} (ratio {max_count/min_count:.1e}:1).")
|
print(
|
||||||
print(f"3. {below_109} entries ({below_109/n*100:.1f}%) have frequency < 109.")
|
f"2. Frequency range: {min_count} to {max_count} (ratio {max_count / min_count:.1e}:1)."
|
||||||
print(f"4. Median frequency: {counts_sorted_desc[n//2]}.")
|
)
|
||||||
print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9*n)]}.")
|
print(f"3. {below_109} entries ({below_109 / n * 100:.1f}%) have frequency < 109.")
|
||||||
print(f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01*n)]}.")
|
print(f"4. Median frequency: {counts_sorted_desc[n // 2]}.")
|
||||||
|
print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9 * n)]}.")
|
||||||
|
print(
|
||||||
|
f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01 * n)]}."
|
||||||
|
)
|
||||||
print("7. Large frequency plateaus exist (many IDs share same frequency).")
|
print("7. Large frequency plateaus exist (many IDs share same frequency).")
|
||||||
print("8. Smoothing should handle extreme frequency ratios (1:5e8).")
|
print("8. Smoothing should handle extreme frequency ratios (1:5e8).")
|
||||||
|
|
||||||
# Save data for plotting
|
# Save data for plotting
|
||||||
with open("rank_freq.csv", "w") as f:
|
with open("rank_freq.csv", "w") as f:
|
||||||
f.write("rank,frequency\n")
|
f.write("rank,frequency\n")
|
||||||
|
|
@ -173,5 +222,6 @@ def main():
|
||||||
f.write(f"{rank},{freq}\n")
|
f.write(f"{rank},{freq}\n")
|
||||||
print("\nRank-frequency data saved to rank_freq.csv")
|
print("\nRank-frequency data saved to rank_freq.csv")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
@ -3,6 +3,7 @@ from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
def modify_pinyin_statistics(file_path: Path) -> None:
|
def modify_pinyin_statistics(file_path: Path) -> None:
|
||||||
"""
|
"""
|
||||||
一次性修改拼音统计JSON文件。
|
一次性修改拼音统计JSON文件。
|
||||||
|
|
@ -13,7 +14,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
||||||
"""
|
"""
|
||||||
# 1. 加载原数据
|
# 1. 加载原数据
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
data: Dict[str, Any] = json.load(f)
|
data: Dict[str, Any] = json.load(f)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print(f"错误:文件不存在 {file_path}", file=sys.stderr)
|
print(f"错误:文件不存在 {file_path}", file=sys.stderr)
|
||||||
|
|
@ -34,7 +35,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
||||||
"id": 0,
|
"id": 0,
|
||||||
"char": "",
|
"char": "",
|
||||||
"pinyin": "",
|
"pinyin": "",
|
||||||
"count": original_zero_count + 1 # 原count + 1
|
"count": original_zero_count + 1, # 原count + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
# 3.2 处理其他所有记录,键和id都+1
|
# 3.2 处理其他所有记录,键和id都+1
|
||||||
|
|
@ -54,15 +55,16 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
||||||
# 这里保持原时间戳不变,因为是一次性修改
|
# 这里保持原时间戳不变,因为是一次性修改
|
||||||
|
|
||||||
# 写回文件,保持可读格式
|
# 写回文件,保持可读格式
|
||||||
backup_path = file_path.with_suffix('.json.bak')
|
backup_path = file_path.with_suffix(".json.bak")
|
||||||
try:
|
try:
|
||||||
# 先备份原文件
|
# 先备份原文件
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.copy2(file_path, backup_path)
|
shutil.copy2(file_path, backup_path)
|
||||||
print(f"已创建备份: {backup_path}")
|
print(f"已创建备份: {backup_path}")
|
||||||
|
|
||||||
# 写入新数据
|
# 写入新数据
|
||||||
with open(file_path, 'w', encoding='utf-8') as f:
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
print(f"修改完成!")
|
print(f"修改完成!")
|
||||||
|
|
@ -78,15 +80,21 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
||||||
|
|
||||||
# 使用示例
|
# 使用示例
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 假设你的JSON文件在当前目录
|
# JSON文件相对于项目根目录
|
||||||
json_file = Path("./src/model/assets/pinyin_char_statistics.json")
|
json_file = (
|
||||||
|
Path(__file__).parent.parent
|
||||||
|
/ "src"
|
||||||
|
/ "model"
|
||||||
|
/ "assets"
|
||||||
|
/ "pinyin_char_statistics.json"
|
||||||
|
)
|
||||||
|
|
||||||
# 执行修改
|
# 执行修改
|
||||||
modify_pinyin_statistics(json_file)
|
modify_pinyin_statistics(json_file)
|
||||||
|
|
||||||
# 验证修改:读取并显示前几条记录
|
# 验证修改:读取并显示前几条记录
|
||||||
print("\n验证前5条记录:")
|
print("\n验证前5条记录:")
|
||||||
with open(json_file, 'r', encoding='utf-8') as f:
|
with open(json_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
|
|
@ -8,110 +8,127 @@ import math
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def ascii_histogram(data, bins=20, width=60):
|
def ascii_histogram(data, bins=20, width=60):
|
||||||
"""Create ASCII histogram"""
|
"""Create ASCII histogram"""
|
||||||
if not data:
|
if not data:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
min_val = min(data)
|
min_val = min(data)
|
||||||
max_val = max(data)
|
max_val = max(data)
|
||||||
|
|
||||||
# Use log bins for wide range
|
# Use log bins for wide range
|
||||||
if max_val / min_val > 1000:
|
if max_val / min_val > 1000:
|
||||||
log_min = math.log10(min_val) if min_val > 0 else 0
|
log_min = math.log10(min_val) if min_val > 0 else 0
|
||||||
log_max = math.log10(max_val)
|
log_max = math.log10(max_val)
|
||||||
bin_edges = [10**(log_min + i*(log_max-log_min)/bins) for i in range(bins+1)]
|
bin_edges = [
|
||||||
|
10 ** (log_min + i * (log_max - log_min) / bins) for i in range(bins + 1)
|
||||||
|
]
|
||||||
hist = [0] * bins
|
hist = [0] * bins
|
||||||
for val in data:
|
for val in data:
|
||||||
if val > 0:
|
if val > 0:
|
||||||
log_val = math.log10(val)
|
log_val = math.log10(val)
|
||||||
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * bins), bins-1)
|
bin_idx = min(
|
||||||
|
int((log_val - log_min) / (log_max - log_min) * bins), bins - 1
|
||||||
|
)
|
||||||
hist[bin_idx] += 1
|
hist[bin_idx] += 1
|
||||||
bin_labels = [f"{bin_edges[i]:.1e}-{bin_edges[i+1]:.1e}" for i in range(bins)]
|
bin_labels = [f"{bin_edges[i]:.1e}-{bin_edges[i + 1]:.1e}" for i in range(bins)]
|
||||||
else:
|
else:
|
||||||
bin_width = (max_val - min_val) / bins
|
bin_width = (max_val - min_val) / bins
|
||||||
bin_edges = [min_val + i*bin_width for i in range(bins+1)]
|
bin_edges = [min_val + i * bin_width for i in range(bins + 1)]
|
||||||
hist = [0] * bins
|
hist = [0] * bins
|
||||||
for val in data:
|
for val in data:
|
||||||
bin_idx = min(int((val - min_val) / (max_val - min_val) * bins), bins-1)
|
bin_idx = min(int((val - min_val) / (max_val - min_val) * bins), bins - 1)
|
||||||
hist[bin_idx] += 1
|
hist[bin_idx] += 1
|
||||||
bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i+1]:.1f}" for i in range(bins)]
|
bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i + 1]:.1f}" for i in range(bins)]
|
||||||
|
|
||||||
max_count = max(hist)
|
max_count = max(hist)
|
||||||
result = []
|
result = []
|
||||||
for i in range(bins):
|
for i in range(bins):
|
||||||
if hist[i] == 0:
|
if hist[i] == 0:
|
||||||
continue
|
continue
|
||||||
bar = '#' * int(hist[i] / max_count * width)
|
bar = "#" * int(hist[i] / max_count * width)
|
||||||
result.append(f"{bin_labels[i]:20} | {bar} {hist[i]}")
|
result.append(f"{bin_labels[i]:20} | {bar} {hist[i]}")
|
||||||
|
|
||||||
return "\n".join(result)
|
return "\n".join(result)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
json_path = (
|
||||||
with open(json_path, 'r', encoding='utf-8') as f:
|
Path(__file__).parent.parent
|
||||||
|
/ "src"
|
||||||
|
/ "model"
|
||||||
|
/ "assets"
|
||||||
|
/ "pinyin_char_statistics.json"
|
||||||
|
)
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
pairs = data.get('pairs', {})
|
pairs = data.get("pairs", {})
|
||||||
counts = [pair.get('count', 0) for pair in pairs.values() if pair.get('count') is not None]
|
counts = [
|
||||||
|
pair.get("count", 0) for pair in pairs.values() if pair.get("count") is not None
|
||||||
|
]
|
||||||
|
|
||||||
print("FREQUENCY DISTRIBUTION ANALYSIS")
|
print("FREQUENCY DISTRIBUTION ANALYSIS")
|
||||||
print("="*60)
|
print("=" * 60)
|
||||||
|
|
||||||
print("\n1. ASCII Histogram (log bins):")
|
print("\n1. ASCII Histogram (log bins):")
|
||||||
print(ascii_histogram(counts, bins=20, width=60))
|
print(ascii_histogram(counts, bins=20, width=60))
|
||||||
|
|
||||||
# Rank-frequency plot in ASCII
|
# Rank-frequency plot in ASCII
|
||||||
print("\n2. Rank-Frequency Relationship (Top 50):")
|
print("\n2. Rank-Frequency Relationship (Top 50):")
|
||||||
counts_sorted_desc = sorted(counts, reverse=True)
|
counts_sorted_desc = sorted(counts, reverse=True)
|
||||||
max_freq = counts_sorted_desc[0]
|
max_freq = counts_sorted_desc[0]
|
||||||
max_rank = 50
|
max_rank = 50
|
||||||
|
|
||||||
for rank in range(1, max_rank + 1):
|
for rank in range(1, max_rank + 1):
|
||||||
freq = counts_sorted_desc[rank-1]
|
freq = counts_sorted_desc[rank - 1]
|
||||||
bar_length = int(math.log(freq) / math.log(max_freq) * 40)
|
bar_length = int(math.log(freq) / math.log(max_freq) * 40)
|
||||||
bar = '#' * bar_length
|
bar = "#" * bar_length
|
||||||
print(f"Rank {rank:3}: {freq:12} {bar}")
|
print(f"Rank {rank:3}: {freq:12} {bar}")
|
||||||
|
|
||||||
# ID vs Frequency plot (sampled)
|
# ID vs Frequency plot (sampled)
|
||||||
print("\n3. ID vs Frequency (sampled every 500 IDs):")
|
print("\n3. ID vs Frequency (sampled every 500 IDs):")
|
||||||
# Build ID to count mapping
|
# Build ID to count mapping
|
||||||
id_to_count = {}
|
id_to_count = {}
|
||||||
for key, pair in pairs.items():
|
for key, pair in pairs.items():
|
||||||
char_id = pair.get('id')
|
char_id = pair.get("id")
|
||||||
count = pair.get('count')
|
count = pair.get("count")
|
||||||
if char_id is not None and count is not None:
|
if char_id is not None and count is not None:
|
||||||
id_to_count[char_id] = count
|
id_to_count[char_id] = count
|
||||||
|
|
||||||
all_ids = sorted(id_to_count.keys())
|
all_ids = sorted(id_to_count.keys())
|
||||||
max_id = all_ids[-1]
|
max_id = all_ids[-1]
|
||||||
|
|
||||||
print("ID Frequency log10(freq)")
|
print("ID Frequency log10(freq)")
|
||||||
for id in range(0, max_id + 1, 500):
|
for id in range(0, max_id + 1, 500):
|
||||||
if id in id_to_count:
|
if id in id_to_count:
|
||||||
freq = id_to_count[id]
|
freq = id_to_count[id]
|
||||||
log_freq = math.log10(freq) if freq > 0 else 0
|
log_freq = math.log10(freq) if freq > 0 else 0
|
||||||
bar = '#' * int(log_freq / math.log10(max_freq) * 40)
|
bar = "#" * int(log_freq / math.log10(max_freq) * 40)
|
||||||
print(f"{id:6} {freq:10} {log_freq:6.2f} {bar}")
|
print(f"{id:6} {freq:10} {log_freq:6.2f} {bar}")
|
||||||
|
|
||||||
# Zipf's law fit
|
# Zipf's law fit
|
||||||
print("\n4. Zipf's Law Analysis:")
|
print("\n4. Zipf's Law Analysis:")
|
||||||
print(" Rank * Frequency ≈ constant for Zipf's law")
|
print(" Rank * Frequency ≈ constant for Zipf's law")
|
||||||
print(" Top 10 ranks:")
|
print(" Top 10 ranks:")
|
||||||
for rank in range(1, 11):
|
for rank in range(1, 11):
|
||||||
freq = counts_sorted_desc[rank-1]
|
freq = counts_sorted_desc[rank - 1]
|
||||||
product = rank * freq
|
product = rank * freq
|
||||||
print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}")
|
print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}")
|
||||||
|
|
||||||
# Check if product is roughly constant
|
# Check if product is roughly constant
|
||||||
products = [(rank+1) * counts_sorted_desc[rank] for rank in range(10)]
|
products = [(rank + 1) * counts_sorted_desc[rank] for rank in range(10)]
|
||||||
avg_product = sum(products) / len(products)
|
avg_product = sum(products) / len(products)
|
||||||
std_product = math.sqrt(sum((p - avg_product)**2 for p in products) / len(products))
|
std_product = math.sqrt(
|
||||||
|
sum((p - avg_product) ** 2 for p in products) / len(products)
|
||||||
|
)
|
||||||
print(f" Average product (ranks 2-11): {avg_product:.3e} ± {std_product:.3e}")
|
print(f" Average product (ranks 2-11): {avg_product:.3e} ± {std_product:.3e}")
|
||||||
print(f" Coefficient of variation: {std_product/avg_product*100:.1f}%")
|
print(f" Coefficient of variation: {std_product / avg_product * 100:.1f}%")
|
||||||
|
|
||||||
# Frequency spectrum
|
# Frequency spectrum
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
freq_counter = Counter(counts)
|
freq_counter = Counter(counts)
|
||||||
print("\n5. Frequency Spectrum (how many entries have each frequency):")
|
print("\n5. Frequency Spectrum (how many entries have each frequency):")
|
||||||
print(" Frequency Count Cumulative")
|
print(" Frequency Count Cumulative")
|
||||||
|
|
@ -120,21 +137,21 @@ def main():
|
||||||
count = freq_counter[freq]
|
count = freq_counter[freq]
|
||||||
cum += count
|
cum += count
|
||||||
print(f" {freq:10} {count:6} {cum:6}")
|
print(f" {freq:10} {count:6} {cum:6}")
|
||||||
|
|
||||||
# Summary statistics
|
# Summary statistics
|
||||||
print("\n6. Key Statistics:")
|
print("\n6. Key Statistics:")
|
||||||
n = len(counts)
|
n = len(counts)
|
||||||
print(f" Total entries: {n}")
|
print(f" Total entries: {n}")
|
||||||
print(f" Min frequency: {min(counts)}")
|
print(f" Min frequency: {min(counts)}")
|
||||||
print(f" Max frequency: {max(counts)}")
|
print(f" Max frequency: {max(counts)}")
|
||||||
print(f" Ratio max/min: {max(counts)/min(counts):.2e}")
|
print(f" Ratio max/min: {max(counts) / min(counts):.2e}")
|
||||||
|
|
||||||
percentiles = [0.01, 0.1, 0.5, 0.9, 0.99]
|
percentiles = [0.01, 0.1, 0.5, 0.9, 0.99]
|
||||||
for p in percentiles:
|
for p in percentiles:
|
||||||
idx = int(p * n)
|
idx = int(p * n)
|
||||||
value = counts_sorted_desc[idx]
|
value = counts_sorted_desc[idx]
|
||||||
print(f" {p*100:5.1f}th percentile: {value:12} (rank ~{idx})")
|
print(f" {p * 100:5.1f}th percentile: {value:12} (rank ~{idx})")
|
||||||
|
|
||||||
# Save data for external plotting
|
# Save data for external plotting
|
||||||
with open("id_vs_freq.csv", "w") as f:
|
with open("id_vs_freq.csv", "w") as f:
|
||||||
f.write("id,frequency\n")
|
f.write("id,frequency\n")
|
||||||
|
|
@ -142,5 +159,6 @@ def main():
|
||||||
f.write(f"{id},{id_to_count[id]}\n")
|
f.write(f"{id},{id_to_count[id]}\n")
|
||||||
print("\nData saved to id_vs_freq.csv for external plotting")
|
print("\nData saved to id_vs_freq.csv for external plotting")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
@ -5,10 +5,9 @@ warnings.filterwarnings("ignore", message=".*pkg_resources.*")
|
||||||
import jieba
|
import jieba
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import re
|
|
||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -21,7 +20,6 @@ from torch.utils.data import IterableDataset
|
||||||
|
|
||||||
from .query import QueryEngine
|
from .query import QueryEngine
|
||||||
|
|
||||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
|
||||||
|
|
||||||
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
|
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
|
||||||
CHAR_TO_ID["`"] = 27 # 显式添加反引号
|
CHAR_TO_ID["`"] = 27 # 显式添加反引号
|
||||||
|
|
@ -76,6 +74,8 @@ class PinyinInputDataset(IterableDataset):
|
||||||
merge_max_total_chars: int = 6,
|
merge_max_total_chars: int = 6,
|
||||||
low_freq_repeat: float = 50.0,
|
low_freq_repeat: float = 50.0,
|
||||||
high_freq_repeat: float = 0.1,
|
high_freq_repeat: float = 0.1,
|
||||||
|
data_kwargs: Optional[Dict] = None,
|
||||||
|
target_labels: Optional[Set[int]] = None,
|
||||||
):
|
):
|
||||||
# 频率调整参数 - 幂律平滑方案
|
# 频率调整参数 - 幂律平滑方案
|
||||||
self.min_freq = 109
|
self.min_freq = 109
|
||||||
|
|
@ -88,6 +88,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
self.merge_max_short_words = merge_max_short_words
|
self.merge_max_short_words = merge_max_short_words
|
||||||
self.merge_max_total_chars = merge_max_total_chars
|
self.merge_max_total_chars = merge_max_total_chars
|
||||||
|
|
||||||
|
self.data_kwargs = data_kwargs or {}
|
||||||
|
self.target_labels = target_labels
|
||||||
|
|
||||||
jieba.initialize()
|
jieba.initialize()
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
|
@ -98,7 +101,9 @@ class PinyinInputDataset(IterableDataset):
|
||||||
self.max_iter_length = max_iter_length
|
self.max_iter_length = max_iter_length
|
||||||
self.max_seq_length = max_seq_length
|
self.max_seq_length = max_seq_length
|
||||||
self.text_field = text_field
|
self.text_field = text_field
|
||||||
self.dataset = load_dataset(data_path, split="train", streaming=True)
|
load_kwargs = {"split": "train", "streaming": True}
|
||||||
|
load_kwargs.update(self.data_kwargs)
|
||||||
|
self.dataset = load_dataset(data_path, **load_kwargs)
|
||||||
self.max_workers = max_workers
|
self.max_workers = max_workers
|
||||||
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
|
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
|
||||||
self.shuffle_buffer_size = shuffle_buffer_size
|
self.shuffle_buffer_size = shuffle_buffer_size
|
||||||
|
|
@ -155,12 +160,13 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 生成对应文本的拼音
|
# 生成对应文本的拼音
|
||||||
def generate_pinyin(self, text: str) -> List[str]:
|
def generate_pinyin(self, text: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
流式处理单条文本,转换为拼音列表。
|
将文本转换为拼音列表。对整段文本调用 lazy_pinyin,
|
||||||
|
利用 errors 回调确保一一对应,对生僻字从 QueryEngine 回退。
|
||||||
|
|
||||||
特性:
|
特性:
|
||||||
1. 严格一一对应:len(result) == len(text)
|
1. 严格一一对应:len(result) == len(text)
|
||||||
2. 高多音字准确率:利用 pypinyin 内部的词语分词能力
|
2. 对 pypinyin 不认识的生僻字,回退到 QueryEngine 最高频读音
|
||||||
3. 高性能:预分配内存,无多余对象创建
|
3. 非汉字字符原样占位
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 输入字符串
|
text: 输入字符串
|
||||||
|
|
@ -171,39 +177,35 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if not text:
|
if not text:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
text_len = len(text)
|
def _fallback(chars):
|
||||||
# 2. 预分配结果列表,初始化占位符。
|
# lazy_pinyin 会把连续无拼音的字符聚合成一个字符串传入,
|
||||||
# 使用 None 或空字符串均可,这里用空字符串方便后续判断
|
# 必须逐字符处理,确保返回列表长度与输入字符数一致。
|
||||||
result: List[str] = [""] * text_len
|
result = []
|
||||||
|
for char in chars:
|
||||||
|
if self.query_engine.is_chinese_char(char):
|
||||||
|
ids = self.query_engine.query_by_char(char, limit=1)
|
||||||
|
if ids:
|
||||||
|
result.append(ids[0][1])
|
||||||
|
else:
|
||||||
|
result.append(char)
|
||||||
|
else:
|
||||||
|
result.append(char)
|
||||||
|
return result
|
||||||
|
|
||||||
# 3. 遍历所有连续汉字片段
|
pinyin_list = lazy_pinyin(text, errors=_fallback)
|
||||||
for match in _HANZI_RE.finditer(text):
|
|
||||||
start_idx = match.start()
|
|
||||||
hanzi_segment = match.group()
|
|
||||||
|
|
||||||
# 4. 核心转换:利用 pypinyin 的分词能力处理该片段
|
# 防御性校验:若长度仍不匹配(极罕见),逐字回退
|
||||||
# style=Style.NORMAL 获取不带声调的拼音
|
if len(pinyin_list) != len(text):
|
||||||
pinyin_list = lazy_pinyin(hanzi_segment)
|
logger.warning(
|
||||||
|
f"pinyin length mismatch: text_len={len(text)}, "
|
||||||
|
f"pinyin_len={len(pinyin_list)}, text={text[:50]!r}"
|
||||||
|
)
|
||||||
|
pinyin_list = []
|
||||||
|
for c in text:
|
||||||
|
result = lazy_pinyin(c, errors=_fallback)
|
||||||
|
pinyin_list.append(result[0] if result else c)
|
||||||
|
|
||||||
# 5. 健壮性兜底:
|
return pinyin_list
|
||||||
# 正常情况下,pypinyin 返回的拼音数应等于汉字数。
|
|
||||||
# 若不等(极罕见,如遇到特殊 Unicode 标点被误判为汉字),降级为单字转换
|
|
||||||
if len(pinyin_list) != len(hanzi_segment):
|
|
||||||
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
|
|
||||||
|
|
||||||
# 6. 直接通过索引填充到预分配的位置
|
|
||||||
# 这比 list slicing assignment (result[start:end] = pinyin_list) 略快且更直观
|
|
||||||
for i, py in enumerate(pinyin_list):
|
|
||||||
result[start_idx + i] = py
|
|
||||||
|
|
||||||
# 7. 填充非汉字字符
|
|
||||||
# 遍历原文,如果 result 对应位置为空,则填入原字符
|
|
||||||
# 注意:对于纯汉字文本,这一步很快;对于混合文本,这是必要的
|
|
||||||
for i, char in enumerate(text):
|
|
||||||
if not result[i]:
|
|
||||||
result[i] = char
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_mask_pinyin(
|
def get_mask_pinyin(
|
||||||
self, text: str, pinyin_list: List[str]
|
self, text: str, pinyin_list: List[str]
|
||||||
|
|
@ -243,51 +245,61 @@ class PinyinInputDataset(IterableDataset):
|
||||||
pinyin_ids = pinyin_ids[:24]
|
pinyin_ids = pinyin_ids[:24]
|
||||||
return torch.tensor(pinyin_ids, dtype=torch.long)
|
return torch.tensor(pinyin_ids, dtype=torch.long)
|
||||||
|
|
||||||
def _add_word_samples(
|
def _build_single_sample(
|
||||||
self,
|
self,
|
||||||
batch_samples: list,
|
label: int,
|
||||||
labels: list,
|
history: list,
|
||||||
encoded: dict,
|
text: str,
|
||||||
part4: str,
|
word_start: int,
|
||||||
part1: str,
|
word_end: int,
|
||||||
part3: str,
|
part2: str,
|
||||||
pinyin_str: str,
|
|
||||||
pinyin_ids: torch.Tensor,
|
pinyin_ids: torch.Tensor,
|
||||||
) -> list:
|
words: list,
|
||||||
for label_idx, label in enumerate(labels):
|
) -> dict:
|
||||||
base_repeats = self.adjust_frequency(self.sample_freqs.get(label, 0))
|
"""构造单条样本,每次调用都会重新随机采样上下文"""
|
||||||
if base_repeats == 0:
|
|
||||||
continue
|
|
||||||
weight = (
|
|
||||||
self._history_weights[label_idx]
|
|
||||||
if label_idx < len(self._history_weights)
|
|
||||||
else 3.0
|
|
||||||
)
|
|
||||||
repeats = max(1, int(base_repeats * weight))
|
|
||||||
|
|
||||||
history = labels[:label_idx]
|
# part1 长度:高斯分布 N(36, 6^2),截断 [0, min(48, word_start)]
|
||||||
if len(history) > 8:
|
part1_len = min(max(int(random.gauss(36, 6)), 0), 48, word_start)
|
||||||
history = history[-8:]
|
part1 = text[word_start - part1_len : word_start]
|
||||||
else:
|
|
||||||
history.extend([0] * (8 - len(history)))
|
|
||||||
|
|
||||||
sample_dict = {
|
# part3:每次重新 roll
|
||||||
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
|
part3 = ""
|
||||||
"token_type_ids": torch.tensor(
|
if random.random() > 0.7:
|
||||||
encoded["token_type_ids"], dtype=torch.long
|
part3 = text[word_end : word_end + random.randint(1, 16)]
|
||||||
),
|
|
||||||
"attention_mask": torch.tensor(
|
# part4:每次重新 roll
|
||||||
encoded["attention_mask"], dtype=torch.long
|
part4 = ""
|
||||||
),
|
if random.random() > 0.7 and words:
|
||||||
"label": torch.tensor([label], dtype=torch.long),
|
num_words = random.randint(1, 3)
|
||||||
"history_slot_ids": torch.tensor(history, dtype=torch.long),
|
selected_words = random.sample(words, min(num_words, len(words)))
|
||||||
"prefix": f"{part4}^{part1}",
|
part4 = "|".join(selected_words)
|
||||||
"suffix": part3,
|
|
||||||
"pinyin": pinyin_str,
|
encoded = self.tokenizer(
|
||||||
"pinyin_ids": pinyin_ids,
|
f"{part4}|{part1}",
|
||||||
}
|
part3,
|
||||||
batch_samples.extend([sample_dict] * repeats)
|
max_length=self.max_seq_length,
|
||||||
return batch_samples
|
truncation=True,
|
||||||
|
return_token_type_ids=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 确保 history 长度为 8
|
||||||
|
hist = list(history)
|
||||||
|
if len(hist) > 8:
|
||||||
|
hist = hist[-8:]
|
||||||
|
while len(hist) < 8:
|
||||||
|
hist.append(0)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
|
||||||
|
"token_type_ids": torch.tensor(encoded["token_type_ids"], dtype=torch.long),
|
||||||
|
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
|
||||||
|
"label": torch.tensor([label], dtype=torch.long),
|
||||||
|
"history_slot_ids": torch.tensor(hist, dtype=torch.long),
|
||||||
|
"prefix": f"{part4}^{part1}",
|
||||||
|
"suffix": part3,
|
||||||
|
"pinyin": part2,
|
||||||
|
"pinyin_ids": pinyin_ids,
|
||||||
|
}
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
|
|
@ -436,42 +448,43 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if not should_break and random.random() <= 0.1:
|
if not should_break and random.random() <= 0.1:
|
||||||
labels.append(0)
|
labels.append(0)
|
||||||
|
|
||||||
# part1: 词起点前的文本(所有样本共享)
|
# 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
|
||||||
part1 = text[max(0, word_start - 48) : word_start]
|
processed_history = []
|
||||||
|
for label_idx, label in enumerate(labels):
|
||||||
|
base_repeats = self.adjust_frequency(
|
||||||
|
self.sample_freqs.get(label, 0)
|
||||||
|
)
|
||||||
|
if base_repeats == 0:
|
||||||
|
processed_history.append(label)
|
||||||
|
continue
|
||||||
|
if (
|
||||||
|
self.target_labels is not None
|
||||||
|
and label not in self.target_labels
|
||||||
|
):
|
||||||
|
processed_history.append(label)
|
||||||
|
continue
|
||||||
|
|
||||||
# part3: 词后文本
|
weight = (
|
||||||
part3 = ""
|
self._history_weights[label_idx]
|
||||||
if random.random() > 0.7:
|
if label_idx < len(self._history_weights)
|
||||||
part3 = text[word_end : word_end + random.randint(1, 16)]
|
else 3.0
|
||||||
|
)
|
||||||
|
repeats = max(1, int(base_repeats * weight))
|
||||||
|
|
||||||
# part4: 词提示
|
for _ in range(repeats):
|
||||||
part4 = ""
|
sample = self._build_single_sample(
|
||||||
if random.random() > 0.7:
|
label=label,
|
||||||
num_words = random.randint(1, 3)
|
history=processed_history,
|
||||||
if words:
|
text=text,
|
||||||
selected_words = random.sample(
|
word_start=word_start,
|
||||||
words, min(num_words, len(words))
|
word_end=word_end,
|
||||||
|
part2=part2,
|
||||||
|
pinyin_ids=pinyin_ids,
|
||||||
|
words=words,
|
||||||
)
|
)
|
||||||
part4 = "|".join(selected_words)
|
batch_samples.append(sample)
|
||||||
|
|
||||||
encoded = self.tokenizer(
|
processed_history.append(label)
|
||||||
f"{part4}|{part1}",
|
|
||||||
part3,
|
|
||||||
max_length=self.max_seq_length,
|
|
||||||
truncation=True,
|
|
||||||
return_token_type_ids=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_samples = self._add_word_samples(
|
|
||||||
batch_samples,
|
|
||||||
labels,
|
|
||||||
encoded,
|
|
||||||
part4,
|
|
||||||
part1,
|
|
||||||
part3,
|
|
||||||
part2,
|
|
||||||
pinyin_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ========== Phase 2: 破词续接 ==========
|
# ========== Phase 2: 破词续接 ==========
|
||||||
if should_break and break_pos < word_len_chars:
|
if should_break and break_pos < word_len_chars:
|
||||||
|
|
@ -533,33 +546,44 @@ class PinyinInputDataset(IterableDataset):
|
||||||
if random.random() <= 0.1:
|
if random.random() <= 0.1:
|
||||||
cont_labels.append(0)
|
cont_labels.append(0)
|
||||||
|
|
||||||
# part1_cont: 包含已确认前缀的上下文
|
# 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
|
||||||
part1_cont = text[max(0, cont_start - 48) : cont_start]
|
cont_processed_history = []
|
||||||
|
|
||||||
# part3_cont: 续接目标后的文本
|
|
||||||
cont_end = cont_positions[-1] + 1
|
cont_end = cont_positions[-1] + 1
|
||||||
part3_cont = ""
|
for label_idx, label in enumerate(cont_labels):
|
||||||
if random.random() > 0.7:
|
base_repeats = self.adjust_frequency(
|
||||||
part3_cont = text[cont_end : cont_end + random.randint(1, 16)]
|
self.sample_freqs.get(label, 0)
|
||||||
|
)
|
||||||
|
if base_repeats == 0:
|
||||||
|
cont_processed_history.append(label)
|
||||||
|
continue
|
||||||
|
if (
|
||||||
|
self.target_labels is not None
|
||||||
|
and label not in self.target_labels
|
||||||
|
):
|
||||||
|
cont_processed_history.append(label)
|
||||||
|
continue
|
||||||
|
|
||||||
encoded_cont = self.tokenizer(
|
weight = (
|
||||||
f"{part4}|{part1_cont}",
|
self._history_weights[label_idx]
|
||||||
part3_cont,
|
if label_idx < len(self._history_weights)
|
||||||
max_length=self.max_seq_length,
|
else 3.0
|
||||||
truncation=True,
|
)
|
||||||
return_token_type_ids=True,
|
repeats = max(1, int(base_repeats * weight))
|
||||||
)
|
|
||||||
|
|
||||||
batch_samples = self._add_word_samples(
|
for _ in range(repeats):
|
||||||
batch_samples,
|
sample = self._build_single_sample(
|
||||||
cont_labels,
|
label=label,
|
||||||
encoded_cont,
|
history=cont_processed_history,
|
||||||
part4,
|
text=text,
|
||||||
part1_cont,
|
word_start=cont_start,
|
||||||
part3_cont,
|
word_end=cont_end,
|
||||||
part2_cont,
|
part2=part2_cont,
|
||||||
pinyin_ids_cont,
|
pinyin_ids=pinyin_ids_cont,
|
||||||
)
|
words=words,
|
||||||
|
)
|
||||||
|
batch_samples.append(sample)
|
||||||
|
|
||||||
|
cont_processed_history.append(label)
|
||||||
|
|
||||||
idx = merge_end_idx
|
idx = merge_end_idx
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
|
|
||||||
步骤 1: find-missing — 扫描已预处理数据,找出从未出现的 label ID,输出 JSON
|
步骤 1: find-missing — 扫描已预处理数据,找出从未出现的 label ID,输出 JSON
|
||||||
步骤 2: generate-template — 根据 JSON 生成 JSONL 占位文件,供用户手动填入包含缺失字的真实文本
|
步骤 2: generate-template — 根据 JSON 生成 JSONL 占位文件,供用户手动填入包含缺失字的真实文本
|
||||||
|
步骤 3: preprocess-supplement — 将填好的 JSONL 文本预处理为 .npz 分片,输出到独立目录
|
||||||
|
|
||||||
用法:
|
用法:
|
||||||
python -m model.supplement_missing find-missing \
|
python -m model.supplement_missing find-missing \
|
||||||
|
|
@ -13,6 +14,12 @@
|
||||||
python -m model.supplement_missing generate-template \
|
python -m model.supplement_missing generate-template \
|
||||||
--missing-chars missing_chars.json \
|
--missing-chars missing_chars.json \
|
||||||
--output supplement_texts.jsonl
|
--output supplement_texts.jsonl
|
||||||
|
|
||||||
|
python -m model.supplement_missing preprocess-supplement \
|
||||||
|
--missing-chars missing_chars.json \
|
||||||
|
--supplement-texts supplement_texts.jsonl \
|
||||||
|
--output-dir ./preprocessed/supplement \
|
||||||
|
--num-samples 100000
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
@ -21,12 +28,17 @@ from pathlib import Path
|
||||||
from typing import Set
|
from typing import Set
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .dataset import PinyinInputDataset
|
||||||
|
from .preprocess import collect_samples
|
||||||
from .query import QueryEngine
|
from .query import QueryEngine
|
||||||
|
from .trainer import preprocess_collate_fn, worker_init_fn
|
||||||
|
|
||||||
|
|
||||||
def scan_labels(preprocessed_dir: Path) -> Set[int]:
|
def scan_labels(preprocessed_dir: Path) -> Set[int]:
|
||||||
|
|
@ -175,14 +187,116 @@ def cmd_generate_template(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_preprocess_supplement(args):
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
# 加载缺失字符
|
||||||
|
missing_path = Path(args.missing_chars)
|
||||||
|
if not missing_path.exists():
|
||||||
|
console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(missing_path, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
missing_chars = data.get("missing_chars", [])
|
||||||
|
if not missing_chars:
|
||||||
|
console.print("[bold green]没有缺失字符,无需处理[/bold green]")
|
||||||
|
return
|
||||||
|
|
||||||
|
target_labels = {entry["id"] for entry in missing_chars}
|
||||||
|
target_labels.add(0) # 包含 EOS
|
||||||
|
|
||||||
|
# 解析参数
|
||||||
|
py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
|
||||||
|
length_weights = {
|
||||||
|
int(k): int(v)
|
||||||
|
for k, v in (item.split(":") for item in args.length_weights.split(","))
|
||||||
|
}
|
||||||
|
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
max_iter = args.num_samples * 5
|
||||||
|
num_workers = args.num_workers
|
||||||
|
|
||||||
|
console.print("[bold cyan]=== 补充数据预处理 ===[/bold cyan]")
|
||||||
|
console.print(f"补充文本: {args.supplement_texts}")
|
||||||
|
console.print(f"缺失字符数: {len(missing_chars)}")
|
||||||
|
console.print(f"目标样本: {args.num_samples:,}")
|
||||||
|
console.print(f"输出目录: {output_dir}")
|
||||||
|
console.print(f"Worker 数: {num_workers}")
|
||||||
|
console.print()
|
||||||
|
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
|
console.print("[bold cyan]创建补充数据集...[/bold cyan]")
|
||||||
|
dataset = PinyinInputDataset(
|
||||||
|
data_path="json",
|
||||||
|
max_workers=num_workers,
|
||||||
|
max_iter_length=max_iter,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
text_field="text",
|
||||||
|
py_style_weight=py_style_weight,
|
||||||
|
shuffle_buffer_size=100,
|
||||||
|
length_weights=length_weights,
|
||||||
|
data_kwargs={
|
||||||
|
"data_files": args.supplement_texts,
|
||||||
|
"streaming": False,
|
||||||
|
},
|
||||||
|
target_labels=target_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataloader_kwargs = {
|
||||||
|
"batch_size": args.batch_size,
|
||||||
|
"num_workers": num_workers,
|
||||||
|
"pin_memory": False,
|
||||||
|
"worker_init_fn": worker_init_fn,
|
||||||
|
"collate_fn": preprocess_collate_fn(args.max_seq_length),
|
||||||
|
}
|
||||||
|
if num_workers > 0:
|
||||||
|
dataloader_kwargs["prefetch_factor"] = 2
|
||||||
|
dataloader_kwargs["persistent_workers"] = True
|
||||||
|
|
||||||
|
dataloader = DataLoader(dataset, **dataloader_kwargs)
|
||||||
|
|
||||||
|
logger.info("开始收集补充数据...")
|
||||||
|
count = collect_samples(
|
||||||
|
dataloader,
|
||||||
|
args.num_samples,
|
||||||
|
output_dir,
|
||||||
|
"supplement",
|
||||||
|
args.max_seq_length,
|
||||||
|
args.shard_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
if count < args.num_samples:
|
||||||
|
logger.warning(f"补充样本不足: 目标 {args.num_samples}, 实际 {count}")
|
||||||
|
|
||||||
|
console.print("\n[bold green]=== 补充预处理完成 ===[/bold green]")
|
||||||
|
console.print(f"生成样本: {count:,}")
|
||||||
|
console.print(f"输出目录: {output_dir}")
|
||||||
|
|
||||||
|
total_size = sum(
|
||||||
|
f.stat().st_size for f in output_dir.iterdir() if f.suffix == ".npz"
|
||||||
|
)
|
||||||
|
console.print(f"总大小: {total_size / (1024**3):.2f} GB (compressed)")
|
||||||
|
console.print()
|
||||||
|
console.print(
|
||||||
|
"[bold yellow]提示[/bold yellow]: 请检查补充数据质量,清洗无误后手动将 shard_*.npz 合并到 train/ 目录并更新 metadata.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="缺失字符补充工具",
|
description="缺失字符补充工具",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog="""
|
epilog="""
|
||||||
子命令:
|
子命令:
|
||||||
find-missing 扫描已预处理数据,找出从未出现的 label ID
|
find-missing 扫描已预处理数据,找出从未出现的 label ID
|
||||||
generate-template 根据缺失字符 JSON 生成 JSONL 占位文件
|
generate-template 根据缺失字符 JSON 生成 JSONL 占位文件
|
||||||
|
preprocess-supplement 将填好的 JSONL 预处理为 .npz 分片(独立目录)
|
||||||
|
|
||||||
示例:
|
示例:
|
||||||
python -m model.supplement_missing find-missing \\
|
python -m model.supplement_missing find-missing \\
|
||||||
|
|
@ -192,6 +306,12 @@ def main():
|
||||||
python -m model.supplement_missing generate-template \\
|
python -m model.supplement_missing generate-template \\
|
||||||
--missing-chars missing_chars.json \\
|
--missing-chars missing_chars.json \\
|
||||||
--output supplement_texts.jsonl
|
--output supplement_texts.jsonl
|
||||||
|
|
||||||
|
python -m model.supplement_missing preprocess-supplement \\
|
||||||
|
--missing-chars missing_chars.json \\
|
||||||
|
--supplement-texts supplement_texts.jsonl \\
|
||||||
|
--output-dir ./preprocessed/supplement \\
|
||||||
|
--num-samples 100000
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
subparsers = parser.add_subparsers(dest="command", help="子命令")
|
subparsers = parser.add_subparsers(dest="command", help="子命令")
|
||||||
|
|
@ -232,6 +352,77 @@ def main():
|
||||||
help="每个缺失字符生成的模板条数(默认: 3)",
|
help="每个缺失字符生成的模板条数(默认: 3)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# preprocess-supplement
|
||||||
|
p_pre = subparsers.add_parser(
|
||||||
|
"preprocess-supplement", help="将 JSONL 预处理为 .npz 分片"
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--missing-chars",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="缺失字符 JSON 文件路径(由 find-missing 生成)",
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--supplement-texts",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="已填写的补充文本 JSONL 文件路径",
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="输出目录(独立目录,不会覆盖已有数据)",
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--num-samples",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="目标样本数量",
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=128,
|
||||||
|
help="批大小(默认: 128)",
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="DataLoader worker 数量。本地 JSONL 小文件建议 0(默认: 0)",
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--max-seq-length",
|
||||||
|
type=int,
|
||||||
|
default=128,
|
||||||
|
help="最大序列长度(默认: 128)",
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="随机种子(默认: 42)",
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--shard-size",
|
||||||
|
type=int,
|
||||||
|
default=5_000_000,
|
||||||
|
help="分片大小(样本数),控制内存峰值(默认: 500万)",
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--py-style-weight",
|
||||||
|
type=str,
|
||||||
|
default="9,2,1",
|
||||||
|
help="拼音风格权重(逗号分隔,默认: 9,2,1)",
|
||||||
|
)
|
||||||
|
p_pre.add_argument(
|
||||||
|
"--length-weights",
|
||||||
|
type=str,
|
||||||
|
default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
|
||||||
|
help="词长权重(默认: 1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.command is None:
|
if args.command is None:
|
||||||
|
|
@ -242,6 +433,8 @@ def main():
|
||||||
cmd_find_missing(args)
|
cmd_find_missing(args)
|
||||||
elif args.command == "generate-template":
|
elif args.command == "generate-template":
|
||||||
cmd_generate_template(args)
|
cmd_generate_template(args)
|
||||||
|
elif args.command == "preprocess-supplement":
|
||||||
|
cmd_preprocess_supplement(args)
|
||||||
|
|
||||||
|
|
||||||
app = main
|
app = main
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append("src")
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -26,7 +27,7 @@ from pypinyin.contrib.tone_convert import to_initials
|
||||||
from torch.utils.data import IterableDataset
|
from torch.utils.data import IterableDataset
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer"
|
Path(__file__).parent.parent / "src" / "model" / "assets" / "tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||||
|
|
@ -83,7 +84,9 @@ sample = {
|
||||||
|
|
||||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||||
|
|
||||||
checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu")
|
checkpoint = torch.load(
|
||||||
|
"/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu"
|
||||||
|
)
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
|
@ -100,7 +103,7 @@ for k, v in sample.items():
|
||||||
start = time.time()
|
start = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||||||
print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
|
print(f"计算时长: {(time.time() - start) * 1000:4f}ms")
|
||||||
sort_res = sorted(
|
sort_res = sorted(
|
||||||
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
|
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
|
||||||
)
|
)
|
||||||
|
|
@ -392,7 +392,13 @@ def check_compile_issues():
|
||||||
issues = []
|
issues = []
|
||||||
|
|
||||||
# 检查 components.py 中的潜在问题
|
# 检查 components.py 中的潜在问题
|
||||||
with open("src/model/components.py", "r") as f:
|
components_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||||
|
"src",
|
||||||
|
"model",
|
||||||
|
"components.py",
|
||||||
|
)
|
||||||
|
with open(components_path, "r") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# 检查 float('-inf')
|
# 检查 float('-inf')
|
||||||
|
|
@ -4,9 +4,13 @@
|
||||||
解决设备转换和权重加载问题
|
解决设备转换和权重加载问题
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 添加项目根目录到路径
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -196,7 +200,7 @@ def test_id_mapping():
|
||||||
|
|
||||||
query_engine = QueryEngine()
|
query_engine = QueryEngine()
|
||||||
stats_path = (
|
stats_path = (
|
||||||
Path(__file__).parent
|
Path(__file__).parent.parent
|
||||||
/ "src"
|
/ "src"
|
||||||
/ "model"
|
/ "model"
|
||||||
/ "assets"
|
/ "assets"
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append("src")
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
@ -42,23 +43,49 @@ def worker_init_fn(worker_id: int) -> None:
|
||||||
|
|
||||||
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
自定义批处理函数,将多个样本组合成一个batch
|
自定义批处理函数,将多个样本组合成一个batch。
|
||||||
|
支持动态padding:根据batch内最大序列长度进行padding。
|
||||||
Args:
|
|
||||||
batch: 样本列表,每个样本是一个字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
批处理后的字典,tensor字段已stack,字符串字段保持为列表
|
|
||||||
"""
|
"""
|
||||||
# 处理tensor字段 - 使用squeeze去除多余的batch维度
|
input_ids_list = [item["input_ids"] for item in batch]
|
||||||
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
|
token_type_ids_list = [item["token_type_ids"] for item in batch]
|
||||||
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
|
attention_mask_list = [item["attention_mask"] for item in batch]
|
||||||
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
|
|
||||||
|
target_len = max(ids.shape[0] for ids in input_ids_list)
|
||||||
|
|
||||||
|
padded_input_ids = []
|
||||||
|
padded_token_type_ids = []
|
||||||
|
padded_attention_mask = []
|
||||||
|
for ids, tt_ids, mask in zip(
|
||||||
|
input_ids_list, token_type_ids_list, attention_mask_list
|
||||||
|
):
|
||||||
|
seq_len = ids.shape[0]
|
||||||
|
if seq_len < target_len:
|
||||||
|
pad_len = target_len - seq_len
|
||||||
|
padded_input_ids.append(
|
||||||
|
torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
|
||||||
|
)
|
||||||
|
padded_token_type_ids.append(
|
||||||
|
torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)])
|
||||||
|
)
|
||||||
|
padded_attention_mask.append(
|
||||||
|
torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
|
||||||
|
)
|
||||||
|
elif seq_len > target_len:
|
||||||
|
padded_input_ids.append(ids[:target_len])
|
||||||
|
padded_token_type_ids.append(tt_ids[:target_len])
|
||||||
|
padded_attention_mask.append(mask[:target_len])
|
||||||
|
else:
|
||||||
|
padded_input_ids.append(ids)
|
||||||
|
padded_token_type_ids.append(tt_ids)
|
||||||
|
padded_attention_mask.append(mask)
|
||||||
|
|
||||||
|
input_ids = torch.stack(padded_input_ids)
|
||||||
|
token_type_ids = torch.stack(padded_token_type_ids)
|
||||||
|
attention_mask = torch.stack(padded_attention_mask)
|
||||||
labels = torch.stack([item["label"].squeeze(0) for item in batch])
|
labels = torch.stack([item["label"].squeeze(0) for item in batch])
|
||||||
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
|
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
|
||||||
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
|
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
|
||||||
|
|
||||||
# 字符串字段保持为列表
|
|
||||||
prefixes = [item["prefix"] for item in batch]
|
prefixes = [item["prefix"] for item in batch]
|
||||||
suffixes = [item["suffix"] for item in batch]
|
suffixes = [item["suffix"] for item in batch]
|
||||||
pinyins = [item["pinyin"] for item in batch]
|
pinyins = [item["pinyin"] for item in batch]
|
||||||
|
|
@ -96,6 +123,5 @@ dataloader = DataLoader(
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, shape in tqdm(enumerate(dataloader), total=1000000/512):
|
for i, shape in tqdm(enumerate(dataloader), total=1000000 / 512):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -6,8 +6,8 @@ import torch.nn as nn
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
# 添加src目录到路径
|
# 添加项目根目录到路径
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
from src.model.model import InputMethodEngine
|
from src.model.model import InputMethodEngine
|
||||||
from src.model.trainer import Trainer
|
from src.model.trainer import Trainer
|
||||||
Loading…
Reference in New Issue