
Introduction
Loop models are becoming popular lately, with exciting results [1,2,3,4,5]:

[1] A. Jolicoeur-Martineau, "Less is More: Recursive Reasoning with Tiny Networks," 2025.
[2] J. Geiping, S. McLeish, N. Jain, J. Kirchenbauer, S. Singh, B. Bartoldson, B. Kailkhura, A. Bhatele, T. Goldstein, "Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach," 2025.
[3] H. Prairie, Z. Novack, T. Berg-Kirkpatrick, D. Fu, "Parcae: Scaling Laws For Stable Looped Language Models," 2026.
[4] R. Zhu, Z. Wang, K. Hua, T. Zhang, Z. Li, H. Que, B. Wei, Z. Wen, F. Yin, H. Xing, et al., "Scaling latent reasoning via looped language models," 2025.
[5] G. Wang, J. Li, Y. Sun, X. Chen, C. Liu, Y. Wu, M. Lu, Y. Yadkori, "Hierarchical Reasoning Model," 2025.
[6] S. Bai, J. Kolter, V. Koltun, "Deep Equilibrium Models," 2019.

Once we decide to reuse the same block across multiple layers, however, one practical question becomes unavoidable: what does the loop cost during training?
That is the question of this post. Loop models now come with a growing collection of training tricks and schedules, but those choices are easy to blur together if we describe them all under the loose label of “recurrence.” In this post, I build a clean ablation chain over the most common ingredients: non-shared versus shared weights, terminal loss versus per-step losses, outer-step detach, instant update, internal truncation inside the cell, and gradient checkpointing.
Disclaimer. I am not trying to settle which training strategy gives the best downstream performance. Different strategies change not only cost, but also the learning algorithm itself, so accuracy comparisons deserve a separate discussion. In this essay, I focus on the cost analysis.
Setup and Measurement Contract
A common loop model has the following structure. Write the outer loop as
$$ h_t = f_\theta(h_{t-1}, x_t),\quad t=1,\ldots,T, $$where $h_0$ is an initial state that is usually not parameterized.
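As a concrete toy instance, here is a minimal sketch of that outer loop (a scalar affine-tanh cell of my own choosing, in the spirit of the reference cell used later, not any specific paper's model):

```python
import math

def cell(h_prev, x, theta):
    """One call to f_theta: a scalar affine-tanh cell h_t = tanh(w_h*h_prev + w_x*x + b)."""
    w_h, w_x, b = theta
    return math.tanh(w_h * h_prev + w_x * x + b)

def rollout(xs, theta, h0=0.0):
    """Run the T-step outer loop h_t = f_theta(h_{t-1}, x_t); h0 is not parameterized."""
    hs = [h0]
    for x in xs:
        hs.append(cell(hs[-1], x, theta))
    return hs

hs = rollout([0.5, -0.2, 0.1], theta=(0.8, 1.0, 0.0))
```

With $T = 3$ inputs, `hs` holds the initial state plus one state per outer step.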
Throughout the post,
- cell means one call to the repeated computation $f_\theta$ in the outer loop. When that cell contains its own solver, I call those internal iterations inner loops.
- NFE means the number of forward applications of the repeated block that define the effective depth of the computation. In the simple outer-loop picture above, that is one count per outer-step cell call. In the toy benchmark later, where each outer-step cell contains an inner refinement loop, I count those inner refinements toward NFE, because they are the actual repeated block applications that make up the logical depth.
- checkpointing means gradient checkpointing. I keep the full term when first introducing it, and then shorten it to checkpointing.
For the cost story, two Jacobian objects matter:
$$ {\color{blue} J_t} = \frac{\partial h_t}{\partial h_{t-1}},\quad {\color{red} B_t} = \frac{\partial h_t}{\partial \theta_t} \text{ or } {\color{red} B_t} = \frac{\partial h_t}{\partial \theta}, $$depending on whether the weights are non-shared or shared.
A compact schematic for the final step is:
In this picture, the uppercase index $T$ simply marks the last step of the rollout. The blue top arrow is therefore the last temporal Jacobian,
$$ J_T = \frac{\partial h_T}{\partial h_{T-1}}, $$and the red downward arrow is the last-step local parameter map,
$$ B_T = \frac{\partial h_T}{\partial \theta_T} \text{ or } B_T = \frac{\partial h_T}{\partial \theta}, $$depending on whether the weights are non-shared or shared.
With that notation in place, we can finally state the measurement contract used in the rest of the post. Throughout, we focus on the costs during one training interval/optimizer interval, namely the computation between two model updates. Unless stated otherwise, Variants 1-4 all use the same schedule: run the full $T$-step outer rollout, accumulate all outer losses, launch one backward pass, and then take one optimizer step. Variants 5-7 intentionally change that contract: once instant update appears, there is one optimizer step per outer step, which is the schedule used in HRM and TRM-style training setups [1,5].
This distinction matters because the optimizer interval determines what counts as one training interval, which in turn changes both the backward graph and the lifetime of saved tensors.
Throughout the post, the two main cost types refer to different parts of this graph:
- Symbolic FLOPs cost counts backward-side training arithmetic at fixed forward work: propagating a state gradient through ${\color{blue}J_t}$, forming parameter-gradient contributions through ${\color{red}B_t}$, accumulating shared gradients, and applying the optimizer update.
- Memory cost counts tensors that remain alive across the forward/backward boundary.
- Parameter-side memory is also real, and I account for it separately through parameter tensors, gradient buffers, and optimizer states.
Defining the FLOPs Cost Model
We define the following local costs:
- $c_\ell$: cost of one local loss backward $\partial \ell_t / \partial h_t$; for example, backward through a softmax-cross-entropy head or an MSE decoder attached to $h_t$.
- $e_J$: cost of obtaining or evaluating the local Jacobian operator ${\color{blue}J_t}$.
- $m_J$: cost of multiplying the incoming state gradient by ${\color{blue}J_t}$.
- $e_B$: cost of obtaining or evaluating the local Jacobian operator ${\color{red}B_t}$.
- $m_B$: cost of multiplying the incoming state gradient by ${\color{red}B_t}$.
- $c_h$: cost of one hidden-state gradient addition; for example, adding the current-step local gradient to the future-to-past contribution in a recursion like $a_t = \delta_t + a_{t+1} J_{t+1}$.
- $c_\theta$: cost of one accumulation into the shared-gradient buffer for one step-local parameter block; under tensor-level accounting, multiply by the number of tensors in that block.
- $c_u$: cost of updating one step-local parameter block once in the optimizer step; under tensor-level accounting, multiply by the number of tensors in that block. For example, SGD applies $W \leftarrow W - \eta G$, while Adam-style updates also touch moment buffers.
Because the Jacobian-evaluation term and the Jacobian-product term usually appear together, define the grouped costs
$$ c_J = e_J + m_J, \qquad c_B = e_B + m_B. $$This split is just a convenient decomposition rather than a claim that autograd materializes or pays for two fully separate Jacobian constructions. In real kernels, parts of the local derivative preparation may be shared or fused across the temporal and parameter VJP paths.
I keep the $e_{*}$ / $m_{*}$ split for interpretation, but most formulas below use the compact forms $c_J$ and $c_B$.
Here the incoming row adjoint is
$$ a_t = \frac{dL}{dh_t}. $$The right-boundary condition is whichever loss touches the final state. For a terminal-loss rollout,
$$ a_T = \frac{\partial \ell(h_T)}{\partial h_T}, $$while for a per-step-loss rollout,
$$ a_T = \delta_T = \frac{\partial \ell_T}{\partial h_T}. $$This does not mean explicitly materializing full Jacobian matrices. It is only a convenient decomposition between local Jacobian evaluation and the arithmetic driven by the incoming adjoint; in real autograd code, these substeps are often fused inside one backward kernel.
Both $c_\theta$ and $c_u$ scale with parameter shape. The difference is conceptual:
- $c_\theta$ is one shared-gradient accumulation for one step-local parameter block;
- $c_u$ is one optimizer update of one step-local parameter block.
For SGD, $c_\theta$ and $c_u$ are often of similar order. For Adam-style optimizers, $c_u$ is typically larger because the optimizer also updates moment buffers and applies extra elementwise operations.
Defining the Memory Cost Model
For memory, it helps to separate activation-side effects from parameter-side effects. Start from
$$ M_{\text{total}} = M_{\text{act}} + M_{\text{param}} + M_{\text{grad}} + M_{\text{opt}}. $$The main formulas below focus on the activation term $M_{\text{act}}$, because detach, checkpointing, and internal truncation primarily change that term. Parameter-side memory still matters, so I track it separately through parameter tensors, gradient buffers, and optimizer state.
Define:
- $p_\theta$: storage of one parameter block.
- $p_g$: storage of one parameter-gradient buffer.
- $p_{\text{opt}}$: storage of optimizer state for one parameter block. For plain SGD without momentum this can be $0$; for momentum SGD it is often about one extra tensor; for Adam-style methods it is often about two extra tensors.
For shared and non-shared parameterizations, the parameter-side memory is
$$ M_{\text{param-side}}^{\text{shared}} = p_\theta + p_g + p_{\text{opt}}, \qquad M_{\text{param-side}}^{\text{non-shared}}(T) = T\bigl(p_\theta + p_g + p_{\text{opt}}\bigr). $$These parameter-side formulas refer to the recurrent block and whatever head is included in the modeled training cell. In larger language-model style systems, prelude and coda parameters contribute additional mostly constant terms relative to the outer-loop ablation chain here.
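This bookkeeping can be sketched in a few lines (a helper of my own, with optimizer state expressed in the same units as the parameter block: roughly $0$ for plain SGD, about $1\times$ for momentum SGD, about $2\times$ for Adam-style methods):

```python
def param_side_memory(T, p_theta, p_g, p_opt, shared):
    """Parameter-side memory: one block if weights are shared, T blocks if non-shared."""
    per_block = p_theta + p_g + p_opt
    return per_block if shared else T * per_block

# e.g. a 4 MB block trained with an Adam-style optimizer (~2 extra tensors), T = 8:
shared_mb = param_side_memory(8, 4.0, 4.0, 8.0, shared=True)       # one block
non_shared_mb = param_side_memory(8, 4.0, 4.0, 8.0, shared=False)  # T blocks
```

The non-shared total grows linearly in $T$, while the shared total is flat, which is exactly the $T(\cdot)$ versus $(\cdot)$ split in the formulas above.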
Now define the activation-side quantities:
- $a_f$: activation memory retained by one iterative cell $f$ under ordinary autograd. For the affine-tanh reference cell $h_t = \tanh(W_h h_{t-1} + W_x x_t + b)$, this includes the boundary hidden state, the preactivation $z_t = W_h h_{t-1} + W_x x_t + b$, and any normalization or mask tensors saved for backward. If you want finer resolution, write $a_f = a_h + a_{\text{int}}$, where $a_h$ is the boundary hidden state and $a_{\text{int}}$ are the internal saved tensors of that step.
- $a_\ell$: activation memory retained by one loss/head branch; for example, logits $o_t = U h_t$ together with softmax or decoder-side saved tensors.
- $a_f^{\text{ckpt}}$: activation memory retained by one iterative cell under checkpointing, with $a_f^{\text{ckpt}} \le a_f$, typically with strict inequality in the intended checkpointed settings.
- $r_{\text{ckpt}}$: extra recomputation FLOPs induced by checkpointing one iterative cell during backward; for that same affine-tanh cell, this is the cost of rerunning the affine map and nonlinearity to reconstruct dropped intermediates.
- $e_B^{\text{trunc}}$: cost of obtaining or evaluating the truncated local Jacobian operator for the parameter path when gradients are truncated inside the cell.
- $m_B^{\text{trunc}}$: cost of multiplying the incoming gradient by that truncated local parameter Jacobian operator.
- $a_f^{\text{trunc}}$: activation memory retained by one iterative cell when gradients are truncated inside the cell.
- $r_{\text{ckpt}}^{\text{trunc}}$: extra recomputation FLOPs when checkpointing the remaining differentiable part of an already truncated cell.
- $a_f^{\text{trunc,ckpt}}$: activation memory retained by one truncated iterative cell under checkpointing.
For compact formulas in the truncated case, also define
$$ c_B^{\text{trunc}} = e_B^{\text{trunc}} + m_B^{\text{trunc}}. $$Below, whenever I write $M_{(\cdot)}$ without further qualifiers, I mean the activation-memory term $M_{\text{act}}$. If you want total training memory, simply add the parameter-side terms above.
Variants 1-4 keep the fixed end-of-rollout schedule introduced at the start of the post. Variants 5-7 intentionally switch to one optimizer step per outer step as soon as the instant-update strategy appears.
FLOPs
- \(c_\ell\): one local loss backward
- \(e_J, m_J\): split temporal-Jacobian costs; together \(c_J = e_J + m_J\)
- \(e_B, m_B\): split local parameter-backward costs; together \(c_B = e_B + m_B\)
- \(c_h\): one hidden-state gradient addition
- \(c_\theta\): one shared-gradient accumulation
- \(c_u\): one optimizer update for one parameter block
- \(r_{\text{ckpt}}\): extra recomputation FLOPs from checkpointing
- \(c_B^{\text{trunc}}\): local parameter-backward cost under internal cell truncation
- \(r_{\text{ckpt}}^{\text{trunc}}\): extra recomputation FLOPs for checkpointing the remaining differentiable part of a truncated cell
Memory
- \(a_f\): activation memory of one iterative cell
- \(a_\ell\): activation memory of one loss/head branch
- \(a_f^{\text{ckpt}}\): activation memory of one checkpointed iterative cell
- \(a_f^{\text{trunc}}\): activation memory of one iterative cell under internal cell truncation
- \(a_f^{\text{trunc,ckpt}}\): activation memory of one truncated iterative cell under checkpointing
The Ablation Chain
From this point on, the cell-local algebra is fixed. Each row changes exactly one training choice. To avoid re-deriving the same local identities every time, keep the notation from Defining the FLOPs Cost Model in view: $a_t = \frac{dL}{dh_t}$ with terminal-loss boundary $a_T = \frac{\partial \ell(h_T)}{\partial h_T}$ (or per-step-loss boundary $a_T = \delta_T$, with $\delta_t = \frac{\partial \ell_t}{\partial h_t}$); $J_t = \frac{\partial h_t}{\partial h_{t-1}}$ and $B_t = \frac{\partial h_t}{\partial \theta}$; and the grouped costs $c_J = e_J + m_J$ and $c_B = e_B + m_B$. All of these keep the same meaning. The only thing that changes from row to row is how these local pieces compose across the rollout.
For a terminal-loss rollout,
$$ a_T = \frac{\partial \ell(h_T)}{\partial h_T}, \qquad a_{t-1} = a_t J_t, \qquad \nabla_{\theta_t}L = a_t B_t. $$For a per-step-loss rollout, with $\delta_t = \partial \ell_t / \partial h_t$,
$$ a_T = \delta_T, \qquad a_t = \delta_t + a_{t+1} J_{t+1}. $$The boundary conventions used below are simple:
- Right boundary: there is no future term beyond step $T$, so every backward recursion starts from $a_T$.
- Left boundary: if $h_0$ is a fixed initial state, we stop after forming $a_1$; only if $h_0$ were learnable would we continue one more multiplication to get $a_0 = a_1 J_1$.
- Degenerate case $T=1$: every temporal term disappears, so any formula with $(T-1)c_J$ or $(T-1)c_h$ collapses to its purely local part.
This is why the formulas below repeatedly show $T-1$ temporal-Jacobian applications but $T$ local parameter-gradient terms: a $T$-step rollout has $T-1$ state-to-state edges but $T$ step-local parameter leaves.
Before walking through the rows, it helps to place the recent loop-model literature into three nearby buckets.
- HRM-style latent reasoning: HRM uses repeated supervision segments, detached carried state between segments, parameter updates after each segment, and a one-step gradient approximation for the final inner transition. In the language of this post, it sits closest to the run from + detach to + instant update to + internal truncation inside the cell [5].
- TRM-style latent reasoning: TRM keeps the repeated supervision, detached carried state, and per-segment updates, but backpropagates through the full final recursion and treats the one-step approximation as a weaker ablation. So it is closer to + detach to + instant update, with a full inner-step backward in the last segment rather than Variant 6-style inner truncation [1].
- Large-scale looped language models: recurrent-depth LMs such as Geiping et al.'s recurrent-depth approach and Parcae use a shared recurrent block with a terminal language-model loss, stochastic loop depth, and truncated backpropagation through the main recurrent-depth loop. So they are closest to the shared-weight + terminal-loss setup, plus an extra outer-loop truncation ingredient that is adjacent to, rather than identical with, the seven-row chain below [2,3].
With that context in mind, let’s start changing one knob at a time.
V1. Baseline: non-shared + final loss
One final loss at h3 sends credit through the whole state chain, and each step owns its own parameter leaf.
Mental picture: one blue state chain, one terminal loss, and one separate red parameter leaf at each step.
This untied-depth model is the natural reference point. I keep it because it makes the effect of weight tying completely explicit in the next variant.
Start from the non-shared case:
$$ h_t = f_{\theta_t}(h_{t-1}, x_t),\quad L = \ell(h_T). $$There are two equivalent ways to write the gradient:
- the fully expanded chain-rule form, which makes the dependency on every later step explicit;
- the adjoint form, which compresses that suffix product into the state gradient \(a_t = dL/dh_t\).
This baseline reuses the local definitions introduced earlier: $a_t = \frac{dL}{dh_t}$ with $a_T = \frac{\partial \ell(h_T)}{\partial h_T}$; $J_t = \frac{\partial h_t}{\partial h_{t-1}}$ and $B_t = \frac{\partial h_t}{\partial \theta_t}$ in the non-shared case; and $c_J = e_J + m_J$, $c_B = e_B + m_B$. Start with the fully expanded chain-rule expression:
$$ \nabla_{\theta_t}L = \frac{\partial \ell(h_T)}{\partial h_T}\left(\prod_{k=t+1}^{T}\frac{\partial h_k}{\partial h_{k-1}}\right)\frac{\partial h_t}{\partial \theta_t}. $$Now define the adjoint at step \(t\) to be the suffix of that expression up to \(h_t\):
$$ a_t = \frac{dL}{dh_t} = \frac{\partial \ell(h_T)}{\partial h_T}\left(\prod_{k=t+1}^{T}\frac{\partial h_k}{\partial h_{k-1}}\right). $$With that definition, the same gradient becomes the local parameter term hit by the incoming adjoint,
$$ \nabla_{\theta_t}L = a_t {\color{red} B_t}, $$and the adjoints themselves satisfy the one-step reverse recurrence
$$ a_{t-1} = a_t {\color{blue} J_t}. $$So the compact notation is not a new approximation or a different derivation. It is just the same expanded chain rule rewritten in terms of the recursively computed state gradient:
$$ \nabla_{\theta_t}L = a_T \left(\prod_{k=t+1}^{T} {\color{blue} J_k}\right) {\color{red} B_t} = a_t {\color{red} B_t}. $$Reverse mode touches each kind of edge exactly once:
- \(T-1\) uses of \({\color{blue} J_t}\);
- \(T\) uses of \({\color{red} B_t}\);
- \(T\) optimizer updates, one for each step-local parameter block.
So, under the cost model above,
$$ C_{\text{non-shared, final}}(T) = c_\ell + (T-1)c_J + Tc_B + Tc_u. $$The activation-memory cost is
$$ M_{\text{non-shared, final}}(T) = T a_f + a_\ell. $$All $T$ iterative cells must keep their step-local activations alive, because the terminal loss sends credit through the entire chain. The gradient graph is therefore a state chain ending in one loss node, with a separate parameter leaf $\theta_t$ at each step.
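The equivalence between the expanded chain rule and the adjoint recursion can be checked numerically. Here is a sketch under a toy assumption of my own: a scalar linear cell $h_t = w_t h_{t-1} + x_t$ (so $J_t = w_t$ and $B_t = h_{t-1}$) with terminal loss $\ell = \tfrac12 h_T^2$:

```python
def rollout(ws, xs, h0=0.0):
    """Scalar linear cell h_t = w_t * h_{t-1} + x_t with non-shared per-step weights."""
    hs = [h0]
    for w, x in zip(ws, xs):
        hs.append(w * hs[-1] + x)
    return hs

def grads_adjoint(ws, xs, h0=0.0):
    """Terminal loss l = 0.5*h_T^2; reverse recursion a_{t-1} = a_t * J_t, grad_t = a_t * B_t."""
    hs = rollout(ws, xs, h0)
    T = len(ws)
    a = hs[T]                          # a_T = dl/dh_T = h_T
    grads = [0.0] * T
    for t in range(T, 0, -1):
        grads[t - 1] = a * hs[t - 1]   # B_t = dh_t/dw_t = h_{t-1}
        a = a * ws[t - 1]              # J_t = dh_t/dh_{t-1} = w_t
    return grads

def grads_expanded(ws, xs, h0=0.0):
    """Same gradient via the fully expanded suffix product of temporal Jacobians."""
    hs = rollout(ws, xs, h0)
    T = len(ws)
    out = []
    for t in range(1, T + 1):
        prod = 1.0
        for k in range(t + 1, T + 1):
            prod *= ws[k - 1]          # one factor J_k per state-to-state edge
        out.append(hs[T] * prod * hs[t - 1])
    return out
```

Both routes produce identical per-step gradients; the adjoint version applies each $J_t$ exactly once instead of re-walking the suffix product for every $t$, which is the linear-time counting used in the cost formula.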
V2. + shared weights
Relative to Variant 1, the blue time path is unchanged; only the parameter side changes from separate leaves to one shared accumulation point.
What changes: the blue temporal chain stays the same, but the red stepwise parameter contributions now merge into one shared leaf.
Relative to Variant 1, the blue state recursion is unchanged. Only the red parameter side changes: each step still contributes a local parameter term, but those terms now accumulate into one shared parameter block.
$$ h_t = f_{\theta}(h_{t-1}, x_t),\quad L = \ell(h_T). $$The adjoint recursion is therefore exactly the same as in Variant 1:
$$ a_{t-1} = a_t {\color{blue} J_t}. $$What changes is only how the local parameter contributions are collected. In fully expanded form, the shared-parameter gradient is the sum over all step-local uses of that same parameter block:
$$ \nabla_{\theta}L = \frac{\partial \ell(h_T)}{\partial h_T}\sum_{t=1}^{T}\left(\prod_{k=t+1}^{T}\frac{\partial h_k}{\partial h_{k-1}}\right)\frac{\partial h_t}{\partial \theta}. $$Now reuse the same adjoint definition as in Variant 1,
$$ a_t = \frac{dL}{dh_t} = \frac{\partial \ell(h_T)}{\partial h_T}\left(\prod_{k=t+1}^{T}\frac{\partial h_k}{\partial h_{k-1}}\right), $$so the shared-parameter gradient compresses to
$$ \nabla_{\theta}L = a_T \sum_{t=1}^{T} \left(\prod_{k=t+1}^{T} {\color{blue} J_k}\right) {\color{red} B_t} = \sum_{t=1}^{T} a_t {\color{red} B_t}. $$This is the key structural difference from Variant 1:
- in Variant 1, each \(a_t {\color{red} B_t}\) goes to its own parameter block \(\theta_t\);
- in Variant 2, the same \(a_t {\color{red} B_t}\) terms are accumulated into one shared gradient buffer for \(\theta\).
Weight tying still gives linear-time BPTT. Compared with Variant 1, it adds \(T-1\) shared-gradient accumulations and replaces \(T\) separate optimizer updates with one shared update.
Here \(c_\theta\) counts only cross-step accumulation into an already existing shared gradient buffer. The first write is absorbed into the corresponding local parameter-backward term, which is why the accumulation count is \(T-1\) rather than \(T\).
Hence
$$ C_{\text{shared, final}}(T) = c_\ell + (T-1)c_J + Tc_B + (T-1)c_\theta + c_u. $$The activation-memory cost is unchanged,
$$ M_{\text{shared, final}}(T) = T a_f + a_\ell. $$because tying the parameters does not let step \(t\) reuse the saved activations of step \(t+1\). The state chain is the same; only the parameter leaves have merged.
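A numeric sketch of this accumulation, under the same toy scalar linear cell assumption as before (now with one shared weight $w$, so $J_t = w$ and $B_t = h_{t-1}$), checked against a finite difference:

```python
def rollout(w, xs, h0=0.0):
    """Shared scalar linear cell h_t = w * h_{t-1} + x_t."""
    hs = [h0]
    for x in xs:
        hs.append(w * hs[-1] + x)
    return hs

def grad_shared(w, xs, h0=0.0):
    """Terminal loss l = 0.5*h_T^2: one backward sweep, accumulating sum_t a_t * B_t."""
    hs = rollout(w, xs, h0)
    T = len(xs)
    a = hs[T]
    g = 0.0
    for t in range(T, 0, -1):
        g += a * hs[t - 1]   # accumulate a_t * B_t into the single shared buffer
        a = a * w            # J_t = w for every step
    return g

def grad_fd(w, xs, h0=0.0, eps=1e-6):
    """Finite-difference check of the same shared-weight gradient."""
    def loss(w_):
        return 0.5 * rollout(w_, xs, h0)[-1] ** 2
    return (loss(w + eps) - loss(w - eps)) / (2 * eps)
```

The accumulation loop is exactly the $(T-1)\,c_\theta$ bookkeeping in the cost formula: the first write lands in the buffer, and every later step adds into it.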

This variant is the cleanest symbolic match to the large-scale looped-language-model family. Geiping et al. [2] and Prairie et al. [3] both follow the same basic prelude $\to$ shared recurrent block $\to$ coda pattern with one final language-model loss, even though in practice they also introduce sampled loop depth and truncated backpropagation through the main recurrent-depth loop.
V3. + per-step losses
Each state now has its own local loss branch, but the blue cross-time backward path is still intact and all red parameter contributions still merge into one shared θ.
What changes: each state now gets its own local loss branch. The temporal chain remains intact.
Relative to Variant 2, add one local loss branch at every step:
$$ h_t = f_{\theta}(h_{t-1}, x_t),\quad L = \sum_{t=1}^{T}\ell_t(h_t). $$Define the local loss gradient
$$ \delta_t = \frac{\partial \ell_t}{\partial h_t}. $$Then the backward recursion becomes
$$ a_T = \delta_T,\quad a_t = \delta_t + a_{t+1} {\color{blue} J_{t+1}},\quad t=T-1,\ldots,1. $$The parameter gradient is still
$$ \nabla_{\theta}L = \sum_{t=1}^{T} a_t {\color{red} B_t}. $$Compared with Variant 2, this replaces one terminal loss backward with \(T\) step-local loss backward contributions, so the incremental increase is \(T-1\) additional local loss backward terms plus \(T-1\) hidden-state gradient additions, while keeping the blue temporal path fully alive.
So
$$ C_{\text{shared, per-step}}(T) = Tc_\ell + (T-1)c_J + (T-1)c_h + Tc_B + (T-1)c_\theta + c_u. $$The activation-memory cost becomes
$$ M_{\text{shared, per-step}}(T) = T(a_f + a_\ell). $$because every step now retains both one iterative cell and one local loss/head branch.
For the next variants, it is useful to separate the cost into local and temporal pieces:
$$ C_{\text{local}} = c_\ell + c_B + c_\theta,\quad C_{\text{temporal}} = c_J + c_h. $$Then
$$ C_{\text{shared, per-step}}(T) = T C_{\text{local}} + (T-1) C_{\text{temporal}} - c_\theta + c_u. $$This decomposition makes the next row easy to read: detach removes the temporal term and leaves the local term behind.
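The decomposition is a strict algebraic identity, which a few lines of arithmetic confirm (the unit costs here are arbitrary placeholder numbers, not measurements):

```python
def cost_v3(T, c):
    """C_shared_per_step(T), exactly as in the expanded formula above."""
    return (T * c["l"] + (T - 1) * c["J"] + (T - 1) * c["h"]
            + T * c["B"] + (T - 1) * c["th"] + c["u"])

def cost_v3_decomposed(T, c):
    """Same cost via the split: T*C_local + (T-1)*C_temporal - c_theta + c_u."""
    c_local = c["l"] + c["B"] + c["th"]       # per-step local work
    c_temporal = c["J"] + c["h"]              # per-edge temporal work
    return T * c_local + (T - 1) * c_temporal - c["th"] + c["u"]
```

The $-c_\theta$ correction appears because $C_{\text{local}}$ charges $c_\theta$ at every step, while the accumulation count is only $T-1$.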
V4. + detach each step under end-of-rollout backward
Detach removes only the blue temporal backward edges. The local loss branches and the shared red parameter path remain fully active.
What changes: only the blue temporal edges are cut. The gold local losses and the red parameter branches remain.
In this row, I keep the same end-of-rollout execution contract as Variant 3: all detached step losses are accumulated first, and one backward pass is launched only after the full outer rollout.
Relative to Variant 3, cut only the blue temporal edges:
$$ h_t = f_{\theta}(\operatorname{detach}(h_{t-1}), x_t),\quad L = \sum_{t=1}^{T}\ell_t(h_t). $$After detach, future losses no longer flow backward across time. The recursion collapses to
$$ a_t = \delta_t = \frac{\partial \ell_t}{\partial h_t}, \qquad t = 1, \ldots, T. $$So the parameter gradient becomes
$$ \nabla_{\theta}L = \sum_{t=1}^{T} \delta_t {\color{red} B_t}. $$This is the key point of the article:
- detach removes the blue temporal path. It does not remove the gold local loss branches or the red parameter-gradient work.
- Under a fixed end-of-rollout backward, that means fewer temporal FLOPs but still $T$ detached local graphs kept alive for backward.
After detach, the outer-step graphs are independent. So for the detached per-step objective, there are two execution contracts:
- accumulate all step losses and call one backward at the end of the rollout;
- call loss_t.backward() after each outer step, accumulate gradients in the shared parameter buffers, and delay optimizer.step() until the end of the rollout.
These two contracts produce the same parameter gradient up to floating-point accumulation order, because the step graphs are already disconnected by detach. The second contract is usually the more sensible implementation: it releases each detached local graph earlier and reduces the outer-depth activation peak from linear-in-$T$ storage to effectively constant-in-$T$ storage.
I keep the end-of-rollout version here only to keep the comparison clean: in this row, the new change is just that detach cuts cross-step gradients, not that backward is called at a different time. In practice, once detach is present, streaming backward accumulation is often the better implementation, as the toy experiment section will show.
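A sketch of the two contracts on the toy scalar linear cell (per-step losses $\ell_t = \tfrac12 h_t^2$ with detach on $h_{t-1}$, my own toy setup), showing that they accumulate the same gradient:

```python
def detached_grad_end_of_rollout(w, xs, h0=0.0):
    """Contract 1: run the whole rollout, then one backward over all detached step graphs."""
    hs = [h0]
    for x in xs:
        hs.append(w * hs[-1] + x)
    # with detach(h_{t-1}), the gradient is just sum_t delta_t * B_t = sum_t h_t * h_{t-1}
    return sum(hs[t] * hs[t - 1] for t in range(1, len(xs) + 1))

def detached_grad_streaming(w, xs, h0=0.0):
    """Contract 2: backward after each step, accumulating into the shared buffer;
    each step-local graph can be released immediately after its backward."""
    h, g = h0, 0.0
    for x in xs:
        h_prev = h
        h = w * h_prev + x
        g += h * h_prev          # step-local backward, done inside the rollout
    return g
```

Because the step graphs are already disconnected, the streaming version keeps only one step's activations alive at a time while producing the same accumulated gradient.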
So, relative to Variant 3, the entire temporal term $(T-1)C_{\text{temporal}}$ disappears, while the local loss branches and red parameter branches remain.
Hence
$$ C_{\text{shared, per-step, detach}}(T) = Tc_\ell + Tc_B + (T-1)c_\theta + c_u. $$Under the fixed execution schedule above, the activation-memory scaling is still linear in $T$:
$$ M_{\text{shared, per-step, detach}}(T) = T(\tilde a_f + \tilde a_\ell), \qquad \tilde a_f + \tilde a_\ell = \Theta(a_f + a_\ell). $$Equivalently, the saved FLOPs are exactly the temporal part:
$$ C_{\text{shared, per-step}}(T) - C_{\text{shared, per-step, detach}}(T) = (T-1)c_J + (T-1)c_h = (T-1)C_{\text{temporal}}. $$If $C_{\text{local}}$ is large, detach may save much less compute than the recurrence diagram first suggests. The memory constant can also shift a bit: after detach, each step-local backward no longer needs to send gradients into $h_{t-1}$, so an autodiff engine may skip saving some tensors that would only have been needed for that input-gradient path. But the parameter-gradient path is still alive, so most of the step-local state needed for the loss/head backward and the local parameter backward still has to remain. That is why the robust claim here is about scaling rather than exact bytes: under a fixed end-of-rollout backward, detach changes temporal FLOPs while keeping the outer-depth activation scaling linear in $T$.
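To see how detach changes the gradient itself (not just the cost), compare full per-step-loss BPTT against the detached version on the same toy scalar linear cell with $\ell_t = \tfrac12 h_t^2$:

```python
def rollout(w, xs, h0=0.0):
    """Shared scalar linear cell h_t = w * h_{t-1} + x_t."""
    hs = [h0]
    for x in xs:
        hs.append(w * hs[-1] + x)
    return hs

def grad_per_step(w, xs, h0=0.0):
    """Full BPTT with per-step losses: a_t = delta_t + a_{t+1} * J_{t+1}."""
    hs = rollout(w, xs, h0)
    T = len(xs)
    a, g = 0.0, 0.0
    for t in range(T, 0, -1):
        a = hs[t] + a * w      # delta_t = h_t, and the carried adjoint crosses J_{t+1} = w
        g += a * hs[t - 1]     # a_t * B_t
    return g

def grad_detached(w, xs, h0=0.0):
    """Same losses with detach(h_{t-1}): only the local term delta_t * B_t survives."""
    hs = rollout(w, xs, h0)
    return sum(hs[t] * hs[t - 1] for t in range(1, len(xs) + 1))
```

The two agree only when the temporal contribution vanishes; in general, detach trades a different learning signal for the removed $(T-1)C_{\text{temporal}}$ work.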
V5. + instant update
The outer graph is still detached across time, but the runtime schedule changes: each local step is backpropagated and updated immediately, so θ0 → θ1 → θ2 → θ3 inside the rollout.
What changes: the detached local graph is the same as in Variant 4, but the training schedule changes. This is the first row that changes the learning algorithm itself. Instead of waiting until the end of the rollout, every outer step now does forward $\to$ local backward $\to$ optimizer update immediately. Step $t+1$ therefore sees parameters that have already been updated after step $t$.
From here on, the natural unit is one optimizer interval = one outer step. This is exactly the point where both the measurement unit and the learning algorithm change.
Per optimizer interval,
$$ C_{\text{shared, per-step, detach, instant}}^{\text{interval}} = c_\ell + c_B + c_u. $$and the peak activation memory is
$$ M_{\text{shared, per-step, detach, instant}}^{\text{peak}} = a_f + a_\ell. $$There is no $c_\theta$ term here, because there is no cross-step gradient accumulation before the optimizer update.
Over the same $T$-step horizon, the total cost is
$$ C_{\text{shared, per-step, detach, instant}}^{\text{rollout}}(T) = T(c_\ell + c_B + c_u), $$with peak activation memory
$$ M_{\text{shared, per-step, detach, instant}}^{\text{rollout}} = a_f + a_\ell. $$Within the numbered chain, this is the first row whose profiled optimizer interval is one outer step. That schedule has constant outer-depth activation peak. A detached streaming-backward implementation would already achieve the same outer-depth peak without changing the optimizer interval. What instant update adds on top is the optimizer-schedule change: the parameters are updated inside the rollout rather than after the rollout.
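A sketch of the instant-update schedule (toy scalar linear cell, plain SGD, my own assumptions), where the key property is that step $t+1$ runs with already-updated parameters:

```python
def instant_update_rollout(w, xs, lr=0.1, h0=0.0):
    """One optimizer interval per outer step: forward -> local backward -> update.
    Detached scalar linear cell with per-step loss l_t = 0.5*h_t^2."""
    h = h0
    ws = [w]                       # parameter trajectory theta_0 -> theta_1 -> ...
    for x in xs:
        h_prev = h                 # detached: a constant for this step's backward
        h = w * h_prev + x         # forward through the cell
        grad = h * h_prev          # local backward only: delta_t * B_t
        w = w - lr * grad          # immediate SGD update, inside the rollout
        ws.append(w)
    return h, ws

_, ws = instant_update_rollout(0.5, [1.0, 1.0], lr=0.1, h0=1.0)
```

Only one step's activations ever exist at once, and the returned trajectory makes the within-rollout drift $\theta_0 \to \theta_1 \to \cdots$ explicit.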
V6. + internal truncation inside the cell
The outer one-step schedule is the same as Variant 5. The extra change is inside each cell: the purple dashed inner region marks a smaller cell-internal backward graph.
What changes: nothing at the outer schedule level. The only new truncation happens inside the cell-local backward.
Relative to Variant 5, the outer schedule is unchanged. The only new change is internal to the cell: if $f_\theta$ itself contains an inner loop, we can truncate gradients there as well.
This does not change the outer-step forward computation. It only reduces the local backward and local activation storage inside the cell.
Under the instant-update schedule of Variant 5, the exact cost per optimizer interval becomes
$$ C_{\text{shared, per-step, detach, instant, trunc}}^{\text{interval}} = c_\ell + c_B^{\text{trunc}} + c_u, $$where
$$ c_B^{\text{trunc}} \le c_B, $$typically with strict inequality in the intended truncated settings.
The peak activation memory per optimizer interval becomes
$$ M_{\text{shared, per-step, detach, instant, trunc}}^{\text{peak}} = a_f^{\text{trunc}} + a_\ell,\quad a_f^{\text{trunc}} \le a_f, $$again typically with strict inequality in the intended benchmark settings.
Over the same $T$-step horizon, the total compute is
$$ C_{\text{shared, per-step, detach, instant, trunc}}^{\text{rollout}}(T) = T(c_\ell + c_B^{\text{trunc}} + c_u), $$with peak activation memory still
$$ M_{\text{shared, per-step, detach, instant, trunc}}^{\text{rollout}} = a_f^{\text{trunc}} + a_\ell. $$This row combines two logically separate changes:
- instant update changes the optimizer schedule at the outer-loop level;
- internal truncation reduces the local backward and local activation memory inside each cell.
The forward cost of one outer step is unchanged; the savings come entirely from the local backward and activation-storage side of the cell.
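Here is a sketch of cell-internal truncation (an inner refinement loop $z_k = \tanh(w z_{k-1} + x)$ of my own choosing, loosely in the spirit of HRM's one-step approximation rather than its actual architecture), comparing a full inner backward against a last-step-only backward:

```python
import math

def inner_forward(w, h_in, x, K=8):
    """K inner refinement steps z_k = tanh(w * z_{k-1} + x) inside one outer cell."""
    zs = [h_in]
    for _ in range(K):
        zs.append(math.tanh(w * zs[-1] + x))
    return zs

def dz_dw_full(w, h_in, x, K=8):
    """Backward through all K inner steps (the full c_B)."""
    zs = inner_forward(w, h_in, x, K)
    g = 0.0
    carry = 1.0                        # d z_K / d z_k, built right to left
    for k in range(K, 0, -1):
        s = 1 - zs[k] ** 2             # tanh'(pre_k) via the saved output z_k
        g += carry * s * zs[k - 1]     # parameter path entering at step k
        carry *= s * w                 # extend the inner temporal chain
    return g

def dz_dw_trunc(w, h_in, x, K=8):
    """Backward through only the last inner step (the truncated c_B^trunc)."""
    zs = inner_forward(w, h_in, x, K)
    return (1 - zs[K] ** 2) * zs[K - 1]
```

The truncated backward touches one inner step instead of $K$, and only needs the last two inner states saved: exactly the $c_B^{\text{trunc}} \le c_B$ and $a_f^{\text{trunc}} \le a_f$ gaps in the formulas above.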
This is the cleanest place in the literature to anchor HRM rather than TRM. Both use internal truncation, but with different truncation lengths.
HRM backpropagates through only one inner step: HRM trains on repeated supervision segments, detaches the carried state between segments, updates the parameters after each segment, and uses a one-step gradient approximation for the final inner transition [5]. That approximation is motivated as a practical surrogate for the Implicit Function Theorem-style gradient signal discussed in Bai et al. [6].
TRM backpropagates through more inner steps: TRM keeps the same repeated-supervision, detach-between-segments, and per-segment update structure, but extends the gradient path across more inner steps [1].
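To make the truncation-length contrast concrete, here is a small pure-Python sketch that instantiates \(c_B^{\text{trunc}}\) for the affine-tanh toy cell used later in this post, where only the last \(k\) of the \(K\) inner refinements stay differentiable. The helper `c_B_trunc` and the reading of HRM-style one-step truncation as \(k = 1\) are illustrative assumptions of this sketch, not numbers reported by either paper.

```python
def c_B_trunc(k, B, D, X):
    """Leading-term local parameter-backward cost when only the last k
    inner refinements are differentiated; k equal to the full inner
    depth recovers the untruncated c_B accounting used in this post."""
    gemm_hh = 2 * B * D * D              # hidden-hidden GEMM, [B,D] x [D,D]
    gemm_xh = 2 * B * X * D              # input-hidden GEMM, [B,X] x [X,D]
    expose = k * (gemm_hh + gemm_xh)     # parameter-gradient exposure per refinement
    links = (k - 1) * gemm_hh            # adjoint hops across the k-1 inner links
    return expose + links

B, D, X, K = 32, 256, 256, 6
one_step = c_B_trunc(1, B, D, X)         # HRM-style one-step truncation
full_inner = c_B_trunc(K, B, D, X)       # no truncation, TRM-like long path
```

With the toy benchmark sizes, the one-step truncation shrinks the cell-local backward from about 71.3 M to about 8.4 M counted FLOPs, which is the whole point of the \(c_B^{\text{trunc}} \le c_B\) inequality above.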
V7. + gradient checkpointing
This keeps the one-step truncated schedule of Variant 6, but now the remaining differentiable part of each cell is checkpointed: memory drops a bit further, and backward pays extra recomputation FLOPs.
What changes: keep the one-step outer schedule and the inner truncation, then checkpoint the remaining differentiable part of the cell.
Relative to Variant 6, the outer schedule is unchanged and the NFE is unchanged. The local gradient structure is still the truncated one-step version of Variant 6. The only new change is that the remaining differentiable part of the already truncated cell is now checkpointed instead of stored in the ordinary way.
So, per optimizer interval,
$$ C_{\text{shared, per-step, detach, instant, trunc, ckpt}}^{\text{interval}} = c_\ell + c_B^{\text{trunc}} + c_u + r_{\text{ckpt}}^{\text{trunc}}. $$The peak activation memory becomes
$$ M_{\text{shared, per-step, detach, instant, trunc, ckpt}}^{\text{peak}} = a_f^{\text{trunc,ckpt}} + a_\ell, \qquad a_f^{\text{trunc,ckpt}} \le a_f^{\text{trunc}}, $$with strict inequality in the intended checkpointed settings.
Over the same $T$-step horizon,
$$ C_{\text{shared, per-step, detach, instant, trunc, ckpt}}^{\text{rollout}}(T) = T\bigl(c_\ell + c_B^{\text{trunc}} + c_u + r_{\text{ckpt}}^{\text{trunc}}\bigr), $$with peak activation memory still
$$ M_{\text{shared, per-step, detach, instant, trunc, ckpt}}^{\text{rollout}} = a_f^{\text{trunc,ckpt}} + a_\ell. $$The measured charts in the experiment section also include three indented auxiliary rows right after Variants 3 and 4: the fixed-schedule 3 + checkpointing, 4 + streaming backward accumulation, and 4 + checkpointing comparisons, shown for apples-to-apples reference, but not counted as new numbered variants in the main ablation line.
How gradient checkpointing saves memory
Without checkpointing, autograd keeps the step-local activations produced inside each iterative cell so that backward can reuse them later. If we write
$$ a_f = a_h + a_{\text{int}}, $$then $a_h$ is the boundary hidden state and $a_{\text{int}}$ are the internal saved tensors of that step.
With checkpointing, we keep only a smaller boundary representation during forward and drop most of the internal saved tensors. During backward, we rerun the forward of the checkpointed region to reconstruct the missing activations and only then apply the local backward. That is the whole tradeoff:
$$ a_f^{\text{ckpt}} \le a_f, $$with strict inequality in the intended checkpointed settings, and the checkpointed cell pays an extra \(r_{\text{ckpt}}\) recomputation FLOPs during backward.
Checkpointing saves memory by storing fewer forward activations and reconstructing them later. It does not change the gradient graph.
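As a minimal sketch of that tradeoff, the following pure-Python toy implements a scalar chain of tanh steps with a hand-written backward; no autograd or PyTorch is assumed. Both backward passes produce the same gradient, but the checkpointed one keeps only the boundary state and pays for it with a full forward recomputation inside backward.

```python
import math

def forward_chain(w, h0, K):
    """Run K tanh steps h_k = tanh(w * h_{k-1}); return all states."""
    hs = [h0]
    for _ in range(K):
        hs.append(math.tanh(w * hs[-1]))
    return hs

def grad_w_stored(w, h0, K):
    """Ordinary backward: every forward state is kept alive."""
    hs = forward_chain(w, h0, K)          # K+1 saved values
    a, g = 1.0, 0.0                       # adjoint of h_K, grad wrt w
    for k in range(K, 0, -1):
        s = 1.0 - hs[k] * hs[k]           # local tanh' at step k
        g += a * s * hs[k - 1]            # parameter (red) term
        a = a * s * w                     # temporal (blue) term
    return g, len(hs)

def grad_w_ckpt(w, h0, K):
    """Checkpointed backward: keep only the boundary h0, recompute the rest."""
    saved = [h0]                          # boundary representation only
    hs = forward_chain(saved[0], h0, K) if False else forward_chain(w, saved[0], K)
    # ^ recomputation FLOPs (r_ckpt) are paid here, during backward
    a, g = 1.0, 0.0
    for k in range(K, 0, -1):
        s = 1.0 - hs[k] * hs[k]
        g += a * s * hs[k - 1]
        a = a * s * w
    return g, len(saved)

g_full, kept_full = grad_w_stored(0.7, 0.3, 6)
g_ckpt, kept_ckpt = grad_w_ckpt(0.7, 0.3, 6)
```

The kept counts (7 versus 1) are the toy analogue of \(a_f\) versus \(a_f^{\text{ckpt}}\), and the extra `forward_chain` call inside `grad_w_ckpt` is the recomputation term \(r_{\text{ckpt}}\); the gradients themselves agree exactly.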
Shared vs non-shared under checkpointing
Checkpointing acts on activations, not on parameter tying. If two models have the same unrolled depth and the same per-step transition shape, then a fixed checkpointing policy saves the same kind of activation memory in both: it drops step-local internal activations and recomputes them during backward. Weight sharing does not make step-$t$ activations reusable for step $t+1$, because those activations come from different hidden states.
The differences appear elsewhere. Read “shared + checkpointing” as “non-shared + checkpointing,” plus the usual weight-sharing differences:
- parameter memory: $T$ parameter blocks $\to$ one shared parameter block;
- gradient buffers: $T$ separate parameter-gradient buffers $\to$ one shared gradient buffer;
- optimizer state: $T$ optimizer-state blocks $\to$ one optimizer-state block;
- cross-step parameter accumulation: none $\to$ required.
So the clean summary is:
- checkpointing helps both shared and non-shared models because both store step-specific activations;
- weight sharing reduces parameter-side memory, not step-activation memory;
- checkpointing and weight sharing are complementary because they attack different parts of the footprint.
In practice, a shared-weight loop may re-enter the same block many times during checkpointed backward, while an untied model walks through different parameter blocks. The activation-memory story is the same; only the implementation overheads differ.
At this point, the symbolic story is complete:
- shared weights: mostly change parameter-side accumulation, not step-local activation storage;
- per-step losses: add local backward branches;
- detach: removes temporal credit assignment, not local backward;
- instant update: changes the optimizer interval from one full rollout to one outer step;
- internal truncation: shrinks the cell-local backward inside that new schedule;
- checkpointing: swaps saved activations for recomputation without changing the detached gradient estimator.
Toy Experiment Check
To make the discussion less purely symbolic, I ran a small benchmark that mirrors the nested structure of the post. The exact benchmark script is here.
- the outer loop rolls for $T_{\text{out}}$ steps and owns the losses;
- each outer step applies one iterative cell;
- inside that cell, there is an inner refinement loop of depth $K$;
- the internal-truncation variant differentiates only the last refinement, while the earlier refinements run without gradient tracking.
The tying convention in this toy benchmark is deliberately asymmetric. With the setting used below, \(T_{\text{out}} = 32\) and \(K = 6\):
- Variant 1 (non-shared / untied) fully materializes the rollout into \(32 \times 6 = 192\) distinct refinement layers, so every refinement has its own parameter block;
- Variants 2-7 (shared / tied) use one outer-step cell shared across all \(32\) outer steps, with \(6\) inner refinements inside each call;
- the two setups therefore execute the same \(192\) logical refinement calls per full rollout, but Variant 1 realizes them as \(192\) separate parameter blocks, whereas Variants 2-7 reuse one shared block throughout.
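A small accounting sketch of that asymmetry, counting only the cell parameters \(W_h, W_x, b\) and ignoring the head; the fp32/Adam byte arithmetic is an illustrative assumption of this sketch, not a readout of the benchmark's tracked-memory column.

```python
D, X = 256, 256                        # hidden and input sizes
T_out, K = 32, 6                       # outer steps, inner refinements

params_per_block = D * D + D * X + D   # W_h, W_x, b in one refinement layer
untied_blocks = T_out * K              # 192 distinct refinement layers (Variant 1)
untied_params = untied_blocks * params_per_block
shared_params = 1 * params_per_block   # one tied block (Variants 2-7)

untied_param_mb = untied_params * 4 / 2**20   # fp32 parameter bytes, in MiB
# Adam roughly quadruples the parameter-side footprint:
# parameters + gradients + two moment buffers per parameter.
untied_param_side_mb = 4 * untied_param_mb
```

The \(192\times\) gap in parameter blocks, multiplied through gradients and Adam state, is what later separates the tracked-peak columns of Variant 1 and Variant 2 while leaving their saved-activation columns identical.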
Concretely, each outer step uses the same style of cell as the setup section,
$$ h_t = \tanh(W_h h_{t-1} + W_x x_t + b), $$and in the batched benchmark implementation this means
$$ Z_t = H_{t-1} W_h^\top + X_t W_x^\top + \mathbf{1} b^\top, \qquad H_t = \tanh(Z_t), $$with \(H_{t-1} \in \mathbb{R}^{B \times D}\), \(X_t \in \mathbb{R}^{B \times X}\), \(W_h \in \mathbb{R}^{D \times D}\), \(W_x \in \mathbb{R}^{D \times X}\), \(b \in \mathbb{R}^{D}\), and \(\mathbf{1} \in \mathbb{R}^{B}\) the all-ones vector. If \(A_t = dL / dH_t\) is the incoming batch adjoint, define
$$ S_t = 1 - H_t \odot H_t, \qquad \bar A_t = A_t \odot S_t. $$Then the two local backward objects specialized to this benchmark cell are
$$ A_t J_t = \bar A_t W_h, $$and
$$ A_t B_t = \Bigl(\bar A_t^\top H_{t-1},\; \bar A_t^\top X_t,\; \mathbf{1}^\top \bar A_t\Bigr). $$So, in the toy benchmark, the blue temporal term is the exported hidden-to-hidden multiply \(A_t J_t\), while the red local term \(A_t B_t\) collects the gradients with respect to \(W_h\), \(W_x\), and \(b\) inside one outer-step cell. The outer state then goes through one linear head and an MSE loss.
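A quick sanity check of these two local backward objects in the scalar case \(B = D = X = 1\), where every GEMM collapses to an ordinary product; the finite-difference helper `fd` exists only for verification and is not part of the benchmark.

```python
import math

# Scalar instance of the benchmark cell: h = tanh(w_h*h_prev + w_x*x + b),
# with loss 0.5*(h - y)^2, so the incoming adjoint is A = h - y.
def loss(w_h, w_x, b, h_prev, x, y):
    h = math.tanh(w_h * h_prev + w_x * x + b)
    return 0.5 * (h - y) ** 2

w_h, w_x, b, h_prev, x, y = 0.5, -0.3, 0.1, 0.8, 0.4, 0.2
h = math.tanh(w_h * h_prev + w_x * x + b)
A = h - y                    # incoming adjoint A_t
abar = A * (1.0 - h * h)     # A_t ⊙ S_t

# scalar counterparts of the exported temporal term A_t J_t
# and the local parameter term A_t B_t
d_hprev = abar * w_h
d_wh, d_wx, d_b = abar * h_prev, abar * x, abar

def fd(f, i, args, eps=1e-6):
    """Central finite difference of f in its i-th argument."""
    lo, hi = list(args), list(args)
    lo[i] -= eps
    hi[i] += eps
    return (f(*hi) - f(*lo)) / (2 * eps)

args = (w_h, w_x, b, h_prev, x, y)
err = max(
    abs(d_wh - fd(loss, 0, args)),
    abs(d_wx - fd(loss, 1, args)),
    abs(d_b - fd(loss, 2, args)),
    abs(d_hprev - fd(loss, 3, args)),
)
```

All four analytic derivatives match the finite-difference probes to well below the finite-difference noise floor, which is the scalar version of the \(\bar A_t W_h\) and \((\bar A_t^\top H_{t-1}, \bar A_t^\top X_t, \mathbf{1}^\top \bar A_t)\) identities above.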
Benchmark setup
The run below used one deliberately small but still nontrivial setting:
- batch size $= 32$
- outer rollout length $T_{\text{out}} = 32$
- hidden size $d_h = 256$
- input size $d_x = 256$
- inner cell depth $K = 6$
- Adam on CPU
For every row below, torch.profiler wraps the full optimizer-interval body: forward through the iterative cell and head, loss formation, backward, and optimizer.step(). So the FLOPs shown below are profiler-estimated FLOPs per training interval.
These profiler-estimated FLOPs should be read as a comparative operator-accounted proxy under one fixed implementation, not as a complete hardware-level accounting of all floating-point work.²
Under the definition above, checkpointing does not change NFE; it only raises profiler-estimated FLOPs by recomputation during backward.
The two memory columns answer slightly different questions:
- Peak saved activations: how much backward-facing activation storage autograd kept alive.
- Tracked peak memory: peak saved activations plus model parameters, gradient buffers, and optimizer state inside this benchmark setup.
Because this run uses Adam, the optimizer-state breakdown is nontrivial rather than an all-zero auxiliary chart.
The table below maps the article variants to the benchmark implementation. Variant 4 uses detached carry with end-of-rollout backward, while an indented auxiliary row measures the same detached objective under streaming per-step backward accumulation with a delayed optimizer.step().
| Row | Outer loss attachment | Carry between outer steps | Optimizer-step timing | Inner gradient scope | Checkpoint boundary |
|---|---|---|---|---|---|
| 1 baseline | terminal loss on the final outer state only | full graph carried across the rollout | once after full rollout | all 6 inner loops differentiated | none |
| 2 + shared | terminal loss on the final outer state only | full graph carried across the rollout | once after full rollout | all 6 inner loops differentiated | none |
| 3 + per-step losses | one outer loss at every step | full graph carried across the rollout | once after full rollout | all 6 inner loops differentiated | none |
| ↳ + checkpointing on top of 3 | one outer loss at every step | full graph carried across the rollout | once after full rollout | all 6 inner loops differentiated | checkpoint the full outer-step cell |
| 4 + detach each step under end-of-rollout backward | one outer loss at every step | value carried, graph detached between steps | once after full rollout | all 6 inner loops differentiated | none |
| ↳ + streaming backward accumulation on top of 4 | one outer loss at every step | value carried, graph detached between steps | backward after each outer step; optimizer step once after full rollout | all 6 inner loops differentiated | none |
| ↳ + checkpointing on top of 4 | one outer loss at every step | value carried, graph detached between steps | once after full rollout | all 6 inner loops differentiated | checkpoint the full outer-step cell |
| 5 + instant update | one outer loss on the profiled step | next step would start from detached carried state | once per outer step | all 6 inner loops differentiated | none |
| 6 + internal truncation | one outer loss on the profiled step | next step would start from detached carried state | once per outer step | only the last inner loop is differentiable | none |
| 7 + gradient checkpointing | one outer loss on the profiled step | next step would start from detached carried state | once per outer step | only the last inner loop is differentiable | checkpoint only that remaining differentiable inner loop |
The measurement contract changes across the chain, so the empirical results should be read in two blocks: Rows 1-4 plus the three auxiliary fixed-schedule rows are measured per full 32-step outer rollout, whereas Rows 5-7 are measured per single outer step. When I want apples-to-apples compute across schedules, I normalize Rows 5-7 back to the same 32-step outer horizon.
Results
Read `↓` as the main numbered ablation chain and `↳` as an auxiliary branch on the immediately preceding numbered variant. Rows 1-4 plus the three indented auxiliary rows are full-rollout intervals, whereas Rows 5-7 are one-step intervals, so the figure should be read as two blocks rather than one monotone scale. FLOPs are `torch.profiler(with_flops=True)` estimates, and NFE means the logical number of forward cell refinements in the modeled computation, so checkpointing can raise FLOPs without changing NFE; keep the aggregate peak-memory chart visible first and expand the breakdown only when you want the components.
Expand peak-memory breakdown (activations / parameters / gradients / optimizer state)
The charts keep all rows in one compact visual stack. Read ↓ in the left labels as the main numbered ablation chain and ↳ as an auxiliary branch that modifies only the immediately preceding numbered variant.
Why are V3 and V4 so close? A detailed analysis
Start from the two formulas already derived above:
$$ C_{\text{shared, per-step}}(T) = Tc_\ell + (T-1)c_J + (T-1)c_h + Tc_B + (T-1)c_\theta + c_u, $$and
$$ C_{\text{shared, per-step, detach}}(T) = Tc_\ell + Tc_B + (T-1)c_\theta + c_u. $$So the entire V3 \(\to\) V4 gap is
$$ C_{\text{shared, per-step}}(T) - C_{\text{shared, per-step, detach}}(T) = (T-1)c_J + (T-1)c_h. $$That is the right entry point. The only real question is how large \(c_J + c_h\) is, in this toy benchmark, relative to the local term \(c_\ell + c_B + c_\theta\).
For the benchmark cell used here, with batch size \(B\), hidden size \(D\), input size \(X\), and inner depth \(K\), a batched matrix multiply of shape \([B,D] \times [D,D]\) costs \(2BD^2\) FLOPs, and one of shape \([B,X] \times [X,D]\) costs \(2BXD\).
The benchmark uses a square linear head on the outer state. If \(H_t \in \mathbb{R}^{B \times D}\) is the batch of outer states at step \(t\), \(U \in \mathbb{R}^{D \times D}\) is the head matrix, and \(Y_t \in \mathbb{R}^{B \times D}\) is the target, then
$$ O_t = H_t U, \qquad \ell_t = \tfrac12 \|O_t - Y_t\|_F^2. $$Writing \(\Delta_t = O_t - Y_t\), the head backward is
$$ \frac{\partial \ell_t}{\partial H_t} = \Delta_t U^\top, \qquad \frac{\partial \ell_t}{\partial U} = H_t^\top \Delta_t. $$Each term is one \([B,D] \times [D,D]\) GEMM, so each costs \(2BD^2\). Therefore
$$ c_\ell = 4BD^2. $$Likewise, in the earlier symbolic notation \(c_J = e_J + m_J\). For this affine-tanh cell, \(e_J\) is the elementwise derivative preparation
$$ S_t = 1 - H_t \odot H_t, \qquad \bar A_t = A_t \odot S_t, $$while \(m_J\) is the exported temporal VJP
$$ \bar A_t W_h, $$which is one more \([B,D] \times [D,D]\) GEMM. So
$$ m_J = 2BD^2. $$The remaining \(e_J\) work is only elementwise \(O(BD)\) and is not stably covered by torch.profiler(with_flops=True) in this benchmark. Under the profiler-aligned leading-term accounting used in this collapse, I therefore keep only the dominant counted GEMM part and write
$$ c_J \approx m_J = 2BD^2, \qquad c_h \approx 0, \qquad c_\theta \approx 0, \qquad c_u \approx 0, $$because these are elementwise/add/update terms and torch.profiler(with_flops=True) does not give stable FLOPs coverage for them in this benchmark; and, under the same profiler-aligned leading-term accounting,
$$ c_B \approx K(2BD^2 + 2BXD) + (K-1)2BD^2, $$because the local parameter backward still traverses all \(K\) inner refinements. At each refinement, exposing the parameter-gradient contribution costs one hidden-hidden term plus one input-hidden term, namely \(2BD^2 + 2BXD\). In addition, to reach the earlier parameter uses inside the same cell, backward must still propagate the hidden adjoint across the first \(K-1\) inner links, each costing another \(2BD^2\).
With the benchmark setting
$$ T=32,\qquad B=32,\qquad D=X=256,\qquad K=6, $$this becomes
$$ c_\ell = 8.389\text{ M}, \qquad c_J \approx 4.194\text{ M}, \qquad c_B \approx 71.303\text{ M}. $$So the formulas predict
$$ (T-1)c_J + (T-1)c_h \approx 31 \cdot 4.194\text{ M} = 130.023\text{ M}, $$while the local part retained by both Variant 3 and Variant 4 is
$$ T(c_\ell + c_B) + (T-1)c_\theta + c_u \approx 32(8.389 + 71.303)\text{ M} = 2550.137\text{ M}. $$That is why the relative compute drop is modest: detach removes the temporal term, but the local loss branch and the local parameter-backward term are still the dominant pieces in this toy.
Up to this point, I have only compared the backward-side terms that differ between the two rows. To compare against the profiler totals, we now add back the common forward work that the symbolic model deliberately held fixed. For both Variant 3 and Variant 4, each of the \(T\) outer steps runs \(K\) cell refinements and one head forward, so the shared forward baseline is
$$ T\bigl(K(2BD^2 + 2BXD) + 2BD^2\bigr) = 1744.830\text{ M}, $$where \(K(2BD^2 + 2BXD)\) is the cell forward and the final \(2BD^2\) is the head forward on that outer step. Therefore the full analytic interval totals are
$$ \underbrace{1744.830}_{\text{shared forward baseline}} + \underbrace{2680.161}_{\text{Variant 3 backward-side terms}} = 4424.991\text{ M for Variant 3}, $$and
$$ \underbrace{1744.830}_{\text{shared forward baseline}} + \underbrace{2550.137}_{\text{Variant 4 backward-side terms}} = 4294.967\text{ M for Variant 4}, $$which line up with the measured profiler-estimated values
$$ 4428.334\text{ M}, \qquad 4298.310\text{ M}. $$So the small V3 \(\to\) V4 gap is not evidence of an implementation bug. It is exactly what the earlier formulas predict once the toy cell is instantiated under the same decomposition.
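The leading-term arithmetic of this collapse can be replayed end to end in a few lines of pure Python, under the same profiler-aligned accounting (elementwise terms dropped):

```python
B, D, X, K, T = 32, 256, 256, 6, 32
gemm_hh = 2 * B * D * D                # one [B,D] x [D,D] GEMM
gemm_xh = 2 * B * X * D                # one [B,X] x [X,D] GEMM

c_l = 2 * gemm_hh                      # head backward: two GEMMs, 4BD^2
c_J = gemm_hh                          # exported temporal VJP, 2BD^2
c_B = K * (gemm_hh + gemm_xh) + (K - 1) * gemm_hh  # local parameter backward

gap = (T - 1) * c_J                    # V3-only temporal term (c_h ~ 0)
local = T * (c_l + c_B)                # backward work retained by both V3 and V4
fwd = T * (K * (gemm_hh + gemm_xh) + gemm_hh)      # shared forward baseline

v3_total = fwd + local + gap           # analytic Variant 3 interval FLOPs
v4_total = fwd + local                 # analytic Variant 4 interval FLOPs
```

Both analytic totals land within about 0.1% of the measured profiler estimates quoted above, which is as close as this leading-term model can be expected to get.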
In real systems, the visual separation can be even less clean: parts of the local derivative preparation for the temporal VJP and the parameter VJP may be shared or fused inside the same backward kernels. That is why the split into \(c_J\) and \(c_B\) should still be read as a convenient decomposition, not as two physically separate kernel launches.
Key Takeaways
The benchmark tracks the symbolic story closely: the biggest shifts come from changing which backward path exists and when local graphs are released, not from weight sharing alone.
- Weight sharing collapses parameter-side memory, not activation-side memory. Variant 1 and Variant 2 keep the same saved-activation peak at 66.094 MB, but tracked peak memory drops from 451.850 MB to 69.102 MB because the benchmark moves from a fully materialized 192-layer untied stack to one shared recurrent block, collapsing parameters, gradient buffers, and optimizer state.
- Per-step losses make the local term dominant. Variant 2 $\to$ Variant 3 raises profiler-estimated FLOPs from 4038.198 M to 4428.334 M, and Variant 3 $\to$ Variant 4 then falls only to 4298.310 M because detach removes only the temporal part while the local loss and local parameter-backward work remain; even after that drop, saved activations are still 69.000 MB under end-of-rollout backward.
- Once per-step detach is introduced, streaming backward should be the default implementation. After detaching at every step, the local graphs are already temporally disconnected, so backpropagating only at the end of the rollout brings essentially no computational benefit while needlessly retaining activations. In this benchmark, the auxiliary streaming row keeps essentially the same compute as Variant 4, 4302.316 M versus 4298.310 M, but reduces saved activations from 69.000 MB to 2.156 MB by releasing each detached local graph immediately.
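The retention argument can be caricatured with a pure-Python counting sketch (a hypothetical helper, not the benchmark code): deferring backward keeps every detached step-local graph alive until the end of the rollout, while streaming releases each one immediately.

```python
def peak_retained_graphs(T, streaming):
    """Count how many detached step-local graphs are alive at once
    during one T-step rollout with per-step detach."""
    alive, peak = 0, 0
    for _ in range(T):
        alive += 1              # forward of step t builds its local graph
        peak = max(peak, alive)
        if streaming:
            alive -= 1          # immediate per-step backward frees it
    return peak

end_peak = peak_retained_graphs(32, streaming=False)    # deferred backward
stream_peak = peak_retained_graphs(32, streaming=True)  # streaming backward
```

The 32-versus-1 peak is the counting-level analogue of the 69.000 MB versus 2.156 MB saved-activation gap measured above; the compute is unchanged because the same local backwards run either way.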
Future Topics
This post isolates the smallest ablation chain that makes gradient paths, optimizer intervals, and storage policies explicit. The next natural extensions are adaptive or routed recurrence, such as halting or mixture-of-recursions, where NFE becomes input dependent and the accounting has to move from a fixed $T$ to expected depth. Another direction is systems work, especially FSDP and compiler interactions, where communication, sharding boundaries, and recomputation matter just as much as local FLOPs and saved activations.
References
[1] A. Jolicoeur-Martineau. Less is More: Recursive Reasoning with Tiny Networks. 2025. Link
[2] J. Geiping, S. McLeish, N. Jain, J. Kirchenbauer, S. Singh, B. Bartoldson, B. Kailkhura, A. Bhatele, T. Goldstein. Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach. 2025. Link
[3] H. Prairie, Z. Novack, T. Berg-Kirkpatrick, D. Fu. Parcae: Scaling Laws For Stable Looped Language Models. 2026. Link
[4] R. Zhu, Z. Wang, K. Hua, T. Zhang, Z. Li, H. Que, B. Wei, Z. Wen, F. Yin, H. Xing, et al. Scaling latent reasoning via looped language models. 2025.
[5] G. Wang, J. Li, Y. Sun, X. Chen, C. Liu, Y. Wu, M. Lu, Y. Yadkori. Hierarchical Reasoning Model. 2025. Link
[6] S. Bai, J. Kolter, V. Koltun. Deep Equilibrium Models. 2019. Link
In reverse-mode autodiff, backward through a step usually needs some forward-time tensors again, such as the input hidden state, preactivation, normalization statistics, or attention masks. Unless we choose a recomputation policy such as checkpointing, the framework therefore saves those tensors during forward so the later backward pass can form the local derivatives. ↩︎
PyTorch documents with_flops=True as using formulas to estimate FLOPs of specific operators, so I treat the resulting number as a comparative operator-accounted proxy rather than a full hardware-level total. https://docs.pytorch.org/docs/stable/profiler.html ↩︎
Cited as
Use the plain citation or copy the BibTeX entry below.
Benhao Huang. (Apr 2026). Loop-Model FLOPs and Memory in an Ablation Chain. Husky's Log. /husky-blog/posts/recursive_models/loop-cost/
@article{huang2026loopmodel,
title = "Loop-Model FLOPs and Memory in an Ablation Chain",
author = "Benhao Huang",
journal = "Husky's Log",
year = "2026",
month = "Apr",
url = "https://huskydoge.github.io/husky-blog/posts/recursive_models/loop-cost/"
}