# -*- coding: utf-8 -*-
import sys, os
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from tqdm import tqdm
# Checkpoint to resume from: if this file exists at startup, its state_dict is loaded.
MODEL_PATH = "CNNmodel.pth.tar"
# Destination the model's state_dict is saved to after every training epoch.
MODEL_SAVE_PATH = "CNNmodel_normal.pth.tar"
def load_cifar10(batch=128, valid_size=0.2, num_workers=4):
    """Build train/validation/test DataLoaders for MNIST.

    NOTE(review): despite the name, this loads MNIST (28x28 grayscale
    digits), not CIFAR10 — the name is kept for backward compatibility
    with existing callers.

    Args:
        batch: mini-batch size used by all three loaders.
        valid_size: fraction of the training set held out for validation
            (was previously hard-coded to 0.2).
        num_workers: worker processes per DataLoader (was hard-coded to 4).

    Returns:
        dict with keys 'train_loader', 'valid_loader', 'test_loader'.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.MNIST(root='data', train=True, download=True, transform=to_tensor)
    test_data = datasets.MNIST(root='data', train=False, download=True, transform=to_tensor)
    # Shuffle indices once, then carve off the first `valid_size` fraction as
    # the validation subset; the two samplers keep the subsets disjoint.
    num_train = len(train_data)
    indices = list(range(num_train))
    np.random.shuffle(indices)
    split = int(np.floor(valid_size * num_train))
    train_index, valid_index = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_index)
    valid_sampler = SubsetRandomSampler(valid_index)
    # Use the directly-imported DataLoader (it was imported but unused before).
    train_loader = DataLoader(train_data, batch_size=batch,
                              sampler=train_sampler, num_workers=num_workers)
    valid_loader = DataLoader(train_data, batch_size=batch,
                              sampler=valid_sampler, num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=batch,
                             num_workers=num_workers)
    return {'train_loader': train_loader, 'valid_loader': valid_loader, 'test_loader': test_loader}
class MyCNN(torch.nn.Module):
def __init__(self):
super(MyCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 5) # 28x28x1 -> 24x24x32
self.pool = nn.MaxPool2d(2, 2) # 24x24x32 -> 12x12x32
self.dropout1 = nn.Dropout2d(0.2)
self.conv2 = nn.Conv2d(32, 64, 5) # 12x12x32 -> 8x8x64
self.dropout2 = nn.Dropout2d(0.2)
self.fc1 = nn.Linear(8 * 8 * 64, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x): # predictに相当(順伝搬)
x = self.pool(F.relu(self.conv1(x)))
x = F.relu(self.conv2(x))
x = self.dropout1(x)
x = x.view(-1, 8 * 8 * 64)
x = F.relu(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
return x
def save_checkpoint(state, filename):
    """Serialize *state* (e.g. a model/optimizer state_dict) to *filename*."""
    torch.save(state, filename)
def train():
    """Train `net` for `epoch` epochs, validating after each epoch.

    Reads module-level globals: net, loader, device, optimizer, criterion,
    epoch, history, MODEL_SAVE_PATH.  Appends per-sample average loss and
    accuracy to `history`, saves the state_dict after every epoch, and
    writes one-off checkpoints the first time validation accuracy crosses
    70% and 99%.
    """
    print("will begin training")
    flag_70 = False
    flag_99 = False
    n_train = len(loader['train_loader'].sampler)
    n_valid = len(loader['valid_loader'].sampler)
    for ep in range(epoch):
        train_loss_total = 0
        train_acc_total = 0
        valid_loss_total = 0
        valid_acc_total = 0
        net.train()
        for i, (images, labels) in enumerate(loader['train_loader']):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()               # reset gradients
            outputs = net(images)               # forward pass
            loss = criterion(outputs, labels)
            # BUG FIX: loss.item() is a per-batch *mean*; scale it by the
            # batch size so dividing by the sample count below yields the
            # true per-sample average (the old code under-reported loss).
            train_loss_total += loss.item() * images.size(0)
            acc = (outputs.max(1)[1] == labels).sum()  # correct predictions in batch
            train_acc_total += acc.item()
            loss.backward()                     # backward pass
            optimizer.step()                    # weight update
            if i % 10 == 0:
                # BUG FIX: report the real processed count and dataset size
                # (the old message hard-coded batch=128 and "50000").
                print('Training log: {} epoch ({} / {} train. data). Loss: {}, Acc: {}'.format(
                    ep + 1, (i + 1) * images.size(0), n_train, loss.item(), acc))
        torch.save(net.state_dict(), MODEL_SAVE_PATH)
        history['train_loss'].append(train_loss_total / n_train)
        history['train_acc'].append(train_acc_total / n_train)
        # ---- validation ----
        net.eval()
        with torch.no_grad():
            for images, labels in tqdm(loader['valid_loader']):
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                loss = criterion(outputs, labels)
                valid_loss_total += loss.item() * images.size(0)  # same per-sample scaling
                acc = (outputs.max(1)[1] == labels).sum()
                valid_acc_total += acc.item()
        valid_acc = valid_acc_total / n_valid
        history['valid_loss'].append(valid_loss_total / n_valid)
        history['valid_acc'].append(valid_acc)
        print("valid_acc=", valid_acc)
        # BUG FIX: use independent ifs — with the original elif, an epoch
        # that jumped straight past 99% saved only the 70% checkpoint.
        if valid_acc >= 0.7 and not flag_70:
            print("70%over")
            flag_70 = True
            torch.save(net.state_dict(), 'CNNmodel_checkpoint_70.pth.tar')
        if valid_acc >= 0.99 and not flag_99:
            print("99%over")
            flag_99 = True
            torch.save(net.state_dict(), "CNNmodel_checkpoint_99.pth.tar")
def test():
    """Evaluate `net` on the test loader; record overall and per-class accuracy.

    Reads module-level globals: net, loader, device, criterion, history,
    classes.  Appends per-sample average test loss/accuracy to `history`
    and prints overall plus per-class accuracy.
    """
    test_loss_total = 0
    test_acc_total = 0
    total = 0
    class_correct = [0.0] * 10
    class_total = [0.0] * 10
    net.eval()  # inference mode
    with torch.no_grad():
        for images, labels in tqdm(loader['test_loader']):
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            loss = criterion(outputs, labels)
            _, pred = torch.max(outputs, 1)
            test_acc_total += pred.eq(labels).sum().item()
            total += labels.size(0)
            test_loss_total += loss.item() * images.size(0)
            correct_mask = (pred == labels)
            # BUG FIX: the old code looped `for i in range(4)`, so per-class
            # stats counted only the first 4 samples of every batch (and
            # shadowed the outer loop index); count every sample instead.
            for j in range(labels.size(0)):
                label = labels[j].item()
                class_correct[label] += correct_mask[j].item()
                class_total[label] += 1
    test_loss = test_loss_total / len(loader['test_loader'].sampler)
    test_acc = test_acc_total / len(loader['test_loader'].sampler)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * test_acc_total / total))
    for i in range(10):
        if class_total[i] > 0:  # robustness: avoid division by zero for absent classes
            print('Accuracy of %5s : %2d %%' % (
                classes[i], 100 * class_correct[i] / class_total[i]))
def plot():
    """Save loss and accuracy curves to img/CNN_loss.png and img/CNN_acc.png.

    Reads module-level globals: epoch, history.
    """
    os.makedirs('img', exist_ok=True)  # robustness: savefig fails if the dir is missing
    epochs = range(1, epoch + 1)
    plt.figure()
    plt.plot(epochs, history['train_loss'], label='train_loss', color='red')
    plt.plot(epochs, history['valid_loss'], label='val_loss', color='blue')
    # BUG FIX: titles said CIFAR10, but the data actually loaded is MNIST.
    plt.title('CNN Training Loss [MNIST]')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.savefig('img/CNN_loss.png')
    plt.figure()
    plt.plot(epochs, history['train_acc'], label='train_acc', color='red')
    plt.plot(epochs, history['valid_acc'], label='val_acc', color='blue')
    plt.title('CNN Accuracies [MNIST]')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.legend()
    plt.savefig('img/CNN_acc.png')
    plt.close()
if __name__ == '__main__':
    start = time.time()
    epoch = 10
    loader = load_cifar10()
    # BUG FIX: the loaders serve MNIST digits, not CIFAR10, so the original
    # CIFAR10 class names mislabeled the per-class accuracy report.
    classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
    torch.backends.cudnn.benchmark = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("device=", device)
    net: MyCNN = MyCNN().to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params=net.parameters(), lr=0.01,
                                momentum=0.9, weight_decay=0.00005)
    flag = os.path.exists(MODEL_PATH)
    if flag:  # resume training from a previous checkpoint
        print('loading parameters...')
        # map_location keeps CPU-only machines working with GPU-saved checkpoints
        source = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)
        net.load_state_dict(source)
        print('parameters loaded')
    else:
        print("途中のパラメータなし")
    # Per-epoch metrics collected by train()/test(), consumed by plot().
    history = {
        'train_loss': [],
        'train_acc': [],
        'valid_loss': [],
        'valid_acc': [],
        'test_loss': [],
        'test_acc': []
    }
    train()
    test()
    if not flag:  # only plot for a fresh run (history covers all epochs)
        plot()
    elapsed_time = time.time() - start
    print("elapsed_time:{0}".format(elapsed_time) + "[sec]")
# NOTE(review): the lines below are a pasted requirements.txt listing plus a
# stray page footer — not Python code. They made the file unparseable, so they
# are preserved here as comments; move them to a separate requirements.txt.
# cycler==0.10.0
# dataclasses==0.6
# future==0.18.2
# kiwisolver==1.3.1
# matplotlib==3.3.3
# numpy==1.19.4
# Pillow==8.0.1
# pyparsing==2.4.7
# python-dateutil==2.8.1
# six==1.15.0
# torch==1.7.0+cu110
# torchaudio==0.7.0
# torchvision==0.8.1+cu110
# tqdm==4.52.0
# typing-extensions==3.7.4.3
# 最近のコメント  (stray "recent comments" page footer from the source blog)