# Token representation: pass the index tensor through the embedding module.
x_tok = self.embedding(idx)

# Position indices 0..T-1, created on the same device as `idx` so the
# positional module does not receive a tensor from a different device.
positions = torch.arange(T, device=idx.device)
x_pos = self.pos_encoding(positions)

# Combine token and positional terms.
# NOTE(review): presumably x_tok is (batch, T, dim) and x_pos is (T, dim), so
# the add broadcasts over the batch axis -- confirm against the caller.
x = x_tok + x_pos