본문 바로가기
딥러닝/기타 리뷰

[Optimizer]AdamW

by 혜 림 2022. 2. 10.

 

 

"Decoupled Weight Decay Regularization"

 

 

[혜림의 말로 요약하기] Adam은 일반화 성능이 안 좋다. SGD에서는 weight decay와 L2 정규화가 (수식에서) 똑같은 의미를 갖는 반면에, Adam 에서는 다르기 때문이다. Adam에서는 L2 정규화의 효과가 weight decay에 비해서 떨어진다. 따라서 weight decay를 Adam 과 함께 이용하기 위해 고안한 것이 AdamW 이다. weight decay를 더하는 식을 가중치 업데이트 식에 별도로 넣어줌으로써, weight decay와 learning rate 하이퍼 파라미터가 정규화에 주는 영향을 decouple 한 방법이다. 

 

 

1. Introduction

 

 Adam 이 좋은 optimizer로 알려져 있지만, 사실 CIFAR-10 등을 이용한  image classification의 SOTA 모델은 주로 SGD를 이용하고 있다. 나아가, momentum(관성) + SGD에 비해서 Adam은 일반화가 잘 안 된다는 연구 결과도 꽤 나와있다. 그럼에도 불구, Adam 이 좋은 optimizer로 알려져 있는 것은 그 속도 때문이다. 실질적으로 solution도 좋은지에 대한 결과는 없다고 한다. 

 

* generalization : 새로운 test 데이터 셋이 들어왔을 때 얼마나 error 없이 predict 할 수 있는가? 즉 train 데이터와 사뭇 다른 test 데이터가 들어왔을 때도 모델은 여전히 적용가능한가. train 데이터 셋으로 학습한 모델이 일반화 가능한가. 

 

 그렇다면 Adam과 같은 Adaptive gradient method 의 일반화 성능은 왜 떨어지는가? Adam의 일반화 성능이 떨어지는 것은 L2 정규화가 SGD와 나는 시너지가 L2 정규화와 Adam이 내는 시너지에 비해 크기 때문이다. 따라서 Adam의 일반화 성능을 끌어올리기 위해서는 이 차이를 간소화시키면서, 어떻게 L2 정규화와 weight decay를 SGD와 Adam에 잘 버무릴 것인지 에 대해 고민해볼 필요가 있다. 

 

 1. L2 정규화와 weight decay는 다르다. 
 2. L2 정규화는 Adam에 효과적이지 않다.
 3. weight decay는 SGD와 Adam 과 동일한 시너지를 낸다 : L2 정규화는 그렇지 않다. 
 4. 최적의 weight decay 값은 가중치 업데이트된 횟수에 영향을 받는다. 
 5. Adam은 scheduled learning rate 를 통해 시너지를 낸다. 

 

 

그 고민에 대한 해답으로, 본 논문에서는 Adam이 더 일반화가 잘 될 수 있도록,  weight decay와 L2 정규화를 분리하여 반영할 수 있도록 가중치 업데이트 방식을 바꾸어주었다. (수식을 보면 명확하다!)

 

2. Decoupling the weight decay from the gradient base update

 

 Proposition 1 weight decay = L2 reg for standard SGD 

 proposition 2 weight decay != L2 reg for Adaptive gradients

 

By proposition 1 & 2:

 

 

 

 

 

3. JUSTIFICATION OF DECOUPLED WEIGHT DECAY VIA A VIEW OF ADAPTIVE GRADIENT METHODS AS BAYESIAN FILTERING

 

 베이지안 필터링을 직관적으로 응용한 것이 wegiht decay이기 때문에, weight decay를 이용한 것이 L2 정규화에 비해 결과가 좋을 수 밖에 없다고 한다. 이 주장은 다른 논문의 주장을 가져온 것으로, 이 블로그에서는 저 그렇구나~ 하는 것으로 넘어가도록 하자 ^^

 

4. Experimental Validation

 

4.1 Evaluating Decoupled Weight Decay With Different Learning Rate Schedules

 

Figure 1

 

 위 그림에서 X축은 weight decay/L2 regularization과 관련된 파라미터 값, Y축은 LR과 관련된 파라미터 값이다. 그리고 오른쪽에 heatmap bar로 나와있는 부분이 있는데, 붉을수록 test error 가 큰 것이고 파랄수록 test error가 적은 것이다. Adam의 generalization 성능에 대해서 말하는 것이기 때문에, test error를 기준으로 Adam과 AdamW, cosine annealing을 이용한 AdamW를 비교한다. 

 가장 오른쪽, 아래에 있는 그림의 파란 영역이 가장 넓은 것을 확인할 수 있다. 즉, L2 정규화 대신 weight decay를 썼을 때, 그리고 LR이 주기적으로 상승했다 감소하는 cosine annealing과 함께 AdamW를 썼을 때 가장 test error 가 적다는 것이다. 

 위 그림 때문에 여러 하이퍼 파라미터 값에도 성능의 변동이 크지 않다는 의미에서 '일반화' 를 말한다고 생각할 수 있지만, 본 논문에서는 test error를 일반화의 성능 척도로 보고 있다. test error 인 즉슨, train 데이터 셋에서 나오지 않았던 data에 얼마나 잘 예측하냐이기 때문에, 앞에서 언급했던 일반화로 계속 이해하는 게 맞을 것 같다. *아닐 수 있고 

 

=>  Adaptive gradient algorithm의 경우 이미 정의 자체로서 고정되어 있지 않고 변하는 learning rate를 사용하기는 하지만, cosine annealing schedule과 같은 learing rate scheduler를 썼을 때 그 성능이 더욱 개선된다. 

 

 

4.2 Decoupling the Weight Decay and Initial Learning Rate Parameters

 

[Figure 2]

weight decay 와 learning rate  하이퍼 파라미터인 람다와 알파를 decouple 하는 게 의미가 있을지 비교하기 위한 시험을 진행하였다. 

가장 첫번째 그림을 살펴보자. SGD에 L2 정규화를 이용한 경우다. 점은 test error 가 가장 작았던 지점을 의미한다. 현재 diagonal 하게 점이 찍혀 있는 것을 알 수 있다. 즉, 두 하이퍼 파라미터가 coupling 되어 있어서 하나의 파라미터를 변하면 바로 결과가 안 좋게 나타난다. 따라서 독립적으로 두 하이퍼 파라미터를 바꾸면서 실험을 진행하는 것이 불가능하다. 따라서 성능을 개선하기 위해서는 반드시 두 파라미터를 함께 변화시켜야 한다. 이런 것 때문에 SGD 는 초기 learning rate와 L2 정규화 가중치 설정에 sensitive 하다고 알려진 듯하다고 논문에서는 언급하고 있다.

 반면 대조적으로, SGDW를 보자. 이는 weight decay와 initial learning rate가 분리된 경우이다. learning rate가 최적화 되어 있지 않더라도, 그걸 고정시킨 채로 weight decay 하이퍼 파라미터를 변화시키면서 우리는 최적의 경우를 구할 수 있다. 

 

밑의 두 그림 역시 유사한 맥락으로 해석 가능하다. 

 

 => 요약하자면, wegith decay와  learning rate 하이퍼 파라미터를 decouple 하는 것은 Adam의 성능을 SGD 만큼 올릴 수 있게 하는 중요한 factor 였다. 

 

4.3 Better Generalization of AdamW

 

위에서는 최적의 파라미터의 basin이 AdamW의 경우 더 넓다는 것을 증명한 그림이라고 한다. 그리고 이번 섹션에서는 더 깊게 훈련을 하면서 Adam과 AdamW의 일반화 성능을 비교해보자.

 

[Figure 3]

 초기에는 Adam과 AdamW의 loss 곡선이 유사하게 출렁거린다. 하지만 결국은 AdamW의 loss가 더 낮게 나온다. 

 

 그러나 위 결과만으로 AdamW 가 일반화를 잘 한다고 결론 내리기에는 성급하다. 그냥 수렴을 더 잘 하는 것 때문일 수도 있기 때문이다. 오른쪽 아래의 그림을 보도록 하자. 동일한 train loss 에도 불구하고 AdamW의 test error가 더 낮은 경향을 보인다는 것을 확인할 수 있다. 따라서 더 나은 일반화 성능을 가진다고 믿는 것이 타당하다. 

 

4.4 AdamWR With Warm Restarts for Better Anytiem Performance

 

 SGDW와 AdamW의 성능을 높이기 위해서는 warm restart를 함께 사용하면 된다. 아래의 Figure 4에서 볼 수 있듯이, Warm Restart를 쓴 경우 test error가 epoch이 얼마 되지 않을 때도 낮음을 확인할 수 있다. [각주:1] test error가 낮아졌다가 다시 튀는 부분은 learing rate가 다시 높게 튀는 부분이 아닐까 싶다. 

 

Figure 4

 

*warm restart는 learing rate를 epoch에 따라 줄이다가 어느 순간에 확 늘리는 것을 의미한다. 단순히 선형적으로나/단계적으로 감소하는 것과 달리 이 방법을 이용하게 되면 local minimun에 빠질 가능성이 줄어든다. 

4.5 Use of AdamW on other dataset and architectures

 

 face detector에 썼더니 SOTA의 정확도를 얻으면서 훨씬 빨리 학습할 수 있었다. EEG 데이터 분석에도 적용시킨 사례가 있다. 그 외에 다양한 모델과 함께 썼을 때도 좋은 성과를 냈다. 블라블라 같은 내용~

 

5. Conclusion and Futre Work

 

Adam에서는 L2 정규화와 weight decay가 동일하지 않다. 그래서 weight decay를 decouple 하였더니 훨씬 일반화가 잘되었다. 또한 warm restart를 같이 쓰면 성능이 개선된다는 것을 밝혀냈다.

 

Adam 말고도 다른 adaptive gradient method에도 AdamW와 비슷한 관점으로 새로운 시도를 한다면 흥미로울 것이다.

 

 

총총..

  1. AdamWR achieved the same improved results but with a much better anytime performance. [본문으로]

댓글