# xformers 기반의 메모리 효율적인 인과적 어텐션 구현 def memory_efficient_causal_attention(q, k, v): # 필요한 어텐션 점수만 계산 (마스킹된 점수는 계산하지 않음) batch_size, seq_len, num_heads, head_dim = q.shape # 출력 텐서 초기화 output = torch.zeros_like(q) # 누적 소프트맥스 상태 추적 softmax_scale = 1.0 / math.sqrt(head_dim) m_prev = torch.zeros((batch_size, num_heads, 1), device=q.device) # 최대값 저장 p_sum = torch.zeros((batch_size, num_heads, 1), device=q.device) # 합 저장 o_acc = torch.zeros((batch_size, num_heads, head_dim), device=q.device) # 가중합 저장 # 순차적 처리로 메모리 사용 최소화 for i in range(seq_len): # 현재 위치의 쿼리 qi = q[:, i:i+1] # [B, 1, H, D] # 현재 위치까지의 키와 값만 처리 (인과성) ki = k[:, :i+1] # [B, i+1, H, D] vi = v[:, :i+1] # [B, i+1, H, D] # 어텐션 점수 계산 (스케일링 포함) si = torch.matmul(qi, ki.transpose(-1, -2)) * softmax_scale # [B, 1, H, i+1] # Rabe & Staats 방식: 소프트맥스 최적화 mi = torch.max(si, dim=-1, keepdim=True)[0] m_curr = torch.maximum(m_prev, mi) # 이전 누적값 스케일 조정 scale_factor = torch.exp(m_prev - m_curr) p_sum = p_sum * scale_factor o_acc = o_acc * scale_factor.unsqueeze(-1) # 새로운 어텐션 가중치 계산 pi = torch.exp(si - m_curr) p_sum = p_sum + pi.sum(dim=-1, keepdim=True) # 출력 누적 (중간 어텐션 행렬 저장 없이) o_acc = o_acc + torch.matmul(pi, vi).squeeze(1) # 현재 위치 출력 계산 output[:, i] = o_acc / p_sum # 상태 업데이트 m_prev = m_curr return output
