关于Checkpoints的内容在教程2里已经有了详细的说明,在本节,需要用它来利用模型进行预测
加载checkpoint并预测
使用模型进行预测的最简单方法是使用LightningModule中的load_from_checkpoint加载权重。
model = LitModel.load_from_checkpoint("best_model.ckpt")
model.eval()
x = torch.randn(1, 64)
with torch.no_grad():
y_hat = model(x)
predict_step方法
加载检查点并进行预测仍然会在预测阶段的epoch留下许多boilerplate,LightningModule中的预测步骤删除了这个boilerplate 。
class MyModel(LightningModule):
def predict_step(self, batch, batch_idx, dataloader_idx=0):
return self(batch)
并将任何dataloader传递给Lightning Trainer
data_loader = DataLoader(...)
model = MyModel()
trainer = Trainer()
predictions = trainer.predict(model, data_loader)
预测逻辑
当需要向数据添加复杂的预处理或后处理时,使用predict_step方法。例如,这里我们使用Monte Carlo Dropout 进行预测
class LitMCdropoutModel(pl.LightningModule):
def __init__(self, model, mc_iteration):
super().__init__()
self.model = model
self.dropout = nn.Dropout()
self.mc_iteration = mc_iteration
def predict_step(self, batch, batch_idx):
# enable Monte Carlo Dropout
self.dropout.train()
# take average of `self.mc_iteration` iterations
pred = [self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]
pred = torch.vstack(pred).mean(dim=0)
return pred
启用分布式推理
通过使用Lightning中的predict_step,可以使用BasePredictionWriter进行分布式推理。
import torch
from lightning.pytorch.callbacks import BasePredictionWriter
class CustomWriter(BasePredictionWriter):
def __init__(self, output_dir, write_interval):
super().__init__(write_interval)
self.output_dir = output_dir
def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
# 在'output_dir'中创建N (num进程)个文件,每个文件都包含对其各自rank的预测
torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))
# 可以保存'batch_indices',以便从预测数据中获取有关数据索引的信息
torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))
# 可以设置writer_interval="batch"
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)
也可以加载保存的checkpoint,把它当作一个普通的torch.nn.Module来使用。可以提取所有的torch.nn.Module,并在训练后使用LightningModule保存的checkpoint加载权重。建议从LightningModule的init和forward方法中复制明确的实现。
class Encoder(nn.Module):
...
class Decoder(nn.Module):
...
class AutoEncoderProd(nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, x):
return self.encoder(x)
class AutoEncoderSystem(LightningModule):
def __init__(self):
super().__init__()
self.auto_encoder = AutoEncoderProd()
def forward(self, x):
return self.auto_encoder.encoder(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.auto_encoder.encoder(x)
y_hat = self.auto_encoder.decoder(y_hat)
loss = ...
return loss
# 训练
trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
model = AutoEncoderSystem()
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")
# 创建PyTorch模型并加载checkpoint权重
model = AutoEncoderProd()
checkpoint = torch.load("best_model.ckpt")
hyper_parameters = checkpoint["hyper_parameters"]
# 恢复超参数
model = AutoEncoderProd(**hyper_parameters)
model_weights = checkpoint["state_dict"]
# 通过 dropping `auto_encoder.` 更新key值
for key in list(model_weights):
model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)
model.load_state_dict(model_weights)
model.eval()
x = torch.randn(1, 64)
with torch.no_grad():
y_hat = model(x)