-
๋ณธ ๊ธ์ '๋ชจ๋๋ฅผ ์ํ ๋ฅ๋ฌ๋ ์์ฆ 2'์ 'pytorch๋ก ์์ํ๋ ๋ฅ ๋ฌ๋ ์ ๋ฌธ'์ ๋ณด๋ฉฐ ๊ณต๋ถํ ๋ด์ฉ์ ์ ๋ฆฌํ ๊ธ์ ๋๋ค.
ํ์์ ์๊ฒฌ์ด ์์ฌ ๋ค์ด๊ฐ ๋ถ์ ํํ ๋ด์ฉ์ด ์กด์ฌํ ์ ์์ต๋๋ค.
3๊ฐ ์ด์์ ์ ํ์ง์์ 1๊ฐ๋ฅผ ์ ํ! (softํ๊ฒ max๊ฐ์ ๋ฝ์์ฃผ๋)
โ ๋ค์ค ํด๋์ค ๋ถ๋ฅ (Multi-class classification)
์ธ ๊ฐ ์ด์์ ๋ต ์ค ํ๋๋ฅผ ๊ณ ๋ฅด๋ ๋ฌธ์ .
์๊ทธ๋ชจ์ด๋ ํจ์๋ ๋ก์ง์คํฑ ํจ์์ ํ ์ผ์ด์ค๋ผ ๋ณผ ์ ์๊ณ , ์ธํ์ด ํ๋์ผ ๋ ์ฌ์ฉ๋๋ ์๊ทธ๋ชจ์ด๋ ํจ์๋ฅผ ์ธํ์ด ์ฌ๋ฌ๊ฐ์ผ ๋๋ ์ฌ์ฉํ ์ ์๋๋ก ์ผ๋ฐํ ํ ๊ฒ์ด ์ํํธ๋งฅ์ค ํจ์์ ๋๋ค.
0. ์-ํซ ์ธ์ฝ๋ฉ(one-Hot Encoding)
- ์ ํ์ง์ ๊ฐ์๋งํผ ์ฐจ์์ ๊ฐ์ง๋ค.
- ์ ํ์ง์ ํด๋นํ๋ ์ธ๋ฑ์ค๋ 1, ๋๋จธ์ง๋ 0์ผ๋ก ํํํ๋ค.
ex)
๊ฐ์์ง = [1, 0, 0]
๊ณ ์์ด = [0, 1, 0]
๋์ฅ๊ณ = [0, 0, 1]-
์ ์ ์ธ์ฝ๋ฉ(1, 2, 3)๊ณผ์ ์ฐจ์ด์
โ ์ ์ ์ธ์ฝ๋ฉ์ ๊ฐ ํด๋์ค๊ฐ ์์ ์ ๋ณด๋ฅผ ํ์๋ก ํ ๋ ์ ์ฉํ๋ค.
โ ์ ํซ ์ธ์ฝ๋ฉ์ ์ผ๋ฐ์ ์ธ ๋ถ๋ฅ๋ฌธ์ , ์ฆ ์์๊ฐ ์๋ฏธ์๊ณ ๋ฌด์์์ฑ์ด ์์ ๋ ์ ์ฉํ๋ค.
(๋ชจ๋ ํด๋์ค์ ๊ด๊ณ๋ฅผ ๊ท ๋ฑํ๊ฒ ๋ถ๋ฐฐํ๊ธฐ ๋๋ฌธ!)
1. softmax function
๊ฐ ์ ํ์ง๋ง๋ค ์์๋ฅผ ํ ๋นํด์ ๊ทธ ํฉ์ด 1์ด ๋๊ฒ ๋ง๋๋ ํจ์.

for i = 1, 2, ..., k
pi๋ i๋ฒ ํด๋์ค๊ฐ ์ ๋ต์ผ ํ๋ฅ ์ ๋ปํ๋ค. pi(i=1~k)๋ฅผ ๋ค ๋ํ๋ฉด, ๊ทธ ํฉ์ 1์ด ๋๋ค. ์ฆ, ์ํํธ ๋งฅ์ค ํจ์๋ ์ด๋ ต๊ฒ ์๊ฐํ ํ์ ์์ด ์ฃผ์ด์ง ๊ฐ๋ค์ ๋ํด ํฉ์ด 1์ด ๋๋๋ก ๊ทธ ๊ฐ๋ค์ ๋น์จ์ ๋ง์ถฐ ์์๋ก ์ ๊ทํ ์์ผ์ฃผ๋ ํจ์๋ผ๊ณ ์๊ฐํ๋ฉด ๋๋ค.
Softmax( ( 1xf ) * ( fxC ) + ( Cx1 ) ) = C x 1
์ฐจ๋ก๋ก ์ ๋ ฅ๊ฐ, ๊ฐ์ค์น, ํธํฅ, ์์ธก๊ฐ์ด๋ค. (f๋ ํน์ฑ์ ์, C๋ ํด๋์ค์ ๊ฐ์)
๋ฐ์ดํฐ์ ๊ฐ์์ ๋ฐ๋ผ์ ์ ๋ ฅ๊ฐ์ 1์ด ๋ฐ๋๋ค.
2. cost function
logistic regresstion์์๋ binary cross-entropy๋ฅผ ์ฌ์ฉํ๋ค. ์๋ 2๊ฐ ์ค ํ๋๋ฅผ ๊ฒฐ๊ณผ๊ฐ์ผ๋ก ๋ด ๋์์๋๋ฐ, ์ด BCE ๋ณด๋ค ๋ ๊ทผ์์ ์ธ? ํจ์๊ฐ ์๋ค. ๋ฐ๋ก CE! cross entropy!
CE๋ 3๊ฐ ์ด์์ ๊ฐ ์ค ํ๋๋ฅผ ๋ด์ด ๋๋๋ค.

์ฌ๊ธฐ์ ์ต๋๊ฐ์ธ K๋ฅผ 2๋ก ์ง์ ํ๊ฒ ๋๋ค๋ฉด, BCE์ ์์ด ๋์ค๊ฒ ๋๋ค!
3. Code ๊ตฌํ
softmax์ cross-entropy์ ๊ตฌํ ๋ฐฉ๋ฒ์๋ 3๊ฐ์ง๊ฐ ์๋ค.
#1F.softmax() + torch.log() # = F.log_softmax()#2F.log_softmax() + F.nll_loss() # = F.cross_entropy()#3F.cross_entropy()๊ฒฐ๋ก ์ ์ผ๋ก๋ ํธํ๊ฒ 3๋ฒ๋ง ์ฌ์ฉํ๋ฉด ๋๋ค! ํน์ดํ๊ฒ ๊ฐ์ค ํจ์์ ์์ค ํจ์๋ฅผ ํ๋ฒ์ ์ธ ์ ์๋ค! ์ด๋ด ๊ฒฝ์ฐ, ์ค์ code์์๋ ํ๋ ฌ์ ๊ณฑ๋ง ์์ผ์ฃผ๊ณ , ์์ ํฉ์ด 1์ด ๋๋๋ก ์ ๊ทํ ์์ผ์ฃผ๋ ๊ณผ์ ์ ์์คํจ์๋ฅผ ์ธ ๋ ๊ฐ์ด ํ ์ ์๋ ์ F.cross_entropy()์ ๋ง๊ฒจ์ฃผ๋ฉด ๋๋ค.
4. Full Code
import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimx_train = [[1, 2, 1, 1], #4๊ฐ์ ํน์ฑ์ ๊ฐ์ง๊ณ ์๋ 8๊ฐ์ ํ ์คํธ ์ผ์ด์ค[2, 1, 3, 2],[3, 1, 3, 4],[4, 1, 5, 5],[1, 7, 5, 5],[1, 2, 5, 6],[1, 6, 6, 6],[1, 7, 7, 7]]y_train = [2, 2, 2, 1, 1, 1, 0, 0]x_train = torch.FloatTensor(x_train) #ํ ์๋ก ๋ณํy_train = torch.LongTensor(y_train)y_one_hot = torch.zeros(8, 3)#์์ 8์ ํ ์คํธ ์ผ์ด์ค ๊ฐ์, ๋ค์ 3์ ๋ต์ด 2์ผ๋ [0 0 1] ์ด๋ฐ์์ผ๋ก ๋ํ๋ผ ๊ฒ(์ง๊ธ์ ์๋ฆฌ๋ง ๋ง๋ฆ)y_one_hot.scatter_(1, y_train.unsqueeze(1), 1) #์ค์ ๊ฐ y๋ฅผ ์-ํซ ๋ฒกํฐ๋ก ๋ฐ๊ฟprint(y_one_hot.shape)# ๋ชจ๋ธ ์ด๊ธฐํW = torch.zeros((4, 3), requires_grad=True) #ํน์ฑ์ 4๊ฐ, ๊ฒฐ๊ณผ ๊ฐ์ง์๋ 3๊ฐb = torch.zeros(1, requires_grad=True) #1๋ก ํ๋ฉด 3๊ฐ์ ๊ฐ์ ๊ฐ์ด ๋ํด์ง, 3์ผ๋ก ํด๋ ์๊ด์์!# optimizer ์ค์ optimizer = optim.SGD([W, b], lr=0.1)nb_epochs = 1000for epoch in range(nb_epochs + 1):# ๊ฐ์คhypothesis = F.softmax(x_train.matmul(W) + b, dim=1)# ๋น์ฉ ํจ์ - ์ง์ ๊ณ์ฐ ๋ฒ์ cost = (y_one_hot * -torch.log(hypothesis)).sum(dim=1).mean()# cost๋ก H(x) ๊ฐ์optimizer.zero_grad()cost.backward()optimizer.step()# 100๋ฒ๋ง๋ค ๋ก๊ทธ ์ถ๋ ฅif epoch % 100 == 0:print('Epoch {:4d}/{} Cost: {:.6f}'.format(epoch, nb_epochs, cost.item()))4-1. Full Code with nn,Module
class SoftmaxClassifierModel(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(4, 3) # ์ธํ์ 4, Output์ด 3!def forward(self, x):return self.linear(x)model = SoftmaxClassifierModel() #๋ชจ๋ธ ์์ฑ# optimizer ์ค์ optimizer = optim.SGD(model.parameters(), lr=0.1)nb_epochs = 1000for epoch in range(nb_epochs + 1):# H(x) ๊ณ์ฐ - ํ๋ ฌ ๊ณฑ๋ง ํด์ค๋คprediction = model(x_train)# cost ๊ณ์ฐ - ์ด ํจ์์ softmax์๋์ผ๋ก ์ ์ฉ๋จcost = F.cross_entropy(prediction, y_train)# cost๋ก H(x) ๊ฐ์optimizer.zero_grad()cost.backward()optimizer.step()# 20๋ฒ๋ง๋ค ๋ก๊ทธ ์ถ๋ ฅif epoch % 100 == 0:print('Epoch {:4d}/{} Cost: {:.6f}'.format(epoch, nb_epochs, cost.item()))
<Reference>
https://deeplearningzerotoall.github.io/season2/lec_pytorch.html
ํด๋์๊ธฐ๋ก์ ๋จ๊ธฐ๋ ค๊ณ ๋ ธ๋ ฅํฉ๋๋ค
'๐STUDY > ๐ฅPytorch ML&DL' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
08. Perceptron (0) 2020.03.03 07. Tips and MNIST data (0) 2020.03.01 05. Logistic Regression (0) 2020.02.28 04-2. Loading Data(Mini batch and data load) (0) 2020.02.24 04-1. Multivariable Linear regression (0) 2020.02.24
