From: Sondre Wold Date: Thu, 13 Mar 2025 12:35:35 +0000 (+0100) Subject: fix mistake in einsum X-Git-Url: https://letsjmore.com/?a=commitdiff_plain;h=f868aa71d9007cec5dbc8a00ecf2ee74cc35e695;p=reference_transformer.git fix mistake in einsum --- diff --git a/reference_transformer/reference_transformer/gpt.py b/reference_transformer/reference_transformer/gpt.py index 86919de..02ddc43 100644 --- a/reference_transformer/reference_transformer/gpt.py +++ b/reference_transformer/reference_transformer/gpt.py @@ -95,7 +95,7 @@ class MultiHeadedAttention(nn.Module): attention_weights = torch.softmax(attention_weights, dim=-1) # Multiply with values - scaled_values = torch.einsum("bhqk,bkhc->bkhc", attention_weights, values) + scaled_values = torch.einsum("bhqk,bkhc->bqhc", attention_weights, values) # Concat heads flattened = einops.rearrange(scaled_values, "b k h c -> b k (h c)")