잔차 쿱만 스펙트럼 프로파일링을 이용한 트랜스포머 학습 불안정성 예측 및 방지
Residual Koopman Spectral Profiling for Predicting and Preventing Transformer Training Instability
트랜스포머 학습 과정에서의 발산은 컴퓨팅 자원을 낭비하지만, 불안정성이 발생하는 것을 경험적으로 파악하는 데에는 상당한 시간과 비용이 소요됩니다. 따라서, 학습 시작 전에 트랜스포머의 실패 가능성을 예측하는 것이 중요합니다. 본 연구에서는 잔차 쿱만 스펙트럼 프로파일링(RKSP)을 통해 이러한 예측을 가능하게 합니다. RKSP는 초기화 단계에서 단일 순전파 과정을 통해 각 레이어의 잔차를 획득하고, 화이트닝된 동적 모드 분해(Whitened Dynamic Mode Decomposition)를 적용하여 쿱만 스펙트럼 특징을 추출합니다. 핵심 진단 지표인 '거의 단위 스펙트럼 질량(near-unit spectral mass)'은 단위 원 주변에 집중된 모드의 비율을 측정하며, 이는 불안정성 위험을 나타냅니다. 본 연구에서 개발한 추정기는 다양한 설정에서 발산을 예측하는 데 있어 0.995의 AUROC 값을 달성하여, 기존의 최적 성능을 보이는 경사 기반 방법보다 우수한 성능을 보입니다. 또한, 쿱만 스펙트럼 성형(KSS)을 통해 이 진단 지표를 실제 학습 과정에 적용할 수 있도록 하였습니다. 실험 결과, RKSP는 초기 단계에서 발산을 예측하며, RKSP가 높은 위험을 감지하면 KSS를 활성화하여 발산을 성공적으로 방지할 수 있음을 확인했습니다. 정규화 레이어 없이 높은 학습률을 사용하는 어려운 환경에서도 KSS는 발산율을 66.7%에서 12.5%로 감소시키고, 학습률을 50%에서 150%까지 높일 수 있었습니다. 이러한 결과는 WikiText-103 언어 모델링, CIFAR-10 데이터셋에서의 비전 트랜스포머, GPT-2 및 LLaMA-2 (최대 7B)와 같은 사전 학습된 언어 모델, MoE, Mamba 스타일의 SSM, KAN과 같은 새로운 아키텍처에도 일반화됩니다.
Training divergence in transformers wastes compute, yet practitioners discover instability only after expensive runs begin. They therefore need an expected probability of failure for a transformer before training starts. Our study of Residual Koopman Spectral Profiling (RKSP) provides such an estimate. From a single forward pass at initialization, RKSP extracts Koopman spectral features by applying whitened dynamic mode decomposition to layer-wise residual snapshots. Our central diagnostic, the near-unit spectral mass, quantifies the fraction of modes concentrated near the unit circle, which captures instability risk. For predicting divergence across extensive configurations, this estimator achieves an AUROC of 0.995, outperforming the best gradient baseline. We further make this diagnostic actionable through Koopman Spectral Shaping (KSS), which reshapes spectra during training. We empirically validate that our method works in practice: RKSP predicts divergence at initialization, and when RKSP flags high risk, turning on KSS successfully prevents divergence. In the challenging high learning rate regime without normalization layers, KSS reduces the divergence rate from 66.7% to 12.5% and enables learning rates that are 50% to 150% higher. These findings generalize to WikiText-103 language modeling, vision transformers on CIFAR-10, and pretrained language models, including GPT-2 and LLaMA-2 up to 7B, as well as emerging architectures such as MoE, Mamba-style SSMs, and KAN.
No Analysis Report Yet
This paper hasn't been analyzed by Gemini yet.