손실 함수
손실 함수 loss function
- 비용 함수 cost function 또는 오차 함수 error function라고도 함
- 예측과 실제의 차이를 계산
- 손실 ≥ 0
- 손실을 최소화할 수록 성능이 개선
MSE mean squared error
- 오차 제곱의 평균(오차 = 실제 – 예측)
- 연속변수의 예측에 사용
- 이상치(outlier)에 민감
torch.nn.functional.mse_loss
교차 엔트로피 cross entropy
- 두 확률 분포의 차이를 계산
- 범주형 변수의 예측에 사용
- p와 q 두 확률 분포가 비슷할 수록 작아짐
- 높은 확률로 예측했을 때 맞고, 낮은 확률로 예측했을 때 틀려야 감소
- 다항 분류:
torch.nn.functional.cross_entropy
- 이항 분류:
torch.nn.functional.binary_cross_entropy
우도 likelihood
- 어떤 모형을 가정했을 때, 우리가 가진 샘플의 데이터가 관찰될 가능성
- 우도가 높으면, 우리의 가정이 맞다고 생각할 수 있다 (최대우도법)
- 우도에 로그를 씌운 것이 로그 우도
- 마이너스 로그 우도 = 교차 엔트로피
- 교차 엔트로피의 최소화는 최대우도법과 동일
데이터 로더
from torch.utils.data import DataLoader
BATCH_SIZE = 32
train_loader = DataLoader(
binary_train_dataset, # 훈련 데이터
batch_size=BATCH_SIZE, # 32개씩
shuffle=True) # 섞어서(미니배치마다 조합이 다양하도록)
훈련
from pytorch_lightning import Trainer
model = BinaryClassifier()
logger = pl.loggers.CSVLogger("csv_logs", name="binary_mnist")
trainer = Trainer(max_epochs=1, logger=logger, log_every_n_steps=1)
trainer.fit(model, train_loader)
훈련 기록 보기
import glob
import pandas as pd
import matplotlib.pyplot as plt
# 마지막 실행 기록
last_version = sorted(glob.glob('csv_logs/binary_mnist/version_*'))[-1]
log = pd.read_csv(f"{last_version}/metrics.csv")
plt.subplot(1, 2, 1)
plt.plot(log['loss'])
plt.subplot(1, 2, 2)
plt.plot(log['accuracy'])
(그래프: 왼쪽은 손실(loss)이 감소하는 추세, 오른쪽은 정확도(accuracy)가 증가하는 추세를 보여줌)