The Context-Ready Transformer
Abstract
We introduce the context-ready transformer, a new recurrent neural network architecture built from a -layer transformer block that pre-contextualizes each token before it enters the block. During left-to-right generation, a correction network combines the previous position’s block output—a cached summary of past context—with the current token embedding, so the token enters the block already contextualized rather than as a raw embedding. At sequential inference, the correction chain makes the architecture a recurrent neural network. For training, we unroll the correction process times over the full sequence, processing all positions in parallel at each step. A pretrained transformer can also be converted to a context-ready model by adding a zero-initialized correction FFN and fine-tuning. We evaluate across widths, depths, block sizes, and two datasets, with all comparisons against standard transformers, variants, and ablations. A model beats a 12-layer transformer while generating faster on an A100. With , a single-layer model () beats a 6-layer transformer with a inference speedup, and sequential inference matches parallel to within 0.01 PPL. The architecture benefits most from wide representations and long contexts. On a pointer-chasing task, trained with BPTT solves all 10 composition levels, while standard transformers exhibit staircase-like depth dependence.
1 Introduction
In autoregressive generation, a standard -layer transformer predicts the next token, assigns it a context-free base embedding, and devotes part or all of its layers to re-contextualizing it. This round trip from context to token ID and back to context is an artifact of the architecture, not of the problem.
Context-ready inference. The context-ready transformer shortens this round-trip. Consider sequential left-to-right generation: as the model processes token , the block output encodes the context of tokens . When token arrives with embedding , a correction network computes a correction from and . The token enters the block as plus this correction—carrying contextual information from all preceding tokens—rather than as a raw embedding. Fewer layers are needed to fully contextualize the token, because the correction has already done part of the work.
The training challenge. This sequential process is exact at inference but inherently serial: each position’s correction depends on the previous position’s fully computed block output. A classical RNN faces the same issue and solves it with backpropagation through time (BPTT), which unrolls the recurrence for all positions—making the sequential training depth scale with sequence length.
We take a different approach. We approximate the sequential process by unrolling it times over the full sequence, where is a small constant independent of . All positions are still processed in parallel—just as in a standard transformer—but the correction is refined over unrolling steps rather than . Starting from raw embeddings at all positions, we run a shared-weight block and compute the correction at each position from the block output at positions . We then update all embeddings with these corrections and repeat. Each unrolling step refines the corrections using increasingly contextualized inputs. Crucially, controls the depth of the computation graph—not —so a classical RNN requires -deep unrolling while the context-ready transformer requires only . In practice, suffices for convergence at the depths tested (; see Table 6).
How unrolling steps at training lead to zero iterations at inference. Two design choices make this possible.
- Non-cumulative correction. Each iteration computes a correction from scratch rather than building on the previous one: the block input is the base embedding plus a correction that depends on both the previous iterate’s block output and the token embedding. A residual network accumulates a sum of corrections across iterations; our formulation instead applies a single correction. Previous iterations serve only to bring the block outputs close to the fixed point. Once they have converged, additional iterations produce the same output.
- Past-only contextualization. The correction at position t depends on two quantities: the block output at position t − 1, which encodes the context of all earlier tokens, and the current token embedding. Because the previous block output is already cached from processing the previous token, the correction can be computed as soon as the new token arrives.
Together, these two properties mean that sequential generation naturally produces the converged correction for each new token without iteration (Section 4).
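The two properties can be illustrated with a scalar caricature. The sketch below is a toy under stated assumptions, not the paper's model. `g` and `block` are hypothetical scalar stand-ins for the correction FFN and the transformer block. The toy block is position-wise, whereas the real block also attends over earlier positions. The sketch runs the parallel iteration either non-cumulatively (as described above) or cumulatively (residual-style), and compares the results with a single left-to-right streaming pass.

```python
import math

def g(h_prev: float, e: float) -> float:
    """Hypothetical scalar correction network g(h_{t-1}, e_t)."""
    return 0.5 * math.tanh(h_prev + e)

def block(x: float) -> float:
    """Hypothetical scalar stand-in for the transformer block."""
    return math.sin(x)

def unroll(embs, K, cumulative=False):
    """K parallel iterations over all positions, starting from raw embeddings."""
    x = list(embs)
    for _ in range(K):
        h = [block(v) for v in x]                  # block outputs of the current iterate
        prev = [0.0] + h[:-1]                      # past-only: position t reads h_{t-1}; h_0 = 0
        corr = [g(p, e) for p, e in zip(prev, embs)]
        if cumulative:
            x = [v + c for v, c in zip(x, corr)]   # residual-style: corrections pile up
        else:
            x = [e + c for e, c in zip(embs, corr)]  # non-cumulative: a fresh correction
    return x

def stream(embs):
    """One left-to-right pass, as at inference: each token sees the cached h_{t-1}."""
    xs, h_prev = [], 0.0
    for e in embs:
        x = e + g(h_prev, e)
        xs.append(x)
        h_prev = block(x)
    return xs

embs = [0.3, -0.8, 1.1, 0.4]
```

With four positions, four non-cumulative iterations reproduce the streaming result exactly, and further iterations leave it unchanged. With fewer iterations, only a prefix matches. The cumulative variant keeps drifting as K grows, so no single streaming pass could reproduce it.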
A new kind of RNN. At sequential inference, the correction at position depends on , which depends on , and so on—creating a recurrent computation that unfolds over all positions. The non-cumulative + past-only structure is what enables efficient training: it converts the sequential recurrence into a fixed-point problem, so instead of exact -step BPTT, we can train by unrolling steps over the full sequence in parallel (). This is not equivalent to BPTT—positions beyond the first receive approximate rather than exact gradients—but in practice suffices at the depths tested.
Experimental evidence. We evaluate across widths , depths , block sizes , two datasets (OpenWebText, Wikipedia), and a synthetic reasoning task (Section 5). The most impactful findings: at beats a 12-layer transformer (), halving inference depth and generating faster on an A100. With , at beats a 6-layer transformer () with a inference speedup. Sequential inference matches training to within 0.01 PPL, confirming that streaming exactness holds in practice. On pointer chasing, trained with BPTT solves all 10 composition levels while standard transformers exhibit staircase-like depth dependence. The architecture benefits most from wide representations and long contexts, and requires a dedicated correction network to be effective (Section 5.8). Any pretrained transformer can be converted by adding a zero-initialized correction FFN and fine-tuning.
2 Related Work
Weight-shared and depth-adaptive architectures. ALBERT (Lan et al., 2020), Universal Transformers (Dehghani et al., 2019) with ACT (Graves, 2016), Deep Equilibrium Models (Bai et al., 2019), and Huginn (Geiping et al., 2025) share weights across layers or iterations and iteratively refine hidden states, but do not use a dedicated past-output/token-aware pre-block correction of the kind proposed here. We compare against standard transformers as the representative baseline for how tokens enter the block. Weight sharing in the context-ready transformer arises naturally from unrolling the sequential correction process into training iterations that process all positions in parallel, not as a design choice for parameter efficiency.
Early exit and layer skipping. LayerSkip (Elhoushi et al., 2024), ADEPT (Yoo et al., 2026), and PonderNet (Banino et al., 2021) reduce average layer count but require learned stopping mechanisms. Context-ready uses a fixed depth with no stopping criterion.
Parallel and fixed-point decoding. Lookahead Decoding (Fu et al., 2024) and CLLMs (Kou et al., 2024) apply Jacobi iteration to standard transformers as a decoding strategy. Bai et al. (2021) accelerate DEQ inference by parallelizing fixed-point solves via Jacobi-style updates. Context-ready is an architectural change: the -step unrolling can be viewed as Jacobi iterations on the correction fixed-point equation, but the non-cumulative past-only structure guarantees that a single left-to-right streaming pass recovers the exact correction without any iteration.
Subquadratic and recurrent alternatives. Mamba (Gu and Dao, 2024), RWKV (Peng et al., 2023), Griffin (De et al., 2024), and xLSTM (Beck et al., 2024) replace causal self-attention with compressed recurrent state to achieve linear-time inference. The context-ready transformer solves a different problem: it retains full causal self-attention inside the block and instead changes how tokens enter it. The two approaches are complementary, not competing—one could in principle apply pre-block correction to any of these architectures—and standard transformers remain the natural baseline for evaluating the correction mechanism.
Computational complexity of transformers. Log-precision fixed-depth transformers are confined to (logspace-uniform) TC⁰ (Merrill and Sabharwal, 2023). Unless TC⁰ = NC¹, they therefore cannot solve problems requiring unbounded sequential composition, at any polynomial width. RNNs with arbitrary precision escape this limitation (Siegelmann and Sontag, 1995a; Siegelmann and Sontag, 1995b), but classical RNNs require gradients to flow through all time steps (BPTT), which makes them difficult to train on long sequences. The context-ready transformer at is recurrent at inference, but it is trained by unrolling the correction times rather than by exact -step BPTT. It thus inherits recurrent structure at inference while retaining transformer-style parallel training.
3 Method
3.1 Architecture
We first describe the architecture during sequential left-to-right generation—the setting where context-ready inference is exact. Let denote the sequence length, the embedding dimension, and the vocabulary size. The core component is a -block unit: transformer layers (each consisting of causal self-attention and a feed-forward network with residual connections), described in detail below. Processing token requires the outputs of this block unit from all preceding tokens. Let denote the token embedding at position , and define .
Correction. A dedicated correction FFN generates the correction for position from the cached block output and the current token embedding :
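$$c_t = g\left(h_{t-1},\, e_t\right) \tag{1}$$
where $g$ is the correction FFN, $h_{t-1}$ is the cached block output at position $t-1$ (with $h_0 = 0$), and $e_t$ is the embedding of token $t$.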
The correction FFN is a feed-forward network (Linear → GELU → Linear) with its own weights, separate from the block’s FFN. The correction is token-aware: it depends on both the cached block output (the context of all earlier tokens) and the current token embedding. Since both inputs are available when the token arrives, the correction is causal: it depends on no future tokens.
Contextualization. The new token enters the block with the correction added to its raw embedding:
$$x_t = e_t + c_t \tag{2}$$
where $x_t$ is the input to the block at position $t$.
Block processing. A -block unit applies transformer blocks with separate weights and standard residual connections:
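$$z^{(0)}_t = x_t,\qquad a^{(\ell)}_t = z^{(\ell-1)}_t + \mathrm{Attn}_\ell\big(z^{(\ell-1)}_{1:t}\big)_t,\qquad z^{(\ell)}_t = a^{(\ell)}_t + \mathrm{FFN}_\ell\big(a^{(\ell)}_t\big),\qquad h_t = z^{(D)}_t \tag{3}$$
for $\ell = 1, \dots, D$. Here $\mathrm{Attn}_\ell$ is causal self-attention over positions $1, \dots, t$ (served from the KV cache at inference). Normalization layers, if any, are not shown.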
Each attention sublayer is causal self-attention with Rotary Position Embeddings (RoPE) (Su et al., 2024). Each FFN sublayer is Linear → GELU → Linear. The number of blocks in the unit (D in the tables below) sets the inference depth.
Prediction.
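$$p\left(\text{token}_{t+1} \mid \text{tokens}_{1:t}\right) = \mathrm{softmax}\left(W_{\mathrm{out}}\, h_t\right) \tag{4}$$
where $W_{\mathrm{out}}$ denotes the prediction head.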
After prediction, is cached and the KV caches are updated for future tokens.
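The NumPy sketch below makes the per-token data flow concrete by walking through one streaming step. It rests on several simplifications of ours: a single attention head, no layer normalization, and no RoPE. It also assumes that the correction FFN takes the concatenation of the cached output and the token embedding, which the text does not specify. All names (`decode_step`, `new_state`, the weight dictionaries) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d, V, D = 16, 50, 2  # hypothetical width, vocabulary size, and block depth

def rand(*shape):
    return rng.normal(0.0, 0.2, shape)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp(w1, w2, x):
    return gelu(x @ w1) @ w2

emb = rand(V, d)
corr_w1, corr_w2 = rand(2 * d, 4 * d), rand(4 * d, d)  # ASSUMPTION: g acts on [h_{t-1}; e_t]
layers = [dict(wq=rand(d, d), wk=rand(d, d), wv=rand(d, d), wo=rand(d, d),
               f1=rand(d, 4 * d), f2=rand(4 * d, d)) for _ in range(D)]
w_out = rand(d, V)

def new_state():
    """Empty cache: h_0 = 0 and no keys/values yet in any layer."""
    return {"h_prev": np.zeros(d), "kv": [([], []) for _ in range(D)]}

def decode_step(token, state):
    e = emb[token]
    c = mlp(corr_w1, corr_w2, np.concatenate([state["h_prev"], e]))  # Eq. (1): correction
    z = e + c                                                          # Eq. (2): contextualize
    for p, (keys, vals) in zip(layers, state["kv"]):                   # Eq. (3): D blocks
        keys.append(z @ p["wk"])
        vals.append(z @ p["wv"])
        scores = np.stack(keys) @ (z @ p["wq"]) / np.sqrt(d)         # attend over the cache
        attn = np.exp(scores - scores.max())
        attn /= attn.sum()
        z = z + (attn @ np.stack(vals)) @ p["wo"]
        z = z + mlp(p["f1"], p["f2"], z)
    state["h_prev"] = z                                                # cache h_t for token t+1
    return z @ w_out                                                   # Eq. (4): logits

state = new_state()
logits = [decode_step(tok, state) for tok in [3, 14, 15, 9]]
```

Each call does a fixed amount of work per layer plus one read of the KV cache, and the only state carried between tokens is the cached block output and the per-layer keys and values.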
3.2 Parallel Training
Sequential inference is exact but inherently serial. For training, we unroll the correction process times, processing all positions in parallel at each step. Gradients flow through unrolling steps rather than time steps, making the architecture trainable like a transformer despite being recurrent at inference.
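Given token embeddings $e_1, \dots, e_T$, with $h^{(k)}_0 = 0$ for all $k$ (the initial cache from Section 3.1), initialize $c^{(0)}_t = 0$ for all $t$, so that the first pass sees raw embeddings. For $k = 1, \dots, K$:
$$\begin{aligned} x^{(k)}_t &= e_t + c^{(k-1)}_t && \text{(contextualize)}\\ h^{(k)} &= B\big(x^{(k)}\big) && \text{(block output)}\\ c^{(k)}_t &= g\big(h^{(k)}_{t-1},\, e_t\big) && \text{(correction)} \end{aligned} \tag{5}$$
where $B$ is the $D$-block unit of Eq. (3) applied causally to all positions at once, and $g$ is the correction FFN of Eq. (1).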
The transformer blocks share weights across iterations but have separate weights across layers .
Non-cumulative correction. Each iteration replaces the previous correction entirely: , not . Only the last correction matters.
Past-only correction. The correction at position uses . Corrections propagate left to right: position converges after one iteration, position after two, and so on.
Random-depth training (). We sample each batch with , forcing the model to produce good predictions at any depth, which empirically encourages contraction.
Loss and dropout. The training loss (cross-entropy) is computed on the logits from the final iteration only. Dropout masks are resampled independently at each unrolling step .
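The parallel procedure can be checked against streaming on a toy model. The NumPy sketch below is an illustration built on several assumptions, not the authors' code:

- The block is a single layer (D = 1) with single-head causal attention and a residual, and no FFN or normalization.
- The correction network applies tanh to the concatenated inputs.
- The update order is contextualize, block, correction, starting from a zero correction.

The sketch also omits gradients, dropout, and the loss. In this indexing, K unrolling steps make the first K − 1 block outputs agree with streaming, and K = T + 1 steps make all T agree.

```python
import numpy as np

rng = np.random.default_rng(1)
T, d = 6, 8  # hypothetical sequence length and width

Wq, Wk, Wv = (rng.normal(0.0, 0.3, (d, d)) for _ in range(3))
W1 = rng.normal(0.0, 0.3, (2 * d, 4 * d))
W2 = rng.normal(0.0, 0.3, (4 * d, d))

def block(x):
    """Shared-weight causal block (D = 1): one attention layer with a residual."""
    n = len(x)
    s = (x @ Wq) @ (x @ Wk).T / np.sqrt(d)
    s = np.where(np.tril(np.ones((n, n), dtype=bool)), s, -np.inf)  # causal mask
    a = np.exp(s - s.max(axis=1, keepdims=True))
    a /= a.sum(axis=1, keepdims=True)
    return x + a @ (x @ Wv)

def correction(h_prev, e):
    """Correction network on [h_{t-1}; e_t]; tanh stands in for GELU."""
    return np.tanh(np.concatenate([h_prev, e], axis=-1) @ W1) @ W2

def parallel_unroll(e, K):
    """K >= 1 unrolling steps, each processing all T positions at once."""
    c = np.zeros_like(e)                                # start from raw embeddings
    for _ in range(K):
        h = block(e + c)                                # contextualize, then block
        h_prev = np.vstack([np.zeros((1, d)), h[:-1]])  # shift right; h_0 = 0
        c = correction(h_prev, e)                       # replaces c, never accumulates
    return h

def sequential(e):
    """Streaming reference: one token at a time, recomputing the block on the prefix."""
    xs, h_prev = [], np.zeros(d)
    for t in range(len(e)):
        xs.append(e[t] + correction(h_prev, e[t]))
        h_prev = block(np.array(xs))[-1]
    return block(np.array(xs))

e = rng.normal(0.0, 1.0, (T, d))
```

The prefix check reflects the left-to-right propagation described under "Past-only correction". Later positions are only approximate until enough steps have run.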
3.3 Streaming Inference
When a new token arrives with embedding , the model computes the correction from the cached and , passes the corrected embedding through the -block unit, caches the output , and predicts. This is one forward pass—no iteration over steps—regardless of the training depth .
Why inference needs no iteration. During training, unrolling steps refine the corrections. At inference, this iteration is unnecessary: since earlier positions are already computed and cached, the correction for token is fully determined by and in a single pass (Theorem 2). The training approximation matters only during training: the first positions converge exactly after steps (Lemma 1 in Appendix A.2), and for later positions, the approximation error shrinks geometrically with when the correction operator is contractive (Theorem 3).
4 Theoretical Analysis
Full formal statements and proofs are in Appendix A.
Theorem 1 (Structural characterization).
Why non-cumulative and past-only? Under Assumptions I–II (Appendix A.1), if a weight-shared architecture unrolls a shared block times during training and applies it once per token during streaming, then for the unrolled training to converge to the same output that streaming produces, the correction must be non-cumulative (, not a sum of successive increments) and past-only (the correction at position depends only on and corrections from positions ). The resulting system has a unique fixed point, and streaming computes it exactly.
Full proof in Appendix A.1.
Theorem 2 (Exact streaming fixed point).
Why is inference exact without iteration? During sequential generation, the correction at position depends only on , which are already computed and cached. By prefix consistency (appending tokens does not change the operator at earlier positions, which holds by causal masking), the correction is exact in a single pass.
Full proof in Appendix A.2.
Theorem 3 (Training convergence).
How fast does the training iteration converge? If the correction operator is -Lipschitz with , then unrolling steps reduce the error to the fixed point by a factor of . This governs the training approximation: positions beyond the first receive approximate corrections, and the approximation improves geometrically with .
Full formal statement in Appendix A.3.
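The geometric rate in Theorem 3 is the classical Banach fixed-point bound. The scalar sketch below illustrates it for a hypothetical map `F` with Lipschitz constant 0.5: after k steps, the error is at most 0.5^k times the initial error. This is only an analogy, since the paper's operator acts on whole sequences of correction vectors.

```python
import math

LIPSCHITZ = 0.5  # |F'(x)| = 0.5 * |sin x| <= 0.5, so F is a contraction

def F(x: float) -> float:
    """Hypothetical contractive map standing in for the correction operator."""
    return 0.5 * math.cos(x) + 0.2

# Approximate the unique fixed point x* = F(x*) by iterating to convergence.
x_star = 0.0
for _ in range(200):
    x_star = F(x_star)

x0 = 3.0
errors, x = [], x0
for k in range(1, 11):
    x = F(x)
    errors.append(abs(x - x_star))
bounds = [LIPSCHITZ**k * abs(x0 - x_star) for k in range(1, 11)]
```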
Proposition 1 (Depth separation).
Why can the context-ready architecture use fewer layers? Under a stylized state-tracking abstraction (Appendix A.4):
(a) Context-ready propagation is handled by the correction chain. The context-ready architecture needs only layers for the per-token map; propagation across the sequence is handled by the recurrent correction chain rather than extra transformer layers.
(b) Standard transformers need depth for propagation. With attention window , a standard transformer needs at least layers just for information from the earliest tokens to reach position , on top of the layers needed for the per-token map.
Full statement and proof in Appendix A.4. When propagation and local computation cannot be interleaved (as in pointer chasing), the standard transformer needs at least additional layers; in general, some layers may serve both roles.
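The propagation argument in (b) can be checked by brute force. The sketch assumes that a window of w means each position attends to itself and the w − 1 positions before it. This is our assumption; the exact bound is stated in Appendix A.4. Under it, information moves at most w − 1 positions per layer, so reaching position t from position 0 takes ⌈t / (w − 1)⌉ layers.

```python
import math

def layers_needed(t: int, w: int) -> int:
    """Closed form: each layer carries information at most w - 1 positions to the right."""
    return math.ceil(t / (w - 1))

def simulate(t: int, w: int) -> int:
    """Brute force: grow the set of positions that have seen token 0, one layer at a time."""
    seen, depth = {0}, 0
    while t not in seen:
        # position p attends to positions p - w + 1 .. p
        seen = {p for p in range(t + 1) if any(p - w < q <= p for q in seen)}
        depth += 1
    return depth
```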
5 Experiments
5.1 Setup
Data. OpenWebText (Gokaslan and Cohen, 2019) with byte-pair encoding (BPE, vocabulary 32,000). Context lengths: 64, 256, 512, and 1024 depending on the experiment. Ablations use English Wikipedia (BPE 16k); additional Wikipedia results in Appendix C.8. All results are validation perplexity (PPL) on held-out splits.
Training. AdamW optimizer (Loshchilov and Hutter, 2019), gradient clipping at 1.0. Learning rate unless noted. Training depth with by default; where noted. Dropout 0.2, FFN expansion factor 4. Full hyperparameters in Appendix C.
FLOP accounting. Following standard convention in the literature, we count multiply-accumulate operations (MACs) and label them FLOPs; the actual number of floating-point operations is roughly twice this count. We report total FLOPs per token, including the transformer blocks, the correction FFN, and the prediction head. The head is a d × V projection costing dV FLOPs/token, where d is the model width and V the vocabulary size. Each transformer block costs 12d² FLOPs: 4d² for the attention projections and 8d² for the FFN with expansion factor 4. A context-ready model with D blocks therefore costs 12Dd², plus the correction FFN, plus dV; a standard L-layer transformer costs 12Ld² + dV. Same-width comparisons (Tables 2, 3) share the same prediction head and are FLOP-matched. Cross-width comparisons (Table 1) explore depth-width tradeoffs: wider models pay for a larger prediction head, so total FLOPs differ. Despite higher total FLOPs, the wider, shallower context-ready models have lower wall-clock inference latency, because autoregressive decoding latency on modern GPUs is dominated by the number of sequential layers (Section 5.4).
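The accounting above can be written as a small calculator. The 12d² block cost and the dV head follow the text. The correction-FFN cost is left as a parameter because its exact shape is not stated here; the default of 8d² (a d → 4d → d network) is our assumption. Attention-score terms that grow with context length are excluded, as in the text.

```python
def flops_per_token(d, layers, vocab=32_000, context_ready=False, corr_coeff=8):
    """Multiply-accumulates per generated token (the text's 'FLOPs')."""
    block = 12 * d * d                                  # 4d^2 attention projections + 8d^2 FFN
    head = d * vocab                                    # prediction head: d x V
    corr = corr_coeff * d * d if context_ready else 0   # ASSUMPTION: correction FFN cost
    return layers * block + corr + head

# Hypothetical widths, chosen only to show the arithmetic.
std = flops_per_token(d=768, layers=12)
cr = flops_per_token(d=2048, layers=1, context_ready=True)
```

With these guessed widths, the two calls give about 110M and 149M. These values are close to the figures quoted for the 12-layer roformer and the single-layer model, but the widths actually used in the paper are not stated in this section.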
Baselines. Standard transformers with separate weights per layer and RoPE attention (Su et al., 2024). All results are single runs. Instead of multi-seed runs on a single configuration, we evaluate broadly: across widths (–), depths (–), block sizes (64–1024), two datasets, a synthetic task, and multiple training strategies. The context-ready architecture matches or beats the baselines in most of these settings. However, individual single-run margins of a few tenths of a PPL should be read as parity rather than as robust wins.
5.2 Cross-Width Results
| Model | FLOPs/tok | Width/depth | Val PPL | ΔPPL vs. context-ready |
|---|---|---|---|---|
| D=5 | 121M | 224 | 36.38 | |
| D=6 | 117M | 171 | 36.56 | |
| Roformer | 120M | 181 | 37.76 | 1.38 |
| Roformer | 110M | 64 | 37.83 | 1.45 |
| Roformer | 146M | 944 | 42.99 | |
| Roformer | 376M | 45 | 28.68 | |
| Roformer | 389M | 128 | 29.01 | |
| D=6 | 401M | 341 | 29.04 | 0.03 |
| Roformer | 411M | 363 | 30.35 |
Table 1 compares context-ready models against standard transformers across depth-width tradeoffs at two compute scales. At the smaller scale (block size 256, 100K iterations), context-ready at achieves 36.38 PPL, beating both roformer at (37.76, ) and roformer at (37.83, ). At the larger scale (block size 256, 200K iterations), context-ready at (29.04) matches roformer at (29.01) despite a much higher width-to-depth ratio. However, the 376M-FLOP roformer (28.68) remains the best model at this scale. Depth has diminishing returns: going from to gains only 0.33 PPL.
5.3 Correction Efficiency
| Model | Inference FLOPs | Val PPL |
|---|---|---|
| Roformer | 33.41 | |
| Roformer | 32.82 | |
| D=12 context-ready | 32.28 | |
| Roformer | 32.34 | |
| Roformer | 29.42 | |
| D=23 context-ready | 28.89 |
Table 2 isolates the value of the correction mechanism at fixed width (), so the prediction head is identical across all rows and the comparison is FLOP-matched. context-ready ( FLOPs) beats roformer () by 0.54 PPL and matches roformer (). The correction FFN adds only FLOPs yet provides a measurable PPL improvement at the same parameter budget. At deeper scale, () edges out () by 0.53 PPL. This is directionally consistent with the result, but it is a single-run margin that should be read as evidence of parity rather than robust superiority.
5.4 Width Scaling
| Width | PPL (baseline) | PPL (with correction) | ΔPPL | Relative gain |
|---|---|---|---|---|
| 256 | 158.83 | 143.57 | 15.26 | 9.6% |
| 512 | 95.48 | 84.69 | 10.79 | 11.3% |
| 1024 | 72.83 | 60.84 | 11.99 | 16.5% |
Table 3 tests whether the correction advantage is a small-scale artifact. Comparing vs. at the same width with Chinchilla-matched token budgets, the relative improvement grows from 9.6% at width 256 to 16.5% at width 1024.
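The "Relative" column in Table 3 appears to be the PPL drop divided by the baseline PPL. The short check below reproduces all three entries under that reading, which is ours rather than a stated definition.

```python
# Table 3 entries: width -> (baseline PPL, PPL with correction)
table3 = {256: (158.83, 143.57), 512: (95.48, 84.69), 1024: (72.83, 60.84)}

def relative_gain(base: float, new: float) -> float:
    """Absolute PPL drop as a percentage of the baseline PPL."""
    return 100.0 * (base - new) / base

gains = {w: round(relative_gain(b, n), 1) for w, (b, n) in table3.items()}
```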
Token-matched results. At with block size 64, beats at every depth tested once training progresses past a crossover point: beats by 34.4 PPL (crossover at M tokens), by 7.6 PPL (M), by 3.0 PPL (M), and by 0.7 PPL (M). All gaps are still growing at the end of training. The correction mechanism provides a consistent advantage at every depth; the advantage is largest when depth is small, consistent with the correction doing the most work when the block has the fewest layers.
Inference latency. We measure autoregressive generation speed on an A100 over 10,000 tokens with KV caching. (149M FLOPs/tok) generates at 919 tokens/s vs. 351 tokens/s for roformer (120M FLOPs/tok), a 2.6× speedup despite higher total FLOPs. (121M FLOPs/tok) generates at 349 tokens/s vs. 201 tokens/s for roformer (110M FLOPs/tok), a 1.7× speedup. The wider, shallower models are faster because single-sequence decoding latency on modern GPUs is dominated by the number of sequential layers rather than by total FLOPs. Per-token latency is roughly flat across sequence length. This indicates that, with KV caching, the growing attention cost is not the bottleneck at these lengths. Full timing details in Appendix C.9.
KV cache savings. Fewer layers also reduce KV cache memory ( per token). Despite being wider, at uses less cache than at ; at uses less than at .
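The cache arithmetic is simple: every layer stores one key vector and one value vector per token. The sketch assumes full-width keys and values (no grouped-query or multi-query attention) stored in 16-bit precision. The widths in the example are hypothetical.

```python
def kv_cache_bytes(layers: int, d: int, tokens: int, bytes_per_elem: int = 2) -> int:
    """Keys + values: two width-d vectors per layer per cached token."""
    return 2 * layers * d * tokens * bytes_per_elem

wide_shallow = kv_cache_bytes(layers=1, d=2048, tokens=1024)
narrow_deep = kv_cache_bytes(layers=6, d=1088, tokens=1024)
ratio = narrow_deep / wide_shallow
```

Under these assumptions, one layer at width 2048 caches about 3.2 times less than six layers at width 1088, because cache size scales with layers × width.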
5.5 Single-Layer Performance ()
Proposition 1(a) predicts that may match deeper transformers when the task is dominated by context propagation and sufficient and width are provided. We test , (149M FLOPs/tok) against roformer , (120M FLOPs/tok) at block size 1024 using three training strategies: fixed , random-depth with , and fine-tuning from a pretrained roformer (batch 16).
| Strategy | PPL at 100K | Best PPL | Notes |
|---|---|---|---|
| Roformer (baseline) | 35.37 | n/a | |
| , fixed depth | 34.40 | 33.63 (110K) | Beats at 65K |
| , | 36.28 | 33.63 (135K) | +1.88 penalty at 100K |
| , fine-tuned | 46.66∗ | 31.35 (215K) | Best final PPL |

∗At 100K total (15K fine-tune); fine-tune reaches 35.92 at 150K total.
Table 4 compares the three strategies. In these experiments, larger training depth substantially improves the model’s ability to exploit the available recurrent depth. With fixed , surpasses at 65K iterations and reaches 34.40 at 100K (0.97 PPL below the roformer baseline, still improving) at per-iteration cost.
Random-depth training () incurs a modest penalty of 1.88 PPL at 100K relative to fixed depth, but converges to the same quality 25K iterations later (both reach 33.63). The penalty yields values consistent with contraction ( vs. for fixed ).
Fine-tuning from a pretrained roformer () with a zero-initialized correction FFN at reaches the best final PPL: 31.35 at 215K total iterations (the baseline is at 100K, so total compute differs).
The gap between and shrinks with block size ( at 256, at 512, at 1024 with ), consistent with longer contexts providing more sequential steps. Full block-size scaling in Appendix C.5.
5.6 Pointer Chasing: Depth Separation
| Model | Levels solved | Iters |
|---|---|---|
| Roformer | 1 / 11 | 50K |
| Roformer | 3 / 11 | 50K |
| Roformer | 6 / 11 | 50K |
| Roformer | 7 / 11 | 50K |
| Roformer | 8 / 11 | 50K |
| Roformer | 11 / 11 | 50K |
| D=1 context-ready (BPTT) | 11 / 11 | 16K |
Informally, an -layer transformer can compose only about sequential reasoning steps in a single forward pass (cf. Merrill and Sabharwal, 2023). To test whether the context-ready architecture can exceed this depth limit, we design a pointer-chasing task. The input contains a base table that maps keys to values, followed by levels of index tables, each of which maps new keys to keys at the previous level. Answering a query at level therefore requires chaining sequential lookups through the tables. We use windowed causal attention (window ) to prevent the model from bypassing these chains by attending directly to the base table. A full task specification with a worked example is given in Appendix B.
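A toy generator makes the level structure concrete. The format below is hypothetical and is not the specification in Appendix B: string keys such as `k3_5`, integer values, and uniform random permutations between levels. It also does not model the windowed attention. A query at level ℓ needs ℓ index lookups followed by one base-table lookup.

```python
import random

def make_tables(levels: int, n: int, seed: int = 0):
    """tables[0] maps base keys to values; tables[l] maps level-l keys to level-(l-1) keys."""
    rng = random.Random(seed)
    tables = [{f"k0_{i}": rng.randrange(100) for i in range(n)}]
    for lvl in range(1, levels + 1):
        targets = [f"k{lvl - 1}_{i}" for i in range(n)]
        rng.shuffle(targets)  # a random bijection onto the previous level's keys
        tables.append({f"k{lvl}_{i}": targets[i] for i in range(n)})
    return tables

def resolve(tables, key: str):
    """Answer a query by chasing pointers down to the base table; also count index hops."""
    lvl, hops = int(key[1:key.index("_")]), 0
    while lvl > 0:
        key = tables[lvl][key]  # each hop depends on the result of the previous one
        lvl -= 1
        hops += 1
    return tables[0][key], hops

tables = make_tables(levels=10, n=8)
value, hops = resolve(tables, "k10_3")
```

Because each lookup needs the previous lookup's result, the hops cannot be done in parallel. This is the sequential dependence the task is designed to expose.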
Table 5 shows a staircase-like depth dependence: deeper transformers solve more levels, while shallow transformers fail well before full depth. The context-ready model, trained with BPTT to exploit the full sequential depth of the recurrent correction chain, solves all 11 levels (the base lookup plus 10 composition levels) in 16K iterations. It also scales to 20 hops (all 21 levels).
5.7 Fine-Tuning from Pretrained
Any standard transformer can be converted to context-ready by adding a zero-initialized correction FFN and fine-tuning. To isolate the effect of conversion from additional training, we compare against a same-iteration control: the original roformer trained for the same total number of iterations without conversion. Per-iteration compute differs because fine-tuning runs the block times. At , reaches 29.92 PPL at 200K iterations; continued training to 400K yields 27.20. Converting at 200K and fine-tuning to 400K total iterations yields 26.14, a gain of 1.06 PPL over the continued-training baseline at matched iterations. At , converting to improves from 29.42 to 28.99 (0.43 PPL) in 18K fine-tuning iterations. The zero-initialized correction ensures no disruption at conversion (the model is function-preserving). During fine-tuning, sees a transient PPL increase before recovering, while shows no transient increase.
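Function preservation at conversion can be seen in a few lines. The sketch assumes that only the output layer of the correction FFN is zeroed; this is our choice, since the text says only "zero-initialized". The correction is then exactly zero, so the converted model reproduces the pretrained one at any unrolling depth. Meanwhile, the hidden activations stay nonzero, so the output layer still receives gradient. The `np.ones(d)` upstream gradient is a placeholder.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # hypothetical width

W1 = rng.normal(0.0, 0.3, (2 * d, 4 * d))  # input layer: ordinary random init
W2 = np.zeros((4 * d, d))                   # output layer: zero init

def hidden(h_prev, e):
    return np.tanh(np.concatenate([h_prev, e], axis=-1) @ W1)  # tanh stands in for GELU

def correction(h_prev, e):
    return hidden(h_prev, e) @ W2

h_prev, e = rng.normal(size=d), rng.normal(size=d)
x = e + correction(h_prev, e)  # block input equals the raw embedding at conversion
# dL/dW2 = outer(hidden, dL/dc); np.ones(d) is a placeholder upstream gradient
grad_W2 = np.outer(hidden(h_prev, e), np.ones(d))
```

Had `W1` also been zeroed, the hidden layer would output zeros. Both weight gradients would then vanish, and the correction could never move away from zero.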
5.8 Ablations and Diagnostics
We compare against alternative architectures that attempt the same goal. Among the variants tested, the gain depends on a dedicated correction network with its own weights.
Convergence and sequential exactness. Table 6 shows the full depth progression. Convergence is geometric, consistent with Theorem 3: at D = 1 (block size 1024), two unrolling steps close 91% of the gap between one step and the converged value, and three steps close 98%. Sequential matches parallel to within 0.01 PPL at every configuration, confirming Theorem 2. The correction contribution (“Corr.” column) grows as D shrinks: 38–55 PPL at D = 1 vs. 3.9 at D = 23.
| D | bs | 1 step | 2 steps | 3 steps | | | Seq | Corr. |
|---|---|---|---|---|---|---|---|---|
| 1 | 256 | 70.80 | 36.21 | 33.23 | 32.70 | 32.79 | 32.80 | 38.1 |
| 1 | 512 | 73.22 | 34.33 | 31.24 | 30.61 | 30.68 | 30.69 | 42.6 |
| 1 | 1024 | 84.13 | 34.42 | 30.39 | 29.51 | 29.43 | 29.43 | 54.7 |
| 5 | 256 | 58.17 | 40.18 | 38.80 | 38.60 | 38.61 | 38.62 | 19.6 |
| 12 | 256 | 38.33 | 32.58 | 32.30 | 32.29 | 32.29 | 32.29 | 6.0 |
| 23 | 256 | 32.80 | 29.03 | 28.89 | 28.88 | 28.88 | 28.88 | 3.9 |
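The percentages quoted for D = 1 at block size 1024 can be recomputed from Table 6. The sketch assumes that the first three PPL columns are one, two, and three unrolling steps; this is our reading of the header. It also takes the best PPL in the row as the converged value. Under that reading, it reproduces the 91% and 98% figures and the 54.7 "Corr." entry.

```python
def gap_closed(ppl_first: float, ppl_k: float, ppl_best: float) -> float:
    """Fraction of the one-step-to-converged PPL gap closed after k steps."""
    return (ppl_first - ppl_k) / (ppl_first - ppl_best)

row = [84.13, 34.42, 30.39, 29.51, 29.43]  # D = 1, block size 1024, parallel PPL columns
first, best = row[0], min(row)
closed = [round(100 * gap_closed(first, p, best)) for p in row[1:3]]
corr = round(first - best, 1)
```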
Contraction. With , empirical (measured as ); without , . is a trajectory-local diagnostic, not the global in Theorem 3.
Correction FFN is essential. Using the block’s own residual as the correction (“block_head”) gives no improvement (27.32 PPL, slightly worse than 27.19 for a standard transformer). With a dedicated correction FFN, context-ready beats the FLOP-matched baseline by 1.8 PPL. Tying correction FFN weights to the block FFN collapses performance. The add variant () matches or beats token-blind at all depths.
Training efficiency. suffices at ; with torch.compile achieves faster training at 1.07 PPL cost. Full ablation tables in Appendix C.2.
6 Conclusion
A standard transformer assigns each new token a context-free embedding and relies entirely on depth to contextualize it. The context-ready transformer shortcuts this process: a correction derived from the previous position’s block output pre-contextualizes the token before it enters the block. Two structural choices—non-cumulative correction and past-only contextualization—make this exact at streaming inference (Theorem 2) and trainable with unrolling steps (each processing all positions in parallel) rather than full BPTT.
A model beats a 12-layer transformer while generating faster; beats a 6-layer transformer with a speedup, and sequential inference matches parallel to within 0.01 PPL. The advantage grows with width and context length. Fewer layers also reduce KV cache memory. On pointer chasing, solves all composition levels that standard transformers need proportional depth to reach. Any pretrained transformer can be converted by adding a zero-initialized correction FFN and fine-tuning.
Limitations. Datasets and scale. Results are on OpenWebText, Wikipedia, and a synthetic task. We have validated at 110–150M and 375–410M FLOPs/token but not yet on standard benchmarks or at billion-parameter scale. Single runs. All results are single runs (Section 5.1), so small margins should be read as parity rather than as robust differences. Training cost. From-scratch training runs the block times per iteration rather than once. Backpropagation through steps also requires storing activations for all passes, scaling activation memory by . Several approaches can reduce this cost. (i) Pretrain as a standard transformer at cost, then convert and fine-tune. In our experiments, this matches or exceeds from-scratch context-ready training (Section 5.5) and beats continued standard training at matched iterations (Section 5.7). (ii) Use random-depth training (), which samples each batch and converges to the same quality as fixed (Section 5.5). Prefill. Processing a prompt of length in parallel requires unrolling steps, giving effective prefill depth vs. for a standard transformer.
References
- Bai et al. (2019) Shaojie Bai, J. Zico Kolter, and Vladlen Koltun. Deep equilibrium models. In Advances in Neural Information Processing Systems, 2019.
- Bai et al. (2021) Shaojie Bai, Vladlen Koltun, and J. Zico Kolter. Accelerating feedforward computation via parallel nonlinear equation solving. In International Conference on Machine Learning, 2021.
- Banino et al. (2021) Andrea Banino, Jan Balaguer, and Charles Blundell. PonderNet: Learning to ponder. In ICML Workshop on Uncertainty and Robustness in Deep Learning, 2021.
- Beck et al. (2024) Maximilian Beck, Korbinian Pöppel, Markus Spanring, Andreas Auer, Oleksandra Prudnikova, Michael Kopp, et al. xLSTM: Extended long short-term memory. arXiv preprint arXiv:2405.04517, 2024.
- De et al. (2024) Soham De, Samuel L. Smith, Anushan Fernando, Aleksandar Botev, et al. Griffin: Mixing gated linear recurrences with local attention for efficient language models. arXiv preprint arXiv:2402.19427, 2024.
- Dehghani et al. (2019) Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Łukasz Kaiser. Universal transformers. In International Conference on Learning Representations, 2019.
- Elhoushi et al. (2024) Mostafa Elhoushi, Akshat Shrivastava, Diana Liskovich, Basil Hosmer, Bram Wasti, Liangzhen Lai, Anas Mahmoud, Bilge Acber, Saurabh Agarwal, Ahmed Roman, et al. LayerSkip: Enabling early exit inference and self-speculative decoding. arXiv preprint arXiv:2404.16710, 2024.
- Fu et al. (2024) Yichao Fu, Peter Bailis, Ion Stoica, and Hao Zhang. Break the sequential dependency of LLM inference using lookahead decoding. arXiv preprint arXiv:2402.02057, 2024.
- Geiping et al. (2025) Jonas Geiping, Tom Goldstein, Avi Schwarzschild, C. Bayan Bruss, et al. Scaling up test-time compute with latent reasoning: A recurrent depth approach. arXiv preprint arXiv:2502.05171, 2025.
- Gokaslan and Cohen (2019) Aaron Gokaslan and Vanya Cohen. Openwebtext corpus. http://Skylion007.github.io/OpenWebTextCorpus, 2019.
- Graves (2016) Alex Graves. Adaptive computation time for recurrent neural networks. arXiv preprint arXiv:1603.08983, 2016.
- Gu and Dao (2024) Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces. In Conference on Language Modeling (COLM), 2024.
- Kou et al. (2024) Siqi Kou, Lanxiang Hu, Zhezhi He, Zhijie Deng, and Hao Zhang. CLLMs: Consistency large language models. arXiv preprint arXiv:2403.00835, 2024.
- Lan et al. (2020) Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, and Radu Soricut. ALBERT: A lite BERT for self-supervised learning of language representations. In International Conference on Learning Representations, 2020.
- Loshchilov and Hutter (2019) Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. In International Conference on Learning Representations, 2019.
- Merrill and Sabharwal (2023) William Merrill and Ashish Sabharwal. The parallelism tradeoff: Limitations of log-precision transformers. Transactions of the Association for Computational Linguistics, 11:531–545, 2023.
- Peng et al. (2023) Bo Peng, Eric Alcaide, Quentin Anthony, Alon Albalak, Samuel Arcadinho, Huanqi Cao, Xin Cheng, Michael Chung, et al. RWKV: Reinventing RNNs for the transformer era. In Findings of EMNLP 2023, 2023.
- Siegelmann and Sontag (1995a) Hava T. Siegelmann and Eduardo D. Sontag. Computational capabilities of recurrent NARX neural networks. IEEE Transactions on Systems, Man, and Cybernetics, 26(4):535–544, 1995a.
- Siegelmann and Sontag (1995b) Hava T. Siegelmann and Eduardo D. Sontag. On the computational power of neural nets. Journal of Computer and System Sciences, 50(1):132–150, 1995b.
- Su et al. (2024) Jianlin Su, Murtadha Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, and Yunfeng Liu. RoFormer: Enhanced transformer with rotary position embedding. Neurocomputing, 568:127063, 2024.
- Yoo et al. (2026) Seunghyun Yoo et al. ADEPT: Adaptive dynamic early-exit process for transformers. arXiv preprint arXiv:2601.03700, 2026.
NeurIPS Paper Checklist
1. Claims
   Answer: [Yes]
   Justification: The abstract and introduction state the architectural contributions and experimental results with specific numbers. Limitations (datasets, scale, training cost) are discussed explicitly in Section 6.
2. Limitations
   Answer: [Yes]
   Justification: Section 6 includes a dedicated Limitations paragraph covering dataset scope, scale, training cost, and single-run reporting.
3. Theory assumptions and proofs
   Answer: [Yes]
4. Experimental result reproducibility
   Answer: [Yes]
5. Open access to data and code
   Answer: [No]
   Justification: Code is not released at submission time due to patent considerations. The architecture and training procedure are described in sufficient detail to reproduce.
6. Experimental setting/details
   Answer: [Yes]
7. Experiment statistical significance
   Answer: [No]
   Justification: All results are single runs. We acknowledge this explicitly and report trends across multiple configurations (widths, depths, block sizes) rather than relying on individual comparisons.
8. Experiments compute resources
   Answer: [Yes]
9. Code of ethics
   Answer: [Yes]
   Justification: The research conforms with the NeurIPS Code of Ethics. No human subjects, private data, or dual-use concerns.
10. Broader impacts
   Answer: [N/A]
   Justification: This is foundational architecture research. The work reduces inference cost for language models, which has broadly positive efficiency implications but no direct path to specific negative applications beyond those inherent to language modeling in general.
11. Safeguards
   Answer: [N/A]
   Justification: No pretrained language models or scraped datasets are released.
12. Licenses for existing assets
   Answer: [Yes]
   Justification: OpenWebText is cited [Gokaslan and Cohen, 2019]. Wikipedia is publicly available. All referenced works are cited.
13. New assets
   Answer: [N/A]
   Justification: No new datasets, models, or code are released with this submission.
14. Crowdsourcing and research with human subjects
   Answer: [N/A]
   Justification: No crowdsourcing or human subjects research.
15. Institutional review board (IRB) approvals or equivalent for research with human subjects
   Answer: [N/A]
   Justification: No human subjects research.
16. Declaration of LLM usage
   Answer: [N/A]
   Justification: LLMs were not used as a component of the core methodology.
Appendix A Full Proofs
Notation. Throughout the appendix, $e = (e_1, \dots, e_T)$ denotes the base token embeddings, $F$ denotes the full correction operator, and $c^{(k)}$ denotes the correction vector at iteration $k$. The fixed point is $c^* = F(c^*)$.
A.1 Proof of Theorem 1 (Structural Characterization)
Formal statement.
Setup. Let $e_t$ denote the base embedding at position $t$, let Block denote a $D$-layer transformer block with causal attention, and let $g$ be a correction function. The architecture processes position $t$ as follows: form the corrected input $x_t = e_t + g(h_{t-1}, e_t)$, compute $h_t = \mathrm{Block}(x_1, \dots, x_t)_t$, and cache $h_t$ for future positions.
Assumptions.
-
(I)
Additive correction. The corrected input has the form $x_t = e_t + c_t$ with $c_t = g(h_{t-1}, e_t)$, where $g$ is a continuous correction function that maps into a bounded ball $\{c : \|c\| \le R\}$.
-
(II)
Single-pass streaming. During inference, tokens are processed left-to-right, one at a time. The correction at position $t$ uses the cached block output $h_{t-1}$ from the previous position and the current embedding $e_t$. By causality, $h_{t-1}$ depends only on positions $1, \dots, t-1$.
Conclusions.
-
(a)
Past-only. By streaming (Property II), the correction at position $t$ depends only on previously computed outputs: $c_t = g(h_{t-1}, e_t)$, where $h_{t-1}$ encodes positions $1, \dots, t-1$.
-
(b)
Non-cumulative. Among additive unrolling strategies, only the non-cumulative form is compatible with streaming. The cumulative (resnet) alternative fails because its correction is driven to zero at the fixed point, $g(h^*_{t-1}, x^*_t) = 0$.
-
(c)
Existence and uniqueness. The non-cumulative past-only system has a unique fixed point, and streaming computes it exactly.
Proof.
Part (a): Past-only. Streaming (Property II) processes tokens left-to-right. When token $t$ arrives, the correction uses $h_{t-1}$ (cached from the previous token) and $e_t$. Since causal attention ensures that $h_{t-1}$ depends only on positions $1, \dots, t-1$, the correction at position $t$ is a function of past outputs only.
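Part (b): Non-cumulative. Given the additive correction form (Property I), there are two ways to unroll $K$ training steps. Both start from $x^{(0)} = e$ and write $h^{(k)} = \mathrm{Block}(x^{(k)})$:
Non-cumulative: $x^{(k+1)}_t = e_t + g(h^{(k)}_{t-1}, e_t)$. Each step recomputes the correction from the base embedding $e_t$. At the fixed point, $x^*_t - e_t = g(h^*_{t-1}, e_t)$, a nonzero correction. In streaming, past outputs are exact (from the cache), so a single evaluation gives $e_t + g(h^*_{t-1}, e_t) = x^*_t$. Exact.
Cumulative (resnet): $x^{(k+1)}_t = x^{(k)}_t + g(h^{(k)}_{t-1}, x^{(k)}_t)$. Each step adds an increment to the previous output. At the fixed point, $x^*_t = x^*_t + g(h^*_{t-1}, x^*_t)$, which forces $g(h^*_{t-1}, x^*_t) = 0$: the correction function learns to output zero at convergence. In streaming, the single step starts from $e_t$ and gives $e_t + g(h^*_{t-1}, e_t)$. But $g$ was trained to vanish at $(h^*_{t-1}, x^*_t)$, not at $(h^*_{t-1}, e_t)$, so the result is in general neither $e_t$ (a zero correction) nor the correct fixed point $x^*_t$.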
Therefore, among additive unrolling strategies, the non-cumulative form is the one that reproduces the correct fixed-point output during streaming.
Part (c): Existence and uniqueness. By part (a), the system is triangular: $c_1$ is determined first (it depends on no earlier output), then $c_2$ (from $h_1$ and $e_2$), then $c_3$, and so on. Each step is a deterministic evaluation, so the fixed point exists, is unique, and is constructively computable by left-to-right evaluation, which is exactly the streaming computation. ∎
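The equivalence in Part (c) can be checked numerically on a toy model. The sketch below is a minimal illustration under stated assumptions, not the paper's implementation: activations are scalars, `block` is a hypothetical causal stand-in for the transformer block (a running mean followed by tanh), `correction` is a hypothetical choice of the correction function, and `H0 = 0.0` is an assumed placeholder for the cached output preceding the first token. It compares left-to-right streaming with non-cumulative parallel unrolling for a given number of steps.

```python
import math
from typing import List

H0 = 0.0  # assumed cached output "before" the first token (hypothetical placeholder)


def block(xs: List[float]) -> List[float]:
    """Hypothetical causal block: output t depends only on xs[0..t]."""
    out, total = [], 0.0
    for t, x in enumerate(xs):
        total += x
        out.append(math.tanh(total / (t + 1)))
    return out


def correction(h_prev: float, e: float) -> float:
    """Hypothetical correction g(h_prev, e); exactness does not depend on its form."""
    return 0.5 * h_prev - 0.1 * e * h_prev


def stream(es: List[float]) -> List[float]:
    """Left-to-right inference: one corrected input per token, using the cached output."""
    xs, hs, h_prev = [], [], H0
    for e in es:
        xs.append(e + correction(h_prev, e))
        h_prev = block(xs)[-1]  # a real model would reuse a KV cache here
        hs.append(h_prev)
    return hs


def unrolled(es: List[float], steps: int) -> List[float]:
    """Non-cumulative parallel unrolling: every step recomputes corrections from e."""
    xs = list(es)  # step 0: raw embeddings, i.e. zero correction
    for _ in range(steps):
        prev = [H0] + block(xs)[:-1]  # shift outputs right by one position
        xs = [e + correction(h, e) for e, h in zip(es, prev)]
    return block(xs)


es = [0.3, -1.2, 0.8, 2.0, -0.5]
print(stream(es))
print(unrolled(es, len(es)))  # agrees with stream(es) once steps >= sequence length
```

Because `correction` always receives the base embedding `e` rather than the previous corrected input, this is the non-cumulative form; the first `k` positions of `unrolled(es, k)` already agree with streaming.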
A.2 Proof of Theorem 2 (Exact Streaming)
Formal statement. Let $c^{*,T}$ be the unique fixed point for a sequence of length $T$. Assume prefix consistency: for any $T' > T$ and $t \le T$, the correction operator satisfies $F^{T'}_t = F^{T}_t$, i.e., appending tokens beyond position $T$ does not change the operator at earlier positions. (This holds by construction for any causal architecture where the operator at position $t$ depends only on positions $1, \dots, t$.) Then:
-
(a)
Prefix invariance. $c^{*,T'}_t = c^{*,T}_t$ for $t \le T$ and any $T' > T$.
-
(b)
Exactness. If past corrections are at the fixed point, the streaming operator produces the exact fixed-point correction for the new token.
-
(c)
No contraction needed. Exactness holds without any contraction assumption.
The following lemma establishes that the Jacobi iteration converges in finitely many steps, which is the foundation for the streaming exactness proof.
Lemma 1 (Finite-step exactness).
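Let $F$ be a past-only correction operator, i.e., $F_t(c)$ depends only on $c_1, \dots, c_{t-1}$. Then: (a) The fixed point $c^*$ exists and is unique, without contraction. (b) After $k$ Jacobi iterations $c^{(j+1)} = F(c^{(j)})$ from any $c^{(0)}$: $c^{(k)}_t = c^*_t$ for all $t \le k$. (c) The iteration reaches the exact fixed point after at most $T$ steps: $c^{(T)} = c^*$.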
Proof.
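(a) By construction: $F_1$ is a constant, so $c^*_1 = F_1$ is determined. Given $c^*_1, \dots, c^*_{t-1}$, we have $c^*_t = F_t(c^*_1, \dots, c^*_{t-1})$ uniquely.
(b) By induction on $k$. Base ($k = 1$): $c^{(1)}_1 = F_1 = c^*_1$. Step: assume $c^{(k)}_s = c^*_s$ for $s \le k$. Then for every $t \le k+1$, $c^{(k+1)}_t = F_t(c^{(k)}_1, \dots, c^{(k)}_{t-1}) = F_t(c^*_1, \dots, c^*_{t-1}) = c^*_t$.
(c) Set $k = T$ in (b). ∎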
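Lemma 1 can be checked directly on a small example. The sketch below is a minimal illustration, not part of the method: a past-only operator is represented by a hypothetical function `example_F(t, past)` that may read only the corrections before position `t` (0-indexed here). The operator is deliberately not a contraction (its partial derivatives reach 3 in magnitude), yet Jacobi iteration from an arbitrary start reproduces the left-to-right fixed point one more position per step.

```python
import math
from typing import Callable, List, Sequence

# F(t, past) may read only past = c[0..t-1]; this is what "past-only" means.
PastOnlyOp = Callable[[int, Sequence[float]], float]


def example_F(t: int, past: Sequence[float]) -> float:
    """Hypothetical nonlinear past-only operator with partial derivatives up to 3."""
    return 3.0 * math.sin(sum(past)) + 2.0 * t


def jacobi_step(F: PastOnlyOp, c: List[float]) -> List[float]:
    """One Jacobi step: all positions are updated from the previous iterate."""
    return [F(t, c[:t]) for t in range(len(c))]


def sequential_fixed_point(F: PastOnlyOp, T: int) -> List[float]:
    """Left-to-right evaluation c*_t = F_t(c*_<t), i.e. the streaming computation."""
    c: List[float] = []
    for t in range(T):
        c.append(F(t, c))
    return c


T = 6
c_star = sequential_fixed_point(example_F, T)
c = [100.0] * T  # arbitrary starting point
for k in range(1, T + 1):
    c = jacobi_step(example_F, c)
    print(k, c[:k] == c_star[:k])  # compares the first k positions with the fixed point
```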
Proof of Theorem 2.
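(a) By prefix consistency, $F^{T'}_t = F^{T}_t$ for $t \le T$. Hence the fixed-point equations for positions $1, \dots, T$ are identical in the length-$T$ and length-$T'$ systems, and by uniqueness (Lemma 1(a)), $c^{*,T'}_t = c^{*,T}_t$ for $t \le T$.
(b) $c^*_t = F_t(c^*_1, \dots, c^*_{t-1})$. By (a), the values $c^*_s$ for $s < t$ do not change when token $t$ is appended. The streaming operator computes $F_t$ exactly using the cached corrections $c^*_1, \dots, c^*_{t-1}$. Concretely, the abstract operator $F_t$ is realized by the correction FFN applied to the cached block output $h_{t-1}$ and the current token embedding $e_t$: once earlier positions are at their fixed-point values, $h_{t-1}$ is fully determined, and the streaming step computes $c^*_t = g(h_{t-1}, e_t)$ exactly.
(c) By induction on $t$, from the base case $c^*_1 = F_1$ and part (b); no step uses a contraction property. ∎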
A.3 Proof of Theorem 3 (Convergence)
Formal statement. Let $F$ denote the full correction operator over positions $1, \dots, T$. Assume $\|F(c) - F(c')\| \le \gamma \|c - c'\|$ for all $c, c'$. If $\gamma < 1$, then:
-
(a)
Contraction. $\|c^{(k)} - c^*\| \le \gamma^k \|c^{(0)} - c^*\|$.
-
(b)
Warm-start bound. .
Proof.
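Part (a). $\|c^{(k)} - c^*\| = \|F(c^{(k-1)}) - F(c^*)\| \le \gamma \|c^{(k-1)} - c^*\|$; iterating gives the claim (Banach fixed-point theorem with contraction constant $\gamma$).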
Part (b). , so . Then . ∎
Theorem 3 assumes $\gamma < 1$ but does not say how to verify this from the per-position Jacobian structure. The following lemma provides a practical bound: given bounds on the partial derivatives $\partial F_t / \partial c_s$, one can choose a weighted norm that makes the global Lipschitz constant explicit.
Lemma 2 (Causal contraction bound).
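Let $F$ be the full-sequence correction operator with past-only dependencies, and suppose $\|\partial F_t / \partial c_s\| \le \kappa_{ts}$ for $s < t$, where $\|\cdot\|$ is the operator norm. For positive weights $w_1, \dots, w_T$, define $\|c\|_w = \max_t w_t \|c_t\|$. Then $F$ is Lipschitz in $\|\cdot\|_w$ with constant:
$$\gamma_w = \max_t \sum_{s<t} \frac{w_t}{w_s} \kappa_{ts}.$$
In particular, if $\gamma_w < 1$, then $F$ is a contraction.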
Proof.
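For each $t$: $\|F_t(c) - F_t(c')\| \le \sum_{s<t} \kappa_{ts} \|c_s - c'_s\| \le \Big(\sum_{s<t} \frac{\kappa_{ts}}{w_s}\Big) \|c - c'\|_w$. Multiplying by $w_t$ and taking the maximum over $t$ gives $\|F(c) - F(c')\|_w \le \gamma_w \|c - c'\|_w$. ∎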
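The constant in Lemma 2 is straightforward to evaluate once a bound matrix is available. In the minimal sketch below, the Jacobian bounds (all equal to 0.5 for $s < t$) are a hypothetical example rather than measured values, and `weighted_lipschitz` is an illustrative helper name. With unit weights the largest row sum exceeds 1, while weights that decay geometrically with position give a constant below 1 for the same bounds, which is the point of the weighted norm.

```python
from typing import Sequence


def weighted_lipschitz(kappa: Sequence[Sequence[float]], w: Sequence[float]) -> float:
    """gamma_w = max_t sum_{s<t} (w_t / w_s) * kappa[t][s] for past-only bounds kappa."""
    T = len(w)
    return max(
        (sum(w[t] / w[s] * kappa[t][s] for s in range(t)) for t in range(T)),
        default=0.0,
    )


T = 4
kappa = [[0.5 if s < t else 0.0 for s in range(T)] for t in range(T)]  # hypothetical bounds
print(weighted_lipschitz(kappa, [1.0] * T))                      # unit weights
print(weighted_lipschitz(kappa, [2.0 ** -t for t in range(T)]))  # w_t = 2^-t
```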
When the number of unrolling steps $K$ is chosen at training time, one may want to know how close $c^{(K)}$ is to $c^*$ without computing $c^*$. The next lemma gives a computable bound using only consecutive iterates.
Lemma 3 (A posteriori error bound).
If $\gamma < 1$, then after $k \ge 1$ iterations: $\|c^{(k)} - c^*\| \le \frac{\gamma}{1 - \gamma} \|c^{(k)} - c^{(k-1)}\|$.
Proof.
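By contraction and the triangle inequality, $\|c^{(k)} - c^*\| \le \gamma \|c^{(k-1)} - c^*\| \le \gamma \big(\|c^{(k-1)} - c^{(k)}\| + \|c^{(k)} - c^*\|\big)$; rearranging gives the bound. ∎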
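A minimal sketch of Lemma 3 on a scalar contraction. The map $c \mapsto 0.5\cos c$ is a hypothetical stand-in with Lipschitz constant $\gamma = 0.5$, and `a_posteriori_bound` and `error_and_bound` are illustrative helper names. Each row pairs the true error (computed against a converged reference) with the computable bound, so one can confirm that the bound dominates.

```python
import math
from typing import List, Tuple


def a_posteriori_bound(c_k: float, c_km1: float, gamma: float) -> float:
    """Lemma 3: |c_k - c*| <= gamma / (1 - gamma) * |c_k - c_{k-1}| (needs gamma < 1)."""
    if not 0.0 <= gamma < 1.0:
        raise ValueError("gamma must lie in [0, 1)")
    return gamma / (1.0 - gamma) * abs(c_k - c_km1)


def F(c: float) -> float:
    return 0.5 * math.cos(c)  # |F'(c)| <= 0.5, so gamma = 0.5


def error_and_bound(c0: float, steps: int, gamma: float = 0.5) -> List[Tuple[float, float]]:
    c_star = c0
    for _ in range(200):  # reference fixed point, converged to machine precision
        c_star = F(c_star)
    rows, c_prev, c = [], c0, F(c0)
    for _ in range(steps):
        rows.append((abs(c - c_star), a_posteriori_bound(c, c_prev, gamma)))
        c_prev, c = c, F(c)
    return rows


for err, bound in error_and_bound(c0=3.0, steps=6):
    print(f"true error {err:.2e}  <=  bound {bound:.2e}")
```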
Theorem 3(a) gives a global contraction rate that treats all positions uniformly, but this is overly pessimistic. Because $F$ is past-only, its Jacobian is strictly lower-triangular and therefore nilpotent, so position $t$ reaches its exact fixed point after at most $t$ iterations rather than $T$ (Lemma 1(b)). The following lemma exploits this structure to give a tighter, position-dependent error bound via causal path sums.
Lemma 4 (Finite-depth error bound).
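Let $F$ be a past-only correction operator with Jacobian bounds $\|\partial F_t / \partial c_s\| \le \kappa_{ts}$ for $s < t$. Let $M$ be the strictly lower-triangular $T \times T$ matrix with entries $M_{ts} = \kappa_{ts}$ for $s < t$ (and $0$ otherwise), and let $b_t = \|F_t(0)\|$. After $k$ Jacobi iterations from $c^{(0)} = 0$:
$$\|c^{(k)}_t - c^*_t\| \le \big[M^k (I - M)^{-1} b\big]_t \quad \text{for all } t.$$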
Proof.
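By the integral form of the mean value theorem, the error $\epsilon^{(k)} = c^{(k)} - c^*$ satisfies $\epsilon^{(k+1)} = \bar{J}^{(k)} \epsilon^{(k)}$, where $\|\bar{J}^{(k)}_{ts}\| \le \kappa_{ts}$ by the Jacobian bound assumption. Bounding $\bar{J}^{(k)}$ by the entrywise matrix $M$ and iterating from $k = 0$ gives $u^{(k)} \le M^k u^{(0)}$ entrywise, where $u^{(k)}_t = \|\epsilon^{(k)}_t\|$. To bound $u^{(0)}_t = \|c^*_t\|$: the fixed point satisfies $c^* = F(c^*)$, so $\|c^*_t\| \le \|F_t(0)\| + \sum_{s<t} \kappa_{ts} \|c^*_s\|$, i.e., $(I - M) u^{(0)} \le b$, hence $u^{(0)} \le (I - M)^{-1} b$ (well-defined since $M$ is nilpotent, and entrywise nonnegative since $(I - M)^{-1} = \sum_{j=0}^{T-1} M^j$). Substituting: $u^{(k)} \le M^k (I - M)^{-1} b$. ∎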
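The bound in Lemma 4 needs only matrix-vector products: the strictly lower-triangular bound matrix (called `M` below) is nilpotent, so the inverse of `I - M` is a finite Neumann sum. In this minimal sketch, the operator $F_t(c) = 0.9\sin(c_{t-1}) + \beta_t$ is a hypothetical example whose only Jacobian entry per row is bounded by $0.9$, the offsets `beta` are made-up values, and the helper names are illustrative. The loop compares the actual Jacobi error, starting from zero corrections, with the bound at every iteration.

```python
import math
from typing import List

Matrix = List[List[float]]


def matvec(M: Matrix, v: List[float]) -> List[float]:
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def finite_depth_bound(M: Matrix, b: List[float], k: int) -> List[float]:
    """Entrywise bound M^k (I - M)^{-1} b for a strictly lower-triangular M."""
    T = len(b)
    term, u = list(b), list(b)  # (I - M)^{-1} b = sum_{j < T} M^j b (finite Neumann series)
    for _ in range(T - 1):
        term = matvec(M, term)
        u = [x + y for x, y in zip(u, term)]
    for _ in range(k):
        u = matvec(M, u)
    return u


beta = [1.0, -0.5, 2.0, 0.3, -1.5]  # hypothetical offsets
T = len(beta)


def F(c: List[float]) -> List[float]:
    """Hypothetical past-only operator: position t reads only c[t-1]."""
    return [beta[0]] + [0.9 * math.sin(c[t - 1]) + beta[t] for t in range(1, T)]


M = [[0.9 if s == t - 1 else 0.0 for s in range(T)] for t in range(T)]
b = [abs(x) for x in beta]  # b_t = ||F_t(0)|| because sin(0) = 0

c_star: List[float] = []
for t in range(T):  # left-to-right fixed point
    c_star.append(beta[0] if t == 0 else 0.9 * math.sin(c_star[t - 1]) + beta[t])

c = [0.0] * T
for k in range(T + 1):
    err = [abs(x - y) for x, y in zip(c, c_star)]
    print(k, all(e <= u + 1e-12 for e, u in zip(err, finite_depth_bound(M, b, k))))
    c = F(c)  # Jacobi step: every position reads the previous iterate
```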
A.4 Proof of Proposition 1 (Depth Separation)
Formal statement.
Setup. Suppose the data is generated by a process with state update $\sigma_{t+1} = \Phi(\sigma_t, u_t)$, where $\sigma_t$ is the process state and $u_t$ collects the token embeddings in the current window. The context-ready architecture (Equations 1–3.1) with attention window $W$ maintains a state $z_t \in \mathbb{R}^{W \times d}$ over the full window of $W$ positions, where $d$ is the per-position state dimension. The state evolves via $z_{t+1} = G(z_t, u_t)$, where $G$ composes the correction FFN and block (Equations 1–3.1). Let $\lambda$ be the Lipschitz constant of $G$ in its first argument.
To compare these two systems, let $\pi$ be a map from the process state to the architecture's state space. The architecture faithfully tracks the process when projection and state updates commute: advancing the process state by $\Phi$ and then projecting by $\pi$ gives the same result as projecting and then advancing by $G$. The commutation error
$$\varepsilon = \sup_{\sigma,\,u} \big\| \pi(\Phi(\sigma, u)) - G(\pi(\sigma), u) \big\|$$
measures how far the two paths are from commuting: it captures both the block's finite-depth approximation error and any information lost when the two state spaces differ.
Assumptions.
-
(i)
$\lambda < 1$ and $z_0 = \pi(\sigma_0)$.
-
(ii)
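Prediction sufficiency. $\pi$ preserves prediction-relevant information: the next-token prediction can be written as $\rho(\pi(\sigma_t))$ for a readout $\rho$ applied to the projected state.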
-
(iii)
Lipschitz readout. The prediction function $\rho$ is $\mu$-Lipschitz.
Conclusions.
-
(a)
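Context-ready error bound. The accumulated state error satisfies $\|z_t - \pi(\sigma_t)\| \le \varepsilon / (1 - \lambda)$ uniformly in $t$. By prediction sufficiency and Lipschitz readout, the prediction error is bounded by $\mu \varepsilon / (1 - \lambda)$. If $\varepsilon = 0$, then streaming is exact for all $t$.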
-
(b)
Standard transformer receptive-field bound. Let $L$ be the number of layers in a standard transformer with attention window $W$. Then position $t$ has no computational path to any position before $t - LW$, so the transformer cannot represent any function whose output at position $t$ depends on inputs before $t - LW$.
Proof.
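Part (a). At step $t$, the architecture computes $z_{t+1} = G(z_t, u_t)$ while the projected true state satisfies $\pi(\sigma_{t+1}) = \pi(\Phi(\sigma_t, u_t))$. By the triangle inequality:
$$\|z_{t+1} - \pi(\sigma_{t+1})\| \le \|G(z_t, u_t) - G(\pi(\sigma_t), u_t)\| + \|G(\pi(\sigma_t), u_t) - \pi(\Phi(\sigma_t, u_t))\| \le \lambda \|z_t - \pi(\sigma_t)\| + \varepsilon.$$
Unrolling with $z_0 = \pi(\sigma_0)$ gives $\|z_t - \pi(\sigma_t)\| \le \varepsilon \sum_{j=0}^{t-1} \lambda^j \le \varepsilon / (1 - \lambda)$.
Part (b). With attention window $W$, the output of layer $\ell$ at position $t$ depends only on positions in $[t - \ell W, t]$. After $L$ layers, position $t$ has no computational path to any position before $t - LW$. ∎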
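The unrolling step in Part (a) is a scalar recursion: with contraction factor $\lambda$ and commutation error $\varepsilon$, the state error obeys $e_{t+1} \le \lambda e_t + \varepsilon$. The short sketch below iterates the worst case of this recursion from $e_0 = 0$ for hypothetical values $\lambda = 0.8$ and $\varepsilon = 0.01$, showing that the accumulated error saturates at $\varepsilon / (1 - \lambda)$ instead of growing with $t$.

```python
from typing import List


def accumulated_error(lam: float, eps: float, steps: int) -> List[float]:
    """Worst case of e_{t+1} <= lam * e_t + eps, starting from e_0 = 0."""
    e, history = 0.0, []
    for _ in range(steps):
        e = lam * e + eps
        history.append(e)
    return history


lam, eps = 0.8, 0.01  # hypothetical values
hist = accumulated_error(lam, eps, 100)
print(hist[0], hist[9], hist[-1], eps / (1 - lam))
```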
Remark (Depth separation).
Parts (a) and (b) together suggest a depth separation: the context-ready architecture propagates context through the correction chain at no additional depth cost, while a standard windowed transformer must allocate layers for propagation. When propagation and local computation cannot be interleaved—as in the pointer-chasing task where each hop requires a separate lookup—the standard transformer needs at least additional layers. In general, some layers may serve both roles, so is an upper bound on the required depth.
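The receptive-field argument reduces to integer arithmetic. The sketch below assumes one particular window convention (each layer lets position $p$ attend to positions $p - W, \dots, p$); other conventions shift the boundary by one position per layer. The helper names are illustrative, and the window value 38 is borrowed from the pointer-chasing settings in Appendix B purely as an example.

```python
def earliest_reachable(t: int, num_layers: int, window: int) -> int:
    """Earliest position that position t can depend on after num_layers windowed layers."""
    return max(0, t - num_layers * window)


def layers_needed(t: int, source: int, window: int) -> int:
    """Fewest windowed layers giving position t a path to an earlier position source."""
    if source > t:
        raise ValueError("source must not come after t")
    return -(-(t - source) // window)  # ceiling division


W = 38
print(earliest_reachable(200, 2, W))  # positions before this are invisible to position 200
print(layers_needed(200, 0, W))       # layers needed before position 200 can see position 0
```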
Appendix B Pointer Chasing Details
Motivation. Fixed-depth transformers are confined to TC$^0$ [Merrill and Sabharwal, 2023]: an $L$-layer transformer can compose at most $O(L)$ sequential reasoning steps in a single forward pass. We design a synthetic task that directly tests this depth limit. Answering a query requires chaining a variable number of sequential lookups, so a model that can only perform a fixed number of parallel steps will fail once the required chain length exceeds its depth. The context-ready architecture sidesteps this barrier because its recurrent correction chain provides sequential computation at inference, even with a single block ($D = 1$).
Task definition. The pointer-chasing task has $H$ hops and $n$ keys per level. The input contains a base table (level 0) mapping keys to values, followed by $H$ index tables (levels $1, \dots, H$), each mapping keys to keys of the previous level via random permutations (bijections). After each table, a query section provides dense targets: a triplet Q key → answer for every key at every level defined so far. Resolving a query at level $\ell$ requires $\ell + 1$ sequential lookups.
Worked example ($H = 2$, $n = 3$, 10 values). The base table maps A → v3, B → v0, C → v8. Index table 1 maps D → A, E → B, F → C. Index table 2 maps G → D, H → E, I → F. The encoding uses reversed triplets (value = key) so that causal attention can see the value to the left of the key:
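Level 0 (base table + queries):

| Input | v3 | = | A | v0 | = | B | v8 | = | C |
|---|---|---|---|---|---|---|---|---|---|
| Target | _ | _ | _ | _ | _ | _ | _ | _ | _ |

| Input | Q | A | v3 | Q | B | v0 | Q | C | v8 |
|---|---|---|---|---|---|---|---|---|---|
| Target | _ | **v3** | _ | _ | **v0** | _ | _ | **v8** | _ |

Level 1 (index table + queries):

| Input | A | = | D | B | = | E | C | = | F |
|---|---|---|---|---|---|---|---|---|---|
| Target | _ | _ | _ | _ | _ | _ | _ | _ | _ |

| Input | Q | D | v3 | Q | E | v0 | Q | F | v8 |
|---|---|---|---|---|---|---|---|---|---|
| Target | _ | **v3** | _ | _ | **v0** | _ | _ | **v8** | _ |

Level 2 (index table + final query):

| Input | D | = | G | E | = | H | F | = | I | Q | G |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Target | _ | _ | _ | _ | _ | _ | _ | _ | _ | _ | **v3** |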
Targets (bold) appear only at key positions in query sections. Level-0 queries are trivial lookups (Q A → v3). Level-1 queries require one composition (Q D → A → v3). Level-2 queries require two compositions (Q G → D → A → v3). The final token is the actual test query, with no answer provided in the input. Dense targets at every level are essential: without them, the model cannot learn multi-hop composition even with BPTT.
Each level uses its own key namespace (A, B, C at level 0; D, E, F at level 1; G, H, I at level 2; etc.) to prevent ambiguity. Key ordering within each table is fixed (not shuffled), so the model can exploit positional patterns via RoPE.
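The encoding can be made concrete with a short generator. This is a minimal sketch that follows the worked example, not the authors' data pipeline. Assumptions: `tables[l]` maps each level-$l$ key to its value (level 0) or to a key of level $l-1$; each query section covers only the current level's keys, as in the worked example; the last level ends with a single test query; and `None` marks positions without a target. The function names are hypothetical.

```python
from typing import Dict, List, Optional, Tuple

Table = Dict[str, str]


def resolve(key: str, level: int, tables: List[Table]) -> str:
    """Chase pointers from a key at `level` down to a base-table value."""
    for lvl in range(level, -1, -1):
        key = tables[lvl][key]
    return key


def encode(tables: List[Table], key_order: List[List[str]],
           final_key: str) -> Tuple[List[str], List[Optional[str]]]:
    """Return (input tokens, per-position targets) for one pointer-chasing sequence."""
    inputs: List[str] = []
    targets: List[Optional[str]] = []
    last = len(tables) - 1
    for level, keys in enumerate(key_order):
        for k in keys:  # table section: reversed triplet "value = key"
            inputs += [tables[level][k], "=", k]
            targets += [None, None, None]
        for k in (keys if level < last else [final_key]):
            answer = resolve(k, level, tables)
            inputs += ["Q", k]
            targets += [None, answer]  # the target sits at the key position
            if level < last:  # answers appear in the input except for the test query
                inputs.append(answer)
                targets.append(None)
    return inputs, targets


tables = [
    {"A": "v3", "B": "v0", "C": "v8"},  # level 0: key -> value
    {"D": "A", "E": "B", "F": "C"},     # level 1: key -> level-0 key
    {"G": "D", "H": "E", "I": "F"},     # level 2: key -> level-1 key
]
order = [["A", "B", "C"], ["D", "E", "F"], ["G", "H", "I"]]
inputs, targets = encode(tables, order, final_key="G")
print(" ".join(inputs))
print([t for t in targets if t is not None])
```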
Settings. hops, keys, 10 values, embedding dimension , 4 attention heads, batch size 64, window size 38, fixed key ordering, per-level key tokens, RoPE attention. Learning rate .
Why windowed attention. Without windowed attention, all models, including deep transformers, can attend directly from any query position to the base table and reach high accuracy without genuine composition. Windowed attention ($W = 38$) ensures that higher-level query sections cannot see the base table, forcing the model to chain through intermediate levels. This reveals the true depth-limited structure of fixed-depth transformers.
Wave propagation in BPTT. The model solves levels sequentially: level 0 converges first, then level 1, then 2, and so on. This is visible in the training dynamics (see progression below). The wave pattern is consistent with corrections propagating through the recurrent chain.
BPTT progression. The progression below uses a smaller configuration (, lr , 20 values) to demonstrate wave propagation at reduced compute:
| Iter | L0 | L1 | L2 | L3 | L4 | L5 | L6 | L7 | L8 | L9 | L10 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 3K | 1.00 | 0.74 | 0.27 | 0.20 | 0.19 | 0.18 | 0.18 | 0.18 | 0.17 | 0.16 | 0.16 |
| 8K | 1.00 | 1.00 | 0.80 | 0.59 | 0.36 | 0.21 | 0.21 | 0.19 | 0.20 | 0.18 | 0.15 |
| 13.5K | 1.00 | 1.00 | 0.99 | 0.99 | 0.94 | 0.83 | 0.51 | 0.24 | 0.20 | 0.20 | 0.18 |
| 23K | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 0.99 | 0.99 | 0.99 | 0.99 | 0.98 | 0.98 |
20-hop scaling. At with , the same architecture solves all 21 levels (20 hops) in K iterations, confirming that the recurrent mechanism scales to deeper composition chains.
Appendix C Extended Experimental Details
C.1 Hyperparameters
| Hyperparameter | M | M | Token-matched | Width scaling |
|---|---|---|---|---|
| Block size | 256 | 256 | 64 | 64 |
| Batch size | 64 | 32 | 1024 | 1024 |
| Learning rate | | | | |
| Training iters | 100K | 200K | varies | varies |
| / | 5 / 2 | 5 / 2 | 5 / 2 | 5 / 2 |
| Dropout | 0.2 | 0.2 | 0.2 | 0.2 |
| Vocab size | 32,000 | 32,000 | 32,000 | 32,000 |
C.2 Ablations
| Model | FLOPs/tok | Val PPL | Seq. |
|---|---|---|---|
| D=3 corr_ffn | 23.98 | 23.96 | |
| Roformer-hFFN | 25.78 | — | |
| D=3 block_head (no corr_ffn) | 27.32 | 28.46 | |
| Roformer | 27.19 | — | |
| Variant | FLOPs | PPL | Seq | ||
|---|---|---|---|---|---|
| 2 | corr_ffn (token-blind) | 26.68 | 26.72 | 0.74 | |
| | corr_ffn_add | 26.09 | 26.48 | 0.54 | |
| | corr_ffn_concat | 25.48 | 25.82 | 0.54 | |
| 3 | corr_ffn (token-blind) | 23.98 | 23.96 | — | |
| | corr_ffn_add | 23.79 | 24.12 | 0.55 | |
| | corr_ffn_concat | 23.41 | 23.73 | 0.74 | |
| | Comparison | Context-Ready | Baseline | Gap |
|---|---|---|---|---|
| 50 | D=3 vs. Roformer-hFFN | 84.3 | 83.0 | +1.3 |
| 74 | D=3 vs. Roformer-hFFN | 62.1 | 61.4 | +0.7 |
| 446 | D=3 vs. Roformer | 23.79 | 24.85 | -1.06 |
| 768 | D=3 vs. Roformer | 18.66 | 20.05 | -1.39 |
At very small widths (), the correction FFN’s overhead outweighs its benefit; the correction advantage emerges at moderate widths () and grows with scale. All main-body claims are based on results at .
| Metric | No | |
|---|---|---|
| Val PPL () | 84.32 | 84.16 |
| Seq | 84.61 | 84.19 |
| Parallel | 118.35 | 130.95 |
| Empirical | 0.72 | 0.94 |
C.3 Training Details
All experiments: block size 1024, , softmax attention, , OWT.
| Iter | | | Gap |
|---|---|---|---|
| 40K | 43.24 | 43.83 | +0.59 |
| 60K | 39.21 | 39.27 | +0.06 |
| 65K | 38.55 | 38.49 | -0.06 |
| 80K | 36.99 | 36.38 | -0.61 |
| 100K | 35.37 | 34.40 | -0.97 |
C.4 Sequential Validation
Full depth progression for all configurations, confirming Theorem 2.
| Block size | Iters | Par. | Par. | Par. | Par. | Par. | Seq. |
|---|---|---|---|---|---|---|---|
| 256 | 100K | 84.44 | 43.77 | 40.52 | 39.95 | 40.02 | 40.03 |
| 256 | 400K | 70.80 | 36.21 | 33.23 | 32.70 | 32.79 | 32.80 |
| 512 | 100K | 81.02 | 38.25 | 35.01 | 34.35 | 34.41 | 34.43 |
| 512 | 200K | 73.22 | 34.33 | 31.24 | 30.61 | 30.68 | 30.69 |
| 1024 | 200K | 84.13 | 34.42 | 30.39 | 29.51 | 29.43 | 29.43 |
| Par. | Par. | Par. | Par. | Par. | Seq. | |
|---|---|---|---|---|---|---|
| 5 | 58.17 | 40.18 | 38.80 | 38.60 | 38.61 | 38.62 |
| 8 | 51.60 | 40.12 | 39.22 | 39.10 | 39.10 | 39.10 |
| 12 | 38.33 | 32.58 | 32.30 | 32.29 | 32.29 | 32.29 |
| 23 | 32.80 | 29.03 | 28.89 | 28.88 | 28.88 | 28.88 |
At higher , convergence is faster: and match within 0.01 PPL at ; is within 0.12 PPL. Parallel ratio to actual quality shrinks with (from at to at ), confirming that the correction mechanism accounts for a larger fraction of quality at low .
C.5 Block Size Scaling
Longer context lengths give the correction chain more sequential steps to accumulate depth, so the gap to should shrink with block size. Table A9 confirms this: at , the gap narrows from +1.19 at block size 256 to +0.31 at 1024. With at block size 1024, overtakes entirely (Section 5.5).
| Block size | PPL | PPL | Gap |
|---|---|---|---|
| 256 | 34.15 | 35.34 | +1.19 |
| 512 | 30.11 | 30.57 | +0.46 |
| 1024 | 29.22 | 29.53 | +0.31 |
C.6 Token-Matched Training Curves
To isolate the effect of the correction mechanism from FLOP differences, we compare context-ready against standard transformers at the same embedding dimension (, block size 64), so both see the same number of tokens per training iteration. At every depth tested, the context-ready model overtakes the baseline after a crossover point and the gap continues to grow.
| Comparison | PPL | PPL | Gap | Crossover |
|---|---|---|---|---|
| vs. | 114.8 | 80.4 | 34.4 | M |
| vs. | 73.7 | 66.1 | 7.6 | M |
| vs. | 62.4 | 59.4 | 3.0 | M |
| vs. | 53.0 | 52.2 | 0.7 | M |
C.7 Fine-Tuning Details
Any pretrained -layer transformer can be converted to a context-ready model by adding a zero-initialized correction FFN and fine-tuning. The zero initialization ensures no disruption at conversion: the correction is identically zero, so the model behaves exactly as the original transformer. As fine-tuning progresses, the correction FFN learns to exploit cached context, yielding PPL improvements. At , fine-tuning causes a transient PPL increase before recovering; at , there is no transient increase.
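Why zero initialization leaves the converted model unchanged can be seen in a few lines. The sketch below is a minimal pure-Python illustration under assumptions the text does not state: the correction FFN is taken to be a two-layer ReLU network on the concatenation of the cached block output and the token embedding (a `corr_ffn_concat` variant appears in the ablations of C.2, but its exact form is assumed here), and only its output layer is zero-initialized. The class and function names are hypothetical.

```python
import random
from typing import List

Vector = List[float]


def linear(W: List[Vector], b: Vector, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]


class ZeroInitCorrectionFFN:
    """Hypothetical correction FFN: ReLU MLP on concat(h_prev, e), zero output layer."""

    def __init__(self, d: int, hidden: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.W1 = [[rng.gauss(0.0, 0.02) for _ in range(2 * d)] for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.W2 = [[0.0] * hidden for _ in range(d)]  # zero init: output is exactly 0
        self.b2 = [0.0] * d

    def __call__(self, h_prev: Vector, e: Vector) -> Vector:
        z = [max(0.0, a) for a in linear(self.W1, self.b1, h_prev + e)]  # list concat
        return linear(self.W2, self.b2, z)


def corrected_input(ffn: ZeroInitCorrectionFFN, h_prev: Vector, e: Vector) -> Vector:
    return [ei + ci for ei, ci in zip(e, ffn(h_prev, e))]


ffn = ZeroInitCorrectionFFN(d=4, hidden=8)
e = [0.2, -1.0, 0.5, 3.0]
h_prev = [1.5, 0.0, -0.3, 0.7]
print(corrected_input(ffn, h_prev, e) == e)  # the block receives the raw embedding
```

Because the hidden activations are generally nonzero, the output layer still receives a nonzero gradient at the first fine-tuning step (the first layer does not until the output layer moves), so the correction can grow away from zero.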
| Conversion | Baseline | Fine-tuned | Cont. baseline | vs. cont. | FT iters |
|---|---|---|---|---|---|
| | 29.92 | 26.14 | 27.20 | 1.06 | 200K |
| | 29.42 | 28.99 | — | 0.43∗ | 18K |
| | 33.41 | 32.21 | — | 1.20∗ | 50K |

∗Gain vs. pre-conversion checkpoint; continued-training control not available.
C.8 Wikipedia Results
For completeness, we include results on English Wikipedia (BPE 16k, context 256, 100K iterations).
| Model | FLOPs/tok | Val PPL | Gap |
|---|---|---|---|
| D=5 concat | 42.5M | 16.69 | |
| Roformer | 42.5M | 17.95 | 1.26 |
| D=6 add | 15.9M | 20.40 | |
| Roformer-hFFN | 15.9M | 21.44 | 1.04 |
| D=2 concat | 7.2M | 25.48 | |
| Roformer | 7.2M | 27.19 | 1.71 |
C.9 Inference Timing
Autoregressive generation speed measured on a single A100 GPU over 10,000 tokens with KV caching, batch size 1 (single-sequence generation).
| Model | Params | tok/s | ms/tok | Speedup |
|---|---|---|---|---|
| | 215M | 919 | 1.09 | |
| Roformer | 155M | 351 | 2.85 | |
| | 157M | 349 | 2.86 | |
| Roformer | 134M | 201 | 4.96 |
Per-token latency is flat across sequence length in both comparisons ( growth from to ), confirming that KV caching amortizes attention cost to per token. The context-ready models are faster despite having more parameters: at batch size 1, inference latency on modern GPUs is dominated by the number of sequential layers rather than by parameter count, and the context-ready models have fewer layers. In addition, fewer layers reduce total KV cache memory ( per token): uses less cache than ; uses less than .
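The speedups can be recomputed from the throughput and latency columns. A minimal sketch, assuming each context-ready row is compared with the Roformer row directly below it; the pair labels are placeholders, and `speedup` is an illustrative helper name.

```python
def speedup(cr_tok_per_s: float, base_tok_per_s: float) -> float:
    """Throughput ratio of a context-ready model over its Roformer baseline."""
    return cr_tok_per_s / base_tok_per_s


# (context-ready, Roformer) values from the timing table; pair labels are placeholders.
tok_per_s = {"pair 1": (919, 351), "pair 2": (349, 201)}
ms_per_tok = {"pair 1": (1.09, 2.85), "pair 2": (2.86, 4.96)}

for name, (cr, base) in tok_per_s.items():
    cr_ms, base_ms = ms_per_tok[name]
    # Latency ratio is inverted: baseline ms/tok divided by context-ready ms/tok.
    print(f"{name}: {speedup(cr, base):.2f}x by tok/s, {base_ms / cr_ms:.2f}x by ms/tok")
```

Both ratios agree to within rounding, at roughly 2.6× for the first pair and 1.7× for the second.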