-
๋ณธ ๊ธ์ '๋ชจ๋๋ฅผ ์ํ ๋ฅ๋ฌ๋ ์์ฆ 2'์ 'pytorch๋ก ์์ํ๋ ๋ฅ ๋ฌ๋ ์ ๋ฌธ'์ ๋ณด๋ฉฐ ๊ณต๋ถํ ๋ด์ฉ์ ์ ๋ฆฌํ ๊ธ์ ๋๋ค.
ํ์์ ์๊ฒฌ์ด ์์ฌ ๋ค์ด๊ฐ ๋ถ์ ํํ ๋ด์ฉ์ด ์กด์ฌํ ์ ์์ต๋๋ค.
0. Overfitting
๊ณผ์ ํฉ์ด๋ ๊ฐ์ค์น๋ค์ด ์์น์ ์ผ๋ก ๋๋ฌด ํ์ต๋ฐ์ดํฐ์๋ง ์๋ง๋ ๊ฒ์ ๋งํ๋ค. ์ฆ, ํ์ต์ด ๋๋ฌด ๊ณ ์ฐจ์ ์ ์ผ๋ก ๋ง์ด ๋ ๊ฒ์ด๋ค.
๊ทธ๋ฐ๋ฐ ์ ๊น, ํ์ต๋ฐ์ดํฐ์ ์ ํฉํ๋ฉด ์ข์๊ฒ ์๋๊ฐ?
์๋๋ค! ํ์ต๋ฐ์ดํฐ = ํ ์คํธ ๋ฐ์ดํฐ๊ฐ ์๋๊ธฐ ๋๋ฌธ์ ํ์ต๋ฐ์ดํฐ์๋ ์ ํฉํ๋๋ผ๋, ์คํ๋ ค ์ค์ ์์๋ ์ฑ๋ฅ์ด ๋จ์ด์ง ์ ์๋ค.
์ ๋ฆฌํ์๋ฉด, Train set์์๋ high accuracy, test set์์๋ low accuracy๋ฅผ ๋ณด์ด๋ ํ์์ ๋งํ๋ค.
๊ทธ๋ ๋ค๋ฉด, ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด์๋ ์ด๋ป๊ฒ ํด์ผ ํ ๊น?
๋ํ์ ์ธ ๋ฐฉ๋ฒ 4๊ฐ์ง๊ฐ ์๋ค.
1. More training data
๋ชจ๋ธ์ ๋ฐ์ดํฐ์ ์์ด ์ ์ ๊ฒฝ์ฐ, ํ์ต์ ๋ง์ง ์๋ ๊ทธ ๋ฐ์ดํฐ ๋ง์ ํน์ ํจํด์ด๋ ๋ ธ์ด์ฆ๊น์ง ํ์ตํด์ ๊ณผ์ ํฉ ํ์์ด ๋ฐ์ํ ํ๋ฅ ์ด ๋์์ง๋ค. ๊ทธ๋ ๊ธฐ์, ๋ฐ์ดํฐ์ ์์ ๋๋ฆด ์๋ก ๋ชจ๋ธ์ ๋ฐ์ดํฐ์ ์ผ๋ฐ์ ์ธ ํจํด์ ํ์ตํด์ ๊ณผ์ ํฉ์ ๋ฐฉ์งํ ์ ์๋ค.
๋ง์ฝ, ๋ฐ์ดํฐ์ ์์ด ์ ์ ๊ฒฝ์ฐ์๋ ์๋์ ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์กฐ๊ธ์ฉ ๋ณํ, ์ถ๊ฐํ์ฌ ๋ฐ์ดํฐ์ ์์ ๋๋ฆฌ๊ธฐ๋ ํ๋๋ฐ ์ด๋ฅผ ๋ฐ์ดํฐ ์ฆ์ ๋๋ ์ฆ๊ฐ(Data Augmentation)์ด๋ผ๊ณ ํ๋ค. ์ด๋ฏธ์ง์ ๊ฒฝ์ฐ ๋ฐ์ดํฐ ์ฆ์์ด ๋ง์ด ์ฌ์ฉ๋๋ค.
2. Reduce the number of features
๋ชจ๋ธ์ ๋ณต์ก๋, ํน์ง์ ์ค์ด๋ ๋ฐฉ๋ฒ์ด๋ค. ์ธ๊ณต ์ ๊ฒฝ๋ง์ ๋ณต์ก๋๋ ์๋์ธต์ ์๋ ๋งค๊ฐ๋ณ์์ ์(์์ฉ๋ ฅ) ๋ฑ์ผ๋ก ๊ฒฐ์ ๋๋ค. ๋ชจ๋ธ์ ๋ณต์ก๋๊ฐ ํ์ ์ด์์ผ๋ก ๋์๋ ๊ณผ์ ํฉ์ด ์ผ์ด๋ ์ ์๋ค.์๋ฅผ ๋ค์ด, 3๊ฐ์ ์๋์ธต์ด ์๋ค๋ฉด 2๊ฐ๋ก ์ค์ด๋ ๊ฒ์ด๋ค.
3. Regularization
๊ฐ์ค์น์ ๊ท์ ๋ฅผ ์ฃผ๋ ๋ฐฉ๋ฒ์ด๋ค. ์์์ ๋งํ๋ฏ์ด ๋ณต์กํ ๋ชจ๋ธ์ด ๊ฐ๋จํ ๋ชจ๋ธ๋ณด๋ค ๊ณผ์ ํฉ๋ ๊ฐ๋ฅ์ฑ์ด ๋๋ค. ๊ทธ๋ ๊ธฐ์ ๋ณต์กํ ๋ชจ๋ธ์ ์ข ๋ ๊ฐ๋จํ๊ฒ ํ๋ ๋ฐฉ๋ฒ์ผ๋ก ๊ฐ์ค์น ๊ท์ ๊ฐ ์๋ค. ๊ท์ ์๋ ๋๊ฐ์ง ์ข ๋ฅ๊ฐ ์๋ค.
L1 ๊ท์ : ๊ฐ์ค์น๋ค์ ์ ๋๊ฐ ํฉ๊ณ๋ฅผ ๋น์ฉํจ์์ ์ถ๊ฐ.
L2 ๊ท์ : ๋ชจ๋ ๊ฐ์ค์น๋ค์ ์ ๊ณฑํฉ์ ๋น์ฉํจ์์ ์ถ๊ฐ.
4. Dropout
๋๋กญ ์์์ ํ์ต ๊ณผ์ ์์ ์ ๊ฒฝ๋ง์ ์ผ๋ถ๋ฅผ ์ฌ์ฉํ์ง ์๋ ๋ฐฉ๋ฒ์ด๋ค. ์ฆ, ์ ๊ตฌ๋ฅผ ์ผ๊ณ ๋๋ ๊ฒ์ฒ๋ผ ์ ๊ฒฝ๋ง์ on/off ํ๋ ๊ฒ์ด๋ค.
์๋ฅผ ๋ค์ด, ๋๋กญ์์์ ๋น์จ์ 0.5๋ก ํ๋ค๋ฉด ํ์ต ๊ณผ์ ๋ง๋ค ๋๋ค์ผ๋ก ๋ด๋ฐ์ ์ ๋ฐ์ ์ฌ์ฉํ์ง ์๊ณ ํ์ตํ๋ค.
๋๋กญ์์์ ์ ๊ฒฝ๋ง ํ์ต ์์๋ง ์ฌ์ฉํ๊ณ , ์์ธก(์ค์ ) ์์๋ ์ฌ์ฉํ์ง ์๋๋ค. ๋๋กญ์์์ ์ฌ์ฉํ๋ฉด ํ์ต ํ ๋ ์ธ๊ณต ์ ๊ฒฝ๋ง์ด ํน์ ๋ด๋ฐ ๋๋ ํน์ ์กฐํฉ์ ๋๋ฌด ์์กด์ ์ด๊ฒ ๋๋ ๊ฒ์ ๋ฐฉ์งํด์ฃผ๊ณ , ์๋ก ๋ค๋ฅธ ์ ๊ฒฝ๋ง๋ค์ ์์๋ธํ์ฌ ์ฌ์ฉํ๋ ๊ฒ ๊ฐ์ ํจ๊ณผ๋ฅผ ๋ด์ ๊ณผ์ ํฉ์ ๋ฐฉ์งํ๋ค. (๋คํธ์ํฌ ์์๋ธ)
dropout = torch.nn.Dropout(p=drop_prob) #p๋ ํ๋ฅ !!! (์ฌ์ ์ ํน์ ๋ ํน์ ํ๋ฅ , ์ ์ฒด ๋ ธ๋์ค์ ๋ชํผ์ผํธ๋ฅผ ์ฌ์ฉํ์ง ์์๊ฑด์ง.) # model model = torch.nn.Sequential(linear1, relu, dropout, linear2, relu, dropout, linear3, relu, dropout, linear4, relu, dropout, linear5).to(device) model.train() # set the model to train mode (dropout=True) #ํ์ตํ๊ธฐ ์ ์ ๊ผญ train()์ ์ ์ธํด์ค์ผํจ.... : ํ๋ฅ ๋๋ก ๋๋ผ๊ณ ์ง์ model.eval() # set the model to evaluation mode (dropout=False) #์ค์ ๋ก ํ ์คํธ ํ ๋๋ ๋๋์์์ ์ํ๊ณ ์ ์ฒด ์ ๊ฒฝ๋ง์ผ๋ก ๋ด์ผํจ!!! #๊ทธ๋์ eval()์ ํธ์ถํด์ ๋๋์์์ ์ฌ์ฉ์ํ๋ค๋๊ฑธ ์๋ ค์ค์ผํจ.
<Reference>
https://deeplearningzerotoall.github.io/season2/lec_pytorch.html
'๐STUDY > ๐ฅPytorch ML&DL' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
Skip connection์์ add(summation) vs concatenation (0) 2021.11.22 09-4. Batch Normalization (0) 2020.03.12 09-2. Weight initialization (0) 2020.03.10 09-1. ํ์ฑํ ํจ์(Activation function) (0) 2020.03.07 08. Perceptron (0) 2020.03.03 ๋๊ธ