淘先锋技术网

首页 1 2 3 4 5 6 7

  • 在参考内容的基础上做了一些优化和规范化处理

版本信息

  • PyTorch: 1.12.1
  • Python: 3.7.13

导包

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence
from torch import nn
import csv
import time
import math
import matplotlib.pyplot as plt

数据集预览

  • 数据
  • 下载
    • 百度网盘: https://pan.baidu.com/s/1vZ27gKp8Pl-qICn_p2PaSw (提取码: cxe4)
    • Kaggle:https://www.kaggle.com/datasets/alionsss/namecountry
filename = "data/names_test.csv"
# The csv module expects files it reads to be opened with newline="".
with open(filename, "rt", newline="") as f:
    rows = list(csv.reader(f))
# Column 0 is the name, column 1 the country
rows[:10]
[['Abl', 'Czech'],
 ['Alt', 'Czech'],
 ['Bacon', 'Czech'],
 ['Bartonova', 'Czech'],
 ['Benesch', 'Czech'],
 ['Bilek', 'Czech'],
 ['Blazek', 'Czech'],
 ['Bleskan', 'Czech'],
 ['Bohac', 'Czech'],
 ['Borovka', 'Czech']]

数据集处理

  • Tokenizer
class CountryTokenizer:
    """Map country label strings to contiguous integer ids and back.

    Ids are assigned 0, 1, 2, ... in order of first appearance of each country
    in the second column of the CSV file at ``file_path``. Blank rows are skipped.
    """

    def __init__(self, file_path):
        # country -> id, filled in order of first appearance
        self.country2idx_dict = dict()
        with open(file_path, "rt", newline="") as f:
            for row in csv.reader(f):
                if not row:  # csv.reader yields [] for blank lines
                    continue
                country = row[1]
                if country not in self.country2idx_dict:
                    self.country2idx_dict[country] = len(self.country2idx_dict)

        self.country_num = len(self.country2idx_dict)
        # Inverse mapping id -> country, used by decode()
        self.idx2country_dict = {v: k for k, v in self.country2idx_dict.items()}

    def encode(self, country):
        """Return the id of ``country``; raises KeyError for unseen countries."""
        return self.country2idx_dict[country]

    def decode(self, idx):
        """Return the country string for id ``idx``; raises KeyError for unknown ids."""
        return self.idx2country_dict[idx]

    def get_country_size(self):
        """Return the number of distinct countries (i.e. number of classes)."""
        return self.country_num
  • Dataset
class NameDataset(Dataset):
    """Map-style dataset of ``(name, country)`` string pairs read from a CSV file.

    Each non-blank row contributes its first column as the name and its second
    column as the country; ``dataset[i]`` returns that pair unchanged.
    """

    def __init__(self, file_path):
        super().__init__()
        self.names = []
        self.countries = []
        with open(file_path, "rt", newline="") as f:
            for row in csv.reader(f):
                if not row:  # csv.reader yields [] for blank lines
                    continue
                self.names.append(row[0])
                self.countries.append(row[1])
        # Kept as an attribute for callers that read it directly.
        self.length = len(self.names)

    def __getitem__(self, index):
        return self.names[index], self.countries[index]

    def __len__(self):
        return self.length

  • 数据处理函数
def collate_fn(data, encode=None):
    """Turn a batch of ``(name, country)`` pairs into padded tensors.

    Args:
        data: Sequence of ``(name, country)`` string pairs. It is not modified.
        encode: Callable mapping a country string to an int label. Defaults to
            the module-level ``tokenizer.encode``.

    Returns:
        name_seq: long tensor (batch, max_name_len) of character code points
            (``ord``), zero-padded on the right, rows sorted by name length
            (longest first).
        name_len_seq: long tensor (batch,) of name lengths, same order.
        countries: long tensor (batch,) of encoded country labels, same order.
    """
    if encode is None:
        encode = tokenizer.encode

    # Longest name first: pack_padded_sequence's default enforce_sorted=True
    # expects descending lengths. sorted() leaves the caller's list untouched.
    batch = sorted(data, key=lambda unit: len(unit[0]), reverse=True)
    max_name_len = len(batch[0][0]) if batch else 0

    # 0 is the padding value. NOTE(review): the model's embedding has N_CHARS=128
    # entries, so this presumably assumes ASCII names - confirm for the dataset.
    name_seq = torch.zeros(len(batch), max_name_len, dtype=torch.long)
    name_len_seq, countries = [], []
    for idx, (name, country) in enumerate(batch):
        name_seq[idx, :len(name)] = torch.tensor([ord(c) for c in name], dtype=torch.long)
        name_len_seq.append(len(name))
        countries.append(encode(country))

    return (
        name_seq,
        torch.tensor(name_len_seq, dtype=torch.long),
        torch.tensor(countries, dtype=torch.long),
    )
  • 看一下数据
# Peek at a single batch of the training data
train_file_path = "data/names_train.csv"
tokenizer = CountryTokenizer(train_file_path)
train_data = NameDataset(file_path=train_file_path)
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
)

# Only the first batch is needed
tensor_seq, len_seq, country_seq = next(iter(train_dataloader))

print(len(train_dataloader))

tensor_seq, len_seq, country_seq
836
(tensor([[ 78,  97, 107, 104,  97,  98, 116, 115, 101, 118],
         [ 82,  97, 104, 108, 101, 118, 115, 107, 121,   0],
         [ 65, 119, 116, 117, 107, 104, 111, 102, 102,   0],
         [ 65, 116,  97,  98, 101, 107, 111, 118,   0,   0],
         [ 68,  97, 110, 105, 108, 121, 117, 107,   0,   0],
         [ 71, 117, 108, 101, 118, 105,  99, 104,   0,   0],
         [ 68, 111, 118, 108,  97, 116, 111, 118,   0,   0],
         [ 84, 122, 101, 103, 111, 101, 118,   0,   0,   0],
         [ 72, 105,  99, 107, 109,  97, 110,   0,   0,   0],
         [ 75, 114, 105, 110, 103, 111, 115,   0,   0,   0],
         [ 84, 115,  97, 108, 105, 101, 118,   0,   0,   0],
         [ 83,  97, 114, 114,  97, 102,   0,   0,   0,   0],
         [ 83, 112, 101, 101, 100,   0,   0,   0,   0,   0],
         [ 80, 121, 106, 111, 118,   0,   0,   0,   0,   0],
         [ 79, 109, 111, 114, 105,   0,   0,   0,   0,   0],
         [ 83,  97, 114, 116, 105,   0,   0,   0,   0,   0]]),
 tensor([10,  9,  9,  8,  8,  8,  8,  7,  7,  7,  7,  6,  5,  5,  5,  5]),
 tensor([ 6,  6,  6,  6,  6,  6,  6,  6,  9, 11,  6,  2,  9,  6,  3, 12]))

构建模型 RNN

  • RNN分类模型
class RNNClassifier(nn.Module):
    """Character-level GRU classifier.

    Embeds each character code and runs an (optionally bidirectional)
    multi-layer GRU over the packed sequences. It then maps the final hidden
    state(s) of the last layer to ``output_size`` class logits.

    Args:
        input_size: Size of the embedding table. Every code in ``inputs`` must
            be smaller than this.
        hidden_size: Embedding dimension and GRU hidden size.
        output_size: Number of output classes.
        n_layers: Number of stacked GRU layers.
        bidirectional: Whether the GRU is bidirectional.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
        super(RNNClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.n_directions = 2 if bidirectional else 1
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=bidirectional)
        # The classifier sees the last layer's final state from each direction.
        self.fc = nn.Linear(hidden_size * self.n_directions, output_size)

    def _init_hidden(self, batch_size, device=None):
        """Return an all-zero GRU state of shape (n_layers * n_directions, batch, hidden)."""
        return torch.zeros(
            self.n_layers * self.n_directions, batch_size, self.hidden_size, device=device
        )

    def forward(self, inputs, len_seq):
        """Compute class logits for a right-padded batch.

        Args:
            inputs: (batch, max_len) integer tensor of character codes.
            len_seq: (batch,) true sequence lengths, as a tensor on any device
                or a sequence of ints. Does not need to be sorted.

        Returns:
            (batch, output_size) logits, rows in the same order as ``inputs``.
        """
        # The GRU uses the default batch_first=False, so switch to (max_len, batch).
        inputs = inputs.t()
        embedding = self.embedding(inputs)  # (max_len, batch, hidden_size)
        # pack_padded_sequence requires CPU lengths. With enforce_sorted=False it
        # sorts internally and the GRU returns h_n in the original batch order.
        lengths = torch.as_tensor(len_seq, dtype=torch.long).cpu()
        gru_input = pack_padded_sequence(embedding, lengths, enforce_sorted=False)
        hidden = self._init_hidden(inputs.size(1), device=inputs.device)
        _, hidden = self.gru(gru_input, hidden)
        # Per the nn.GRU docs, for a bidirectional GRU hidden[-2] / hidden[-1] are
        # the last layer's forward / backward states. The [backward, forward]
        # order is kept so existing fc weights remain valid.
        if self.n_directions == 2:
            hidden_cat = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden_cat = hidden[-1]
        return self.fc(hidden_cat)

开始训练

  • 准备参数、数据、模型
# The training log reproduced in this tutorial was produced with 30 epochs.
EPOCH_NUM = 30
BATCH_SIZE = 32

# Embedding vocabulary for character codes; collate_fn pads with 0.
# NOTE(review): 128 presumably assumes ASCII-only names - confirm for the dataset.
N_CHARS = 128
HIDDEN_SIZE = 100
N_LAYER = 2


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Data. The tokenizer is built from the training file only; collate_fn reads it
# through the module-level name `tokenizer`. tokenizer.encode is a dict lookup,
# so every test-set country must also occur in the training file.
train_file_path = "data/names_train.csv"
tokenizer = CountryTokenizer(train_file_path)
train_data = NameDataset(train_file_path)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

# Must load the held-out test file: evaluating on the training file would only
# measure training fit. Evaluation order does not matter, so no shuffling.
test_file_path = "data/names_test.csv"
test_data = NameDataset(test_file_path)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)

# Model
model = RNNClassifier(N_CHARS, HIDDEN_SIZE, tokenizer.get_country_size(), N_LAYER).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  • 训练
def time_since(since):
    """Format the wall-clock time elapsed since ``since`` as ``'<m>m <s>s'``.

    ``since`` is a ``time.time()`` timestamp. Seconds are truncated, not rounded.
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    seconds = elapsed - 60 * minutes
    return f"{minutes}m {int(seconds)}s"

start = time.time()
accuracy_list = []  # one evaluation accuracy (fraction in [0, 1]) per epoch
print(f"Training for {EPOCH_NUM} epochs...")
for epoch in range(1, EPOCH_NUM + 1):
    # ---- Train ----
    model.train()
    total_step = len(train_loader)
    for i, (tensor_seq, len_seq, country_seq) in enumerate(train_loader, 1):
        # len_seq stays on the CPU: pack_padded_sequence requires CPU lengths.
        tensor_seq, country_seq = tensor_seq.to(device), country_seq.to(device)

        output = model(tensor_seq, len_seq)
        loss = criterion(output, country_seq)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f'Train... [time_cost {time_since(start)}] \t [Epoch {epoch}/{EPOCH_NUM}, Step {i}/{total_step}] \t [loss={loss.item()}]')

    # ---- Eval ----
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for tensor_seq, len_seq, country_seq in test_loader:
            tensor_seq, country_seq = tensor_seq.to(device), country_seq.to(device)
            predicted = model(tensor_seq, len_seq).argmax(dim=1)
            correct += (predicted == country_seq).sum().item()
            total += len(country_seq)

    # Guard against an empty evaluation set.
    accuracy = correct / total if total else 0.0
    print(f'Eval...  [time_cost {time_since(start)}] \t [Epoch {epoch}/{EPOCH_NUM}] \t [accuracy = {(100 * accuracy)}%]')

    accuracy_list.append(accuracy)
Training for 30 epochs...
Train... [time_cost 0m 0s] 	 [Epoch 1/30, Step 100/418] 	 [loss=1.3148053884506226]
Train... [time_cost 0m 0s] 	 [Epoch 1/30, Step 200/418] 	 [loss=1.0236825942993164]
Train... [time_cost 0m 1s] 	 [Epoch 1/30, Step 300/418] 	 [loss=0.6916112303733826]
Train... [time_cost 0m 1s] 	 [Epoch 1/30, Step 400/418] 	 [loss=0.5988736748695374]
Eval...  [time_cost 0m 2s] 	 [Epoch 1/30] 	 [accuracy = 79.16105877074922%]
Train... [time_cost 0m 3s] 	 [Epoch 2/30, Step 100/418] 	 [loss=0.7692814469337463]
Train... [time_cost 0m 3s] 	 [Epoch 2/30, Step 200/418] 	 [loss=0.7502091526985168]
Train... [time_cost 0m 4s] 	 [Epoch 2/30, Step 300/418] 	 [loss=0.5968592762947083]
Train... [time_cost 0m 4s] 	 [Epoch 2/30, Step 400/418] 	 [loss=0.5268656015396118]
Eval...  [time_cost 0m 5s] 	 [Epoch 2/30] 	 [accuracy = 84.11843876177659%]
Train... [time_cost 0m 6s] 	 [Epoch 3/30, Step 100/418] 	 [loss=0.5093168020248413]
Train... [time_cost 0m 6s] 	 [Epoch 3/30, Step 200/418] 	 [loss=0.47434017062187195]
Train... [time_cost 0m 7s] 	 [Epoch 3/30, Step 300/418] 	 [loss=0.5101759433746338]
Train... [time_cost 0m 7s] 	 [Epoch 3/30, Step 400/418] 	 [loss=0.4200824499130249]
Eval...  [time_cost 0m 8s] 	 [Epoch 3/30] 	 [accuracy = 87.39344997756842%]
......
Train... [time_cost 1m 15s] 	 [Epoch 28/30, Step 100/418] 	 [loss=0.09028349071741104]
Train... [time_cost 1m 15s] 	 [Epoch 28/30, Step 200/418] 	 [loss=0.02227078191936016]
Train... [time_cost 1m 16s] 	 [Epoch 28/30, Step 300/418] 	 [loss=0.0021461930591613054]
Train... [time_cost 1m 16s] 	 [Epoch 28/30, Step 400/418] 	 [loss=0.08961803466081619]
Eval...  [time_cost 1m 17s] 	 [Epoch 28/30] 	 [accuracy = 98.0933153880664%]
Train... [time_cost 1m 18s] 	 [Epoch 29/30, Step 100/418] 	 [loss=0.015819933265447617]
Train... [time_cost 1m 18s] 	 [Epoch 29/30, Step 200/418] 	 [loss=0.027052825316786766]
Train... [time_cost 1m 19s] 	 [Epoch 29/30, Step 300/418] 	 [loss=0.005545806139707565]
Train... [time_cost 1m 19s] 	 [Epoch 29/30, Step 400/418] 	 [loss=0.03861624002456665]
Eval...  [time_cost 1m 20s] 	 [Epoch 29/30] 	 [accuracy = 97.82413638402872%]
Train... [time_cost 1m 20s] 	 [Epoch 30/30, Step 100/418] 	 [loss=0.11895031481981277]
Train... [time_cost 1m 21s] 	 [Epoch 30/30, Step 200/418] 	 [loss=0.036510467529296875]
Train... [time_cost 1m 21s] 	 [Epoch 30/30, Step 300/418] 	 [loss=0.06800784915685654]
Train... [time_cost 1m 22s] 	 [Epoch 30/30, Step 400/418] 	 [loss=0.029184680432081223]
Eval...  [time_cost 1m 22s] 	 [Epoch 30/30] 	 [accuracy = 97.30073276506654%]

绘制曲线

plt.figure(figsize=(12.8, 7.2))
# accuracy_list holds one evaluation accuracy per epoch; epochs are numbered
# from 1 in the training loop, so the x-axis starts at 1 as well.
plt.plot(range(1, len(accuracy_list) + 1), accuracy_list, label="eval")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.show()

image.png

拿几条数据预测一下

  • 代码
my_data = [
    ('Yamura', 'Japanese'), ('XiaoMing', 'Chinese'), ('Rovnev', 'Russian'), ('LiLei', 'Chinese')
]

# collate_fn returns the batch sorted by name length (longest first), so the
# printed rows below follow that order rather than the order of my_data.
tensor_seq, len_seq, country_seq = collate_fn(my_data)

# Switch to inference mode explicitly instead of relying on prior cell state.
model.eval()
with torch.no_grad():
    tensor_seq, country_seq = tensor_seq.to(device), country_seq.to(device)

    output = model(tensor_seq, len_seq)
    output = output.argmax(dim=1)

    # Drop the 0 padding codes and turn the remaining code points back into text.
    print("Names   ", ["".join([chr(a) for a in n_seq if a > 0]) for n_seq in tensor_seq.tolist()])
    print("Predict ", [tokenizer.decode(i) for i in output.tolist()])
    print("Real    ", [tokenizer.decode(i) for i in country_seq.tolist()])
  • 输出信息
Names    ['XiaoMing', 'Yamura', 'Rovnev', 'LiLei']
Predict  ['Chinese', 'Japanese', 'Russian', 'Chinese']
Real     ['Chinese', 'Japanese', 'Russian', 'Chinese']

参考