淘先锋技术网

首页 1 2 3 4 5 6 7

位置:espnet/espnet/asr/pytorch_backend/asr.py

一、读取输入输出维度

idim_list:特征向量维数[23](20 Fbank + 3 pitch)
odim:483(汉字字符数)

    # Read the input and output dimensions from the validation JSON file.
    # In the article's example, idim_list is [23] (feature dimension) and
    # odim is 483 (number of Chinese characters).
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim_list = [
        int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
    ]  # input dimension per encoder: last entry of "shape" of the first utterance
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])  # output dimension

二、载入设置的模型

load_trained_modules(idim, odim, args, interface=ASRInterface)
返回带有初始化权重的模型
模型由args.model_module决定

    # Load the configured model; only the first encoder's input dimension is passed.
    model = load_trained_modules(idim_list[0], odim, args)

三、在model.json中写入相关参数

    # Write the input/output dimensions and every attribute of args to model.json
    # (per the article, args carries all model settings from the .yaml config).
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                # a single int for one encoder, otherwise the whole list of input dims
                (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )

四、设置 optimizer (以adam为例)

	# Adam over all model parameters; weight decay is taken from args.
	model_params = model.parameters()
	optimizer = torch.optim.Adam(model_params, weight_decay=args.weight_decay)

五、设置converter

	# CustomConverter: per the article, returns the subsampled xs_pad, ilens and ys_pad.
	converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)

六、读取数据

1、make_batchset函数从json中读取数据转换为 List[List[Tuple[str, dict]]] 格式的batch set。

make_batchset的用法:
    >>> data = {'utt1': {'category': 'A', 'input': ...},
    ...         'utt2': {'category': 'B', 'input': ...},
    ...         'utt3': {'category': 'B', 'input': ...},
    ...         'utt4': {'category': 'A', 'input': ...}}
    >>> make_batchset(data, batchsize=2, ...)
    [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]
# Load the training data (the validation set is handled the same way).
	with open(args.train_json, "rb") as f:
    	train_json = json.load(f)["utts"]
# Build the batch set for the training data.
	train = make_batchset(
	    train_json,
	    args.batch_size,
	    args.maxlen_in,
	    args.maxlen_out,
	    args.minibatches,
	    # at least one utterance per GPU when more than one GPU is used
	    min_batch_size=args.ngpu if args.ngpu > 1 else 1,
	    shortest_first=use_sortagrad,
	    count=args.batch_count,
	    batch_bins=args.batch_bins,
	    batch_frames_in=args.batch_frames_in,
	    batch_frames_out=args.batch_frames_out,
	    batch_frames_inout=args.batch_frames_inout,
	    iaxis=0,
	    oaxis=0,
	)

2、LoadInputsAndTargets的功能是构造mini batch,其 `__call__` 函数:`__call__(self, batch, return_uttid=False)` 可以从dict中提取输入特征向量(feats)和标签(targets)。
feats = [(T_1, D), (T_2, D), …, (T_B, D)]
targets = [(L_1), (L_2), …, (L_B)]

LoadInputsAndTargets用法:
>>> batch = [('utt1',
...           dict(input=[dict(feat='some.ark:123',
...                            filetype='mat',
...                            name='input1',
...                            shape=[100, 80])],
...                output=[dict(tokenid='1 2 3 4',
...                             name='target1',
...                             shape=[4, 31])]))]
>>> load_tr = LoadInputsAndTargets()
>>> feat, target = load_tr(batch)
    # Loader that turns one batch of (utt_id, info) pairs into features and targets.
    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,  # preprocessing config to apply, e.g. SpecAugment
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )

3、ChainerDataLoader是一个Chainer风格的pytorch DataLoader。
TransformDataset将数据转换为PyTorch Dataset:

	class TransformDataset(torch.utils.data.Dataset):
		"""Dataset that applies ``transform`` to each item of ``data`` on access.

		Args:
			data: Indexable, sized collection of raw items.
			transform: Callable applied to ``data[idx]`` when an item is fetched.
		"""

		def __init__(self, data, transform):
			# The one-argument form super(TransformDataset) creates an unbound
			# super object, so the base-class initializer was never invoked;
			# the two-argument form binds it to this instance.
			super(TransformDataset, self).__init__()
			self.data = data
			self.transform = transform

		def __len__(self):
			"""Return the number of items in ``data``."""
			return len(self.data)

		def __getitem__(self, idx):
			"""Return ``transform(data[idx])``."""
			return self.transform(self.data[idx])
    # Each item of `train` is loaded by load_tr and passed to the converter on access.
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
        batch_size=1,  # one dataset item per step
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,  # no shuffling while sortagrad is in use
        collate_fn=lambda x: x[0],  # unwrap the single-element list produced by batch_size=1
    )

七、设置Updater

自定义CustomUpdater,核心代码(简化后)如下:

    def update_core(self):
        """Run one forward/backward pass and step the optimizer when due.

        Simplified excerpt: the "......" lines mark code elided by the article.
        The optimizer step and gradient reset happen only once forward_count
        reaches accum_grad, or when the iterator has moved on to a new epoch.
        """
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
		train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        epoch = train_iter.epoch
        
        batch = train_iter.next()
        x = _recursive_to(batch, self.device)
        # True when fetching the batch advanced the iterator's epoch counter
        is_new_epoch = train_iter.epoch != epoch
        
        # Scale the mean loss by accum_grad before accumulating gradients
        loss = (data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad)
        loss.backward()  # backpropagation
        # Gradient noise injection (a regularization method)
        if self.grad_noise:
			......

        self.forward_count += 1
        # Keep accumulating until accum_grad passes are done, unless a new epoch began
        if not is_new_epoch and self.forward_count != self.accum_grad:
            return
            
        self.forward_count = 0
		# Compute grad_norm and check that the gradient is valid
		......
		optimizer.step()  # update the parameters
        optimizer.zero_grad()  # reset the gradients
        
    def update(self):
        """Call update_core and count an iteration when forward_count is back at 0."""
        self.update_core()
        # forward_count is reset to 0 only after an optimizer step in update_core
        if self.forward_count == 0:
            self.iteration += 1
    updater = CustomUpdater(
        model,
        args.grad_clip,  # clipping threshold: per the article, gradients above it are limited to this range to prevent exploding gradients
        {"main": train_iter},  # chainer iterator
        optimizer,  # optimizer instance
        device,
        args.ngpu,
        args.grad_noise,  # gradient noise injection (a regularization method)
        args.accum_grad,  # gradient accumulation count; the article states a default of 2 - TODO confirm against the config in use
        use_apex=use_apex,
    )

八、设置Chainer训练器

格式为 trainer = training.Trainer(updater, (max_epoch, 'epoch'), out=path)

    # Set up the Chainer trainer: training.Trainer(updater, (max_epoch, 'epoch'), out=path)
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

九、训练器扩展功能

	# Evaluate the model on the validation iterator
    trainer.extend(CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu))
    # Save attention weights every epoch
    trainer.extend(att_reporter, trigger=(1, "epoch"))
    # Save CTC probabilities every epoch
    trainer.extend(ctc_reporter, trigger=(1, "epoch"))
    
    # Plot loss.png (total, CTC and attention losses for training and validation)
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ctc",
                "validation/main/loss_ctc",
                "main/loss_att",
                "validation/main/loss_att"
            ],
            "epoch",
            file_name="loss.png",
        )
    )
	# Plot acc.png
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )
    # Plot cer.png
    trainer.extend(
        extensions.PlotReport(
            ["main/cer_ctc", "validation/main/cer_ctc"], "epoch", file_name="cer.png",
        )
    )

	# Save the model whenever the validation loss reaches a new minimum
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    # Save the model whenever the validation accuracy reaches a new maximum
    trainer.extend(
    	snapshot_object(model, "model.acc.best"),
   		trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
	)
	# Save a snapshot every epoch (per the article, used for model averaging)
	trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
	
	# Write a log entry every args.report_interval_iters iterations (100 per the article)
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )   
    # At the same interval, print report_keys ("epoch", "iteration", "main/loss", ...)
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )
	# At the same interval, refresh the progress bar
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))

十、设置早停

十一、运行

    # Run the training loop, then call the early-stop check with the configured epoch count.
    trainer.run()
    check_early_stop(trainer, args.epochs)