
Training MNIST with PyTorch

KAU 2020. 2. 24. 14:44

<2020 02 16>

I trained MNIST with the PyTorch library on Google Colab.

 

MNIST


MNIST is a set of 28x28 handwritten-digit images provided by the 'torchvision' library. The lines below download it.

 

mnist_train = torchvision.datasets.MNIST(root="MNIST_data/", train=True, transform=torchvision.transforms.ToTensor(), download=True)

mnist_test = torchvision.datasets.MNIST(root="MNIST_data/", train=False, transform=torchvision.transforms.ToTensor(), download=True)

 

You can see that the dataset has been downloaded.
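As a quick sanity check (a minimal sketch; these print statements are not part of the original post), you can also inspect the dataset objects directly:

print(mnist_train)         # dataset summary: split, number of samples, root folder
print(len(mnist_train))    # 60000 training images
print(len(mnist_test))     # 10000 test images

image, label = mnist_train[0]
print(image.shape, label)  # torch.Size([1, 28, 28]) and the corresponding digit label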

 

 

You can view the images directly with the code below.

 

import matplotlib.pyplot as plt

def plot_img(image):
    # image: a single sample of shape [1, 28, 28]
    image = image.numpy()[0]
    mean = 0.1307
    std = 0.3081
    image = image * std + mean  # undo MNIST normalization (not strictly needed here, since only ToTensor() is applied)
    plt.imshow(image, cmap='gray')

# data_loader is defined in the full code below
sample_data = next(iter(data_loader))
plot_img(sample_data[0][2])

plot_img(sample_data[0][30])

 

<Full code>

import torch
import torchvision
import matplotlib.pyplot as plt

batch_size = 1000

mnist_train = torchvision.datasets.MNIST(root="MNIST_data/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root="MNIST_data/", train=False, transform=torchvision.transforms.ToTensor(), download=True)
data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # use the Colab GPU when available

linear = torch.nn.Linear(784, 10, bias=True).to(device)  # a single linear layer: 784 pixels -> 10 digit classes
loss = torch.nn.CrossEntropyLoss().to(device)
SGD = torch.optim.SGD(linear.parameters(), lr=0.1)

total_batch = len(data_loader)  # 60 = 60000 / 1000 (total / batch_size)
training_epochs = 10

for epoch in range(training_epochs):
    total_cost = 0
    for X, Y in data_loader:
        X = X.view(-1, 28 * 28).to(device)  # flatten each 1x28x28 image into a 784-dim vector
        Y = Y.to(device)

        hypothesis = linear(X)
        cost = loss(hypothesis, Y)
        SGD.zero_grad()
        cost.backward()
        SGD.step()
        total_cost += cost.item()  # .item() takes the Python float, so the autograd graph is not kept around
    avg_cost = total_cost / total_batch
    print("Epoch:", "%03d" % (epoch+1), "cost =", "{:.9f}".format(avg_cost))


with torch.no_grad():
    X_test = mnist_test.data.view(-1, 28 * 28).float().to(device) / 255.  # scale to [0, 1] to match ToTensor() in training
    Y_test = mnist_test.targets.to(device)
    prediction = linear(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print("Accuracy: ", accuracy.item())


def plot_img(image):
    # image: a single sample of shape [1, 28, 28]
    image = image.numpy()[0]
    mean = 0.1307
    std = 0.3081
    image = image * std + mean  # undo MNIST normalization (not strictly needed here, since only ToTensor() is applied)
    plt.imshow(image, cmap='gray')

sample_data = next(iter(data_loader))
plot_img(sample_data[0][2])