From f868aa71d9007cec5dbc8a00ecf2ee74cc35e695 Mon Sep 17 00:00:00 2001
From: Sondre Wold
Date: Thu, 13 Mar 2025 13:35:35 +0100
Subject: [PATCH] fix mistake in einsum

---
 reference_transformer/reference_transformer/gpt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/reference_transformer/reference_transformer/gpt.py b/reference_transformer/reference_transformer/gpt.py
index 86919de..02ddc43 100644
--- a/reference_transformer/reference_transformer/gpt.py
+++ b/reference_transformer/reference_transformer/gpt.py
@@ -95,7 +95,7 @@ class MultiHeadedAttention(nn.Module):
         attention_weights = torch.softmax(attention_weights, dim=-1)
 
         # Multiply with values
-        scaled_values = torch.einsum("bhqk,bkhc->bkhc", attention_weights, values)
+        scaled_values = torch.einsum("bhqk,bkhc->bqhc", attention_weights, values)
 
         # Concat heads
         flattened = einops.rearrange(scaled_values, "b k h c -> b k (h c)")
-- 
2.39.5