논문/NLP

논문 리뷰) Titans: Learning to Memorize at Test Time

Dogun Kim 2025. 1. 20. 15:39

Titans: Learning to Memorize at Test Time

https://arxiv.org/abs/2501.00663

 

Titans: Learning to Memorize at Test Time

Over more than a decade there has been an extensive research effort on how to effectively utilize recurrent models and attention. While recurrent models aim to compress the data into a fixed-size memory (called hidden state), attention allows attending to

arxiv.org

 

0. Abstract

 

 RNN, LSTM과 같은 Recurrent Model들은 고정 크기의 메모리((Hidden-state))로 압축하는 걸 목표로 한다. 

이에 반대로 Attention은 전체 context window에 접근하여 모든 토큰 간의 직접적인 종속성들을 캡처할 수 있도록 한다.

cf. context window: 한 번에 처리하는 입력 데이터의 범위 또는 길이

 

 하지만 이러한 더 정확한 종속성 모델링((Attention))은 이차적((Quadratic)) 비용이 수반되어 모델을 고정 길이의 context로 제한하게 된다. 

 

cf. Accurate dependency modeling은 데이터 특히 시퀀스 데이터 내의 요소들 간의 의존성을 정확히 표현하고 학습하는 과정

 

 해당 논문은 어텐션이 현재 context에도 집중하면서도, 과거 context를 기억하는 법을 학습하여 오래된 정보도 활용할 수 있는 새로운 Neural Long-term Memory Module 신경 기반 장기 메모리 모듈을 제안한다. 이 신경 메모리는 빠른 병렬 학습이 가능하면서도 빠른 추론 속도를 유지하는 장점이 있다. 

 

 메모리 관점에서 어텐션은 제한된 context 길이와 정확한 종속성 모델링에 의해 단기 메모리로 작동하게 되고, 신경 메모리는 데이터를 기억하는 능력을 학습하므로 장기적이고 지속적인 메모리로 작동할 수 있다고 주장한다. 이러한 두 모듈을 바탕으로 새로운 아키텍쳐 Titans를 소개하고 메모리를 효과적으로 통합하는 세 가지 변형 모델을 제시한다.

 

 실험을 통해 Titans는 언어 모델링, 게놈 연구, 시계열 작업 등에서 기존 트랜스포머보다 더 효과적임을 보이고, Scale to Larger Context,  Needle-in-Haystack Tasks에서도 높은 정확도를 보인다.

cf. Scale to Larger Context: 긴 시퀀스 데이터를 효율적으로 처리할 수 있는 능력. 2M 이상의 컨텍스트 창 크기로 확장할 수 있다.

cf. Needle-in-Haystack Tasks: 매우 많은 데이터 중에서 유용한 정보를 찾아내는 작업

 

 


1. Introduction

  • 기존 Transformer의 문제점: Scalability issue

 Transformer는 순수 어텐션 기반 아키텍처로, 시퀀스 모델링 분야에서 in-context 학습과 확장 가능한 학습 능력 덕분에 SOTA 모델로 확립되었다. Transformer의 주요 구성 요소인 어텐션 모듈은 연관 기억 블록으로 작동하며, 쿼리 검색 신호와 키 컨텍스트 간의 쌍별 유사도를 계산해 키-값 연관성을 학습하고 검색한다. 따라서 Transformer의 출력은 현재 context window에서 토큰 간 직접적인 종속성에만 조건화된다

 

 이러한 정확한 종속성 모델링 accurate modeling of dependencies은 context의 길이에 따라 Quadratic time과 메모리 복잡도나 나타나게 된다.

cf. Associative Memory Blocks 연관 기억 블록: 특정 키에 대해 관련 값을 저장하고 회상하는 매커니즘

 

>>> Transformer의 출력 계산은 context window 안에 있는 토큰들끼리 이루어지게 된다. 이 계산은 (n^2)의 계산, 메모리 복잡도가 발생하여 context window가 매우 커지는 복잡한 실제 과제에서는 Transformer의 적용이 어려워진다.

cf. complex real-world tasks 예시: language modeling, video understanding, long-term time series forecasting

 

 

  • To overcome the scalability issue of Transformers

 확장성 문제를 해결하기 위해 최근 연구에서는 어텐션의 Softmax를 커널 함수로 대체하여 메모리 소비를 크게 줄이는 선형 Transformer 변형 모델을 설계한 시도가 있다. 효율성과 긴 컨텍스트 처리 능력에 불구하고 데이터를 행렬 값 상태로 압축하는 커널 기법으로 인해 기존 트랜스포머보다 성능이 떨어지게 된다.

>>> 모순) 확장성, 효율을 높이기 위해 선형 모델 사용 <-> 매우 긴 context는 작은 벡터나 행렬로 적절히 압축 불가능

 cf. 선형 Transformer는 과거 데이터를 고정 크기의 행렬 값 메모리로 압축하는 반면, Transformer는 모든 과거 데이터를 컨텍스트 길이 내에서 압축 없이 유지

 

 

  • 기존 아키텍쳐들이 놓치고 있는 것들

 효율성 외에도 Hopfield 네트워크, LSTM, Transformer에 이르는 대부분의 기존 아키텍처는 complex real-world tasks에 필수적인 일반화, Length Extrapolation , 또는 추론 처리에서 어려움을 겪는다.

cf. Length Extrapolation: 학습된 데이터 길이보다 더 긴 데이터에서도 성능을 유지하는 능력

 

 이러한 기존 아키텍쳐들은 인간 두뇌에서 영감을 받았지만, 놓치고 있는건 다음과 같다.

1) short-term memory, long-term memory, meta-memory, attending to current context

2) 이러한 구성 요소들이 독립적으로 작동할 수 있는 상호 연결된 시스템이라는 점

3) 데이터로부터 능동적으로 학습하고 과거 이력을 추상적으로 기억하는 능력

 

 우리 인간의 두뇌처럼 효과적인 학습 패터다임에서는, 위와 같이 각기 구분되면서도 상호 연결된 모듈이 존재하여, 각 모듈이 학습 과정에서 중요한 구성 요소를 담당한다고 이 논문은 주장한다. 

 

cf. Meta-memory 메타 메모리: 자신의 기억 과정을 이해하고 제어하는 능력.

 

 

  • Memory Perspective

 기억은 기본적인 정신 과정이며, 인간 학습의 필수적인 요소이다. 이러한 기억이 없다면 인간과 동물은 기본적인 반사 행동과 정형화된 행동으로 제한될 것이다. 따라서 기억은 머신러닝 문헌에서 Hopfield,  LSTM, Transformer와 같은 기념비적인 연구의 영감이 되었다.

 

 대부분의 아키텍쳐는 기억을 입력에 의해 발생하는 Neural update로 간주하며, 학습을 주어진 목표에 따라 효과적이고 유용한 기억을 습득하는 과정으로 정의한다. 이 관점에서 기존 모델들은 다음과 같이 정의할 수 있다.

cf. Neural update 신경 업데이트: 새로운 입력 데이터를 처리하며, 기존 메모리를 갱신. 가중치 등이 변화. 즉 기존 아키텍쳐들은 기억을 입력에 의해 갱신되는 신경망 상태로 간주한다.

 

1) RNN: 벡터 형태의 메모리 모듈 M((=hidden-state))를 가진 모델. 시간 $t$에서 새 입력 $x_t$가 주어지면 f 함수를 사용하여 메모리를 업데이트 하고, 필요한 경우 함수 g를 사용하여 입력과 대응되는 메모리를 검색한다.

 Recurrent Model들은 고정 크기의 메모리((Hidden-state))로 압축하는 걸 목표하여 데이터를 압축하여 저장하기 때문에 기억 용량이 제한된다.

 

2) Transformers: 확장 가능한 메모리를 가진 아키텍처. key와 value 행렬 쌍이 모델의 메모리 역할을 하며, i) key와 value를 메모리에 압축 없이 추가하여 메모리를 업데이트하고, ii) query 벡터와 key 벡터의 유사성을 찾아 해당되는 메모리를 검색하고, 이를 통해 value 벡터에 가중치를 부여하여 출력을 생성한다.

 

이외에도 선형 Transformer, 선형 RNN 구조들이 있는데 이들은 메모리를 압축하여 저장하고, 각각 행렬 값 메모리, 벡터 값 메모리를 사용한다.

 

이러한 구조들의 비교를 통해 논문은 다음과 같은 질문을 얻는다.

 

Q1. 메모리에 적합한 구조는 무엇인가? 

Q2. 적절한 메모리 업데이트 메커니즘은 무엇인가?

Q3. 좋은 메모리 검색 과정은 무엇인가?

 

이에 대한 해답을 찾기 위해 해당 연구는 인간의 기억은 단기, 작업, 장기 기억 등 서로 다른 신경 구조를 가진 다양한 기억 시스템으로 나뉘며, 각 시스템은 독립적으로 작동하고 연합한다는 것에 주목한다. 이 사실을 통해 다음과 같은 추가 질문을 얻는다.

 

Q4. 서로 연결된 다양한 기억 모듈을 효율적으로 통합하는 아키텍처를 어떻게 설계할 것인가?

 

 마지막으로 기억을 저장하는 것은 단순히 데이터를 저장하는 것이 아니라 과거의 abstraction을 인코딩하고 저장하는 신경 과정이다. 따라서 단일 벡터나 행렬로 장기 기억을 저장하는 것은 과도하게 단순화된 접근일 수 있다. 그렇기에 다음과 같은 질문을 또 얻게 된다.

 

Q5. 장기적인 과거를 효과적으로 저장/ 기억하기 위해 깊은 모듈이 필요한가?

 

 

 

  • Contributions and Roadmap

 이 논문에서는 Test time에  효율적이고 효과적으로 기억을 학습할 수 있는 long-term neural memory module을 설계하여 위 다섯 가지 질문에 답하고자 한다. 또한 이 설계를 바탕으로, 그것이 어떻게 아키텍처에 통합될 수 있는지 논의한다.

cf. test time: 학습이 완료된 후, 모델이 실제로 데이터를 처리하는 시간.

 

설계는 다음과 같다.

1) Neural Memory: 장기 신경망 기억 시스템

 '테스트 시간에' 데이터를 저장하고 기억하는 방법을 학습하는 깊은 신경망 장기 기억 시스템. 이 시스템은 a meta in-context model 작동하며, 데이터를 매개변수에 저장하고 처리하는 방식.

 

 인간의 장기 기억 시스템에 영감을 받았으며 예상과 다른 사건, 즉 놀라운 사건을 더 잘 기억할 수 있도록 한다. 이 놀라움을 gradient를 통해 측정하며, 기억에 반영하고, associative memory loss을 활용해 입력에 대한 반응을 학습한다.

 

또한 decaying mechanism을 통해 메모리 크기 제한과 데이터의 놀라움 정도를 고려하여 중요한 정보를 더 잘 저장하고 유지할 수 있게 돕는다. 이는 현대 recurrent model들의 forgetting mechanism을 일반화한 것이다.

 

 mini-batch gradient descent, momentum, and weight decay를 통해 meta 신경망을 최적화하고, 텐서화된 gradient descent활용해 빠르고 병렬적인 학습이 가능하다.

 

2) Titans Architectures

Titans 아키텍처는 Neural Memory를 포함한 세 가지 주요 메모리 모듈을 갖춘 아키텍쳐로...  메모리를 딥러닝 아키텍처에 효과적이고 효율적으로 통합하는 방법에 대한 해답이다. 이것은 세 개의 하이퍼 헤드로 구성된 딥 모델의 집합이다.

  1. Core: 단기 기억을 담당하며, 데이터 처리 흐름을 관리. 어텐션 기법을 사용하여 제한된 윈도우 크기에서 데이터 처리
  2. Long-term Memory: 장기 기억을 담당하여, 과거의 정보를 저장하고 기억하는 역할을 합니다.
  3. Persistent Memory: 지속적인 기억으로, 작업에 대한 지식을 저장하는 학습 가능한 매개변수입니다. 이 매개변수는 날짜와 관계없이 작업 지식을 인코딩합니다.

마지막으로 i) 컨텍스트, ii) 레이어, (iii) a gated branch 를 통해 Titans의 세 가지 변형을 제시한다.

 

 

  • Experimental Results

 

 Titans 아키텍처는 언어 모델링, 상식 추론, 회상 집약적 작업, 바늘 찾기, 시계열 예측, DNA 모델링 등 다양한 작업에서 우수한 성능을 보임. Titans는 모든 현대 순환 모델과 슬라이딩 윈도우 어텐션을 결합한 하이브리드 모델을 포함한 여러 벤치마크에서 성능을 능가. 또한 같은 컨텍스트 윈도우를 사용할 때 Transformers보다 더 우수한 성능을 보였으며, 전체 컨텍스트를 사용하는 Transformers와 비교해도 경쟁력 있는 성능을 보여준다. Titans는 2M 이상의 컨텍스트 윈도우 크기로 확장할 수 있어 Transformers보다 더 큰 데이터셋을 처리할 수 있는 확장성을 가진다.

 

 

 


2. Preliminaries 

 논문에서 사용되는 표기법과 몇 가지 배경 개념에 대해서 설명하는 파트이다.

  • 표기법 정리

cf. Mask: 딥러닝에서 특정 데이터를 선택하거나 무시할 때 사용하는 기법. 주로 어텐션, 패딩 처리에 사용됨.

 

 

  • Backgrounds

1) Attention

입력에 의존하는 키, 값, 쿼리 행렬에 대해 소프트맥스를 적용하여 출력 y를 계산한다. 이 때 모든 W는 학습 가능한 parameter이며 크기는 (d_in) x (d_in).

 

 Transformers는 뛰어난 기억력과 효율성에도 불구하고, 소프트맥스를 통해 출력을 계산하기 위해 최소 𝑁 × 𝑑 연산자가 필요하여, 긴 시퀀스에서는 더 큰 메모리 소비와 낮은 throughput을 초래한다.

 

2) Efficient Attentions

 긴 시퀀스에 대한 소프트맥스 어텐션의 효율성을 개선하기 위해 다음과 같은 방법들이 연구되고 있다.

 

  1. I/O aware implementations: 입력/출력(I/O) 성능을 고려한 구현 방식.
  2. Sparsifying the attention matrix: 어텐션 행렬의 희소화로, 불필요한 값을 제거하여 계산 효율을 높이는 기법.
  3. Softmax approximation: 소프트맥스 근사화로, 계산을 간소화하는 방법.
  4. Kernel-based (linear) attentions: 커널 기반(선형) 어텐션으로, 어텐션 계산을 간소화하는 기법.

 

 4번 방식에 대해서 자세히 알아보면 다음과 같다.

이는 선형 어텐션으로 준 어텐션의 소프트맥스를 대체할 커널 함수 𝜙 ((. , .))를 사용하게 된다.

여기서 𝜙((𝑥, 𝑦)) = 𝜙 ((𝑥))𝜙((𝑦))이다. 즉 두 벡터의 연산을 분리해서 처리할 수 있게 된다.

커널 함수를 사용해서 Q, K의 내적을 계산하고 이를 Value에 적용

재사용 가능한 계산을 통해 처리량이 높아지며, 커널을 단위 행렬로 선택하면 다음과 같은 수식으로 표현할 수 있다.

recursive 방식으로 어텐션 계산 가능

이는 선형 어텐션을 위한 효율적인 추론을 가능하게 한다.

 

3) Modern Linear Models and Their Memory Perspective

 

.... 미완


3. Learning to Memorize at Test Time ****

 트렌스포머 기반 모델들은 context window 안에 있는 토큰들끼리 계산이 이루어지게 되는데, 이는 (n^2)의 계산, 메모리 복잡도가 발생하여 context window가 매우 커지는 복잡한 실제 과제에서는 Transformer의 적용이 어려웠다. 즉 장기 기억에 대한 성능이 떨어지는 것이다.

 

 이러한 장기 기억 부족을 극복하고 모델이 학습, 망각, 정보를 검색할 수 있도록 해당 논문은 Test Time에서도 기억을 학습하는 meta modal로서의 neural long-term memory 신경 장기 기억 모듈을 제시한다.

 

 

3-1. Neural memory 설계 동기

  • 과거를 기억한다? # 단순하게 훈련 데이터를 기억하는 것이 아니라, 과거를 추상적으로 인코딩....

 장기 기억을 갖는다 즉 과거를 기억한다는 것은 단순히 훈련 데이터를 저장하는 것이 아니다. 이는 모델의 일반화 능력을 제한하고, 프라이버시 문제를 야기하며, 결과적으로 테스트 시 성능을 저하시킨다. 또한 훈련 데이터와 테스트 데이터가 분포가 다른 경우 전혀 도움이 되지 않는다.

 

 그렇기에 해당 논문은 Test 시점에서 데이터를 기억하고 망각하는 방법을 학습하는 Online meta-model의 필요성을 주장한다. 이 설정의 모델은 기억이 가능하도록 하면서, 훈련 데이터를 저장하는 방식이 아니기에 훈련 데이터에 과적합되지 않을 것이며 테스트 시 더 좋은 일반화 성능을 제공하게 된다.

 

 

  • 어떻게 효과적으로 기억할 것인가?

 인간도 모든 기억을 온전히 가지고 있지 않다. 그 과거 사건이 얼마나 놀라웠냐에 따라 추상적으로 기억을 보존하게 된다.

그렇기에 이를 모방하여 Test time에서 기억을 효과적으로 기억 /망각하기 위해 과거 정보 $x_1, x_2, ... x_t-1$를 놀라움의 정도에 따라 신경 장기 기억 모듈 $M_t$의 매개변수에 추상적으로 압축하는 방식이 해당 모듈의 핵심 아이디어이다. 

 

1) 단순하게 Gradient를 통해 놀라움을 표현하고 기억을 update

놀라움의 정도는 입력에 대한 gradient로 표현할 수 있다. gradient가 클수록 입력 데이터과 과거 데이터와 더 다를 것이다. gradient로 구한 놀라움의 정도를 통해 기억을 다음과 같이 업데이트 할 수 있다.

loss function은 다음 장에서 정의하게 된다. cf. Test time에서 학습하는 Online meta model임을 잊지 말자..

 

하지만 이러한 방식은 몇 번의 놀라운 단계 이후에 그래디언트가 매우 작아질 수 있으며 그래디언트가 소멸하여 local minima에 빠지거나 시퀀스의 일부 정보를 놓치게 된다.

 

 인간의 관점에서 한 사건이 기억에 오래 남더라도 같은 놀라움을 계속 주지는 않을 것이다. 첫 순간이 놀라우면 이후 데이터를 기억하는 경향이 있다.

 

2) Surprise = Past Surprise + Momentary Surprise

1번 방식의 단점을 해결하기 위해 위 식처럼 놀라움을 과거의 놀라움과 현재의 놀라움으로 나눠 신경 장기 기억 모듈을 update 해간다. 흥미롭게도 $S_가 모멘텀 요소인 모멘텀을 가진 그래디언트 강하와 유사하며 마지막 놀라움을 반영하는 정도를 조정하기 위해 ηt를 0으로 설정하거나, 1로 설정하도록 제어할 수 있다.

 

  • Objective

메모리 모듈은 손실 함수를 최적화하여 입력 데이터를 효율적으로 기억하도록 설계된 메타 학습 모델이다. 여기서 해당 연구는 associative memory에 중점을 두며, 이를 key-value 쌍으로 저장하여 관련성을 기반으로 데이터를 효율적으로 검색하고 기억하게 한다.

입력 데이터를 각각의 two linear layers W를 통해 키와 값으로 변환하고, 메모리 모듈이 키와 값 간의 연관성을 학습하기 위하여 다음과 같은 Loss function을 설정하게 된다.

메모리 모듈의 출력을 실제 값과 비교하여 차이를 최소화하도록 학습하는 손실 함수이며, 메모리 모듈 내 루프에서 이를 최적화함으로써 테스트 시점에서 키와 값 간의 매핑을 기억하는 방법을 학습하게 된다.

cf. 전체 아키텍쳐 학습 과정과 메모리 모듈 내 Test time에서의 학습 과정을 구분해야함.

 

 

  • 어떻게 효과적으로 망각할 것인가? # Forgetting Mechanism 추가

 매우 긴 시퀀스(예: 수백만 개의 토큰)를 처리할 때, 과거 정보 중 어떤 것을 잊어야 할지 관리하는 것이 중요하다. 메모리의 제한된 용량을 더 잘 관리해야하기 때문이다. 이를 위해 얼마나 많은 과거 정보를 잊을지에 대한 gating mechanism $ α_t$​∈[0,1] 를 추가해, 위와 같이 업데이트 규칙을 다음과 같이 수정한다.

 

$α_t$를 0으로 보내면 과거 추상화에 영향을 주지 않을 것이며, 1로 설정하면 전체 메모리가 지워질 것이다.

 

 

  • 해당 메모리 모듈을 어떻게 효과적으로 통합할 것인가?

해당 논문에서는 장기 기억 모듈 설계, 그리고 이를 더 잘 동기화하고 아키텍쳐에 통합하는 걸 초점으로 한다. 그러기에 2 layer 이상의 MLP를 사용하여 장기 기억을 전체 아키텍쳐에 통합한다.

 

cf. 연구 초점에 의해 2 layer 이상의 MLP를 사용한 것. 더 효과적으로 통합하는 구조에 대한 가능성이 존재한다.

 

 

  • Retrieving a Memory

메모리 검색은 입력 데이터를 쿼리로 투영한 후, 메모리 모듈을 통해 관련 정보를 검색하는 방식으로 수행됨. 

 

 

3.2) 설계된 아키텍쳐의 이점

모델 학습이 얼마나 효율적이고 확장 가능한 방식으로 이루어질 수 있는지에 대해 다룬다.

미완...

 

3.3) 아키텍쳐 확장

 학습 가능한 데이터 독립적 매개변수를 사용하여 태스크에 대한 메타 정보를 학습하는 지속적 기억 모듈을 사용하여 아키텍처를 확장 한다. 

 

  • Persistent Memory 추가

효과적인 기억 시스템은 입력 독립적 파라미터를 통해 Task 관련 지식을 저장할 필요가 있다.

 


4. How to Incorporate Memory? 

설계한 기억 모듈을 어떻게 아키텍쳐에 통합할 것이냐에 대한 내용이다.

 

1) Memory as a Context (MAC) Architecture # 시퀀스를 segment로 쪼개서 이용.

과거 정보를 검색, 추출
과거 정보, Task에 대한 영구 정보, 현재 들어온 시퀀스의 segment를 합쳐 어텐션
장기 기억 모듈을 업데이트하고 최종 값 출력

 

2) Memory as a Gate (MAG) Architecture # 시퀀스 그대로 사용 슬라이딩 윈도우 어텐션 이용

 

3) Memory as a Layer (MAL) Architecture

 

 

 

....

 

 

 

 

    •