본문 바로가기
AI/Machine Learning

K-fold 교차 검증

by Mesut Özil 2023. 10. 29.

K-fold 교차 검증

K-fold cross-validation기계 학습 모델성능을 평가하기 위한 기술 중 하나입니다.

이 방법은 주어진 데이터를 훈련 세트검증 세트나누어 모델을 평가하는 데 사용됩니다.

 

  1. 주어진 데이터셋을 K개의 서브셋(또는 폴드)으로 나눕니다.
  2. 모델을 K번 반복하여 훈련하고 검증합니다.
  3. 각 반복에서 하나의 폴드를 검증 데이터로 사용하고, 나머지 폴드를 훈련 데이터로 사용합니다.
  4. 각 반복에서 모델의 성능 지표(예: 정확도, 손실 등)를 기록합니다.
  5. K번의 반복 후,  이 성능 지표들을 평균을 계산하여 최종 성능을 얻습니다.

 

 

예시 코드

from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

# 데이터 불러오기
data = load_iris()
X = data.data
y = data.target

# K-fold cross-validation을 위한 객체 생성 (K를 5로 설정)
kfold = KFold(n_splits=5, shuffle=True)

# Logistic Regression 모델 생성
model = LogisticRegression()

# K-fold cross-validation 수행
fold_accuracies = []
for train_index, test_index in kfold.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    
    # 모델 학습
    model.fit(X_train, y_train)
    
    # 예측
    y_pred = model.predict(X_test)
    
    # 정확도 계산
    accuracy = accuracy_score(y_test, y_pred)
    fold_accuracies.append(accuracy)

# 각 폴드별 정확도 출력
for i, accuracy in enumerate(fold_accuracies):
    print(f"Fold {i+1} accuracy: {accuracy}")

# 전체 폴드의 평균 정확도 계산
mean_accuracy = np.mean(fold_accuracies)
print(f"\nMean accuracy: {mean_accuracy}")

 

이 코드는 Iris 데이터셋을 사용하여 Logistic Regression 모델을 평가하는 K-fold cross-validation을 보여줍니다.

5개의 폴드데이터를 나누고, 각 폴드에서 모델을 훈련하고 테스트하여 정확도를 계산합니다.

마지막으로 각 폴드의 정확도와 전체 평균 정확도를 출력합니다.

 

가장 일반적인 K 값5 또는 10이지만, 더 작거나 더 큰 K 값도 사용될 수 있습니다.

K-fold cross-validation은 모델의 일반화 능력을 평가하고 과적합(Overfitting)을 방지하는 데 유용합니다.

특히 데이터가 적은 경우에는 K-fold cross-validation을 사용하여 모델의 안정성신뢰성을 높일 수 있습니다.

 

이 방법을 사용하면 전체 데이터를 훈련 및 검증에 사용할 수 있으며,

한 번의 단일 훈련 및 검증 과정보다 더 신뢰할 수 있는 모델 성능 지표를 얻을 수 있습니다.

K-fold cross-validation은 모델의 성능을 평가하고 하이퍼파라미터 튜닝 등에 많은 도움을 줍니다.

 

 

 

본 게시글은 ChatGPT의 도움을 받아 작성하였습니다.

 

'AI > Machine Learning' 카테고리의 다른 글

분류 모델 / 회귀 모델  (0) 2024.01.17
Ensemble (앙상블)  (0) 2024.01.10
다중공선성 (VIF, 분산팽창계수)  (0) 2023.12.18
Metrics (모델의 성능 지표)  (2) 2023.11.19
라벨 인코딩, 원핫 인코딩  (0) 2023.11.10