refactor(generate_pinyin): 优化拼音生成逻辑,利用 pypinyin 分词能力处理多音字
This commit is contained in:
parent
8b41bcdc6f
commit
e8eab1f260
33
eval.py
33
eval.py
|
|
@ -14,7 +14,6 @@ eval.py - 评估模型在给定文本上的表现
|
|||
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
|
@ -33,8 +32,6 @@ from src.model.model import InputMethodEngine
|
|||
from src.model.query import QueryEngine
|
||||
from src.model.dataset import text_to_pinyin_ids
|
||||
|
||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
||||
|
||||
class TextEvaluator:
|
||||
def __init__(
|
||||
|
|
@ -171,34 +168,20 @@ class TextEvaluator:
|
|||
|
||||
def generate_pinyin(self, text: str) -> List[str]:
|
||||
"""
|
||||
流式处理单条文本,转换为拼音列表。
|
||||
参考dataset.py中的generate_pinyin方法。
|
||||
将文本转换为拼音列表。对整段文本调用 lazy_pinyin,
|
||||
利用 pypinyin 内部的分词能力处理多音字。
|
||||
参考 dataset.py 中的 generate_pinyin 方法。
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
text_len = len(text)
|
||||
result: List[str] = [""] * text_len
|
||||
pinyin_list = lazy_pinyin(text)
|
||||
|
||||
# 遍历所有连续汉字片段
|
||||
for match in _HANZI_RE.finditer(text):
|
||||
start_idx = match.start()
|
||||
hanzi_segment = match.group()
|
||||
# 健壮性兜底:若长度不匹配(极罕见),降级为逐字转换
|
||||
if len(pinyin_list) != len(text):
|
||||
pinyin_list = [lazy_pinyin(c)[0] for c in text]
|
||||
|
||||
pinyin_list = lazy_pinyin(hanzi_segment)
|
||||
|
||||
if len(pinyin_list) != len(hanzi_segment):
|
||||
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
|
||||
|
||||
for i, py in enumerate(pinyin_list):
|
||||
result[start_idx + i] = py
|
||||
|
||||
# 填充非汉字字符
|
||||
for i, char in enumerate(text):
|
||||
if not result[i]:
|
||||
result[i] = char
|
||||
|
||||
return result
|
||||
return pinyin_list
|
||||
|
||||
def get_mask_pinyin(
|
||||
self, text: str, pinyin_list: List[str]
|
||||
|
|
|
|||
|
|
@ -10,15 +10,22 @@ import math
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main():
|
||||
# Path to the JSON file
|
||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
||||
json_path = (
|
||||
Path(__file__).parent.parent
|
||||
/ "src"
|
||||
/ "model"
|
||||
/ "assets"
|
||||
/ "pinyin_char_statistics.json"
|
||||
)
|
||||
if not json_path.exists():
|
||||
print(f"Error: File not found: {json_path}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loading {json_path}...")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
print(f"Timestamp: {data.get('timestamp')}")
|
||||
|
|
@ -26,7 +33,7 @@ def main():
|
|||
print(f"Total pinyins: {data.get('total_pinyins')}")
|
||||
print(f"Valid input character count: {data.get('valid_input_character_count')}")
|
||||
|
||||
pairs = data.get('pairs', {})
|
||||
pairs = data.get("pairs", {})
|
||||
print(f"Number of pairs: {len(pairs)}")
|
||||
|
||||
# Extract counts and IDs
|
||||
|
|
@ -35,9 +42,9 @@ def main():
|
|||
char_to_count = {}
|
||||
for key, pair in pairs.items():
|
||||
try:
|
||||
char_id = pair.get('id')
|
||||
count = pair.get('count')
|
||||
char = pair.get('char', '')
|
||||
char_id = pair.get("id")
|
||||
count = pair.get("count")
|
||||
char = pair.get("char", "")
|
||||
if char_id is not None and count is not None:
|
||||
counts.append(count)
|
||||
id_to_count[char_id] = count
|
||||
|
|
@ -109,16 +116,16 @@ def main():
|
|||
decreases = 0
|
||||
increases = 0
|
||||
for i in range(1, len(counts_by_id)):
|
||||
if counts_by_id[i] < counts_by_id[i-1]:
|
||||
if counts_by_id[i] < counts_by_id[i - 1]:
|
||||
decreases += 1
|
||||
elif counts_by_id[i] > counts_by_id[i-1]:
|
||||
elif counts_by_id[i] > counts_by_id[i - 1]:
|
||||
increases += 1
|
||||
|
||||
print(f"\n=== ID ORDER ANALYSIS ===")
|
||||
print(f"Total pairs: {len(counts_by_id)}")
|
||||
print(f"Decreases as ID increases: {decreases} times")
|
||||
print(f"Increases as ID increases: {increases} times")
|
||||
print(f"Percentage decreasing: {decreases/(len(counts_by_id)-1)*100:.2f}%")
|
||||
print(f"Percentage decreasing: {decreases / (len(counts_by_id) - 1) * 100:.2f}%")
|
||||
|
||||
# Check if IDs are roughly sorted by frequency
|
||||
# Compute Spearman rank correlation (simplified)
|
||||
|
|
@ -131,7 +138,9 @@ def main():
|
|||
avg_rank_diff = sum(rank_diffs) / len(rank_diffs)
|
||||
max_rank_diff = max(rank_diffs)
|
||||
|
||||
print(f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}")
|
||||
print(
|
||||
f"Average rank difference between ID order and frequency order: {avg_rank_diff:.2f}"
|
||||
)
|
||||
print(f"Maximum rank difference: {max_rank_diff}")
|
||||
|
||||
# Analyze specific ID range 5000-5500
|
||||
|
|
@ -163,7 +172,9 @@ def main():
|
|||
|
||||
# Find IDs with min frequency in this range
|
||||
min_in_range_ids = [id for id in range_ids if id_to_count[id] == range_min]
|
||||
print(f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}")
|
||||
print(
|
||||
f"IDs with min frequency in range: {min_in_range_ids[:10]}{'...' if len(min_in_range_ids) > 10 else ''}"
|
||||
)
|
||||
else:
|
||||
print("No IDs found in range 5000-5500")
|
||||
|
||||
|
|
@ -173,20 +184,26 @@ def main():
|
|||
log_min = math.log10(min_count) if min_count > 0 else 0
|
||||
log_max = math.log10(max_count)
|
||||
num_bins = 20
|
||||
bin_edges = [10**(log_min + i*(log_max-log_min)/num_bins) for i in range(num_bins+1)]
|
||||
bin_edges = [
|
||||
10 ** (log_min + i * (log_max - log_min) / num_bins)
|
||||
for i in range(num_bins + 1)
|
||||
]
|
||||
|
||||
hist = [0] * num_bins
|
||||
for count in counts:
|
||||
if count > 0:
|
||||
log_val = math.log10(count)
|
||||
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * num_bins), num_bins-1)
|
||||
bin_idx = min(
|
||||
int((log_val - log_min) / (log_max - log_min) * num_bins),
|
||||
num_bins - 1,
|
||||
)
|
||||
hist[bin_idx] += 1
|
||||
|
||||
print("Log-scale histogram (count range -> frequency count):")
|
||||
for i in range(num_bins):
|
||||
if hist[i] > 0:
|
||||
lower = bin_edges[i]
|
||||
upper = bin_edges[i+1]
|
||||
upper = bin_edges[i + 1]
|
||||
print(f" {lower:.2e} - {upper:.2e}: {hist[i]} entries")
|
||||
|
||||
# Check for zero or near-zero frequencies
|
||||
|
|
@ -204,22 +221,30 @@ def main():
|
|||
if non_zero_counts:
|
||||
actual_min = min(non_zero_counts)
|
||||
print(f"Actual min frequency (non-zero): {actual_min}")
|
||||
actual_min_ids = [id for id, count in id_to_count.items() if count == actual_min]
|
||||
print(f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}")
|
||||
actual_min_ids = [
|
||||
id for id, count in id_to_count.items() if count == actual_min
|
||||
]
|
||||
print(
|
||||
f"IDs with actual min frequency: {actual_min_ids[:10]}{'...' if len(actual_min_ids) > 10 else ''}"
|
||||
)
|
||||
|
||||
# Summary for smoothing algorithm design
|
||||
print("\n=== SUMMARY FOR SMOOTHING ALGORITHM DESIGN ===")
|
||||
print(f"Frequency range spans {max_count/min_count if min_count>0 else 'inf'}:1 ratio")
|
||||
print(
|
||||
f"Frequency range spans {max_count / min_count if min_count > 0 else 'inf'}:1 ratio"
|
||||
)
|
||||
print(f"Most entries ({p50}) have frequency around {p50}")
|
||||
print(f"Top 10% of entries have frequency > {p90}")
|
||||
print(f"Bottom 10% of entries have frequency < {p10}")
|
||||
print(f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency")
|
||||
print(
|
||||
f"ID order is {'roughly' if decreases > increases else 'not'} sorted by frequency"
|
||||
)
|
||||
|
||||
# Save detailed data for further analysis
|
||||
output_file = "frequency_analysis_results.txt"
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write("Frequency Analysis Results\n")
|
||||
f.write("="*50 + "\n")
|
||||
f.write("=" * 50 + "\n")
|
||||
f.write(f"Min frequency: {min_count}\n")
|
||||
f.write(f"Max frequency: {max_count}\n")
|
||||
f.write(f"Mean frequency: {mean_count:.2f}\n")
|
||||
|
|
@ -227,10 +252,15 @@ def main():
|
|||
f.write(f"10th percentile: {p10}\n")
|
||||
f.write(f"50th percentile: {p50}\n")
|
||||
f.write(f"90th percentile: {p90}\n")
|
||||
f.write(f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n")
|
||||
f.write(f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n")
|
||||
f.write(
|
||||
f"IDs in range 5000-5500 min: {range_min if 'range_min' in locals() else 'N/A'}\n"
|
||||
)
|
||||
f.write(
|
||||
f"IDs in range 5000-5500 max: {range_max if 'range_max' in locals() else 'N/A'}\n"
|
||||
)
|
||||
|
||||
print(f"\nDetailed results saved to {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -7,18 +7,25 @@ import json
|
|||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main():
|
||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
json_path = (
|
||||
Path(__file__).parent.parent
|
||||
/ "src"
|
||||
/ "model"
|
||||
/ "assets"
|
||||
/ "pinyin_char_statistics.json"
|
||||
)
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
pairs = data.get('pairs', {})
|
||||
pairs = data.get("pairs", {})
|
||||
|
||||
# Build ID to count mapping
|
||||
id_to_count = {}
|
||||
for key, pair in pairs.items():
|
||||
char_id = pair.get('id')
|
||||
count = pair.get('count')
|
||||
char_id = pair.get("id")
|
||||
count = pair.get("count")
|
||||
if char_id is not None and count is not None:
|
||||
id_to_count[char_id] = count
|
||||
|
||||
|
|
@ -31,10 +38,10 @@ def main():
|
|||
if id in id_to_count:
|
||||
# Find the pair to get char and pinyin
|
||||
for key, pair in pairs.items():
|
||||
if pair.get('id') == id:
|
||||
char = pair.get('char', '')
|
||||
pinyin = pair.get('pinyin', '')
|
||||
count = pair.get('count', 0)
|
||||
if pair.get("id") == id:
|
||||
char = pair.get("char", "")
|
||||
pinyin = pair.get("pinyin", "")
|
||||
count = pair.get("count", 0)
|
||||
range_data.append((id, count, char, pinyin))
|
||||
if id % 100 == 0: # Print every 100th for overview
|
||||
print(f"{id}\t{count}\t{char}\t{pinyin}")
|
||||
|
|
@ -44,17 +51,22 @@ def main():
|
|||
if range_data:
|
||||
min_item = min(range_data, key=lambda x: x[1])
|
||||
max_item = max(range_data, key=lambda x: x[1])
|
||||
print(f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'")
|
||||
print(f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'")
|
||||
print(
|
||||
f"\nMin in range: ID {min_item[0]}, count {min_item[1]}, char '{min_item[2]}', pinyin '{min_item[3]}'"
|
||||
)
|
||||
print(
|
||||
f"Max in range: ID {max_item[0]}, count {max_item[1]}, char '{max_item[2]}', pinyin '{max_item[3]}'"
|
||||
)
|
||||
|
||||
# Check if frequencies are monotonic in this range
|
||||
counts = [item[1] for item in range_data]
|
||||
increasing = all(counts[i] <= counts[i+1] for i in range(len(counts)-1))
|
||||
decreasing = all(counts[i] >= counts[i+1] for i in range(len(counts)-1))
|
||||
increasing = all(counts[i] <= counts[i + 1] for i in range(len(counts) - 1))
|
||||
decreasing = all(counts[i] >= counts[i + 1] for i in range(len(counts) - 1))
|
||||
print(f"Monotonic in range: increasing={increasing}, decreasing={decreasing}")
|
||||
|
||||
# Check for frequency plateaus
|
||||
from collections import Counter
|
||||
|
||||
freq_count = Counter(counts)
|
||||
most_common = freq_count.most_common(5)
|
||||
print(f"Most common frequencies in range: {most_common}")
|
||||
|
|
@ -70,15 +82,19 @@ def main():
|
|||
|
||||
# Check if they're contiguous
|
||||
sorted_ids = sorted(freq_one_ids)
|
||||
contiguous = all(sorted_ids[i] + 1 == sorted_ids[i+1] for i in range(len(sorted_ids)-1))
|
||||
contiguous = all(
|
||||
sorted_ids[i] + 1 == sorted_ids[i + 1] for i in range(len(sorted_ids) - 1)
|
||||
)
|
||||
print(f"Are they contiguous IDs? {contiguous}")
|
||||
|
||||
# Sample some characters
|
||||
print("\nSample characters with frequency=1:")
|
||||
sample_count = 0
|
||||
for key, pair in pairs.items():
|
||||
if pair.get('count') == 1 and sample_count < 10:
|
||||
print(f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'")
|
||||
if pair.get("count") == 1 and sample_count < 10:
|
||||
print(
|
||||
f" ID {pair.get('id')}: char '{pair.get('char')}', pinyin '{pair.get('pinyin')}'"
|
||||
)
|
||||
sample_count += 1
|
||||
|
||||
# Check overall ID-frequency ordering
|
||||
|
|
@ -90,7 +106,7 @@ def main():
|
|||
non_increasing_segments = 0
|
||||
current_segment_length = 1
|
||||
for i in range(1, len(all_counts)):
|
||||
if all_counts[i] <= all_counts[i-1]:
|
||||
if all_counts[i] <= all_counts[i - 1]:
|
||||
current_segment_length += 1
|
||||
else:
|
||||
if current_segment_length > 1:
|
||||
|
|
@ -104,12 +120,16 @@ def main():
|
|||
|
||||
# Check for frequency plateaus overall
|
||||
from collections import Counter
|
||||
|
||||
overall_freq_count = Counter(all_counts)
|
||||
plateaus = [(freq, count) for freq, count in overall_freq_count.items() if count > 1]
|
||||
plateaus = [
|
||||
(freq, count) for freq, count in overall_freq_count.items() if count > 1
|
||||
]
|
||||
plateaus_sorted = sorted(plateaus, key=lambda x: x[1], reverse=True)[:10]
|
||||
print(f"Top 10 frequency plateaus (freq: count of IDs sharing that freq):")
|
||||
for freq, count in plateaus_sorted:
|
||||
print(f" {freq}: {count} IDs")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,10 +1,15 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
from model.dataset import PinyinInputDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from model.trainer import collate_fn, worker_init_fn
|
||||
|
||||
|
||||
data = PinyinInputDataset('/home/songsenand/Data/corpus/CCI-Data/')
|
||||
data = PinyinInputDataset("/home/songsenand/Data/corpus/CCI-Data/")
|
||||
|
||||
dataloader = DataLoader(
|
||||
data,
|
||||
|
|
@ -18,5 +23,5 @@ dataloader = DataLoader(
|
|||
)
|
||||
|
||||
for i in dataloader:
|
||||
print((i['labels'] == 1).sum())
|
||||
print((i["labels"] == 1).sum())
|
||||
break
|
||||
|
|
@ -9,17 +9,24 @@ import math
|
|||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main():
|
||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
json_path = (
|
||||
Path(__file__).parent.parent
|
||||
/ "src"
|
||||
/ "model"
|
||||
/ "assets"
|
||||
/ "pinyin_char_statistics.json"
|
||||
)
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
pairs = data.get('pairs', {})
|
||||
pairs = data.get("pairs", {})
|
||||
|
||||
# Extract counts
|
||||
counts = []
|
||||
for key, pair in pairs.items():
|
||||
count = pair.get('count')
|
||||
count = pair.get("count")
|
||||
if count is not None:
|
||||
counts.append(count)
|
||||
|
||||
|
|
@ -40,25 +47,53 @@ def main():
|
|||
for p in percentiles:
|
||||
idx = int(p * n)
|
||||
value = counts_sorted_desc[idx]
|
||||
print(f"{p*100:5.1f}%: {value:>12} (rank ~{idx})")
|
||||
print(f"{p * 100:5.1f}%: {value:>12} (rank ~{idx})")
|
||||
|
||||
# Cumulative distribution
|
||||
print("\n=== CUMULATIVE DISTRIBUTION ===")
|
||||
thresholds = [1, 2, 3, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000, 1000000, 5000000, 10000000, 50000000, 100000000, 500000000]
|
||||
thresholds = [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
5,
|
||||
10,
|
||||
20,
|
||||
50,
|
||||
100,
|
||||
200,
|
||||
500,
|
||||
1000,
|
||||
2000,
|
||||
5000,
|
||||
10000,
|
||||
20000,
|
||||
50000,
|
||||
100000,
|
||||
200000,
|
||||
500000,
|
||||
1000000,
|
||||
5000000,
|
||||
10000000,
|
||||
50000000,
|
||||
100000000,
|
||||
500000000,
|
||||
]
|
||||
for thresh in thresholds:
|
||||
if thresh > max_count:
|
||||
break
|
||||
below = sum(1 for c in counts if c <= thresh)
|
||||
above = sum(1 for c in counts if c >= thresh)
|
||||
print(f"Count <= {thresh:10}: {below:6} entries ({below/n*100:5.1f}%)")
|
||||
print(f"Count <= {thresh:10}: {below:6} entries ({below / n * 100:5.1f}%)")
|
||||
# print(f"Count >= {thresh:10}: {above:6} entries ({above/n*100:5.1f}%)")
|
||||
|
||||
# Check min_count=109 parameter
|
||||
print("\n=== ANALYSIS OF THRESHOLD 109 ===")
|
||||
below_109 = sum(1 for c in counts if c < 109)
|
||||
at_or_above_109 = sum(1 for c in counts if c >= 109)
|
||||
print(f"Entries with count < 109: {below_109} ({below_109/n*100:.1f}%)")
|
||||
print(f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109/n*100:.1f}%)")
|
||||
print(f"Entries with count < 109: {below_109} ({below_109 / n * 100:.1f}%)")
|
||||
print(
|
||||
f"Entries with count >= 109: {at_or_above_109} ({at_or_above_109 / n * 100:.1f}%)"
|
||||
)
|
||||
|
||||
# If 109 is a threshold, what's the actual min among those >= 109?
|
||||
counts_ge_109 = [c for c in counts if c >= 109]
|
||||
|
|
@ -70,7 +105,7 @@ def main():
|
|||
print("\n=== RANK-FREQUENCY ANALYSIS (Top 100) ===")
|
||||
print("Rank\tFrequency\tlog(rank)\tlog(freq)")
|
||||
for rank in range(1, 101):
|
||||
freq = counts_sorted_desc[rank-1]
|
||||
freq = counts_sorted_desc[rank - 1]
|
||||
print(f"{rank}\t{freq}\t{math.log(rank):.3f}\t{math.log(freq):.3f}")
|
||||
|
||||
# Frequency spectrum (how many distinct frequencies)
|
||||
|
|
@ -88,8 +123,8 @@ def main():
|
|||
# Build ID to count mapping
|
||||
id_to_count = {}
|
||||
for key, pair in pairs.items():
|
||||
char_id = pair.get('id')
|
||||
count = pair.get('count')
|
||||
char_id = pair.get("id")
|
||||
count = pair.get("count")
|
||||
if char_id is not None and count is not None:
|
||||
id_to_count[char_id] = count
|
||||
|
||||
|
|
@ -108,13 +143,17 @@ def main():
|
|||
]
|
||||
|
||||
for start, end, label in ranges:
|
||||
range_counts = [id_to_count[id] for id in range(start, end) if id in id_to_count]
|
||||
range_counts = [
|
||||
id_to_count[id] for id in range(start, end) if id in id_to_count
|
||||
]
|
||||
if range_counts:
|
||||
min_c = min(range_counts)
|
||||
max_c = max(range_counts)
|
||||
mean_c = sum(range_counts) / len(range_counts)
|
||||
median_c = sorted(range_counts)[len(range_counts)//2]
|
||||
print(f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}")
|
||||
median_c = sorted(range_counts)[len(range_counts) // 2]
|
||||
print(
|
||||
f"{label} ({len(range_counts)} entries): min={min_c}, max={max_c}, mean={mean_c:.1f}, median={median_c}"
|
||||
)
|
||||
|
||||
# Check if IDs are perfectly sorted by frequency
|
||||
print("\n=== ID ORDER VERIFICATION ===")
|
||||
|
|
@ -124,10 +163,12 @@ def main():
|
|||
# Check for any violations of non-increasing order
|
||||
violations = 0
|
||||
for i in range(1, len(all_counts)):
|
||||
if all_counts[i] > all_counts[i-1]:
|
||||
if all_counts[i] > all_counts[i - 1]:
|
||||
violations += 1
|
||||
if violations <= 5:
|
||||
print(f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i-1]} (ID {all_ids[i-1]})")
|
||||
print(
|
||||
f"Violation at ID {all_ids[i]}: {all_counts[i]} > {all_counts[i - 1]} (ID {all_ids[i - 1]})"
|
||||
)
|
||||
|
||||
print(f"Total violations of non-increasing order: {violations}")
|
||||
|
||||
|
|
@ -140,13 +181,17 @@ def main():
|
|||
for i, (id, count) in enumerate(zip(all_ids, all_counts)):
|
||||
if count != current_freq:
|
||||
if current_freq is not None:
|
||||
group_sizes.append((current_freq, group_start, all_ids[i-1], i - group_start))
|
||||
group_sizes.append(
|
||||
(current_freq, group_start, all_ids[i - 1], i - group_start)
|
||||
)
|
||||
current_freq = count
|
||||
group_start = i
|
||||
|
||||
# Last group
|
||||
if current_freq is not None:
|
||||
group_sizes.append((current_freq, group_start, all_ids[-1], len(all_ids) - group_start))
|
||||
group_sizes.append(
|
||||
(current_freq, group_start, all_ids[-1], len(all_ids) - group_start)
|
||||
)
|
||||
|
||||
# Sort groups by size
|
||||
group_sizes.sort(key=lambda x: x[3], reverse=True)
|
||||
|
|
@ -158,11 +203,15 @@ def main():
|
|||
# Summary for smoothing algorithm
|
||||
print("\n=== SMOOTHING ALGORITHM IMPLICATIONS ===")
|
||||
print("1. IDs are perfectly sorted by frequency (non-increasing).")
|
||||
print(f"2. Frequency range: {min_count} to {max_count} (ratio {max_count/min_count:.1e}:1).")
|
||||
print(f"3. {below_109} entries ({below_109/n*100:.1f}%) have frequency < 109.")
|
||||
print(f"4. Median frequency: {counts_sorted_desc[n//2]}.")
|
||||
print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9*n)]}.")
|
||||
print(f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01*n)]}.")
|
||||
print(
|
||||
f"2. Frequency range: {min_count} to {max_count} (ratio {max_count / min_count:.1e}:1)."
|
||||
)
|
||||
print(f"3. {below_109} entries ({below_109 / n * 100:.1f}%) have frequency < 109.")
|
||||
print(f"4. Median frequency: {counts_sorted_desc[n // 2]}.")
|
||||
print(f"5. 90% of entries have frequency <= {counts_sorted_desc[int(0.9 * n)]}.")
|
||||
print(
|
||||
f"6. Top 1% of entries have frequency >= {counts_sorted_desc[int(0.01 * n)]}."
|
||||
)
|
||||
print("7. Large frequency plateaus exist (many IDs share same frequency).")
|
||||
print("8. Smoothing should handle extreme frequency ratios (1:5e8).")
|
||||
|
||||
|
|
@ -173,5 +222,6 @@ def main():
|
|||
f.write(f"{rank},{freq}\n")
|
||||
print("\nRank-frequency data saved to rank_freq.csv")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -3,6 +3,7 @@ from pathlib import Path
|
|||
from typing import Dict, Any
|
||||
import sys
|
||||
|
||||
|
||||
def modify_pinyin_statistics(file_path: Path) -> None:
|
||||
"""
|
||||
一次性修改拼音统计JSON文件。
|
||||
|
|
@ -13,7 +14,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
|||
"""
|
||||
# 1. 加载原数据
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data: Dict[str, Any] = json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"错误:文件不存在 {file_path}", file=sys.stderr)
|
||||
|
|
@ -34,7 +35,7 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
|||
"id": 0,
|
||||
"char": "",
|
||||
"pinyin": "",
|
||||
"count": original_zero_count + 1 # 原count + 1
|
||||
"count": original_zero_count + 1, # 原count + 1
|
||||
}
|
||||
|
||||
# 3.2 处理其他所有记录,键和id都+1
|
||||
|
|
@ -54,15 +55,16 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
|||
# 这里保持原时间戳不变,因为是一次性修改
|
||||
|
||||
# 写回文件,保持可读格式
|
||||
backup_path = file_path.with_suffix('.json.bak')
|
||||
backup_path = file_path.with_suffix(".json.bak")
|
||||
try:
|
||||
# 先备份原文件
|
||||
import shutil
|
||||
|
||||
shutil.copy2(file_path, backup_path)
|
||||
print(f"已创建备份: {backup_path}")
|
||||
|
||||
# 写入新数据
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"修改完成!")
|
||||
|
|
@ -78,15 +80,21 @@ def modify_pinyin_statistics(file_path: Path) -> None:
|
|||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 假设你的JSON文件在当前目录
|
||||
json_file = Path("./src/model/assets/pinyin_char_statistics.json")
|
||||
# JSON文件相对于项目根目录
|
||||
json_file = (
|
||||
Path(__file__).parent.parent
|
||||
/ "src"
|
||||
/ "model"
|
||||
/ "assets"
|
||||
/ "pinyin_char_statistics.json"
|
||||
)
|
||||
|
||||
# 执行修改
|
||||
modify_pinyin_statistics(json_file)
|
||||
|
||||
# 验证修改:读取并显示前几条记录
|
||||
print("\n验证前5条记录:")
|
||||
with open(json_file, 'r', encoding='utf-8') as f:
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for i in range(5):
|
||||
|
|
@ -8,6 +8,7 @@ import math
|
|||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def ascii_histogram(data, bins=20, width=60):
|
||||
"""Create ASCII histogram"""
|
||||
if not data:
|
||||
|
|
@ -20,43 +21,56 @@ def ascii_histogram(data, bins=20, width=60):
|
|||
if max_val / min_val > 1000:
|
||||
log_min = math.log10(min_val) if min_val > 0 else 0
|
||||
log_max = math.log10(max_val)
|
||||
bin_edges = [10**(log_min + i*(log_max-log_min)/bins) for i in range(bins+1)]
|
||||
bin_edges = [
|
||||
10 ** (log_min + i * (log_max - log_min) / bins) for i in range(bins + 1)
|
||||
]
|
||||
hist = [0] * bins
|
||||
for val in data:
|
||||
if val > 0:
|
||||
log_val = math.log10(val)
|
||||
bin_idx = min(int((log_val - log_min) / (log_max - log_min) * bins), bins-1)
|
||||
bin_idx = min(
|
||||
int((log_val - log_min) / (log_max - log_min) * bins), bins - 1
|
||||
)
|
||||
hist[bin_idx] += 1
|
||||
bin_labels = [f"{bin_edges[i]:.1e}-{bin_edges[i+1]:.1e}" for i in range(bins)]
|
||||
bin_labels = [f"{bin_edges[i]:.1e}-{bin_edges[i + 1]:.1e}" for i in range(bins)]
|
||||
else:
|
||||
bin_width = (max_val - min_val) / bins
|
||||
bin_edges = [min_val + i*bin_width for i in range(bins+1)]
|
||||
bin_edges = [min_val + i * bin_width for i in range(bins + 1)]
|
||||
hist = [0] * bins
|
||||
for val in data:
|
||||
bin_idx = min(int((val - min_val) / (max_val - min_val) * bins), bins-1)
|
||||
bin_idx = min(int((val - min_val) / (max_val - min_val) * bins), bins - 1)
|
||||
hist[bin_idx] += 1
|
||||
bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i+1]:.1f}" for i in range(bins)]
|
||||
bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i + 1]:.1f}" for i in range(bins)]
|
||||
|
||||
max_count = max(hist)
|
||||
result = []
|
||||
for i in range(bins):
|
||||
if hist[i] == 0:
|
||||
continue
|
||||
bar = '#' * int(hist[i] / max_count * width)
|
||||
bar = "#" * int(hist[i] / max_count * width)
|
||||
result.append(f"{bin_labels[i]:20} | {bar} {hist[i]}")
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def main():
|
||||
json_path = Path("src/model/assets/pinyin_char_statistics.json")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
json_path = (
|
||||
Path(__file__).parent.parent
|
||||
/ "src"
|
||||
/ "model"
|
||||
/ "assets"
|
||||
/ "pinyin_char_statistics.json"
|
||||
)
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
pairs = data.get('pairs', {})
|
||||
counts = [pair.get('count', 0) for pair in pairs.values() if pair.get('count') is not None]
|
||||
pairs = data.get("pairs", {})
|
||||
counts = [
|
||||
pair.get("count", 0) for pair in pairs.values() if pair.get("count") is not None
|
||||
]
|
||||
|
||||
print("FREQUENCY DISTRIBUTION ANALYSIS")
|
||||
print("="*60)
|
||||
print("=" * 60)
|
||||
|
||||
print("\n1. ASCII Histogram (log bins):")
|
||||
print(ascii_histogram(counts, bins=20, width=60))
|
||||
|
|
@ -68,9 +82,9 @@ def main():
|
|||
max_rank = 50
|
||||
|
||||
for rank in range(1, max_rank + 1):
|
||||
freq = counts_sorted_desc[rank-1]
|
||||
freq = counts_sorted_desc[rank - 1]
|
||||
bar_length = int(math.log(freq) / math.log(max_freq) * 40)
|
||||
bar = '#' * bar_length
|
||||
bar = "#" * bar_length
|
||||
print(f"Rank {rank:3}: {freq:12} {bar}")
|
||||
|
||||
# ID vs Frequency plot (sampled)
|
||||
|
|
@ -78,8 +92,8 @@ def main():
|
|||
# Build ID to count mapping
|
||||
id_to_count = {}
|
||||
for key, pair in pairs.items():
|
||||
char_id = pair.get('id')
|
||||
count = pair.get('count')
|
||||
char_id = pair.get("id")
|
||||
count = pair.get("count")
|
||||
if char_id is not None and count is not None:
|
||||
id_to_count[char_id] = count
|
||||
|
||||
|
|
@ -91,7 +105,7 @@ def main():
|
|||
if id in id_to_count:
|
||||
freq = id_to_count[id]
|
||||
log_freq = math.log10(freq) if freq > 0 else 0
|
||||
bar = '#' * int(log_freq / math.log10(max_freq) * 40)
|
||||
bar = "#" * int(log_freq / math.log10(max_freq) * 40)
|
||||
print(f"{id:6} {freq:10} {log_freq:6.2f} {bar}")
|
||||
|
||||
# Zipf's law fit
|
||||
|
|
@ -99,19 +113,22 @@ def main():
|
|||
print(" Rank * Frequency ≈ constant for Zipf's law")
|
||||
print(" Top 10 ranks:")
|
||||
for rank in range(1, 11):
|
||||
freq = counts_sorted_desc[rank-1]
|
||||
freq = counts_sorted_desc[rank - 1]
|
||||
product = rank * freq
|
||||
print(f" Rank {rank}: {freq:12} rank*freq = {product:.3e}")
|
||||
|
||||
# Check if product is roughly constant
|
||||
products = [(rank+1) * counts_sorted_desc[rank] for rank in range(10)]
|
||||
products = [(rank + 1) * counts_sorted_desc[rank] for rank in range(10)]
|
||||
avg_product = sum(products) / len(products)
|
||||
std_product = math.sqrt(sum((p - avg_product)**2 for p in products) / len(products))
|
||||
std_product = math.sqrt(
|
||||
sum((p - avg_product) ** 2 for p in products) / len(products)
|
||||
)
|
||||
print(f" Average product (ranks 2-11): {avg_product:.3e} ± {std_product:.3e}")
|
||||
print(f" Coefficient of variation: {std_product/avg_product*100:.1f}%")
|
||||
print(f" Coefficient of variation: {std_product / avg_product * 100:.1f}%")
|
||||
|
||||
# Frequency spectrum
|
||||
from collections import Counter
|
||||
|
||||
freq_counter = Counter(counts)
|
||||
print("\n5. Frequency Spectrum (how many entries have each frequency):")
|
||||
print(" Frequency Count Cumulative")
|
||||
|
|
@ -127,13 +144,13 @@ def main():
|
|||
print(f" Total entries: {n}")
|
||||
print(f" Min frequency: {min(counts)}")
|
||||
print(f" Max frequency: {max(counts)}")
|
||||
print(f" Ratio max/min: {max(counts)/min(counts):.2e}")
|
||||
print(f" Ratio max/min: {max(counts) / min(counts):.2e}")
|
||||
|
||||
percentiles = [0.01, 0.1, 0.5, 0.9, 0.99]
|
||||
for p in percentiles:
|
||||
idx = int(p * n)
|
||||
value = counts_sorted_desc[idx]
|
||||
print(f" {p*100:5.1f}th percentile: {value:12} (rank ~{idx})")
|
||||
print(f" {p * 100:5.1f}th percentile: {value:12} (rank ~{idx})")
|
||||
|
||||
# Save data for external plotting
|
||||
with open("id_vs_freq.csv", "w") as f:
|
||||
|
|
@ -142,5 +159,6 @@ def main():
|
|||
f.write(f"{id},{id_to_count[id]}\n")
|
||||
print("\nData saved to id_vs_freq.csv for external plotting")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -5,10 +5,9 @@ warnings.filterwarnings("ignore", message=".*pkg_resources.*")
|
|||
import jieba
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -21,7 +20,6 @@ from torch.utils.data import IterableDataset
|
|||
|
||||
from .query import QueryEngine
|
||||
|
||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
||||
CHAR_TO_ID: Dict[str, int] = {chr(i): i - 96 for i in range(97, 123)} # a-z -> 1-26
|
||||
CHAR_TO_ID["`"] = 27 # 显式添加反引号
|
||||
|
|
@ -76,6 +74,8 @@ class PinyinInputDataset(IterableDataset):
|
|||
merge_max_total_chars: int = 6,
|
||||
low_freq_repeat: float = 50.0,
|
||||
high_freq_repeat: float = 0.1,
|
||||
data_kwargs: Optional[Dict] = None,
|
||||
target_labels: Optional[Set[int]] = None,
|
||||
):
|
||||
# 频率调整参数 - 幂律平滑方案
|
||||
self.min_freq = 109
|
||||
|
|
@ -88,6 +88,9 @@ class PinyinInputDataset(IterableDataset):
|
|||
self.merge_max_short_words = merge_max_short_words
|
||||
self.merge_max_total_chars = merge_max_total_chars
|
||||
|
||||
self.data_kwargs = data_kwargs or {}
|
||||
self.target_labels = target_labels
|
||||
|
||||
jieba.initialize()
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
|
@ -98,7 +101,9 @@ class PinyinInputDataset(IterableDataset):
|
|||
self.max_iter_length = max_iter_length
|
||||
self.max_seq_length = max_seq_length
|
||||
self.text_field = text_field
|
||||
self.dataset = load_dataset(data_path, split="train", streaming=True)
|
||||
load_kwargs = {"split": "train", "streaming": True}
|
||||
load_kwargs.update(self.data_kwargs)
|
||||
self.dataset = load_dataset(data_path, **load_kwargs)
|
||||
self.max_workers = max_workers
|
||||
self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight)
|
||||
self.shuffle_buffer_size = shuffle_buffer_size
|
||||
|
|
@ -155,12 +160,13 @@ class PinyinInputDataset(IterableDataset):
|
|||
# 生成对应文本的拼音
|
||||
def generate_pinyin(self, text: str) -> List[str]:
|
||||
"""
|
||||
流式处理单条文本,转换为拼音列表。
|
||||
将文本转换为拼音列表。对整段文本调用 lazy_pinyin,
|
||||
利用 errors 回调确保一一对应,对生僻字从 QueryEngine 回退。
|
||||
|
||||
特性:
|
||||
1. 严格一一对应:len(result) == len(text)
|
||||
2. 高多音字准确率:利用 pypinyin 内部的词语分词能力
|
||||
3. 高性能:预分配内存,无多余对象创建
|
||||
2. 对 pypinyin 不认识的生僻字,回退到 QueryEngine 最高频读音
|
||||
3. 非汉字字符原样占位
|
||||
|
||||
Args:
|
||||
text: 输入字符串
|
||||
|
|
@ -171,40 +177,36 @@ class PinyinInputDataset(IterableDataset):
|
|||
if not text:
|
||||
return []
|
||||
|
||||
text_len = len(text)
|
||||
# 2. 预分配结果列表,初始化占位符。
|
||||
# 使用 None 或空字符串均可,这里用空字符串方便后续判断
|
||||
result: List[str] = [""] * text_len
|
||||
|
||||
# 3. 遍历所有连续汉字片段
|
||||
for match in _HANZI_RE.finditer(text):
|
||||
start_idx = match.start()
|
||||
hanzi_segment = match.group()
|
||||
|
||||
# 4. 核心转换:利用 pypinyin 的分词能力处理该片段
|
||||
# style=Style.NORMAL 获取不带声调的拼音
|
||||
pinyin_list = lazy_pinyin(hanzi_segment)
|
||||
|
||||
# 5. 健壮性兜底:
|
||||
# 正常情况下,pypinyin 返回的拼音数应等于汉字数。
|
||||
# 若不等(极罕见,如遇到特殊 Unicode 标点被误判为汉字),降级为单字转换
|
||||
if len(pinyin_list) != len(hanzi_segment):
|
||||
pinyin_list = [lazy_pinyin(c)[0] for c in hanzi_segment]
|
||||
|
||||
# 6. 直接通过索引填充到预分配的位置
|
||||
# 这比 list slicing assignment (result[start:end] = pinyin_list) 略快且更直观
|
||||
for i, py in enumerate(pinyin_list):
|
||||
result[start_idx + i] = py
|
||||
|
||||
# 7. 填充非汉字字符
|
||||
# 遍历原文,如果 result 对应位置为空,则填入原字符
|
||||
# 注意:对于纯汉字文本,这一步很快;对于混合文本,这是必要的
|
||||
for i, char in enumerate(text):
|
||||
if not result[i]:
|
||||
result[i] = char
|
||||
|
||||
def _fallback(chars):
|
||||
# lazy_pinyin 会把连续无拼音的字符聚合成一个字符串传入,
|
||||
# 必须逐字符处理,确保返回列表长度与输入字符数一致。
|
||||
result = []
|
||||
for char in chars:
|
||||
if self.query_engine.is_chinese_char(char):
|
||||
ids = self.query_engine.query_by_char(char, limit=1)
|
||||
if ids:
|
||||
result.append(ids[0][1])
|
||||
else:
|
||||
result.append(char)
|
||||
else:
|
||||
result.append(char)
|
||||
return result
|
||||
|
||||
pinyin_list = lazy_pinyin(text, errors=_fallback)
|
||||
|
||||
# 防御性校验:若长度仍不匹配(极罕见),逐字回退
|
||||
if len(pinyin_list) != len(text):
|
||||
logger.warning(
|
||||
f"pinyin length mismatch: text_len={len(text)}, "
|
||||
f"pinyin_len={len(pinyin_list)}, text={text[:50]!r}"
|
||||
)
|
||||
pinyin_list = []
|
||||
for c in text:
|
||||
result = lazy_pinyin(c, errors=_fallback)
|
||||
pinyin_list.append(result[0] if result else c)
|
||||
|
||||
return pinyin_list
|
||||
|
||||
def get_mask_pinyin(
|
||||
self, text: str, pinyin_list: List[str]
|
||||
) -> Tuple[int, List[str]]:
|
||||
|
|
@ -243,51 +245,61 @@ class PinyinInputDataset(IterableDataset):
|
|||
pinyin_ids = pinyin_ids[:24]
|
||||
return torch.tensor(pinyin_ids, dtype=torch.long)
|
||||
|
||||
def _add_word_samples(
|
||||
def _build_single_sample(
|
||||
self,
|
||||
batch_samples: list,
|
||||
labels: list,
|
||||
encoded: dict,
|
||||
part4: str,
|
||||
part1: str,
|
||||
part3: str,
|
||||
pinyin_str: str,
|
||||
label: int,
|
||||
history: list,
|
||||
text: str,
|
||||
word_start: int,
|
||||
word_end: int,
|
||||
part2: str,
|
||||
pinyin_ids: torch.Tensor,
|
||||
) -> list:
|
||||
for label_idx, label in enumerate(labels):
|
||||
base_repeats = self.adjust_frequency(self.sample_freqs.get(label, 0))
|
||||
if base_repeats == 0:
|
||||
continue
|
||||
weight = (
|
||||
self._history_weights[label_idx]
|
||||
if label_idx < len(self._history_weights)
|
||||
else 3.0
|
||||
words: list,
|
||||
) -> dict:
|
||||
"""构造单条样本,每次调用都会重新随机采样上下文"""
|
||||
|
||||
# part1 长度:高斯分布 N(36, 6^2),截断 [0, min(48, word_start)]
|
||||
part1_len = min(max(int(random.gauss(36, 6)), 0), 48, word_start)
|
||||
part1 = text[word_start - part1_len : word_start]
|
||||
|
||||
# part3:每次重新 roll
|
||||
part3 = ""
|
||||
if random.random() > 0.7:
|
||||
part3 = text[word_end : word_end + random.randint(1, 16)]
|
||||
|
||||
# part4:每次重新 roll
|
||||
part4 = ""
|
||||
if random.random() > 0.7 and words:
|
||||
num_words = random.randint(1, 3)
|
||||
selected_words = random.sample(words, min(num_words, len(words)))
|
||||
part4 = "|".join(selected_words)
|
||||
|
||||
encoded = self.tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
part3,
|
||||
max_length=self.max_seq_length,
|
||||
truncation=True,
|
||||
return_token_type_ids=True,
|
||||
)
|
||||
repeats = max(1, int(base_repeats * weight))
|
||||
|
||||
history = labels[:label_idx]
|
||||
if len(history) > 8:
|
||||
history = history[-8:]
|
||||
else:
|
||||
history.extend([0] * (8 - len(history)))
|
||||
# 确保 history 长度为 8
|
||||
hist = list(history)
|
||||
if len(hist) > 8:
|
||||
hist = hist[-8:]
|
||||
while len(hist) < 8:
|
||||
hist.append(0)
|
||||
|
||||
sample_dict = {
|
||||
return {
|
||||
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
|
||||
"token_type_ids": torch.tensor(
|
||||
encoded["token_type_ids"], dtype=torch.long
|
||||
),
|
||||
"attention_mask": torch.tensor(
|
||||
encoded["attention_mask"], dtype=torch.long
|
||||
),
|
||||
"token_type_ids": torch.tensor(encoded["token_type_ids"], dtype=torch.long),
|
||||
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
|
||||
"label": torch.tensor([label], dtype=torch.long),
|
||||
"history_slot_ids": torch.tensor(history, dtype=torch.long),
|
||||
"history_slot_ids": torch.tensor(hist, dtype=torch.long),
|
||||
"prefix": f"{part4}^{part1}",
|
||||
"suffix": part3,
|
||||
"pinyin": pinyin_str,
|
||||
"pinyin": part2,
|
||||
"pinyin_ids": pinyin_ids,
|
||||
}
|
||||
batch_samples.extend([sample_dict] * repeats)
|
||||
return batch_samples
|
||||
|
||||
def __iter__(self):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
|
|
@ -436,42 +448,43 @@ class PinyinInputDataset(IterableDataset):
|
|||
if not should_break and random.random() <= 0.1:
|
||||
labels.append(0)
|
||||
|
||||
# part1: 词起点前的文本(所有样本共享)
|
||||
part1 = text[max(0, word_start - 48) : word_start]
|
||||
|
||||
# part3: 词后文本
|
||||
part3 = ""
|
||||
if random.random() > 0.7:
|
||||
part3 = text[word_end : word_end + random.randint(1, 16)]
|
||||
|
||||
# part4: 词提示
|
||||
part4 = ""
|
||||
if random.random() > 0.7:
|
||||
num_words = random.randint(1, 3)
|
||||
if words:
|
||||
selected_words = random.sample(
|
||||
words, min(num_words, len(words))
|
||||
# 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
|
||||
processed_history = []
|
||||
for label_idx, label in enumerate(labels):
|
||||
base_repeats = self.adjust_frequency(
|
||||
self.sample_freqs.get(label, 0)
|
||||
)
|
||||
part4 = "|".join(selected_words)
|
||||
if base_repeats == 0:
|
||||
processed_history.append(label)
|
||||
continue
|
||||
if (
|
||||
self.target_labels is not None
|
||||
and label not in self.target_labels
|
||||
):
|
||||
processed_history.append(label)
|
||||
continue
|
||||
|
||||
encoded = self.tokenizer(
|
||||
f"{part4}|{part1}",
|
||||
part3,
|
||||
max_length=self.max_seq_length,
|
||||
truncation=True,
|
||||
return_token_type_ids=True,
|
||||
weight = (
|
||||
self._history_weights[label_idx]
|
||||
if label_idx < len(self._history_weights)
|
||||
else 3.0
|
||||
)
|
||||
repeats = max(1, int(base_repeats * weight))
|
||||
|
||||
batch_samples = self._add_word_samples(
|
||||
batch_samples,
|
||||
labels,
|
||||
encoded,
|
||||
part4,
|
||||
part1,
|
||||
part3,
|
||||
part2,
|
||||
pinyin_ids,
|
||||
for _ in range(repeats):
|
||||
sample = self._build_single_sample(
|
||||
label=label,
|
||||
history=processed_history,
|
||||
text=text,
|
||||
word_start=word_start,
|
||||
word_end=word_end,
|
||||
part2=part2,
|
||||
pinyin_ids=pinyin_ids,
|
||||
words=words,
|
||||
)
|
||||
batch_samples.append(sample)
|
||||
|
||||
processed_history.append(label)
|
||||
|
||||
# ========== Phase 2: 破词续接 ==========
|
||||
if should_break and break_pos < word_len_chars:
|
||||
|
|
@ -533,33 +546,44 @@ class PinyinInputDataset(IterableDataset):
|
|||
if random.random() <= 0.1:
|
||||
cont_labels.append(0)
|
||||
|
||||
# part1_cont: 包含已确认前缀的上下文
|
||||
part1_cont = text[max(0, cont_start - 48) : cont_start]
|
||||
|
||||
# part3_cont: 续接目标后的文本
|
||||
# 逐个 label 处理,削峰填谷前置,每次重复重新采样上下文
|
||||
cont_processed_history = []
|
||||
cont_end = cont_positions[-1] + 1
|
||||
part3_cont = ""
|
||||
if random.random() > 0.7:
|
||||
part3_cont = text[cont_end : cont_end + random.randint(1, 16)]
|
||||
|
||||
encoded_cont = self.tokenizer(
|
||||
f"{part4}|{part1_cont}",
|
||||
part3_cont,
|
||||
max_length=self.max_seq_length,
|
||||
truncation=True,
|
||||
return_token_type_ids=True,
|
||||
for label_idx, label in enumerate(cont_labels):
|
||||
base_repeats = self.adjust_frequency(
|
||||
self.sample_freqs.get(label, 0)
|
||||
)
|
||||
if base_repeats == 0:
|
||||
cont_processed_history.append(label)
|
||||
continue
|
||||
if (
|
||||
self.target_labels is not None
|
||||
and label not in self.target_labels
|
||||
):
|
||||
cont_processed_history.append(label)
|
||||
continue
|
||||
|
||||
batch_samples = self._add_word_samples(
|
||||
batch_samples,
|
||||
cont_labels,
|
||||
encoded_cont,
|
||||
part4,
|
||||
part1_cont,
|
||||
part3_cont,
|
||||
part2_cont,
|
||||
pinyin_ids_cont,
|
||||
weight = (
|
||||
self._history_weights[label_idx]
|
||||
if label_idx < len(self._history_weights)
|
||||
else 3.0
|
||||
)
|
||||
repeats = max(1, int(base_repeats * weight))
|
||||
|
||||
for _ in range(repeats):
|
||||
sample = self._build_single_sample(
|
||||
label=label,
|
||||
history=cont_processed_history,
|
||||
text=text,
|
||||
word_start=cont_start,
|
||||
word_end=cont_end,
|
||||
part2=part2_cont,
|
||||
pinyin_ids=pinyin_ids_cont,
|
||||
words=words,
|
||||
)
|
||||
batch_samples.append(sample)
|
||||
|
||||
cont_processed_history.append(label)
|
||||
|
||||
idx = merge_end_idx
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
步骤 1: find-missing — 扫描已预处理数据,找出从未出现的 label ID,输出 JSON
|
||||
步骤 2: generate-template — 根据 JSON 生成 JSONL 占位文件,供用户手动填入包含缺失字的真实文本
|
||||
步骤 3: preprocess-supplement — 将填好的 JSONL 文本预处理为 .npz 分片,输出到独立目录
|
||||
|
||||
用法:
|
||||
python -m model.supplement_missing find-missing \
|
||||
|
|
@ -13,6 +14,12 @@
|
|||
python -m model.supplement_missing generate-template \
|
||||
--missing-chars missing_chars.json \
|
||||
--output supplement_texts.jsonl
|
||||
|
||||
python -m model.supplement_missing preprocess-supplement \
|
||||
--missing-chars missing_chars.json \
|
||||
--supplement-texts supplement_texts.jsonl \
|
||||
--output-dir ./preprocessed/supplement \
|
||||
--num-samples 100000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -21,12 +28,17 @@ from pathlib import Path
|
|||
from typing import Set
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .dataset import PinyinInputDataset
|
||||
from .preprocess import collect_samples
|
||||
from .query import QueryEngine
|
||||
from .trainer import preprocess_collate_fn, worker_init_fn
|
||||
|
||||
|
||||
def scan_labels(preprocessed_dir: Path) -> Set[int]:
|
||||
|
|
@ -175,6 +187,107 @@ def cmd_generate_template(args):
|
|||
)
|
||||
|
||||
|
||||
def cmd_preprocess_supplement(args):
|
||||
console = Console()
|
||||
|
||||
# 加载缺失字符
|
||||
missing_path = Path(args.missing_chars)
|
||||
if not missing_path.exists():
|
||||
console.print(f"[bold red]文件不存在: {missing_path}[/bold red]")
|
||||
return
|
||||
|
||||
with open(missing_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
missing_chars = data.get("missing_chars", [])
|
||||
if not missing_chars:
|
||||
console.print("[bold green]没有缺失字符,无需处理[/bold green]")
|
||||
return
|
||||
|
||||
target_labels = {entry["id"] for entry in missing_chars}
|
||||
target_labels.add(0) # 包含 EOS
|
||||
|
||||
# 解析参数
|
||||
py_style_weight = tuple(int(x) for x in args.py_style_weight.split(","))
|
||||
length_weights = {
|
||||
int(k): int(v)
|
||||
for k, v in (item.split(":") for item in args.length_weights.split(","))
|
||||
}
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
max_iter = args.num_samples * 5
|
||||
num_workers = args.num_workers
|
||||
|
||||
console.print("[bold cyan]=== 补充数据预处理 ===[/bold cyan]")
|
||||
console.print(f"补充文本: {args.supplement_texts}")
|
||||
console.print(f"缺失字符数: {len(missing_chars)}")
|
||||
console.print(f"目标样本: {args.num_samples:,}")
|
||||
console.print(f"输出目录: {output_dir}")
|
||||
console.print(f"Worker 数: {num_workers}")
|
||||
console.print()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
console.print("[bold cyan]创建补充数据集...[/bold cyan]")
|
||||
dataset = PinyinInputDataset(
|
||||
data_path="json",
|
||||
max_workers=num_workers,
|
||||
max_iter_length=max_iter,
|
||||
max_seq_length=args.max_seq_length,
|
||||
text_field="text",
|
||||
py_style_weight=py_style_weight,
|
||||
shuffle_buffer_size=100,
|
||||
length_weights=length_weights,
|
||||
data_kwargs={
|
||||
"data_files": args.supplement_texts,
|
||||
"streaming": False,
|
||||
},
|
||||
target_labels=target_labels,
|
||||
)
|
||||
|
||||
dataloader_kwargs = {
|
||||
"batch_size": args.batch_size,
|
||||
"num_workers": num_workers,
|
||||
"pin_memory": False,
|
||||
"worker_init_fn": worker_init_fn,
|
||||
"collate_fn": preprocess_collate_fn(args.max_seq_length),
|
||||
}
|
||||
if num_workers > 0:
|
||||
dataloader_kwargs["prefetch_factor"] = 2
|
||||
dataloader_kwargs["persistent_workers"] = True
|
||||
|
||||
dataloader = DataLoader(dataset, **dataloader_kwargs)
|
||||
|
||||
logger.info("开始收集补充数据...")
|
||||
count = collect_samples(
|
||||
dataloader,
|
||||
args.num_samples,
|
||||
output_dir,
|
||||
"supplement",
|
||||
args.max_seq_length,
|
||||
args.shard_size,
|
||||
)
|
||||
|
||||
if count < args.num_samples:
|
||||
logger.warning(f"补充样本不足: 目标 {args.num_samples}, 实际 {count}")
|
||||
|
||||
console.print("\n[bold green]=== 补充预处理完成 ===[/bold green]")
|
||||
console.print(f"生成样本: {count:,}")
|
||||
console.print(f"输出目录: {output_dir}")
|
||||
|
||||
total_size = sum(
|
||||
f.stat().st_size for f in output_dir.iterdir() if f.suffix == ".npz"
|
||||
)
|
||||
console.print(f"总大小: {total_size / (1024**3):.2f} GB (compressed)")
|
||||
console.print()
|
||||
console.print(
|
||||
"[bold yellow]提示[/bold yellow]: 请检查补充数据质量,清洗无误后手动将 shard_*.npz 合并到 train/ 目录并更新 metadata.json"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="缺失字符补充工具",
|
||||
|
|
@ -183,6 +296,7 @@ def main():
|
|||
子命令:
|
||||
find-missing 扫描已预处理数据,找出从未出现的 label ID
|
||||
generate-template 根据缺失字符 JSON 生成 JSONL 占位文件
|
||||
preprocess-supplement 将填好的 JSONL 预处理为 .npz 分片(独立目录)
|
||||
|
||||
示例:
|
||||
python -m model.supplement_missing find-missing \\
|
||||
|
|
@ -192,6 +306,12 @@ def main():
|
|||
python -m model.supplement_missing generate-template \\
|
||||
--missing-chars missing_chars.json \\
|
||||
--output supplement_texts.jsonl
|
||||
|
||||
python -m model.supplement_missing preprocess-supplement \\
|
||||
--missing-chars missing_chars.json \\
|
||||
--supplement-texts supplement_texts.jsonl \\
|
||||
--output-dir ./preprocessed/supplement \\
|
||||
--num-samples 100000
|
||||
""",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="子命令")
|
||||
|
|
@ -232,6 +352,77 @@ def main():
|
|||
help="每个缺失字符生成的模板条数(默认: 3)",
|
||||
)
|
||||
|
||||
# preprocess-supplement
|
||||
p_pre = subparsers.add_parser(
|
||||
"preprocess-supplement", help="将 JSONL 预处理为 .npz 分片"
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--missing-chars",
|
||||
type=str,
|
||||
required=True,
|
||||
help="缺失字符 JSON 文件路径(由 find-missing 生成)",
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--supplement-texts",
|
||||
type=str,
|
||||
required=True,
|
||||
help="已填写的补充文本 JSONL 文件路径",
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="输出目录(独立目录,不会覆盖已有数据)",
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
required=True,
|
||||
help="目标样本数量",
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=128,
|
||||
help="批大小(默认: 128)",
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="DataLoader worker 数量。本地 JSONL 小文件建议 0(默认: 0)",
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--max-seq-length",
|
||||
type=int,
|
||||
default=128,
|
||||
help="最大序列长度(默认: 128)",
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="随机种子(默认: 42)",
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--shard-size",
|
||||
type=int,
|
||||
default=5_000_000,
|
||||
help="分片大小(样本数),控制内存峰值(默认: 500万)",
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--py-style-weight",
|
||||
type=str,
|
||||
default="9,2,1",
|
||||
help="拼音风格权重(逗号分隔,默认: 9,2,1)",
|
||||
)
|
||||
p_pre.add_argument(
|
||||
"--length-weights",
|
||||
type=str,
|
||||
default="1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2",
|
||||
help="词长权重(默认: 1:10,2:50,3:50,4:40,5:15,6:10,7:5,8:2)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command is None:
|
||||
|
|
@ -242,6 +433,8 @@ def main():
|
|||
cmd_find_missing(args)
|
||||
elif args.command == "generate-template":
|
||||
cmd_generate_template(args)
|
||||
elif args.command == "preprocess-supplement":
|
||||
cmd_preprocess_supplement(args)
|
||||
|
||||
|
||||
app = main
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append("src")
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
|
||||
import time
|
||||
import torch
|
||||
|
|
@ -26,7 +27,7 @@ from pypinyin.contrib.tone_convert import to_initials
|
|||
from torch.utils.data import IterableDataset
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
Path(str(__file__)).parent / "src" / "model" / "assets" / "tokenizer"
|
||||
Path(__file__).parent.parent / "src" / "model" / "assets" / "tokenizer"
|
||||
)
|
||||
|
||||
_HANZI_RE = re.compile(r"[\u4e00-\u9fff]+")
|
||||
|
|
@ -83,7 +84,9 @@ sample = {
|
|||
|
||||
model = InputMethodEngine(pinyin_vocab_size=30, compile=False)
|
||||
|
||||
checkpoint = torch.load("/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu")
|
||||
checkpoint = torch.load(
|
||||
"/home/songsenand/下载/20260412epoch2.ptrom", map_location="cpu"
|
||||
)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
|
|
@ -100,7 +103,7 @@ for k, v in sample.items():
|
|||
start = time.time()
|
||||
with torch.no_grad():
|
||||
res = model(input_ids, token_type_ids, attention_mask, pinyin_ids, history_slot_ids)
|
||||
print(f'计算时长: {(time.time() - start) * 1000:4f}ms')
|
||||
print(f"计算时长: {(time.time() - start) * 1000:4f}ms")
|
||||
sort_res = sorted(
|
||||
[(i, v) for i, v in enumerate(res[0])], key=lambda x: x[1], reverse=True
|
||||
)
|
||||
|
|
@ -392,7 +392,13 @@ def check_compile_issues():
|
|||
issues = []
|
||||
|
||||
# 检查 components.py 中的潜在问题
|
||||
with open("src/model/components.py", "r") as f:
|
||||
components_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"src",
|
||||
"model",
|
||||
"components.py",
|
||||
)
|
||||
with open(components_path, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
# 检查 float('-inf')
|
||||
|
|
@ -4,9 +4,13 @@
|
|||
解决设备转换和权重加载问题
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
|
|
@ -196,7 +200,7 @@ def test_id_mapping():
|
|||
|
||||
query_engine = QueryEngine()
|
||||
stats_path = (
|
||||
Path(__file__).parent
|
||||
Path(__file__).parent.parent
|
||||
/ "src"
|
||||
/ "model"
|
||||
/ "assets"
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append("src")
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import time
|
||||
|
|
@ -42,23 +43,49 @@ def worker_init_fn(worker_id: int) -> None:
|
|||
|
||||
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
自定义批处理函数,将多个样本组合成一个batch
|
||||
|
||||
Args:
|
||||
batch: 样本列表,每个样本是一个字典
|
||||
|
||||
Returns:
|
||||
批处理后的字典,tensor字段已stack,字符串字段保持为列表
|
||||
自定义批处理函数,将多个样本组合成一个batch。
|
||||
支持动态padding:根据batch内最大序列长度进行padding。
|
||||
"""
|
||||
# 处理tensor字段 - 使用squeeze去除多余的batch维度
|
||||
input_ids = torch.stack([item["input_ids"].squeeze(0) for item in batch])
|
||||
token_type_ids = torch.stack([item["token_type_ids"].squeeze(0) for item in batch])
|
||||
attention_mask = torch.stack([item["attention_mask"].squeeze(0) for item in batch])
|
||||
input_ids_list = [item["input_ids"] for item in batch]
|
||||
token_type_ids_list = [item["token_type_ids"] for item in batch]
|
||||
attention_mask_list = [item["attention_mask"] for item in batch]
|
||||
|
||||
target_len = max(ids.shape[0] for ids in input_ids_list)
|
||||
|
||||
padded_input_ids = []
|
||||
padded_token_type_ids = []
|
||||
padded_attention_mask = []
|
||||
for ids, tt_ids, mask in zip(
|
||||
input_ids_list, token_type_ids_list, attention_mask_list
|
||||
):
|
||||
seq_len = ids.shape[0]
|
||||
if seq_len < target_len:
|
||||
pad_len = target_len - seq_len
|
||||
padded_input_ids.append(
|
||||
torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
|
||||
)
|
||||
padded_token_type_ids.append(
|
||||
torch.cat([tt_ids, torch.zeros(pad_len, dtype=tt_ids.dtype)])
|
||||
)
|
||||
padded_attention_mask.append(
|
||||
torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])
|
||||
)
|
||||
elif seq_len > target_len:
|
||||
padded_input_ids.append(ids[:target_len])
|
||||
padded_token_type_ids.append(tt_ids[:target_len])
|
||||
padded_attention_mask.append(mask[:target_len])
|
||||
else:
|
||||
padded_input_ids.append(ids)
|
||||
padded_token_type_ids.append(tt_ids)
|
||||
padded_attention_mask.append(mask)
|
||||
|
||||
input_ids = torch.stack(padded_input_ids)
|
||||
token_type_ids = torch.stack(padded_token_type_ids)
|
||||
attention_mask = torch.stack(padded_attention_mask)
|
||||
labels = torch.stack([item["label"].squeeze(0) for item in batch])
|
||||
history_slot_ids = torch.stack([item["history_slot_ids"] for item in batch])
|
||||
pinyin_ids = torch.stack([item["pinyin_ids"] for item in batch])
|
||||
|
||||
# 字符串字段保持为列表
|
||||
prefixes = [item["prefix"] for item in batch]
|
||||
suffixes = [item["suffix"] for item in batch]
|
||||
pinyins = [item["pinyin"] for item in batch]
|
||||
|
|
@ -96,6 +123,5 @@ dataloader = DataLoader(
|
|||
persistent_workers=True,
|
||||
)
|
||||
|
||||
for i, shape in tqdm(enumerate(dataloader), total=1000000/512):
|
||||
for i, shape in tqdm(enumerate(dataloader), total=1000000 / 512):
|
||||
pass
|
||||
|
||||
|
|
@ -6,8 +6,8 @@ import torch.nn as nn
|
|||
from rich.console import Console
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# 添加src目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from src.model.model import InputMethodEngine
|
||||
from src.model.trainer import Trainer
|
||||
Loading…
Reference in New Issue