📚STUDY/🔥Pytorch ML&DL

Skip connection에서 add(summation) vs concatenation

해는선 2021. 11. 22. 18:18

 U-Net architecture를 공부하다가 어떤 네트워크는 image size와 channel이 동일해서 add를 하기도 하고, 어떤 네트워크는 image size는 동일하지만 channel이 달라 (보통 2배 차이남) concatenation해 주기도 한다.

 

이 둘의 차이가 뭔지 궁금해서 검색해 보았고, 이를 정리했다.

 

 

What are Skip Connections?

먼저 skip connection이란, 단순한 연결 작업이다. 위의 사진은 FCN network를 간략하게 표시한 것인데, 이 네트워크는 step을 거칠 수록 coarse(거친) 정보들 부터 fine(고차원적인) 정보들을 담게 된다. 그리고 마지막에 바로 큰 image로 키우게 되는데, 이때 고차원적인 정보들만 담고 있는 feature map으로 upsampling하다 보니 coarse한 정보들이 빠지게 된다. 

이때, coarse한 정보들도 넘겨주기 위해 skip connection을 사용하게 된다. 즉, 하나의 network가 low level feature와 high level feature를 조합해서 학습하게 된다.

 

 

Why are skip connections important?

1. short skip connections : deep한 network를 학습할 때, 뒤로 갈 수록 사라지는 gradient를 살리거나 gradient vanishing (기울기 소실) 문제를 해결할 수 있다. 대표적인 예로 Resnet이 있다. 

 

2. long skip connections : downsampling 중 손실되는 spatial 정보를 복구하는데 도움이 된다. (where)

 

3. 둘 중 하나라도 사용한다면 convergence time을 줄일 수 있다!

 

 

When to add and when to concatenation?

주로 short skip connection에는 element-wise add(summation)을 사용한다. 짧고 반복적인 estimation procedure이 주를 이루고, network의 다양한 layer를 모두 거치기 때문이다. 보통, 전체 block에 대해 고정된 feature 수를 유지해서 compact한 solution에 이용된다.

ex) resnet

 

long skip connection을 사용하는 경우는, 보통 subsequent layer(후속, 뒤에 나오는 레이어)가 middle representation (중간 레이어가 가지는 feature)를 재 사용해서, coarse한 정보 (공간적인 정보)를 얻기를 원할 때 이다. (feature re-use)

 

ex) densenet

 

U-net의 경우, 대부분의 architecture가 concatnation을 이용한다. long skip connection이기 때문이다. 애초에 원본 2D U-net이 concatnation을 사용한다.

여기서는 contracting path와 expansive path의 input size가 조금씩 달라 expansive path의 image size에 맞춰서 crop후, concatnation 해 주고 있다. 이런 size 차이를 없애고 싶다면 convolution 연산에 stride = 1, padding = 1을 주면 된다.

 

물론 다양한 u-net variation으로 concatnation이 아닌 summation을 한 경우도 존재한다. 

https://theaisummer.com/unet-architectures/

 

summation을 위해서는 encoder path의 image와 decoder path image의 size, depth가 모두 맞아야 하기에 (아니면 crop하던지) 대부분은 concatnation을 사용하는 것 같다.

 

 

(잘못된 지식이거나 올바른 답을 알고계시는 분은 답글남겨주세요,,, 완전 환영합니다)

 

 

 

 

Reference

https://medium.com/@mikeliao/deep-layer-aggregation-combining-layers-in-nn-architectures-2744d29cab8

 

Deep Layer Aggregation — Combining Layers in NN Architectures

I’ll be working to explain Deep Layer Aggregation, a neural network architecture that explores how best to aggregate layers across a…

medium.com

https://stackoverflow.com/questions/49164230/deep-neural-network-skip-connection-implemented-as-summation-vs-concatenation

 

Deep neural network skip connection implemented as summation vs concatenation?

In deep neural network, we can implement the skip connections to help: Solve problem of vanishing gradient, training faster The network learns a combination of low level and high level features Re...

stackoverflow.com

https://lswook.tistory.com/105

 

Skip connection 정리

Skip connection이란? deep architectures에서 short skip connections[1]은 하나의 layer의 output을 몇 개의 layer를 건너뛰고 다음 layer의 input에 추가하는 것이다. 이는 VGG[2]같은 기존의 model이 output만..

lswook.tistory.com