调整拼音分组及模型参数以优化性能
This commit is contained in:
parent
5857c90be7
commit
350cab20c5
|
|
@ -13,30 +13,30 @@ from torch.utils.data import DataLoader, IterableDataset
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
PG = {
|
PG = {
|
||||||
"y": 0,
|
"r": 0,
|
||||||
"k": 0,
|
"l": 0,
|
||||||
"e": 0,
|
"p": 1,
|
||||||
"l": 1,
|
"d": 1,
|
||||||
"w": 1,
|
"h": 2,
|
||||||
"f": 1,
|
"f": 2,
|
||||||
"q": 2,
|
"g": 3,
|
||||||
"a": 2,
|
"m": 3,
|
||||||
"s": 2,
|
"z": 4,
|
||||||
"x": 3,
|
"o": 4,
|
||||||
"b": 3,
|
"t": 5,
|
||||||
"r": 3,
|
"q": 5,
|
||||||
"o": 4,
|
"b": 6,
|
||||||
"m": 4,
|
"w": 6,
|
||||||
"z": 4,
|
"j": 7,
|
||||||
"g": 5,
|
"e": 7,
|
||||||
"n": 5,
|
"k": 8,
|
||||||
"c": 5,
|
"c": 8,
|
||||||
"t": 6,
|
"s": 9,
|
||||||
"p": 6,
|
"a": 9,
|
||||||
"d": 6,
|
"n": 10,
|
||||||
"j": 7,
|
"x": 10,
|
||||||
"h": 7,
|
"y": 11,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class PinyinInputDataset(IterableDataset):
|
class PinyinInputDataset(IterableDataset):
|
||||||
|
|
@ -70,7 +70,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
repeat_end_freq: int = 10000, # 开始重复的阈值
|
repeat_end_freq: int = 10000, # 开始重复的阈值
|
||||||
max_drop_prob: float = 0.8, # 最大丢弃概率
|
max_drop_prob: float = 0.8, # 最大丢弃概率
|
||||||
max_repeat_expect: float = 50.0, # 最大重复期望
|
max_repeat_expect: float = 50.0, # 最大重复期望
|
||||||
sample_context_section = [0.90, 0.95, 1]
|
sample_context_section=[0.90, 0.95, 1],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化数据集
|
初始化数据集
|
||||||
|
|
@ -127,8 +127,6 @@ class PinyinInputDataset(IterableDataset):
|
||||||
# 上下文采样方式概率区间
|
# 上下文采样方式概率区间
|
||||||
self.sample_context_section = sample_context_section
|
self.sample_context_section = sample_context_section
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_next_chinese_chars(
|
def get_next_chinese_chars(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
|
|
@ -441,7 +439,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
prob = random.random()
|
prob = random.random()
|
||||||
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 8
|
pg = self.pg_groups[processed_pinyin[0]] if processed_pinyin else 12
|
||||||
if prob < 0.1:
|
if prob < 0.1:
|
||||||
py = ""
|
py = ""
|
||||||
else:
|
else:
|
||||||
|
|
@ -455,9 +453,7 @@ class PinyinInputDataset(IterableDataset):
|
||||||
"char_id": torch.tensor([char_info["id"]]),
|
"char_id": torch.tensor([char_info["id"]]),
|
||||||
"char": char,
|
"char": char,
|
||||||
"freq": char_info["freq"],
|
"freq": char_info["freq"],
|
||||||
"pg": torch.tensor(
|
"pg": torch.tensor([pg]),
|
||||||
[pg]
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 根据调整因子重复样本
|
# 根据调整因子重复样本
|
||||||
|
|
|
||||||
|
|
@ -149,7 +149,7 @@ def simulated_annealing_grouping(letters, num_groups=8, iterations=10000):
|
||||||
|
|
||||||
|
|
||||||
# 运行模拟退火算法
|
# 运行模拟退火算法
|
||||||
best_groups, sums_c1, sums_c2, best_energy = simulated_annealing_grouping(list(c1.keys()))
|
best_groups, sums_c1, sums_c2, best_energy = simulated_annealing_grouping(list(c1.keys()), 12)
|
||||||
|
|
||||||
print("模拟退火算法分组结果:")
|
print("模拟退火算法分组结果:")
|
||||||
for i, (group, sum1, sum2) in enumerate(zip(best_groups, sums_c1, sums_c2)):
|
for i, (group, sum1, sum2) in enumerate(zip(best_groups, sums_c1, sums_c2)):
|
||||||
|
|
@ -163,5 +163,4 @@ print(f"c1各组总和变异系数: {np.std(sums_c1) / np.mean(sums_c1):.4f}")
|
||||||
print(f"c2各组总和变异系数: {np.std(sums_c2) / np.mean(sums_c2):.4f}")
|
print(f"c2各组总和变异系数: {np.std(sums_c2) / np.mean(sums_c2):.4f}")
|
||||||
|
|
||||||
|
|
||||||
with open("pinyin_group.json", "w") as f:
|
print({letter: i for i, sub_item in enumerate(best_groups) for letter in sub_item})
|
||||||
json.dump({letter: i for i, sub_item in enumerate(best_groups) for letter in sub_item}, f, indent=4)
|
|
||||||
|
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -96,7 +96,7 @@ class MoEModel(nn.Module):
|
||||||
output_multiplier=2,
|
output_multiplier=2,
|
||||||
d_model=768,
|
d_model=768,
|
||||||
num_resblocks=4,
|
num_resblocks=4,
|
||||||
num_domain_experts=8,
|
num_domain_experts=12,
|
||||||
num_shared_experts=1,
|
num_shared_experts=1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -238,10 +238,26 @@ class MoEModel(nn.Module):
|
||||||
expert_out = self.experts[7](pooled) + self.expert_bias(
|
expert_out = self.experts[7](pooled) + self.expert_bias(
|
||||||
torch.tensor(7, device=pooled.device)
|
torch.tensor(7, device=pooled.device)
|
||||||
)
|
)
|
||||||
else: # group_id == 8
|
elif group_id == 8:
|
||||||
expert_out = self.experts[8](pooled) + self.expert_bias(
|
expert_out = self.experts[8](pooled) + self.expert_bias(
|
||||||
torch.tensor(8, device=pooled.device)
|
torch.tensor(8, device=pooled.device)
|
||||||
)
|
)
|
||||||
|
elif group_id == 9:
|
||||||
|
expert_out = self.experts[9](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(9, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 10:
|
||||||
|
expert_out = self.experts[10](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(10, device=pooled.device)
|
||||||
|
)
|
||||||
|
elif group_id == 11:
|
||||||
|
expert_out = self.experts[11](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(12, device=pooled.device)
|
||||||
|
)
|
||||||
|
else: # group_id == 12
|
||||||
|
expert_out = self.experts[12](pooled) + self.expert_bias(
|
||||||
|
torch.tensor(12, device=pooled.device)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
batch_size = pooled.size(0)
|
batch_size = pooled.size(0)
|
||||||
# 并行计算所有专家输出
|
# 并行计算所有专家输出
|
||||||
|
|
@ -498,7 +514,7 @@ class MoEModel(nn.Module):
|
||||||
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
|
f"step: {global_step}, loss: {avg_loss:.4f}, acc: {acc:.4f}, eval_loss: {eval_loss:.4f}"
|
||||||
)
|
)
|
||||||
batch_loss_sum = 0.0
|
batch_loss_sum = 0.0
|
||||||
if processed_batches + 1 >= stop_batch:
|
if processed_batches - 1 >= stop_batch:
|
||||||
break
|
break
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue