diff --git a/README.md b/README.md index 9765343..ff70a0a 100644 --- a/README.md +++ b/README.md @@ -31,10 +31,10 @@ * **输出**:槽位序列 $S$,形状为 `[batch, Num_Slots, 512]`。 #### C. 交叉注意力融合 (Cross-Attention Fusion) -这是模型的核心创新点,用于动态关联“历史记忆”与“当前语境”。 +这是模型的核心创新点,用于动态关联"历史记忆"与"当前语境"。 * **Query (Q)**:当前步的槽位序列 $S$(经过位置编码后)。 * **Key/Value (K/V)**:Transformer 编码器输出的上下文表示 $H$ [1]。 -* **机制**:让历史槽位主动关注当前文本语境,捕捉如“在‘班级第一名’语境下,‘王次香’比‘王慈祥’更相关”的逻辑。 +* **机制**:让历史槽位主动关注当前文本语境,捕捉如"在'班级第一名'语境下,'王次香'比'王慈祥'更相关"的逻辑。 * **输出**:融合后的特征序列,形状为 `[batch, Num_Slots, 512]`。 #### D. 门控与专家混合 (Gating + MoE) @@ -83,7 +83,7 @@ * 结合历史槽位记忆,通过交叉注意力和 MoE 模块融合特征。 * 分类头输出当前步所有候选字的概率分布。 2. **Teacher Forcing**: - * 在训练过程中,**强制使用真实的上一槽位输出**作为下一步的输入条件。这意味着模型在训练时始终基于“正确的历史”进行预测,从而快速收敛。 + * 在训练过程中,**强制使用真实的上一槽位输出**作为下一步的输入条件。这意味着模型在训练时始终基于"正确的历史"进行预测,从而快速收敛。 3. **反向传播**: * 根据 CrossEntropyLoss [1] 计算梯度,并通过 AdamW [1] 更新模型权重。 @@ -94,73 +94,475 @@ * **候选维护**:每个候选路径独立维护其历史槽位序列及累计概率 [1]。 * **终止条件**:当所有槽位填满(如 8×3=24 步)或所有候选分支的最高概率词均为终止符时退出 [1]。 +## 5. Jupyter Lab 训练示例 -## 5. 代码实现示意 (PyTorch) +以下是在 Jupyter Lab 环境中使用 `trainer.Trainer` 类训练输入法模型的完整示例: ```python +# %% [markdown] +# # 输入法模型训练示例 +# 本笔记本展示如何使用 trainer.Trainer 类训练输入法模型 + +# %% [code] +# 1. 导入必要的库 +import sys +import os +from pathlib import Path +from datetime import datetime + import torch -import torch.nn as nn +from torch.utils.data import DataLoader -class Expert(nn.Module): - def __init__(self, dim=512): - super().__init__() - self.net = nn.Sequential( - nn.Linear(dim, dim * 4), - nn.GELU(), - nn.Linear(dim * 4, dim) - ) - def forward(self, x): - return self.net(x) +# 添加项目路径(适应不同的Jupyter Lab运行位置) +project_root = Path.cwd() +# 检查当前目录是否包含src目录,如果不包含则使用父目录 +if not (project_root / "src").exists(): + project_root = project_root.parent +sys.path.insert(0, str(project_root)) # 优先搜索项目目录 -class InputMethodModel(nn.Module): - def __init__(self, vocab_size, pinyin_vocab_size, slot_vocab_size, dim=512, n_layers=4, n_heads=4, num_experts=20): - super().__init__() - # 1. Context Encoder - self.text_emb = nn.Embedding(vocab_size, dim) - self.pinyin_emb = nn.Embedding(pinyin_vocab_size, dim) - self.pos_emb = nn.Embedding(128, dim) - encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads) - self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) - - # 2. Slot Memory - self.slot_emb = nn.Embedding(slot_vocab_size, dim) - self.slot_pos_emb = nn.Embedding(5, dim) # 假设保留5个历史槽位 - - # 3. Cross-Attention - self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, batch_first=True) - - # 4. MoE Layer - self.num_experts = num_experts - self.experts = nn.ModuleList([Expert(dim) for _ in range(num_experts)]) - self.gate = nn.Linear(dim, num_experts) - - # 5. Classification Head - self.classifier = nn.Linear(dim, vocab_size) +# 导入项目模块 +from src.model.model import InputMethodEngine +from src.model.dataset import PinyinInputDataset +from src.model.trainer import Trainer, worker_init_fn, collate_fn - def forward(self, text_ids, pinyin_ids, history_slot_ids): - # Encode Context - x = self.text_emb(text_ids) + self.pinyin_emb(pinyin_ids) - x += self.pos_emb(torch.arange(x.size(1)).to(x.device)) - H = self.transformer(x) # [B, L, 512] - - # Encode Slots - S = self.slot_emb(history_slot_ids) - S += self.slot_pos_emb(torch.arange(S.size(1)).to(S.device)) - - # Cross-Attention: Q=Slots, K/V=Context - fused, _ = self.cross_attn(S, H, H) # [B, Slots, 512] - - # MoE Processing - # 简化版 MoE: 对所有专家输出进行加权平均 - gate_scores = torch.softmax(self.gate(fused), dim=-1) # [B, Slots, Num_Experts] - expert_outputs = torch.stack([expert(fused) for expert in self.experts], dim=-2) # [B, Slots, Num_Experts, Dim] - moe_out = torch.sum(gate_scores.unsqueeze(-1) * expert_outputs, dim=-2) # [B, Slots, Dim] - - # Pooling & Predict - pooled = moe_out.mean(dim=1) # [B, 512] - logits = self.classifier(pooled) - return logits +# %% [code] +# 2. 配置训练参数 +config = { + # 数据参数 + "train_data_path": "/path/to/your/train/dataset", # 替换为训练数据集路径 + "eval_data_path": "/path/to/your/eval/dataset", # 替换为评估数据集路径 + "output_dir": "./training_output", + + # 模型参数 + "vocab_size": 10019, + "pinyin_vocab_size": 30, + "dim": 512, + "num_slots": 8, + "n_layers": 4, + "n_heads": 4, + "num_experts": 20, + "max_seq_len": 128, + + # 训练参数 + "batch_size": 64, # 根据GPU内存调整 + "num_epochs": 10, + "learning_rate": 3e-4, + "min_learning_rate": 1e-9, + "weight_decay": 0.1, + "warmup_ratio": 0.1, + "label_smoothing": 0.15, + "grad_accum_steps": 2, # 梯度累积,模拟更大batch size + "clip_grad_norm": 1.0, + "eval_frequency": 500, # 每500步评估一次 + "save_frequency": 2000, # 每2000步保存检查点 + + # 高级选项 + "mixed_precision": True, + "use_tensorboard": True, + "seed": 42, + "max_iter_length": 1024 * 1024 * 128, # 最大迭代长度 +} + +# %% [code] +# 3. 设置随机种子和设备 +torch.manual_seed(config["seed"]) +if torch.cuda.is_available(): + torch.cuda.manual_seed_all(config["seed"]) + device = torch.device("cuda") + print(f"✅ 使用 GPU: {torch.cuda.get_device_name(0)}") +else: + device = torch.device("cpu") + print("⚠️ 使用 CPU 进行训练(建议使用 GPU 以获得更好性能)") + +# %% [code] +# 4. 创建数据集和数据加载器 +print("📊 创建数据集和数据加载器...") + +# 训练数据集 +train_dataset = PinyinInputDataset( + data_path=config["train_data_path"], + max_workers=-1, # 自动选择worker数量 + max_iter_length=config["max_iter_length"], + max_seq_length=config["max_seq_len"], + text_field="text", + py_style_weight=(9, 2, 1), + shuffle_buffer_size=5000, + length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, +) + +# 训练数据加载器 +train_dataloader = DataLoader( + train_dataset, + batch_size=config["batch_size"], + num_workers=min(max(1, (os.cpu_count() or 1) - 1), 8), # 合理数量的worker + pin_memory=torch.cuda.is_available(), + worker_init_fn=worker_init_fn, + collate_fn=collate_fn, + prefetch_factor=32, + persistent_workers=True, +) + +# 评估数据集 +eval_dataset = PinyinInputDataset( + data_path=config["eval_data_path"], + max_workers=-1, + max_iter_length=1024, # 评估集较小 + max_seq_length=config["max_seq_len"], + text_field="text", + py_style_weight=(9, 2, 1), + shuffle_buffer_size=1000, + length_weights={1: 10, 2: 50, 3: 50, 4: 40, 5: 15, 6: 10, 7: 5, 8: 2}, +) + +eval_dataloader = DataLoader( + eval_dataset, + batch_size=config["batch_size"], + num_workers=1, + pin_memory=torch.cuda.is_available(), + worker_init_fn=worker_init_fn, + collate_fn=collate_fn, + prefetch_factor=32, + persistent_workers=True, +) + +print(f"✅ 数据加载器创建完成") +print(f" 训练批次大小: {config['batch_size']}") +print(f" 预估训练步数: {config['max_iter_length'] // config['batch_size']}") + +# %% [code] +# 5. 创建模型 +print("🧠 创建输入法模型...") + +model = InputMethodEngine( + vocab_size=config["vocab_size"], + pinyin_vocab_size=config["pinyin_vocab_size"], + dim=config["dim"], + num_slots=config["num_slots"], + n_layers=config["n_layers"], + n_heads=config["n_heads"], + num_experts=config["num_experts"], + max_seq_len=config["max_seq_len"], +) + +# 将模型移动到设备 +model.to(device) + +# 计算参数量 +total_params = sum(p.numel() for p in model.parameters()) +trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + +print(f"✅ 模型创建完成") +print(f" 总参数量: {total_params:,}") +print(f" 可训练参数量: {trainable_params:,}") +print(f" 模型架构: {config['n_layers']}层Transformer, {config['dim']}维度, {config['num_experts']}个MoE专家") + +# %% [code] +# 6. 创建训练器 +print("⚙️ 创建训练器...") + +# 计算总训练步数 +total_steps = int(config["max_iter_length"] / config["batch_size"]) + +trainer = Trainer( + model=model, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + total_steps=total_steps, + output_dir=config["output_dir"], + num_epochs=config["num_epochs"], + learning_rate=config["learning_rate"], + min_learning_rate=config["min_learning_rate"], + weight_decay=config["weight_decay"], + warmup_ratio=config["warmup_ratio"], + label_smoothing=config["label_smoothing"], + grad_accum_steps=config["grad_accum_steps"], + clip_grad_norm=config["clip_grad_norm"], + eval_frequency=config["eval_frequency"], + save_frequency=config["save_frequency"], + mixed_precision=config["mixed_precision"], + use_tensorboard=config["use_tensorboard"], +) + +print(f"✅ 训练器创建完成") +print(f" 总训练步数: {total_steps:,}") +print(f" 学习率: {config['learning_rate']:.2e} -> {config['min_learning_rate']:.2e}") +print(f" 输出目录: {config['output_dir']}") + +# %% [code] +# 7. 开始训练 +print("🚀 开始训练...") +print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + +try: + # 开始训练(可以从检查点恢复训练) + trainer.train(resume_from=None) # 设置检查点路径以恢复训练 + + print("✅ 训练完成!") + print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + print(f"模型和日志保存在: {config['output_dir']}") + +except KeyboardInterrupt: + print("⏹️ 训练被用户中断") + print("💾 保存当前检查点...") + trainer.save_checkpoint("interrupted") + print(f"检查点已保存到: {config['output_dir']}/checkpoint_interrupted.pt") + +except Exception as e: + print(f"❌ 训练过程中出现错误: {e}") + import traceback + traceback.print_exc() + +# %% [code] +# 8. 监控训练进度(如果使用TensorBoard) +if config["use_tensorboard"]: + print("📈 TensorBoard日志已记录在:") + print(f" {config['output_dir']}/tensorboard") + print("\n启动TensorBoard查看训练进度:") + print(" tensorboard --logdir ./training_output/tensorboard") + print("然后在浏览器中打开: http://localhost:6006") + +# %% [code] +# 9. 加载训练好的模型进行推理(示例) +def load_trained_model(checkpoint_path): + """加载训练好的模型进行检查点""" + print(f"📥 加载检查点: {checkpoint_path}") + + # 创建与训练时相同配置的模型 + loaded_model = InputMethodEngine( + vocab_size=config["vocab_size"], + pinyin_vocab_size=config["pinyin_vocab_size"], + dim=config["dim"], + num_slots=config["num_slots"], + n_layers=config["n_layers"], + n_heads=config["n_heads"], + num_experts=config["num_experts"], + max_seq_len=config["max_seq_len"], + ) + + # 加载检查点 + checkpoint = torch.load(checkpoint_path, map_location=device) + loaded_model.load_state_dict(checkpoint["model_state_dict"]) + loaded_model.to(device) + loaded_model.eval() + + print(f"✅ 模型加载完成,训练步数: {checkpoint.get('global_step', 'N/A')}") + print(f" 训练损失: {checkpoint.get('train_loss', 'N/A'):.4f}") + + return loaded_model + +# 使用示例(取消注释以使用) +# trained_model = load_trained_model("./training_output/checkpoint_final.pt") ``` -## 6. 总结 -本方案通过**单流 Transformer 编码**结合**结构化槽位交叉注意力**,并引入**20个专家的 MoE 模块** [1],在保证模型轻量(4层 Transformer)的同时,有效利用了历史输入习惯并提升了模型表达上限。相比暴力拼接或双流架构,该设计在工程实现上更优雅,在推理效率上更高效,是轻量级输入法模型的局部最优解。 +### 关键说明 + +1. **环境要求**: + - Python 3.12+ + - PyTorch 2.10+ + - 建议使用GPU进行训练 + - 安装项目依赖:`pip install -e .` + +2. **数据集格式**: + - 使用Hugging Face `datasets`格式 + - 必须包含`text`字段 + - 支持流式读取(streaming=True) + +3. **训练监控**: + - 控制台输出训练进度和指标 + - TensorBoard记录损失、准确率、学习率等 + - 定期保存模型检查点 + +4. **可调整参数**: + - `batch_size`: 根据GPU内存调整 + - `learning_rate`: 建议在1e-4到5e-4之间 + - `grad_accum_steps`: 模拟更大batch size + - `num_epochs`: 根据数据集大小调整 + +5. **故障排除**: + - GPU内存不足:减小`batch_size`或增加`grad_accum_steps` + - 训练不稳定:降低`learning_rate`或增加`warmup_ratio` + - 过拟合:增加`label_smoothing`或使用更大数据集 + +## 6. 使用指南 + +本项目的训练功能通过命令行工具 `train-model` 提供,支持训练、评估和导出模型。 + +### 6.1 安装与准备 + +#### 使用 uv(推荐) +本项目使用 [`uv`](https://github.com/astral-sh/uv) 作为Python包管理器,它比传统的 pip 更快且更可靠。 + +1. **安装 uv**(如果尚未安装): + ```bash + # Linux/macOS + curl -LsSf https://astral.sh/uv/install.sh | sh + + # 或使用 pipx + pipx install uv + + # Windows (PowerShell) + powershell -c "irm https://astral.sh/uv/install.ps1 | iex" + ``` + +2. **安装项目依赖**: + ```bash + uv pip install -e . + ``` + +#### 使用传统 pip +如果不使用 uv,也可以用标准的 pip 安装: + +```bash +# 创建并激活虚拟环境(推荐) +python -m venv .venv +source .venv/bin/activate # Linux/macOS +# .venv\Scripts\activate # Windows + +# 安装依赖 +pip install -e . +``` + +#### 验证安装 +安装完成后,可通过以下命令验证: +```bash +train-model --help +``` + +### 6.2 数据格式 + +训练数据应为Hugging Face数据集格式,支持本地文件或远程数据集仓库。数据集需包含 `text` 字段,并支持流式读取(streaming=True)。 + +#### 本地数据集示例 +```python +# dataset.py +from datasets import Dataset + +data = { + "text": ["这是第一个样本文本。", "这是第二个样本,用于训练输入法模型。"] +} +dataset = Dataset.from_dict(data) +dataset.save_to_disk("./local_dataset") +``` + +#### 远程数据集示例 +支持Hugging Face Hub或ModelScope上的数据集: +- `huggingface.co/datasets/username/dataset_name` +- `modelscope.cn/datasets/username/dataset_name` + +#### 数据格式要求 +- **必需字段**: `text`(字符串类型,包含中文文本) +- **流式读取**: 数据集必须支持 `streaming=True` 参数 +- **数据量**: 建议至少数百万条文本以获得良好效果 + +#### 数据预处理 +数据集会自动进行以下处理: +1. 文本分词和编码 +2. 拼音转换和编码 +3. 上下文窗口滑动生成训练样本 +4. 频率调整(削峰填谷)以平衡高频/低频字词 + +### 6.3 基本训练命令 + +使用 `train-model train` 命令开始训练: + +```bash +train-model train \ + --train-data-path "path/to/train/dataset" \ + --eval-data-path "path/to/eval/dataset" \ + --output-dir "./output" \ + --batch-size 128 \ + --num-epochs 10 \ + --learning-rate 1e-5 +``` + +#### 学习率建议 +根据模型架构和超参数配置(4层Transformer,512维度),推荐使用以下学习率范围: +- **标准范围**: 1e-4 ~ 5e-4 +- **配合Warmup策略**:在训练初期逐步提高学习率 +- **余弦退火**:使用最小学习率 1e-9 进行细调 + +### 6.4 参数详解 + +#### 数据参数 +- `--train-data-path`, `-t`: 训练数据集路径(必需) +- `--eval-data-path`, `-e`: 评估数据集路径(必需) +- `--output-dir`, `-o`: 输出目录(默认:`./output`) +- `--max_iter_length`: 最大迭代长度,控制每次训练迭代处理的数据量(默认:134217728) + +#### 模型参数 +- `--vocab-size`: 词汇表大小(默认:10019) +- `--pinyin-vocab-size`: 拼音词汇表大小(默认:30) +- `--dim`: 模型维度(默认:512) +- `--num-slots`: 历史槽位数量(默认:8) +- `--n-layers`: Transformer层数(默认:4) +- `--n-heads`: 注意力头数(默认:4) +- `--num-experts`: MoE专家数量(默认:20) +- `--max-seq-len`: 最大序列长度(默认:128) +- `--use-pinyin`: 是否使用拼音特征(默认:False) + +#### 训练参数 +- `--batch-size`, `-b`: 批次大小(默认:128) +- `--num-epochs`: 训练轮数(默认:10) +- `--learning-rate`, `-lr`: 学习率(默认:1e-5) +- `--min-learning-rate`: 最小学习率(默认:1e-9) +- `--weight-decay`: 权重衰减(默认:0.1) +- `--warmup-ratio`: 热身步数比例(默认:0.1) +- `--label-smoothing`: 标签平滑参数(默认:0.15) +- `--grad-accum-steps`: 梯度累积步数(默认:1) +- `--clip-grad-norm`: 梯度裁剪范数(默认:1.0) +- `--eval-frequency`: 评估频率(默认:500步) +- `--save-frequency`: 保存频率(默认:10000步) + +#### 高级选项 +- `--mixed-precision/--no-mixed-precision`: 是否使用混合精度训练(默认:启用) +- `--tensorboard/--no-tensorboard`: 是否使用TensorBoard(默认:启用) +- `--resume-from`: 从检查点恢复训练(可选) +- `--seed`: 随机种子(默认:42) + +### 6.5 监控训练进度 + +训练过程中会显示: +- 当前训练步数/总步数 +- 损失值和准确率 +- 学习率变化 +- 内存使用情况 + +启用TensorBoard后,可使用以下命令查看可视化结果: + +```bash +tensorboard --logdir ./output/tensorboard +``` + +### 6.6 评估模型(开发中) + +当前评估功能尚在开发中: + +```bash +train-model evaluate \ + --checkpoint "./output/checkpoint_final.pt" \ + --data-path "path/to/eval/dataset" \ + --batch-size 32 +``` + +命令将显示"评估功能待实现"的提示信息。该功能计划用于: +- 加载训练好的模型检查点 +- 在评估数据集上计算准确率、困惑度等指标 +- 生成详细的性能报告 + +### 6.7 导出模型(开发中) + +当前导出功能尚在开发中: + +```bash +train-model export \ + --checkpoint "./output/checkpoint_final.pt" \ + --output "./exported_model.onnx" +``` + +命令将显示"导出功能待实现"的提示信息。该功能计划用于: +- 将PyTorch模型转换为ONNX格式 +- 支持在不同推理引擎上部署 +- 提供优化后的推理模型 + +## 7. 总结 +本方案通过**单流 Transformer 编码**结合**结构化槽位交叉注意力**,并引入**20个专家的 MoE 模块** [1],在保证模型轻量(4层 Transformer)的同时,有效利用了历史输入习惯并提升了模型表达上限。相比暴力拼接或双流架构,该设计在工程实现上更优雅,在推理效率上更高效,是轻量级输入法模型的局部最优解。 \ No newline at end of file diff --git a/src/model/dataset.py b/src/model/dataset.py index 394dea1..0b8ddb5 100644 --- a/src/model/dataset.py +++ b/src/model/dataset.py @@ -196,7 +196,13 @@ class PinyinInputDataset(IterableDataset): random.seed(seed % (2**32)) np.random.seed(seed % (2**32)) - self.dataset = self.dataset.shard(num_shards=num_workers, index=worker_id) + # 安全检查:如果worker_id >= num_workers,则该worker不应该工作 + # 这可能发生在self.max_workers小于实际worker数量时 + if worker_id >= num_workers: + return # 产生空迭代器 + + # 使用局部变量存储分片数据集,避免竞争条件 + worker_dataset = self.dataset.shard(num_shards=num_workers, index=worker_id) # 计算每个worker的配额 # 将 max_iter_length 转换为整数以确保整数除法 @@ -213,12 +219,13 @@ class PinyinInputDataset(IterableDataset): # 单worker情况,使用全部配额 worker_quota = int(self.max_iter_length) num_workers = 1 + worker_dataset = self.dataset # 不使用分片 # 每个worker有自己的迭代计数器 current_iter_index = 0 batch_samples = [] - for sample in self.dataset: + for sample in worker_dataset: # 检查是否达到最大迭代次数 if current_iter_index >= worker_quota: break @@ -315,9 +322,10 @@ class PinyinInputDataset(IterableDataset): return_token_type_ids=True, ) samples = [] - for i, label in enumerate(labels): + # 修复变量名冲突:将内层循环变量i重命名为label_idx + for label_idx, label in enumerate(labels): repeats = self.adjust_frequency(label) - masked_labels = labels[:i] + masked_labels = labels[:label_idx] len_l = len(masked_labels) masked_labels.extend([0] * (8 - len_l)) diff --git a/src/model/trainer.py b/src/model/trainer.py index aa6eb05..43507c0 100644 --- a/src/model/trainer.py +++ b/src/model/trainer.py @@ -667,7 +667,7 @@ def train( grad_accum_steps: int = typer.Option(1, "--grad-accum-steps", help="梯度累积步数"), clip_grad_norm: float = typer.Option(1.0, "--clip-grad-norm", help="梯度裁剪范数"), eval_frequency: int = typer.Option(500, "--eval-frequency", help="评估频率"), - save_frequency: int = typer.Option(10000, "--save-frequency", help="保存频率"), + save_frequency: int = typer.Option(1000, "--save-frequency", help="保存频率"), # 其他参数 mixed_precision: bool = typer.Option( True, "--mixed-precision/--no-mixed-precision", help="是否使用混合精度训练" @@ -791,7 +791,7 @@ def train( train_dataloader = DataLoader( train_dataset, batch_size=batch_size, - num_workers=min(max(1, (os.cpu_count() or 1) - 1), 25), + num_workers=2, pin_memory=torch.cuda.is_available(), worker_init_fn=worker_init_fn, collate_fn=collate_fn, @@ -803,7 +803,7 @@ def train( eval_dataset = PinyinInputDataset( data_path=eval_data_path, max_workers=-1, - max_iter_length=1024, # 评估集较小 + max_iter_length=batch_size * 64, # 评估集较小 max_seq_length=max_seq_len, text_field="text", py_style_weight=(9, 2, 1), @@ -814,7 +814,7 @@ def train( eval_dataloader = DataLoader( eval_dataset, batch_size=batch_size, - num_workers=1, + num_workers=2, pin_memory=torch.cuda.is_available(), worker_init_fn=worker_init_fn, collate_fn=collate_fn,