PyTorch

Interesting findings while exploring PyTorch.


nn.Embedding

When you write nn.Embedding(1000, 128), you are creating a table with 1000 rows and 128 columns: one 128-dimensional embedding for each token in a vocabulary of size 1000.

Calling the module is basically a lookup into that table. If your input tokens are [2, 7, 4], PyTorch grabs rows 2, 7, and 4 from the embedding matrix and returns them stacked together.

import torch
import torch.nn as nn

embedding = nn.Embedding(1000, 128)

tokens = torch.tensor([2, 7, 4])
out = embedding(tokens)

print(out.shape)
# torch.Size([3, 128])

You can think of it as something close to manually indexing the weight matrix and stacking the results.

row_2 = embedding.weight[2]
row_7 = embedding.weight[7]
row_4 = embedding.weight[4]

manual = torch.stack([row_2, row_7, row_4])

print(row_2.shape)
# torch.Size([128])

print(manual.shape)
# torch.Size([3, 128])
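Indexing with the whole id tensor at once performs the same gather in a single step. So does multiplying a one-hot matrix by the weight. This is a minimal check that assumes the same 1000 x 128 table and ids [2, 7, 4] as above. The seed is an arbitrary choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embedding = nn.Embedding(1000, 128)
tokens = torch.tensor([2, 7, 4])

out = embedding(tokens)

# Advanced indexing with the whole id tensor gathers rows 2, 7, 4 in one step.
indexed = embedding.weight[tokens]

# A (3, 1000) one-hot matrix times the (1000, 128) weight selects the same rows.
one_hot = F.one_hot(tokens, num_classes=1000).float()
via_matmul = one_hot @ embedding.weight

print(torch.equal(out, indexed))        # True
print(torch.allclose(out, via_matmul))  # True
```

The one-hot form is mainly useful as a mental model. The lookup itself never builds a 1000-wide row for every token.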

So the mental model is: nn.Embedding is a learnable lookup table, and calling it with token ids returns the corresponding rows packed into one tensor.
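The lookup-table view has two consequences that are easy to check. First, a 2D batch of ids simply gains a trailing embedding dimension. Second, during training only the rows that were actually looked up receive gradient. In the small sketch below, the batch of ids and the `out.sum()` loss are arbitrary placeholders.

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(1000, 128)

# A batch of 2 sequences, 3 token ids each: shape (2, 3). Ids are arbitrary.
batch = torch.tensor([[2, 7, 4],
                      [4, 9, 0]])
out = embedding(batch)
print(out.shape)  # torch.Size([2, 3, 128])

# Dummy loss: every output element contributes gradient 1 to the row it came from.
out.sum().backward()
grad = embedding.weight.grad  # dense (1000, 128) tensor, since sparse=False by default

# Rows with any nonzero gradient are exactly the ids that appeared in the batch.
touched = (grad.abs().sum(dim=1) > 0).nonzero().flatten()
print(touched.tolist())  # [0, 2, 4, 7, 9]

# Id 4 appears twice, so its gradient accumulates to 2.0 per entry.
print(grad[4][0].item(), grad[2][0].item())  # 2.0 1.0
```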

tensor.view

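view feels a lot like reshape, but the important difference is that view never copies. It returns a new tensor that shares the original's memory and only reinterprets that memory with a different shape. If the requested shape cannot be expressed on top of the existing memory layout, view raises an error instead of copying.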

If a tensor is already stored contiguously, view only has to update the shape (and stride) metadata. The underlying values stay exactly where they are in memory.

import torch

x = torch.arange(12)
print(x)
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

y = x.view(3, 4)
print(y)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

Conceptually, nothing was rearranged. PyTorch just changed how it reads the same flat memory:

flat memory:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

viewed as (3, 4):
[
  [0, 1, 2, 3],
  [4, 5, 6, 7],
  [8, 9, 10, 11],
]
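PyTorch stores the rule for reading that flat memory as strides. A stride is the number of flat elements to skip for one step along a given dimension. The quick sketch below recomputes one element's flat offset by hand. The index (2, 1) is an arbitrary choice.

```python
import torch

x = torch.arange(12)
y = x.view(3, 4)

# One step down a row skips 4 flat elements; one step across a column skips 1.
print(y.stride())  # (4, 1)

# So y[i, j] lives at flat offset i * 4 + j * 1 in x's memory.
i, j = 2, 1
offset = i * y.stride(0) + j * y.stride(1)
print(offset, y[i, j].item(), x[offset].item())  # 9 9 9
```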

That is the efficiency trick: view is just a new interpretation of the same contiguous memory, not a rebuild of the tensor.
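Because view copies no data, the view and the original share storage. A write through one is therefore visible through the other, as this small sketch shows.

```python
import torch

x = torch.arange(12)
y = x.view(3, 4)

# Both tensors start at the same memory address.
print(y.data_ptr() == x.data_ptr())  # True

# y[1, 2] sits at flat offset 1 * 4 + 2 = 6, so the write shows up in x[6].
y[1, 2] = -1
print(x[6].item())  # -1
```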

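reshape will try to do the same thing. However, if the requested shape is incompatible with the tensor's current memory layout (for example, after a transpose has made it non-contiguous), reshape is allowed to copy the data in order to return the requested shape.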

x = torch.arange(12).view(3, 4)
z = x.t()  # transpose makes layout non-contiguous

print(z.is_contiguous())
# False

# z.view(2, 6) would raise a RuntimeError here

r = z.reshape(2, 6)
print(r.shape)
# torch.Size([2, 6])

In that case, reshape succeeds by materializing data in a new layout, while view refuses because the old memory layout cannot be reinterpreted that way directly.
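The sketch below reproduces the same situation in three steps. It catches the error from view, confirms that reshape produced new memory, and shows the common explicit alternative, `.contiguous().view(...)`, which makes the copy visible in the code.

```python
import torch

x = torch.arange(12).view(3, 4)
z = x.t()  # shape (4, 3), strides (1, 4): not contiguous

view_failed = False
try:
    z.view(2, 6)
except RuntimeError:
    # view refuses rather than copying
    view_failed = True
print(view_failed)  # True

r = z.reshape(2, 6)
# reshape had to copy, so r does not share memory with x
print(r.data_ptr() == x.data_ptr())  # False
print(r.tolist())  # [[0, 4, 8, 1, 5, 9], [2, 6, 10, 3, 7, 11]]

# An explicit copy into contiguous memory lets view succeed with the same result
v = z.contiguous().view(2, 6)
print(torch.equal(v, r))  # True
```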

So the rough mental model is: use view when you want a cheap reshaping of contiguous memory, and remember that reshape may silently fall back to copying data when the layout no longer fits.
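The "silently" part matters when you write into the result. Whether the write reaches the original tensor depends on whether reshape returned a view or a copy. Here is a small sketch of both cases.

```python
import torch

a = torch.arange(12)

# Contiguous input: reshape can return a view, so b shares memory with a.
b = a.reshape(3, 4)
print(b.data_ptr() == a.data_ptr())  # True
b[0, 0] = 100
print(a[0].item())  # 100

# Non-contiguous input (a transpose): reshape has to copy, so c is independent of a.
c = a.view(3, 4).t().reshape(12)
c[0] = -1
print(a[0].item())  # still 100
```

The PyTorch documentation for reshape warns against relying on its view-versus-copy behavior. If you need aliasing, use view. If you need an independent tensor, call `.clone()`.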

Broadcasting: a visualized example with attention

Broadcasting shows up in all of the Q, K, and V projections in attention, but the query projection alone is enough to see it. If your input has shape (batch_size, context_length, embedding_dim) and your linear layer maps embedding_dim -> head_dim, then conceptually you are multiplying each token vector by a weight matrix of shape (embedding_dim, head_dim).

The interesting part is that one object is 3D and the other is 2D, but PyTorch still knows how to do the multiplication. It treats the last dimension of the input as the feature dimension and applies the same projection independently across every batch and time position.

import torch
import torch.nn as nn

batch_size = 2
context_length = 4
embedding_dim = 8
head_dim = 3

x = torch.randn(batch_size, context_length, embedding_dim)
W_q = torch.randn(embedding_dim, head_dim)

q = x @ W_q

print(x.shape)
# torch.Size([2, 4, 8])

print(W_q.shape)
# torch.Size([8, 3])

print(q.shape)
# torch.Size([2, 4, 3])

In other words, each of the batch_size * context_length token vectors (each of length embedding_dim) is multiplied by the same query matrix and becomes a vector of length head_dim. PyTorch just does all of those projections in one batched operation.

x[0, 0] @ W_q  # one token -> shape [head_dim]
x[0, 1] @ W_q  # next token -> shape [head_dim]
x[1, 3] @ W_q  # another token -> shape [head_dim]

# PyTorch batches all of these together:
q = x @ W_q
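To check that the batched call really is the per-token projection, you can compare it against two alternatives. One is an explicit loop over every (batch, position) pair. The other flattens the batch and time dimensions into a single axis of token vectors. The sizes below match the example above. The seed and the atol=1e-6 tolerance (which absorbs small floating-point differences between kernels) are assumptions.

```python
import torch

torch.manual_seed(0)
batch_size, context_length, embedding_dim, head_dim = 2, 4, 8, 3
x = torch.randn(batch_size, context_length, embedding_dim)
W_q = torch.randn(embedding_dim, head_dim)

q = x @ W_q  # shape (2, 4, 3)

# Explicit loop: project every token vector on its own with the same W_q.
q_loop = torch.empty(batch_size, context_length, head_dim)
for b in range(batch_size):
    for t in range(context_length):
        q_loop[b, t] = x[b, t] @ W_q
print(torch.allclose(q, q_loop, atol=1e-6))  # True

# Flattened: treat all tokens as rows of one (batch_size * context_length, embedding_dim) matrix.
q_flat = (x.reshape(-1, embedding_dim) @ W_q).reshape(batch_size, context_length, head_dim)
print(torch.allclose(q, q_flat, atol=1e-6))  # True
```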

So the broadcasting intuition is not that PyTorch literally copies the weight matrix across the batch and sequence dimensions. Instead, it reuses the same 2D weight matrix at every position, and the batched matmul gives the same result as applying that projection to each token one at a time.
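One way to make "reuses, not copies" concrete is expand. It broadcasts W_q to an explicit batch dimension without allocating memory. The new dimension has stride 0, so every batch entry points at the same 2D matrix. The sketch below feeds that expanded view to torch.bmm, reusing the sizes from the example above. It illustrates the idea and makes no claim about which kernel x @ W_q actually dispatches to. The atol=1e-6 tolerance is an assumption.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 8)   # (batch_size, context_length, embedding_dim)
W_q = torch.randn(8, 3)    # (embedding_dim, head_dim)

# expand adds a batch dimension of size 2 without copying W_q.
W_batched = W_q.expand(x.shape[0], -1, -1)
print(W_batched.shape)     # torch.Size([2, 8, 3])
print(W_batched.stride())  # (0, 3, 1): stride 0 means every batch entry reads the same memory
print(W_batched.data_ptr() == W_q.data_ptr())  # True

# A batched matmul over the expanded view matches the 3D @ 2D product.
q_bmm = torch.bmm(x, W_batched)
print(torch.allclose(q_bmm, x @ W_q, atol=1e-6))  # True
```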

This is exactly what happens inside torch.nn.Linear. The conceptual math is the input times a matrix of shape (embedding_dim, head_dim). However, nn.Linear(embedding_dim, head_dim) stores its weight as (head_dim, embedding_dim), that is, (out_features, in_features), and computes x @ weight.T (plus a bias, unless bias=False). The layout is largely a convention in which each row holds the weights for one output feature. Because .T only swaps strides, using the transposed weight does not copy it.
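Here is a quick sketch that confirms both the storage layout and the forward computation of nn.Linear. It uses bias=False so that the comparison reduces to the matmul alone. The sizes and seed are arbitrary, and the atol=1e-6 tolerance is an assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding_dim, head_dim = 8, 3
x = torch.randn(2, 4, embedding_dim)

proj_q = nn.Linear(embedding_dim, head_dim, bias=False)
print(proj_q.weight.shape)  # torch.Size([3, 8]): (out_features, in_features)

# Transposing recovers the conceptual (embedding_dim, head_dim) matrix as a view.
W_q = proj_q.weight.T
print(W_q.shape)  # torch.Size([8, 3])
print(W_q.data_ptr() == proj_q.weight.data_ptr())  # True: same memory, no copy

# The layer's output equals x @ W_q.
with torch.no_grad():
    print(torch.allclose(proj_q(x), x @ W_q, atol=1e-6))  # True
```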