位置:espnet/espnet/asr/pytorch_backend/asr.py
一、读取输入输出维度
idim_list：每个编码器输入的特征向量维数，本例为 [23]（20 维 Fbank + 3 维 pitch）
odim：输出维数，本例为 483（汉字字符数）
# Read the input/output dimensions from the validation JSON ("utts" section).
# Only the first utterance is inspected: idim_list gets one entry per encoder
# (args.num_encs of them), each the last element of that input's "shape";
# odim is the last element of the first output's "shape".
with open(args.valid_json, "rb") as f:
    valid_json = json.load(f)["utts"]
utts = list(valid_json)
first_utt = valid_json[utts[0]]
# input dimension(s), one per encoder stream
idim_list = [
    int(first_utt["input"][enc_idx]["shape"][-1]) for enc_idx in range(args.num_encs)
]
# output dimension
odim = int(first_utt["output"][0]["shape"][-1])
二、载入设置的模型
load_trained_modules(idim, odim, args, interface=ASRInterface)：
返回带有初始化权重的模型，
模型结构由 args.model_module 决定。
# Build the configured model (only the first encoder's input dim is passed).
model = load_trained_modules(idim_list[0], odim, args)
三、在model.json中写入相关参数
# Save <outdir>/model.json containing the input dim(s), odim and every training
# argument from vars(args), serialized as UTF-8 JSON with sorted keys.
model_conf = args.outdir + "/model.json"
with open(model_conf, "wb") as f:
    logging.info("writing a model config file to " + model_conf)
    # A single encoder stores a plain int; multi-encoder setups store the list.
    input_dims = idim_list[0] if args.num_encs == 1 else idim_list
    config_text = json.dumps(
        (input_dims, odim, vars(args)),
        indent=4,
        ensure_ascii=False,
        sort_keys=True,
    )
    f.write(config_text.encode("utf_8"))
四、设置 optimizer (以adam为例)
# Adam optimizer over all model parameters; only weight_decay comes from args,
# every other Adam hyperparameter is left at PyTorch's default.
model_params = model.parameters()
optimizer = torch.optim.Adam(model_params, weight_decay=args.weight_decay)
五、设置converter
# CustomConverter: per the original note, returns the subsampled xs_pad, ilens, ys_pad.
# The subsampling factor is taken from the model's first subsample entry.
converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
六、读取数据
1、make_batchset 函数将从 json 读取的数据字典转换为 List[List[Tuple[str, dict]]] 格式的 batch set。
make_batchset的用法:
>>> data = {'utt1': {'category': 'A', 'input': ...},
... 'utt2': {'category': 'B', 'input': ...},
... 'utt3': {'category': 'B', 'input': ...},
... 'utt4': {'category': 'A', 'input': ...}}
>>> make_batchset(data, batchsize=2, ...)
[[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]
# Load the training set (the validation set is handled the same way).
with open(args.train_json, "rb") as f:
    train_json = json.load(f)["utts"]
# Group utterances into mini-batches (List[List[Tuple[str, dict]]]).
# min_batch_size is 1, or args.ngpu when training on several GPUs
# (presumably so data_parallel can split every batch across the GPUs).
train = make_batchset(
    train_json,
    args.batch_size,
    args.maxlen_in,
    args.maxlen_out,
    args.minibatches,
    min_batch_size=max(args.ngpu, 1),
    shortest_first=use_sortagrad,
    count=args.batch_count,
    batch_bins=args.batch_bins,
    batch_frames_in=args.batch_frames_in,
    batch_frames_out=args.batch_frames_out,
    batch_frames_inout=args.batch_frames_inout,
    iaxis=0,
    oaxis=0,
)
2、LoadInputsAndTargets 的功能是构造 mini-batch，其 __call__(self, batch, return_uttid=False) 方法可以从 dict 中提取输入特征向量（feats）和标签（targets）：
feats = [(T_1, D), (T_2, D), …, (T_B, D)]
targets = [(L_1), (L_2), …, (L_B)]
LoadInputsAndTargets用法:
>>> batch = [('utt1',
... dict(input=[dict(feat='some.ark:123',
... filetype='mat',
... name='input1',
... shape=[100, 80])],
... output=[dict(tokenid='1 2 3 4',
... name='target1',
... shape=[4, 31])]))]
>>> load_tr = LoadInputsAndTargets()
>>> feat, target = load_tr(batch)
# Training-time loader that reads input features and output targets ("asr" mode).
load_tr = LoadInputsAndTargets(
    mode="asr",
    load_output=True,
    preprocess_conf=args.preprocess_conf,  # preprocessing config, e.g. SpecAugment
    preprocess_args={"train": True},  # Switch the mode of preprocessing
)
3、ChainerDataLoader 是一个 Chainer 风格的 PyTorch DataLoader。
TransformDataset 将数据封装为 PyTorch Dataset，其定义如下：
class TransformDataset(torch.utils.data.Dataset):
    """Map-style dataset that lazily applies ``transform`` to each item of ``data``.

    Args:
        data: Indexable, sized sequence of raw items (in this training script,
            the mini-batch list produced by ``make_batchset``).
        transform: Callable applied to ``data[idx]`` on every ``__getitem__``.
    """

    def __init__(self, data, transform):
        # ``super(TransformDataset).__init__()`` creates an *unbound* super
        # object, so the next initializer in the MRO was never run. The bound
        # zero-argument form calls it correctly (and cooperatively).
        super().__init__()
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of items in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``transform(data[idx])``."""
        return self.transform(self.data[idx])
# Each dataset item is already a whole mini-batch: load_tr reads it and
# converter turns it into tensors. So batch_size=1, and collate_fn unwraps
# the resulting single-element list.
train_iter = ChainerDataLoader(
    dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
    batch_size=1,
    num_workers=args.n_iter_processes,
    shuffle=not use_sortagrad,  # no shuffling when batches were built shortest-first (sortagrad)
    collate_fn=lambda x: x[0],
)
七、设置Updater
自定义CustomUpdater,核心代码(简化后)如下:
def update_core(self):
    """Run one forward/backward pass and, when due, one optimizer step.

    Simplified excerpt of CustomUpdater.update_core; the ``......`` lines mark
    code omitted in these notes. Gradients accumulate over ``self.accum_grad``
    forward passes (or until the epoch ends) before ``optimizer.step()`` runs.
    """
    # When we pass one iterator and optimizer to StandardUpdater.__init__,
    # they are automatically named 'main'.
    train_iter = self.get_iterator("main")
    optimizer = self.get_optimizer("main")
    epoch = train_iter.epoch
    batch = train_iter.next()
    # presumably moves every tensor in the batch to self.device
    x = _recursive_to(batch, self.device)
    is_new_epoch = train_iter.epoch != epoch
    # Dividing by accum_grad makes the accumulated gradient an average over passes.
    loss = (data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad)
    loss.backward()  # backpropagation
    # Gradient noise injection (a form of regularization); body omitted.
    if self.grad_noise:
        ......
    self.forward_count += 1
    # Keep accumulating until accum_grad passes are done, unless the epoch just ended.
    if not is_new_epoch and self.forward_count != self.accum_grad:
        return
    self.forward_count = 0
    # Compute grad_norm and check that the gradients are valid (omitted).
    ......
    optimizer.step()  # update the parameters
    optimizer.zero_grad()  # reset the gradients
def update(self):
    """Run ``update_core`` once; count an iteration only after an optimizer step.

    ``update_core`` resets ``forward_count`` to 0 only when it actually steps the
    optimizer, so ``iteration`` counts optimizer steps rather than forward passes.
    """
    self.update_core()
    if self.forward_count != 0:
        # still accumulating gradients: no optimizer step happened
        return
    self.iteration += 1
# Custom updater that performs each training iteration (see update_core).
updater = CustomUpdater(
    model,
    args.grad_clip,  # gradient-clipping threshold, guards against exploding gradients
    {"main": train_iter},  # chainer iterator
    optimizer,
    device,
    args.ngpu,
    args.grad_noise,  # gradient noise injection (a form of regularization)
    args.accum_grad,  # gradient accumulation: optimizer steps every accum_grad passes; original note says default is 2 — TODO confirm
    use_apex=use_apex,
)
八、设置Chainer训练器
格式为 trainer = training.Trainer(updater, (max_epoch, 'epoch'), out=path)
# Chainer trainer: runs the updater for args.epochs epochs, writing outputs to args.outdir.
trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
九、训练器扩展功能
# Evaluate the model on the validation iterator.
trainer.extend(CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu))
# Save attention weights every epoch (per the original note).
trainer.extend(att_reporter, trigger=(1, "epoch"))
# Save CTC probabilities every epoch (per the original note).
trainer.extend(ctc_reporter, trigger=(1, "epoch"))
# Plot training/validation curves against the epoch axis into loss.png,
# acc.png and cer.png, each built from its own list of report keys.
plot_specs = (
    (
        [
            "main/loss",
            "validation/main/loss",
            "main/loss_ctc",
            "validation/main/loss_ctc",
            "main/loss_att",
            "validation/main/loss_att",
        ],
        "loss.png",
    ),
    (["main/acc", "validation/main/acc"], "acc.png"),
    (["main/cer_ctc", "validation/main/cer_ctc"], "cer.png"),
)
for plot_keys, plot_file in plot_specs:
    trainer.extend(extensions.PlotReport(plot_keys, "epoch", file_name=plot_file))
# Save the model as "model.loss.best" when validation loss reaches a new minimum.
trainer.extend(
    snapshot_object(model, "model.loss.best"),
    trigger=training.triggers.MinValueTrigger("validation/main/loss"),
)
# Save the model as "model.acc.best" when validation accuracy reaches a new maximum.
trainer.extend(
    snapshot_object(model, "model.acc.best"),
    trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
)
# Take a snapshot every epoch (per the original note, used later for model averaging).
trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
# Log every args.report_interval_iters iterations (the original note says 100 and
# that this ends up in train.log — presumably the configured value; confirm).
trainer.extend(
    extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
)
# Print report_keys ("epoch", "iteration", "main/loss", ...) at the same interval.
trainer.extend(
    extensions.PrintReport(report_keys),
    trigger=(args.report_interval_iters, "iteration"),
)
# Progress bar, refreshed every args.report_interval_iters iterations.
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
十、设置早停（ESPnet 中通过 set_early_stop(trainer, args) 完成，本文省略该部分代码）
十一、运行
# Run training, then call check_early_stop with the configured epoch count
# (presumably reports whether training stopped before args.epochs).
trainer.run()
check_early_stop(trainer, args.epochs)