def memory_efficient_causal_attention(q, k, v):
    """Compute causal scaled dot-product attention one query position at a time.

    For every query position ``i`` the output is
    ``softmax(q_i . k_j / sqrt(head_dim)) @ v_j`` taken over key positions
    ``j <= i``, evaluated independently for each batch element and head.
    Only a single row of attention scores is materialized at a time, so the
    full ``seq_len x seq_len`` score matrix is never allocated.

    Args:
        q: Query tensor laid out as ``(batch, seq_len, num_heads, head_dim)``.
        k: Key tensor; indexed the same way as ``q`` (only the first
            ``i + 1`` sequence positions are read for query ``i``).
        v: Value tensor; indexed the same way as ``k``.

    Returns:
        Tensor with the same shape, dtype and device as ``q``.
    """
    batch_size, seq_len, num_heads, head_dim = q.shape
    output = torch.zeros_like(q)
    softmax_scale = 1.0 / math.sqrt(head_dim)

    # Move heads next to batch so matmul treats (batch, head) as batch dims
    # and contracts over head_dim / key positions. These are views, not copies.
    q_heads = q.transpose(1, 2)  # (batch, heads, seq, head_dim)
    k_heads = k.transpose(1, 2)
    v_heads = v.transpose(1, 2)

    for i in range(seq_len):
        q_i = q_heads[:, :, i:i + 1]    # (batch, heads, 1, head_dim)
        k_ctx = k_heads[:, :, :i + 1]   # (batch, heads, i + 1, head_dim)
        v_ctx = v_heads[:, :, :i + 1]

        # (batch, heads, 1, i + 1): one score per visible key position.
        scores = torch.matmul(q_i, k_ctx.transpose(-1, -2)) * softmax_scale

        # Subtract the row maximum before exponentiating so exp() cannot
        # overflow. The softmax state is local to this query position; it
        # must not be carried over to the next query.
        row_max = scores.amax(dim=-1, keepdim=True)
        weights = torch.exp(scores - row_max)

        numerator = torch.matmul(weights, v_ctx)        # (batch, heads, 1, head_dim)
        denominator = weights.sum(dim=-1, keepdim=True)  # (batch, heads, 1, 1)
        output[:, i] = (numerator / denominator).squeeze(2)

    return output