# xformers 기반의 메모리 효율적인 인과적 어텐션 구현 def memory_efficient_causal_attention ( q, k, v ): # 필요한 어텐션 점수만 계산 (마스킹된 점수는 계산하지 않음) batch_size, seq_len, num_heads, head_dim = q.shape # 출력 텐서 초기화 output = torch.zeros_like(q) # 누적 소프트맥스 상태 추적 softmax_scale = 1.0 / math.sqrt(head_dim) m_prev = torch.zeros((batch_size, num_heads, 1 ), device=q.device) # 최대값 저장 p_sum = torch.zeros((batch_size, num_heads, 1 ), device=q.device) # 합 저장 o_acc = torch.zeros((batch_size, num_heads, head_dim), device=q.device) # 가중합 저장 # 순차적 처리로 메모리 사용 최소화 for i in range(seq_len): # 현재 위치의 쿼리 qi = q[:, i:i+ 1 ] # [B, 1, H, D] # 현재 위치까지의 키와 값만 처리 (인과성) ki = k[:, :i+ 1 ] # [B, i+1, H, D] vi = v[:, :i+ 1 ] # [B, i+1, H, D] # 어텐션 점수 계산 (스케일링 포함) si = torch.matmul(qi, ki.transpose( -1 , -2 )) * softmax_scale # [B, 1, H, i+1] # Rabe & Staats 방식: 소프트맥스 최적화 mi = torch.max(si, dim= -1 , keepdim= True )[ 0 ] m_curr = torch.maximum(m_prev, mi) # 이전 누적값 스케일 조정 scale_factor = torch.exp(m_prev - m_curr) p_sum = p_sum * scale_factor o_acc = o_acc * scale_factor.unsqueeze( -1 ) # 새로운 어텐션 가중치 계산 pi = torch.exp(si - m_curr) p_sum = p_sum + pi.sum(dim= -1 , keepdim= True ) # 출력 누적 (중간 어텐션 행렬 저장 없이) o_acc = o_acc + torch.matmul(pi, vi).squeeze( 1 ) # 현재 위치 출력 계산 output[:, i] = o_acc / p_sum # 상태 업데이트 m_prev = m_curr return output
