• 09-3. How to Prevent Overfitting

    2020. 3. 10.

    by 해는선

    ๋ณธ ๊ธ€์€ '๋ชจ๋‘๋ฅผ ์œ„ํ•œ ๋”ฅ๋Ÿฌ๋‹ ์‹œ์ฆŒ 2'์™€ 'pytorch๋กœ ์‹œ์ž‘ํ•˜๋Š” ๋”ฅ ๋Ÿฌ๋‹ ์ž…๋ฌธ'์„ ๋ณด๋ฉฐ ๊ณต๋ถ€ํ•œ ๋‚ด์šฉ์„ ์ •๋ฆฌํ•œ ๊ธ€์ž…๋‹ˆ๋‹ค.

    ํ•„์ž์˜ ์˜๊ฒฌ์ด ์„ž์—ฌ ๋“ค์–ด๊ฐ€ ๋ถ€์ •ํ™•ํ•œ ๋‚ด์šฉ์ด ์กด์žฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.


    0. Overfitting

    Overfitting means that the weights have been tuned numerically so that they fit only the training data. In other words, the model has been trained too far, into a fit that is more complex (higher-order) than the data warrants.

    But wait, isn't fitting the training data a good thing?

    No! The training data is not the same as the test data, so a model that fits the training data well can still perform worse in real use.

    To summarize, overfitting is the situation where a model shows high accuracy on the train set but low accuracy on the test set.
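
    As a toy illustration of that train/test gap, the sketch below compares a "model" that only memorizes its training pairs with one that captured the underlying rule. The task, the data, and every name here are made up for illustration; real overfitting is subtler, but the accuracy pattern is the same.

```python
def accuracy(predict, dataset):
    """Fraction of (x, label) pairs that predict() gets right."""
    correct = sum(1 for x, label in dataset if predict(x) == label)
    return correct / len(dataset)

# Hypothetical toy task: classify integers as odd or even.
train_set = [(1, "odd"), (2, "even"), (3, "odd"), (4, "even")]
test_set = [(5, "odd"), (6, "even"), (7, "odd"), (8, "even")]

# "Overfit" model: a lookup table of the training data.
# Inputs it has never seen fall back to a fixed guess.
memory = dict(train_set)

def memorizer(x):
    return memory.get(x, "even")

# Model that learned the general pattern instead of the samples.
def rule_model(x):
    return "odd" if x % 2 else "even"

memorizer_train_acc = accuracy(memorizer, train_set)
memorizer_test_acc = accuracy(memorizer, test_set)
rule_train_acc = accuracy(rule_model, train_set)
rule_test_acc = accuracy(rule_model, test_set)

print("memorizer train/test:", memorizer_train_acc, memorizer_test_acc)
print("rule      train/test:", rule_train_acc, rule_test_acc)
```

    The memorizer is perfect on the data it has seen and no better than its fallback guess on new data, which is the signature to watch for when comparing train and test accuracy.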

     

    So what can we do to fix this?

    There are four representative methods.

     

    1. More training data

    When the amount of data is small, a model is more likely to overfit, because it also learns the noise and the patterns specific to that particular data set rather than patterns that generalize. Conversely, the more data there is, the more the model learns the general patterns of the data, which helps prevent overfitting.

    When data is scarce, the data set is sometimes enlarged on purpose by adding slightly modified copies of the existing samples. This is called data augmentation. It is used especially often with images.
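
    A minimal sketch of the idea in plain Python, assuming images are stored as lists of pixel rows and that a horizontal flip does not change the label. The helper names `horizontal_flip` and `augment` and the sample data are hypothetical, not taken from the referenced lectures.

```python
def horizontal_flip(image):
    """Return a mirrored copy of an image stored as a list of pixel rows."""
    return [list(reversed(row)) for row in image]

def augment(dataset):
    """Return the original samples plus one flipped copy of each.

    The label is reused unchanged, which assumes that mirroring an
    image does not change its class.
    """
    augmented = list(dataset)
    for image, label in dataset:
        augmented.append((horizontal_flip(image), label))
    return augmented

# Hypothetical 2x3 grayscale "image" with a made-up label.
tiny_image = [[0, 1, 2],
              [3, 4, 5]]
dataset = [(tiny_image, "cat")]

bigger = augment(dataset)
print(len(dataset), "->", len(bigger))
print(bigger[1][0])
```

    The transformation has to keep the label valid: mirroring a photo of a cat is harmless, while mirroring a handwritten digit or a letter may not be.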

     

    2. Reduce the number of features

    This method reduces the complexity of the model, that is, the number of features it can represent. The complexity of an artificial neural network is determined by things such as the number of hidden layers and the number of parameters (its capacity). Overfitting can also occur when the model is more complex than the problem needs. For example, if the network has 3 hidden layers, reduce it to 2.
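
    To get a feel for how much capacity one hidden layer adds, the sketch below counts the weights and biases of a fully connected network. The layer widths (784 inputs, 512-unit hidden layers, 10 outputs) are assumed for illustration, and `count_parameters` is a hypothetical helper.

```python
def count_parameters(layer_sizes):
    """Count the weights and biases of a fully connected network.

    layer_sizes lists the width of every layer, input first, output last.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix + bias vector
    return total

three_hidden = [784, 512, 512, 512, 10]  # assumed sizes, 3 hidden layers
two_hidden = [784, 512, 512, 10]         # same network with one layer removed

params_three = count_parameters(three_hidden)
params_two = count_parameters(two_hidden)
print("3 hidden layers:", params_three)
print("2 hidden layers:", params_two)
print("removed:", params_three - params_two)
```

    Under these assumed sizes, dropping a single hidden layer removes more than a quarter of the parameters, so it is a fairly coarse adjustment.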

     

    3. Regularization

    This method puts a constraint on the weights. As mentioned above, a complex model is more likely to overfit than a simple one. Weight regularization is a way of making a complex model behave more simply. There are two kinds.

    L1 regularization: adds the sum of the absolute values of the weights, scaled by a coefficient λ, to the cost function.

    L2 regularization: adds the sum of the squares of all the weights, scaled by a coefficient λ, to the cost function.
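
    The two penalties can be sketched in a few lines of plain Python. The weight values, the base cost, and the coefficient `lam` (standing in for λ) are made-up numbers, and the function names are hypothetical.

```python
def l1_penalty(weights):
    """Sum of the absolute values of the weights."""
    return sum(abs(w) for w in weights)

def l2_penalty(weights):
    """Sum of the squared weights."""
    return sum(w * w for w in weights)

def regularized_cost(base_cost, weights, lam, kind="l2"):
    """Add lam * penalty to the unregularized cost."""
    if kind == "l1":
        penalty = l1_penalty(weights)
    elif kind == "l2":
        penalty = l2_penalty(weights)
    else:
        raise ValueError("kind must be 'l1' or 'l2'")
    return base_cost + lam * penalty

weights = [0.5, -1.5, 2.0]  # made-up weights
base_cost = 1.0             # made-up data-fit cost
lam = 0.01                  # regularization strength (assumed value)

cost_l1 = regularized_cost(base_cost, weights, lam, kind="l1")
cost_l2 = regularized_cost(base_cost, weights, lam, kind="l2")
print("L1:", cost_l1)
print("L2:", cost_l2)
```

    Because large weights raise the cost, the optimizer is pushed toward smaller weights. In PyTorch, an L2-style penalty is commonly applied through the `weight_decay` argument of optimizers such as `torch.optim.SGD`, rather than by adding the term to the loss by hand.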

     

    4. Dropout

    Dropout is a method in which part of the network is not used during training. Like switching a light bulb on and off, parts of the network are turned on and off.

    For example, with a dropout rate of 0.5, each neuron is dropped at random with probability 0.5 at every training step, so about half of the neurons are left out each time.

    Dropout is used only while training the network, not at prediction (real-use) time. During training it keeps the network from becoming too dependent on particular neurons or particular combinations of neurons, and it has an effect similar to using an ensemble of different networks, which helps prevent overfitting. (network ensemble)

     

    import torch

    # Assumed for illustration (not given in this post): the device,
    # the drop_prob value, and the MNIST-style layer sizes below.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    drop_prob = 0.3

    linear1 = torch.nn.Linear(784, 512)
    linear2 = torch.nn.Linear(512, 512)
    linear3 = torch.nn.Linear(512, 512)
    linear4 = torch.nn.Linear(512, 512)
    linear5 = torch.nn.Linear(512, 10)
    relu = torch.nn.ReLU()

    dropout = torch.nn.Dropout(p=drop_prob)
    # p is a probability: the preset chance that each node is dropped (not used).

    # model
    model = torch.nn.Sequential(linear1, relu, dropout,
                                linear2, relu, dropout,
                                linear3, relu, dropout,
                                linear4, relu, dropout,
                                linear5).to(device)

    model.train()    # set the model to train mode (dropout active)
    # Always call train() before training: it tells dropout to drop nodes with probability p.

    model.eval()    # set the model to evaluation mode (dropout disabled)
    # At test time the whole network must be used, without dropout,
    # so call eval() to signal that dropout should not be applied.
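
    What `train()` and `eval()` switch on and off can be sketched in plain Python. This is a simplified stand-in, not PyTorch's implementation: the `dropout` function and its `training` flag are hypothetical, and it mirrors only the documented behaviour of `torch.nn.Dropout`, which zeroes each element with probability p during training and scales the survivors by 1 / (1 - p).

```python
import random

def dropout(activations, p, training, rng=random):
    """Inverted dropout on a list of activations.

    training=True : zero each value with probability p and scale the
                    survivors by 1 / (1 - p), so the expected value of
                    each activation stays the same.
    training=False: return the activations unchanged.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("this sketch expects p in [0, 1)")
    if not training or p == 0.0:
        return list(activations)
    keep_scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else a * keep_scale for a in activations]

rng = random.Random(0)  # fixed seed so repeated runs drop the same units
hidden = [1.0, 2.0, 3.0, 4.0]

train_out = dropout(hidden, p=0.5, training=True, rng=rng)   # like model.train()
eval_out = dropout(hidden, p=0.5, training=False, rng=rng)   # like model.eval()
print("train:", train_out)
print("eval :", eval_out)
```

    In training mode every value is either zeroed or doubled (for p = 0.5), while in evaluation mode the input passes through untouched, which is why forgetting `model.eval()` before testing makes predictions noisy.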

    <Reference>

    https://deeplearningzerotoall.github.io/season2/lec_pytorch.html

    https://wikidocs.net/60751
