1 需求
存取上述特征向量
2 实现
- 数据结构: 使用 `list` 存储这些向量,形如 `[(r_emb, query), ...]`
- 工具: 使用 `torch.save()` 将 tensor 保存为 `.pth` 文件,存取的对象是字典
"""
保存特征向量,推荐使用torch保存,直接保存为tensor
"""
import torch
def save_feature(feature_list, feature_path):
    """Write a sequence of ``(r_emb, query)`` pairs to a single file.

    The pairs are flattened into one dict and that dict is handed to
    ``torch.save``. The pair at position ``i`` of ``feature_list`` is stored
    under the keys ``"r_emb_<i>"`` and ``"query_<i>"``, which is the layout
    ``load_feature`` reads back.

    Args:
        feature_list: Iterable of two-item ``(r_emb, query)`` entries. The
            values are stored as given (the demo in this file passes
            tensors).
        feature_path: Destination forwarded unchanged to ``torch.save``.

    Returns:
        None.
    """
    feature = {}
    # Two keys per pair, indexed by position, so the order can be rebuilt
    # from the key names alone when loading.
    for i, (r_emb, query) in enumerate(feature_list):
        feature[f"r_emb_{i}"] = r_emb
        feature[f"query_{i}"] = query
    torch.save(feature, feature_path)
def load_feature(feature_path, map_location=None):
    """Read back the ``(r_emb, query)`` pairs stored in ``feature_path``.

    The file is expected to hold a dict with the keys ``"r_emb_<i>"`` and
    ``"query_<i>"`` for ``i = 0, 1, ...`` (the layout ``save_feature``
    writes). The number of pairs is taken to be half the number of keys.

    Args:
        feature_path: Source forwarded to ``torch.load``.
        map_location: Forwarded to ``torch.load``. The default ``None``
            keeps ``torch.load``'s own default behaviour; pass e.g.
            ``"cpu"`` to remap storages saved on another device.

    Returns:
        A list of ``(r_emb, query)`` tuples in index order.

    Raises:
        KeyError: If an expected ``"r_emb_<i>"`` or ``"query_<i>"`` key is
            missing from the loaded dict.
    """
    feature = torch.load(feature_path, map_location=map_location)
    # Every pair contributes exactly two keys, so half the key count is the
    # number of pairs to rebuild.
    return [
        (feature[f"r_emb_{i}"], feature[f"query_{i}"])
        for i in range(len(feature) // 2)
    ]
if __name__ == "__main__":
    # Local import: only this demo needs it.
    import os

    # Round-trip demo: two (r_emb, query) pairs whose query tensors differ
    # in their second dimension (22 vs 26).
    r_emb_1 = torch.randn((32, 75, 512))
    query_1 = torch.randn((32, 22, 512))
    r_emb_2 = torch.randn((32, 75, 512))
    query_2 = torch.randn((32, 26, 512))
    feature_list = [(r_emb_1, query_1), (r_emb_2, query_2)]
    feature_path = "./save_feature.pth"

    # Write the demo file only when it is missing, so a fresh run no longer
    # fails in load_feature and an already-saved file is reused untouched.
    if not os.path.exists(feature_path):
        save_feature(feature_list, feature_path)

    feature = load_feature(feature_path)
    print("query_1 shape:", feature[0][1].shape)