AIML/딥러닝 최신 트렌드 알고리즘

[ 딥러닝 최신 알고리즘 - PRMI Lab ] - KD: Knowledge Distillation

Hyunseo😊 2024. 1. 8. 23:20

이전에는 ViT의 논문을 리뷰하고 이에대한 코드를 짜보고 Pre-train과 fine-tuning까지 해보았습니다. 하지만 ViT의 고질적인 문제인 데이터 효율적이지 못하다는 점이 가장 아쉬웠습니다. 이에, 데이터 효율적인 ViT인 DeIT를 리뷰하려고 했습니다. 다만, 그 전에 DeIT에서 Knowledge Distillation과 관련된 사전 지식을 요해서 이와 관련된 내용을 논문과 함께 간단히 정리하고 가면 좋을거 같다고 생각했습니다. 해당 논문은 딥하게 파고들지 않고, 그냥 이런 개념이 있구나~ 정도로만 살펴보도록 하겠습니다. ( 사실 KD가 Nosiy Student Model기반 모델과 유사하다는 느낌이 들어서 얼른 이것도 알아보고 싶거든요!)

 

https://arxiv.org/pdf/1503.02531.pdf

 

초록에서는 기존에 머신러닝에서 모델의 성능을 높이는 방법은, 앙상블이였다고 소개하고 있습니다. 하지만 이러한 앙상블 모델은, 변덕(cumbersome)이 심하고 계산량이 너무 많다는 단점이 있답니다. 계산량이 많으니까 큰 앙상블 모델을 배포하려는 입장에서는 이걸 사용자에게 모바일 환경 같은 곳에서 써보라고 하기엔 부적절할 것입니다. 

 

"Model compression. In Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, KDD"에서는 앙상블에서 지식을 하나의 모델로 압축할 수 있음을 증명해냈습니다. 하지만 해당 논문에서는 이러한 방법과는 다르게, 근본적으로 큰 모델을 작은 모델로 옮기는 Knowledge Distilling 방법을 MNIST 데이터를 사용해 증명하며 살펴보겠다고 합니다!

 

Soft Label

KD의 대략적인 구조

 

Image Classification task에서 신경망의 마지막 벡터를 softmax층에 투입시켜 각 클래스의 확률 값을 뱉어내게 합니다. 이에 기존의 original hard targets는 아래와 같은 그림의 레이블을 말합니다.

original hard targets

실제로 모델이 뱉어내는 값은 아래와 같을 것입니다.

model's output

클래스 마다 다른 확률값을 뱉어낼 것입니다. 가장 높은 출력값인 0.9인 "개"클래스를 예측하게 되는 구조입니다. 하지만 Hinton교수는 예측한 클래스 이외의 값을 주의 깊게 보았습니다. 개를 제외한 고양이, 자동차, 젖소등의 확률을 본것입니다.

 

Hinton교수는 이러한 확률도 해당 모델의 지식이 될 수 있다고 합니다. 우리가 예측하려는 이 사진이 강아지인것은 알겠지만, 자동차나 젖소보다 고양이에 더 가까운 형태를 띠고 있구나 라는 것을 모델이 함축하고 있다는 것이죠. 하지만 이러한 값들은 softmax에 의해 너무 작아져 모델이 반영하기가 어려울 것입니다.

 

따라서 출력값의 분포를 좀 더  soft하게 만들면, 이 값들이 모델이 가진 지식이라고 볼 수 있을거 같습니다. 이게 바로 Knowledge Distillation의 아이디어입니다. 이를 논문에서는 dark knowledge라고 표현합니다.

softmax with T(Temperature)

dark knowledge를 만들기 위해 $T$를 분모에 곱해주었습니다. 이를 통상 온도(temperature)라고 표현합니다. 이 값이 높아지면 더 soft해지고 낮아지면 더 hard하게 만드는 것을 유추할 수 있습니다. 이 $T$라는 값때문에 증류(Distillation)라는 과정이 나오게 된 것으로 판단됩니다.

 

Distillation Loss

KD의 방법과 이를 적용하기 위한 loss식입니다. 사실 위 그림에 다 나와있긴 하지만 첨은을 통해 이해를 돕겠습니다. 우선 큰 모델인 Teacher모델을 Big Data set에 대해 Pre-train시켜버립니다. 그리고 우리는 해당 지식을 Distillation하며 Student모델을 손실함수로 학습시켜 나갈 것입니다.

 

위 식에서 $L$은 손실함수, $S$는 Student model, $T$는 Teacher model, $(x,  y)$는 하나의 이미지와 그 라벨, $\theta$는 모델의 학습가능한 파라미터, $\tau$는 temperature를 의미합니다.

 

$L_{CE}$는 그냥 hard prediction과 hard label y에 대한 Cross Entropy Loss입니다. 그리고 $L_{KD}$는 soft labels와 soft predictions을 통해 계산된 Distillation loss를 의미합니다. 이는 구하는 방법이 많지만, 종종 이와같은 Response Distillation 방식은 확률분포의 차이를 최소화 시켜야하므로 KL Divergence로 계산되곤 합니다. 이때, T와 S모델은 $\tau$인 temperature를 동일하게 설정해주어야 합니다!

 


 

이런 KD는 다양한 방법이 있다고 합니다. 또한, 요즘에는 BERT와 같은 무거운 모델을 경량화하고 배포하기 위해 KD를 사용하는 추세라고 합니다. 이렇게 간단히 논문을 보고 리뷰해보니까, ViT라는 무거운 트랜스포머 기반 모델을, KD와 약간의 data-augmentation을 통해서 더 data-efficient하게 즉 더 적은 데이터로 효율적이게 모델을 만들어버릴 수 있겠다라는 생각이 듭니다. 다음에는 DeIT에 대해 리뷰해보도록 하겠습니다.