git.sondrewold.no Git - reference_transformer.git/commitdiff
fix mistake in einsum [branch: main]
author: Sondre Wold <[email protected]>
Thu, 13 Mar 2025 12:35:35 +0000 (13:35 +0100)
committer: Sondre Wold <[email protected]>
Thu, 13 Mar 2025 12:35:35 +0000 (13:35 +0100)
reference_transformer/reference_transformer/gpt.py

index 86919de415d71417c9cf595540e8d3a0b577b56d..02ddc437cd3561d79d9c1312389839bd28de240c 100644 (file)
@@ -95,7 +95,7 @@ class MultiHeadedAttention(nn.Module):
         attention_weights = torch.softmax(attention_weights, dim=-1)
 
         # Multiply with values
-        scaled_values = torch.einsum("bhqk,bkhc->bkhc", attention_weights, values)
+        scaled_values = torch.einsum("bhqk,bkhc->bqhc", attention_weights, values)
 
         # Concat heads
         flattened = einops.rearrange(scaled_values, "b k h c -> b k (h c)")