[PyTorch] 장단기 메모리 (Long Short-Term Memory, LSTM)
Posted: Updated:
AI 스터디를 하며 ‘파이토치 트랜스포머를 활용한 자연어 처리와 컴퓨터비전 심층학습’ 교재를 정리한 글입니다.
장단기 메모리 (Long Short-Term Memory, LSTM)
메모리 셀(Memory cell), 게이트(Gate) 구조를 도입해 장기 의존성 문제와 기울기 소실 문제 해결
망각 게이트, 입력 게이트, 출력 게이트를 통해 어떤 정보를 버릴지, 기억할지, 출력할지에 대한 연산 수행
활성화 함수로 시그모이드를 사용하여 게이트가 입력값에 대해 얼마나 많은 정보를 통과시킬지 결정
lstm = torch.nn.LSTM( # 활성화함수 매개변수 없음
input_size,
hidden_size,
num_layers=1,
bias=False,
batch_first=True,
dropout=0,
bidirectional=False,
proj_size=0 # 투사 크기: hidden_size보다 작게 설정, 장단기 메모리 계층의 출력에 대한 선형 투사 크기 결정, 0보다 크면 선형 투사로 은닉 상태를 다른 차원에 매핑, 0이면 유지
)
장기 의존성 문제 (Long-term dependencies)
순환 신경망에서 앞선 시점에서의 정보를 끊임없이 반영해 학습 데이터 크기가 커질수록 앞서 학습한 정보가 충분히 전달되지 않을 수 있음
망각 게이트 (Forget gate)
이전 셀에서 어떤 정보를 삭제할지 결정하는 역할
현재 시점의 입력과 이전 시점의 은닉 상태를 입력으로 받아 시그모이드 함수를 거친 값과 메모리 셀을 곱한 값으로 이전 시점의 메모리 셀 업데이트
출력 값은 0에서 1사이이며, 1에 가까울수록 많은 정보를 유지하고, 0에 가까울수록 많은 정보를 삭제
$f_t$: 시그모이드 활성화 함수로 계산
$W_x^{(f)}$, $W_h^{(f)}$: 입력값, 은닉상태를 위한 가중치
$b^{(f)}$: 망각 게이트의 편향
기억 게이트 (입력 게이트, Input gate)
현재 시점 입력과 이전 시점 은닉 상태를 입력받아 시그모이드 함수를 거친 값과 하이퍼볼릭 탄젠트 함수를 거친 값의 곱으로 새로운 기억 값 계산
시그모이드 함수는 입력값의 중요도 결정
하이퍼볼릭 탄젠트 함수는 입력값을 -1과 1 사이의 값으로 변환하여 어떤 정보를 얼마나 추가할지 결정
$g_i$: 하이퍼볼릭 탄젠트 사용, -1에서 1 사이의 값을 가짐
$i_t$: 시그모이드 사용, [0, 1]의 값을 가지며 현재 시점에서 얼마나 많은 정보를 기억할 것인지 결정하는 가중치 역할, 0에 가까울수록 정보 망각
메모리 셀
망각 게이트, 기억 게이트 정보로 현재 시점 메모리 셀 값 계산
\[c_t = f_t \odot c_{t-1} + g_t \odot i_t\]$f_t$: 망각 게이트 출력값
$c_{t-1}$: 이전 시점의 메모리 셀 값
두 값을 원소별 곱셈 연산인 아다마르 곱(Hadamard Product)하여 현재 시점의 메모리 셀 값 계산
망각 게이트 값 $f_t$가 0에 가까울수록 이전 시점 메모리 셀 값이 현재 시점의 메모리 셀 값에 영향을 미치지 않게 됨
출력 게이트 (Output gate)
현재 시점의 입력과 이전 시점의 은닉 상태, 새로 업데이트된 메모리 셀을 입력으로 받아 현재 시점의 출력값 계산
\[o_t = \sigma \left( W_x^{(o)} x_t + W_h^{(o)} h_{t-1} + b^{(o)} \right)\]은닉 상태 갱신
출력 게이트와 현재 메모리 셀을 연산하여 새로운 은닉 상태 계산
\[h_t = o_t \odot \tanh(c_t)\]$h_t$: 출력 게이트와 하이퍼볼릭 탄젠트를 적용한 메모리 셀 값으로 계산
출력 게이트는 [0, 1] 범위, 하이퍼볼릭 탄젠트를 적용한 메모리셀은 [-1, 1] 범위를 가지므로 아다마르 곱 연산하면 이전 시점의 은닉 상태에 얼마나 영향 받는 지 계산할 수 있음
양방향 다층 장단기 메모리
import torch
from torch import nn
input_size = 128
ouput_size = 256
num_layers = 3
bidirectional = True
proj_size = 64
model = nn.LSTM(
input_size=input_size,
hidden_size=ouput_size,
num_layers=num_layers,
batch_first=True,
bidirectional=bidirectional,
proj_size=proj_size,
)
batch_size = 4
sequence_len = 6
inputs = torch.randn(batch_size, sequence_len, input_size)
h_0 = torch.rand(
num_layers * (int(bidirectional) + 1),
batch_size,
proj_size if proj_size > 0 else ouput_size,
)
c_0 = torch.rand(num_layers * (int(bidirectional) + 1), batch_size, ouput_size)
outputs, (h_n, c_n) = model(inputs, (h_0, c_0))
print(outputs.shape)
print(h_n.shape)
print(c_n.shape)
출력
torch.Size([4, 6, 128])
torch.Size([6, 4, 64])
torch.Size([6, 4, 256])
댓글남기기