undersampling, oversampling, loss function, weight sampling
불균형 데이터를 다루는 방법에 대해 알아보자
technique
Author
dotpyo
Published
October 20, 2023
Summary
under sampling : 너무 많은 양의 데이터를 잘라내는 방식
over sampling : 적은 양의 데이터를 증강하는 방식
weight sampling : 학습할 배치에 데이터가 들어갈 확률을 지정하는 방식
loss function : 적은 양의 데이터 학습과정에 가중치를 주는 방식
데이터의 절대량을 조정하는 방법과 (1, 2) 학습할 때 데이터의 균형을 맞추는 방법이 (3, 4) 있다.
이들 중 weight sampling 방법과 imbalanced data task에 적합한 loss function를 알아보겠다.
Weight sampling
배치(batch)크기는 하이퍼파라미터(hyperparameter)의 한 종류로 한 번 기울기를 갱신할 때(step) 사용하는 데이터의 개수를 말한다. 배치는 미니배치(mini batch)라고도 불리며 \(2^n\) 개로 구성된다. 이 때 배치를 구성하는 방식을 샘플링(Sampling)이라고 하는데 Weight Sampling은 배치를 구성하는 데이터를 각각 다른 확률에 따라 추출하는 샘플링 방식이다. 따라서 데이터의 절대적인 개수가 작아 뽑힐 확률이 적은 데이터에게 가중치를 주어 더 자주 뽑힐 수 있게 조정하는 과정을 거칠 수 있다.
예를 들어, 아래 표와 같은 데이터가 있을 때, c가 뽑힐 확률은 0.1, d가 뽑힐 확률은 0.4로 d가 뽑힐 확률이 네배 더 크다. 불균형 데이터(imbalanced data) 에서는 치명적으로, 한 배치에 뽑힌 데이터가 모두 한 label로 구성될 가능성이 있기 때문이다. 아래 표에서 배치가 32라고 할 때, 배치를 구성하는 label이 모두 d라면 모델은 균형있는 학습을 하지 못하게 되거나 d에 과적합 될 수 있다. 그러므로 전체 데이터가 불균형하더라도 배치 안에서는 균형있는 학습을 진행하기 위해 torch.utils.data.WeightedRandomSampler 메서드를 사용한다.
label
count
a
30
b
20
c
10
d
40
WeightedRandomSampler 를 사용하면 코드 셀과 같은 결과를 확인할 수 있다. 첫번째 예제에서는 index 1의 가중치가 0.9로 가장 크며 복원추출(replacement=True)한 결과 역시 1이 세번으로 가장 많이 추출된 것을 확인 할 수 있다.
다른 방법으로는 학습 과정에서 가중치를 주는 방법이 있다. 모델이 문제를 풀 때 해당 문제가 쉬운지 어려운지는 어떻게 판별할까? 분류 문제에서는 최종 확률값으로 문제의 난이도를 판별한다. 이것 같기도 하고, 저것 같기도 해서 헷갈리니 각 label이 답일 확률이 비슷비슷하게 높은 것이다. 따라서 \(logit\)\(^1)\) 값의 평균은 낮을 수밖에 없다. 답을 결정하는 최종 확률은 그 중 가장 높은 값을 고른 것이니 최종 확률이 낮을수록 어려운 문제다.
문제가 어려운 문제인지 아닌지 알아야 하는 이유는 여기에 있다. focal loss는 불균형 데이터 문제를 해결할 때 대표적으로 쓰이는 손실함수로, 쉬운 문제를 틀렸을 때엔 작은 loss 값을, 어려운 문제를 틀렸을 때엔 큰 loss 값을 반환한다. 데이터가 적어 상대적으로 잘 학습하지 못한 label은 틀렸을 때 모델의 성능에 크게 영향을 미치게 되므로 학습 과정에서 가중치를 준다고 생각할 수 있다. 그렇다면 focal loss의 최대값은 어떻게 될까? focal loss는 기본적으로 연산한 loss에서 난이도만큼 값을 ‘깎는’ 원리이므로 focal loss의 최대값은 기본 손실값과 같을 것이다.
이 외에도 기존 손실함수에 가중치를 줄 수 있는데 f1, cross entropy, fbeta, accuracy 등의 함수가 그러하다. 해당 함수들의 ‘average’ 인자값에 ’weighted’를 주면 가중된 손실이 누적된다. 이렇게 가중된 손실함수를 여러개 사용하면 모델의 성능이 개선될 수 있다.