본문 바로가기

데이터사이언스/인공지능

간단한 신경망 만들기(사인함수 예측하기)

이미지를 딥러닝을 이용해 인식하는 모델은 천만에서 1억 개가 넘는 가중치를 가지며 자연스러운 문장을 생성하는 OpenAI GPT-3 언어모델은 1,750억개까지의 가중치를 갖는다. 이러한 복잡한 딥러닝모델을 설계하기 전에 연습삼아 간단한 신경망모델을 만들어보자.

 

 - 파이토치에서 '모듈'은 신경망을 구성하는 기본 객체이다. 모듈에는 구성요소를 정의하는 __init__()함수와 순전파의 동작을 정의하는 forward()함수가 있다.

 - 간단한 신경망은 nn.Sequential, 복잡한 신경망은 nn.Module을 이용한다.

 - MSE(평균 제곱 오차)는 회귀, CE(크로스 엔트로피)는 분류에 이용한다.

 - 피처는 신경망의 입력으로 들어오는 값으로 데이터가 갖고 있는 특징이다.

 - 배치: 데이터셋의 일부로 신경망의 입력으로 들어가는 단위

 - 에포크: 전체 데이터를 모두 한번씩 사용했을 때의 단위

 - 이터레이션: 하나의 에포크에 들어있는 배치 수

 

1. 사인함수 예측하기

 - ① 모델을 학습하기 위해 데이터를 불러온 다음 원하는 만큼 반복해서 모델을 학습

 - ② 불러온 데이터의 예측값계산(데이터가 입력층에서 출력층 방향으로 흘러가기 때문에 순전파라고함)

 - ③ 모델의 손실함수를 이용해 오차계산(여기서는 MSE이용)

 - ④ 오차를 역전파해서 모델의 가중치를 수정

파이토치의 학습과정

 

import math #수학 패키지 임포트
import torch #파이토치 모듈 임포트
import matplot.pyplot as plt 

x = torch.linspace(-math.pi, math.pu, 1000) #-pi부터 pi 사이에 점을 1,000개 추출
# linspace(A, B, C): 시작점 A부터 종료점 B까지 데이터를 C개 반환한다
y = torch.sin(x) # 실제 사인곡선에서 추출한 값으로 y 만들기

a = torch.randn(())
b = torch.randn(())
c = torch.randn(()) # 예측 사인곡선에 사용할 임의의 가중치(계수)를 뽑아 y 만들기
d = torch.randn(())

y_random = a * x**3 + b * x**2 + c * x + d #사인함수를 근사랄 3차 다항식 정의

learning_rate = 1e-6 #학습률 정의

for epoch in range(2000): #학습 2000번 진행)
    y_pred = a * x**3 + b * x**2 + c * x + d
    
    loss = (y_pred - y).pow(2).sum().item() #손실정의 .item()은 실수값으로 반환하라는 뜻
    if epoch % 100 == 0: # 100번마다 손실을 출력해라
        print(f"epoch{epoch+1} loss:{loss}")
        
    grad_y_pred = 2.0 * (y_pred-y) #기울기를 계산하는 부분(가중치는 업데이트하는 데 사용되는 손실값을 미분함)
    grad_a = (grad_y_pred * x **3).sum()
    grad_b = (grad_y_pred * x **2).sum()
    grad_c = (grad_y_pred * x).sum()
    grad_d = grad_y_pred.sum()
    
    a -= learning_rate * grad_a #가중치는 기울기의 반대반향으로 움직인다. 
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d
    
    
plt.subplot(3, 1, 1)
plt.title('y true')
plt.plot(x, y)

plt.subplot(3, 1, 2) # 예측한 가중치의 사인곡선을 그리기
plt.title('y pred')
plt.plot(x, y_pred)

plt.subplot(3, 1, 3) # 학습이 이루어지기 전에 만든 그래프
plt.title('y random')
plt.plot(y_random)

plt.show()

 

참고: Must Have 텐초의 파이토치 딥러닝 특강