diff --git a/src/model/dataset.py b/src/model/dataset.py index db17f61..48cfb69 100644 --- a/src/model/dataset.py +++ b/src/model/dataset.py @@ -7,6 +7,7 @@ from typing import Dict, List, Tuple import numpy as np import torch from datasets import load_dataset +from loguru import logger from modelscope import AutoTokenizer from pypinyin import lazy_pinyin from pypinyin.contrib.tone_convert import to_initials @@ -304,13 +305,19 @@ class PinyinInputDataset(IterableDataset): string_list.append(text[start_pos:end_pos]) # 用|连接所有字符串 part4 = "|".join(string_list) - labels = [ - self.query_engine.get_char_info_by_char_pinyin(c, p).id - for c, p in zip( - text[i : i + pinyin_len], - pinyin_list[i : i + pinyin_len], + try: + labels = [ + self.query_engine.get_char_info_by_char_pinyin(c, p).id + for c, p in zip( + text[i : i + pinyin_len], + pinyin_list[i : i + pinyin_len], + ) + ] + except AttributeError as e: + logger.error( + f"e: {e}, (text, pinyin): {text[i : i + pinyin_len]} - {pinyin_list[i : i + pinyin_len]}" ) - ] + continue if random.random() <= 0.1: labels.append(0)