Skip to content

Learning to Remember Rare Events

hyerim1048 edited this page Mar 11, 2018 · 1 revision

Learning to Remember Rare Events

Abstract

이 논문은 잘 일어나지 않는 rare events의 정보를 어떻게 효과적으로 기억할 수 있을지에 대해 연구한 논문입니다. 그리고 여기서 제안한 모듈은 어떤 supervised NN 에도 쉽게 붙일 수 있는 유연한 모듈입니다.

Introduction

딥러닝으로 할 수 있는 것은 많지만, 잘 나타나지 않는 사건이나 단어들에 대해서 training 하는 것은 쉽지 않습니다. 딥러닝 모델에서는 흔하지 않은 데이터들을 학습시키기 위해서 더 많은 데이터를 모으고 다시 학습을 하는 방식을 사용합니다. 하지만 사람은 그런 데이터도 쉽게 배웁니다.

그래서 이 논문에서는 다양한 NN에서 one shot learning (적은 데이터를 가지고 잘 예측하는 것)을 가능하게 하는 a life-long memory module을 제안했습니다. 이 모듈은 key-value parir 를 사용합니다. key는 NN에서 우리가 선택한 layer의 activation Vector이고, value는 그 label 입니다.

Memory module

  • M = (Key, Value, Age)
  • K : memory-size x key-size
  • V : memory-size
  • A : memory-size

memory module은 matrix K 와 vector V, A 로 이루어진 dictionary 형태의 모듈입니다. memory query는 key-size 크기의 벡터이고 normalize되어 있습니다. (||q|| = 1)

  • NN(q,M) = argmax(i)(dot(q,K[i]))

the nearest neighbor of q in M 은 K의 각 row의 key vector 와 q를 dot product했을 때 가장 큰 값입니다. key들은 다 normalize되어 있기 때문에 cosine similarity를 구한거라고 보면 됩니다. 그리고 k nearest neighbors 는 NNk(q,M) 라고 정의합니다. 이 논문에서는 실험에서 k = 256 로 정의했습니다.

  • (n1, n2, ..., nk) = NNk(q,M)

메모리 모듈에서 k nearest neighbor를 코사인 유사도의 내림차순으로 정렬합니다. 그리고 가장 가까운 neighbor의 value를 V[n1] 이라고 정의합니다. 추가적으로 코사인 유사도 di = dot(q,K[ni]) 로 정의하고 t 를 softmax temperature의 inverse라고 했을 떄 (이 실험에서는 t = 40) softmax(dot(d1,t), dot(d2,t),.....dot(dk,t)) 가 있다고 해봅시다. 이 softmax값과 임베딩된 output을 곱해서 메모리 confidence 에 대한 단일 신호를 제공합니다.

Memory Loss

loss term 에서 각 항은 positive key와 negative key 각각과 query 사이의 코사인 유사도입니다. loss를 줄이면서 positive key 와의 유사도는 키우고 negative key와의 유사도는 줄이는 방향으로 학습하게 됩니다. 그리고 margin alpha값을 통해서 (alpha = 0.1 in all experiments) 충분히 그 거리가 멀면 loss를 전달하지 않도록 합니다. (다음의 논문을 참고한 방식입니다. the one in Schroff et al. (2015) and similar to many other distance metric learning works (Weinberger & Saul, 2009; Weston et al., 2011))

Memory Update

로그를 계산할 때 Memory (M) 도 계속 업데이트를 합니다. 그리고 positive key가 query와 정답셋이 같은지 다른지에 따라 업데이트 방식이 다릅니다.

n1 = NN(q,M) 이라고 했을 떄

  1. V[n1] = v 이면, 즉 query의 라벨과 같다면 K[n1] 은 다음과 같이 업데이트합니다.

Imgur

그리고 age A[n1] <- 0 으로 reset해줍니다.

  1. V[n1] != v 이면, 즉 query 라벨과 다르다면 어떻게 할까요? 이 때 age를 사용하는데 age가 제일 큰 값들 중 랜덤하게 선택을 하고, (LRU 방식) 그 item에 값을 write합니다. 식으로는 다음과 같이 표현합니다.

Imgur

위와 같이 업데이트가 끝난 후에 모든 업데이트 되지 않은 index의 age를 1씩 증가시킵니다.

Efficient nearest neighbor computation

모든 operation 중에 KNN을 계산하는 비용이 제일 높을 것입니다. 이 계산을 효율적으로 하기 위해서 matrix multiplication을 실행합니다.

Q : mini-batch of query (q1,....,qb) Q : batch-size * key-size K.T : key-size * memory-size (X.T 는 X의 transpose matrix 임)

Q * K.T 에서 top-k를 linear하게 뽑습니다.

LSH 해쉬를 통해서 유사한 벡터를 뽑기도 합니다. h1, ... hl까지 random normalized hash vectors 를 뽑은 후에 dot(q,hi) > 0 이면 bi = 1이라고 합니다. 이걸 통해서 b1, ... bl 연속 bit를 만들어내고 만약 유사하다면 이 비트가 동일할 확률도 높아지니까 이 방식을 통해서 유사한 query를 구하는 방식도 있습니다.

2.1 Using the Memory module

메모리 모듈은 어떤 classification network에도 적용할 수 있습니다. 우리는 어떤 layer를 query로 사용할 것이고 모듈의 output을 어떻게 사용할지를 정하면 됩니다. 아래는 몇 가지 예시입니다.

CNN with memory

image classification 에서 memory module의 query로 CNN의 마지막 layer의 output을 사용하고 memory에서 return된 결과를 마지막 network prediction으로 사용합니다.

Seq2seq with Memory

Imgur

구글의 Neural Machine Translation model은 encoder RNN 을 통해서 source language의 문장들의 representation을 만들고, decoder RNN을 통해서 target language sentence를 출력합니다. 여기서 우리는 encoder RNN은 고치지 않고 decoder RNN에서 메모리 모듈을 사용했습니다. decoding 단계에서 attention 에서 나온 벡터를 query로 두었습니다. attention vector는 GNMT 모델에서는 모든 LSTM layer에서 사용되는데 final softmax layer 와 memory output을 Linear layer를 통해 결합했습니다.

4. experiments

One shot learning을 실험하기 매우 어렵긴 함 adam optimizer를 사용하였으며 k = 256, alpha = 0.1으로 유지했습니다.

Omniglot : 50 different alphabets , 1623 hand-written characters 기존의 실험과 맞추기 위해서 데이터 augmentation하고 실행했는데 결과는 다음과 같고, 우리는 batch normailzation 없이 더 간단한 모델로 돌렸는데도~ 잘 나오더라~ 역시 SOTA

Imgur

Translation : GNMT 모델로 english to German translation task 수행함

qualitative 평가 : 우리의 메모리 augmented model은 Dostoevsky와 같이 잘 나오지 않는 단어들도 일반 모델에 비해서 잘 맞췄다.

Imgur

quantitative 평가 : WMT test set을 사용했으며 test set을 두개로 나눴는데 even line set as a context set and the odd line as the one shot evaluation sentence 로 사용하였다. 메모리의 효과를 검증하기 위해서 모델을 training 시킨 후 even text context로 memory module만 업데이트 시킨 모델로 성능을 겸정하고자 했으며 기본 모듈에 비해 좋은 성능을 보였다.

Clone this wiki locally