diff --git a/eval.py b/eval.py index 7cc0b58..657dcc0 100644 --- a/eval.py +++ b/eval.py @@ -78,16 +78,11 @@ class TextEvaluator: def load_tokenizer(self): """加载tokenizer""" - try: - tokenizer_path = ( - Path(__file__).parent / "src" / "model" / "assets" / "tokenizer" - ) - self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path)) - print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}") - except Exception as e: - print(f"⚠️ 无法加载tokenizer: {e}") - print("使用默认的bert-base-chinese tokenizer") - self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese") + tokenizer_path = ( + Path(__file__).parent / "src" / "model" / "assets" / "tokenizer" + ) + self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path)) + print(f"✅ Tokenizer加载完成,词汇表大小: {self.tokenizer.vocab_size}") def load_query_engine(self): """加载查询引擎用于字符-ID转换""" diff --git a/src/model/components.py b/src/model/components.py index d8988f9..513d989 100644 --- a/src/model/components.py +++ b/src/model/components.py @@ -180,14 +180,15 @@ class SlotMemory(nn.Module): Args: history_ids: [batch, total_steps] Flattened sequence of history tokens. - Empty positions should be filled with a special PAD or handled via mask. + Zero positions use start_emb as learnable query. 
Returns: S: [batch, total_steps, 512] Slot sequence representation [1] """ - # Embed history tokens - S = self.emb(history_ids) # [B, 24, 512] + S = self.emb(history_ids) # [B, num_slots, dim] + + zero_mask = (history_ids == 0).unsqueeze(-1).float() # [B, num_slots, 1] + S = S * (1 - zero_mask) + self.start_emb * zero_mask - # Add positional embeddings pos_ids = ( torch.arange(S.size(1), device=S.device).unsqueeze(0).expand_as(history_ids) ) diff --git a/src/model/dataset.py b/src/model/dataset.py index 48cfb69..88d3f4e 100644 --- a/src/model/dataset.py +++ b/src/model/dataset.py @@ -42,7 +42,8 @@ class PinyinInputDataset(IterableDataset): max_seq_length=128, text_field: str = "text", py_style_weight=(9, 2, 1), - shuffle_buffer_size: int = 5000, + shuffle_buffer_size: int = 100000, + retention_ratio: float = 0.5, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ): # 频率调整参数 (可根据需要调整) @@ -64,6 +65,16 @@ class PinyinInputDataset(IterableDataset): self.max_workers = max_workers self.py_style_weight = np.array(py_style_weight) / sum(py_style_weight) self.shuffle_buffer_size = shuffle_buffer_size + self.retention_ratio = retention_ratio + if not (0 < retention_ratio < 1): + raise ValueError( + f"retention_ratio必须在0和1之间,当前值: {retention_ratio}" + ) + self.retention_size = int(shuffle_buffer_size * retention_ratio) + if self.retention_size <= 0: + raise ValueError( + f"计算出的retention_size必须大于0,当前值: {self.retention_size} (shuffle_buffer_size={shuffle_buffer_size}, retention_ratio={retention_ratio})" + ) self.possible_lengths = list(length_weights.keys()) self.weights = list(length_weights.values()) @@ -360,17 +371,31 @@ class PinyinInputDataset(IterableDataset): # 添加到缓冲区 batch_samples.extend(samples) - # 处理shuffle buffer + # 处理shuffle buffer - 单缓冲区半保留方案 if len(batch_samples) >= self.shuffle_buffer_size: + # 全量打乱缓冲区 indices = np.random.permutation(len(batch_samples)) - for idx in indices: + + # 计算实际保留大小(不超过缓冲区大小) + actual_retention = 
min(self.retention_size, len(batch_samples)) + + # 计算输出数量 + output_count = len(batch_samples) - actual_retention + + # 输出前output_count个样本 + for i in range(output_count): if current_iter_index >= worker_quota: - # 清空batch_samples并返回 + # 配额用完,清空缓冲区并返回 batch_samples = [] - return # 使用return而不是break,因为我们在生成器函数中 - yield batch_samples[idx] + return + yield batch_samples[indices[i]] current_iter_index += 1 - batch_samples = [] + + # 保留后actual_retention个样本(不清空缓冲区) + retained_samples = [ + batch_samples[idx] for idx in indices[output_count:] + ] + batch_samples = retained_samples # 处理剩余的样本 if batch_samples: diff --git a/src/model/model.py b/src/model/model.py index 379eed0..70f630c 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -124,20 +124,18 @@ class InputMethodEngine(nn.Module): moe_out = self.moe(fused) # 5. 池化与分类:对槽位维度求平均(使用 mask 池化,完全兼容 torch.compile) - # 使用显式形状重塑,彻底杜绝广播歧义 batch_size = input_ids.size(0) slot_mask = (history_slot_ids != 0).float().view(batch_size, self.num_slots, 1) numerator = (moe_out * slot_mask).sum(dim=1) # [batch, dim] denominator = slot_mask.view(batch_size, -1).sum(dim=1) # [batch] - # 若无有效槽位,使用上下文 H 的掩码均值 - if torch.all(denominator == 0): - # H: [batch, seq_len, dim], attention_mask: [batch, seq_len] - ctx_mask = attention_mask.float().unsqueeze(-1) # [batch, seq_len, 1] - ctx_sum = (H * ctx_mask).sum(dim=1) # [batch, dim] - ctx_cnt = ctx_mask.sum(dim=1) + 1e-8 # [batch, 1] - pooled = ctx_sum / ctx_cnt - else: - pooled = numerator / (denominator.unsqueeze(-1) + 1e-8) # [batch, dim] + + all_slot_mean = moe_out.mean(dim=1) # [batch, dim] + all_zero = denominator == 0 # [batch] + pooled = torch.where( + all_zero.unsqueeze(-1), + all_slot_mean, + numerator / (denominator.unsqueeze(-1) + 1e-8), + ) logits = self.classifier(pooled) # [batch, vocab_size] return logits diff --git a/src/model/trainer.py b/src/model/trainer.py index e9d1f65..4eafa52 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -1070,7 +1070,7 @@ def 
train( max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), - shuffle_buffer_size=50000, + shuffle_buffer_size=100000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) @@ -1093,7 +1093,7 @@ def train( max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), - shuffle_buffer_size=500000, + shuffle_buffer_size=100000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) @@ -1388,7 +1388,7 @@ def expand_and_train( max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), - shuffle_buffer_size=5000, + shuffle_buffer_size=100000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) @@ -1409,7 +1409,7 @@ def expand_and_train( max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), - shuffle_buffer_size=50000, + shuffle_buffer_size=100000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) @@ -1711,7 +1711,7 @@ def expand_finetune( max_seq_length=final_max_seq_len, text_field="text", py_style_weight=(9, 2, 1), - shuffle_buffer_size=5000, + shuffle_buffer_size=100000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) @@ -1730,7 +1730,7 @@ def expand_finetune( max_seq_length=final_max_seq_len, text_field="text", py_style_weight=(9, 2, 1), - shuffle_buffer_size=50000, + shuffle_buffer_size=100000, length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, ) @@ -1802,4 +1802,3 @@ def expand_finetune( if __name__ == "__main__": app() - diff --git a/test.py b/test.py index 64efef4..9416ac0 100644 --- a/test.py +++ b/test.py @@ -47,8 +47,8 @@ def text_to_pinyin_ids(pinyin_str: str) -> List[int]: return [CHAR_TO_ID.get(c, 0) for c in pinyin_str] -part1 = "他是一名大学生,在上海读" -part2 = "dayi" +part1 = "明明是国庆节,可是因为月底要结账,财务部所有人都" +part2 = "bxu" pinyin_ids = text_to_pinyin_ids(part2) len_py = len(pinyin_ids) if len_py < 24: @@ -57,7 +57,7 @@ else: pinyin_ids = pinyin_ids[:24] pinyin_ids = 
torch.tensor(pinyin_ids, dtype=torch.long).unsqueeze(0) masked_labels = [0, 0, 0, 0, 0, 0, 0, 0] -part3 = "。" +part3 = "" part4 = "可行|特别|伤害" encoded = tokenizer( @@ -83,7 +83,7 @@ sample = { model = InputMethodEngine(pinyin_vocab_size=30, compile=False) -checkpoint = torch.load("/home/songsenand/下载/best_model.pt", map_location="cpu") +checkpoint = torch.load("/home/songsenand/下载/20260411(acc34)final_model.pt", map_location="cpu") model.load_state_dict(checkpoint["model_state_dict"]) input_ids = sample["input_ids"]