- The paper introduces Curriculum-Guided Adaptive Recursion (CGAR), which uses a Progressive Depth Curriculum to cut training computation by approximately 41.4% and prevent early-stage overfitting.
- It pairs this with Hierarchical Supervision Weighting, which counteracts gradient decay, reducing gradient variance by roughly 40% and contributing to a 1.71 times speedup in training time.
- The approach maintains competitive accuracy with only a 0.63% drop, showing promise for efficient training in recursive reasoning models.
Accelerating Training Speed of Tiny Recursive Models via Curriculum Guided Adaptive Recursion
Introduction
The paper tackles the inefficiency of training recursive reasoning models, focusing on Tiny Recursive Models (TRMs), which perform complex reasoning tasks with very few parameters but still require substantial training time. To address this, the paper introduces Curriculum-Guided Adaptive Recursion (CGAR), a novel training methodology that applies curriculum learning to recursion depth rather than to data ordering.
Methodology
Progressive Depth Curriculum
The CGAR framework introduces a Progressive Depth Curriculum (PDC), which dynamically adjusts recursion depth during training. Unlike traditional fixed-depth training, which wastes computation and invites overfitting in the early stages, PDC gradually increases the effective recursion depth as training progresses. This is achieved through a piecewise-constant schedule that transitions between shallow, medium, and full-depth configurations. This progression prevents early-stage overfitting and reduces training computation by approximately 41.4%.
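For illustration, a piecewise-constant schedule of this kind might look like the sketch below. The breakpoints and the (n, T) recursion settings are assumed values for illustration, not the paper's exact configuration; the function matches the C_PDC interface used in the training loop later in this article.

```python
def pdc_schedule(progress: float) -> tuple[int, int]:
    """Piecewise-constant depth curriculum (illustrative breakpoints and depths).

    `progress` is the fraction of training completed (current epoch / total epochs).
    Returns (n, T): latent recursion steps and recursion cycles passed to the
    deep_recursion call in the training loop.
    """
    if progress < 0.3:      # early training: shallow recursion, cheap and regularizing
        return 2, 1
    elif progress < 0.7:    # mid training: medium depth
        return 4, 2
    else:                   # late training: full depth
        return 6, 3
```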
Hierarchical Supervision Weighting
Complementing PDC is Hierarchical Supervision Weighting (HSW), which addresses the inefficiency of weighting all supervision steps uniformly. HSW assigns exponentially decaying importance to successive supervision steps, mirroring the gradient decay observed in recursive architectures. Concentrating weight on the early, gradient-rich steps rather than the later, diminishing ones yields roughly a 40% reduction in gradient variance and faster convergence.
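A minimal sketch of the weighting itself, assuming the exponential-decay form with geometric-series normalization that appears in the training loop below (the decay factor 0.7 and six supervision steps are illustrative values):

```python
def hsw_weights(lambda_decay: float, n_sup: int) -> list[float]:
    """Exponentially decaying supervision weights, normalized to sum to 1.

    Step t (1-indexed) gets weight lambda_decay**(t - 1) / Z, where
    Z = (1 - lambda_decay**n_sup) / (1 - lambda_decay) is the geometric sum.
    Assumes 0 < lambda_decay < 1.
    """
    z = (1 - lambda_decay**n_sup) / (1 - lambda_decay)
    return [lambda_decay**(t - 1) / z for t in range(1, n_sup + 1)]

print(hsw_weights(0.7, 6))  # earlier steps dominate; later steps taper off
```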
Implementation and Results
Under controlled conditions on a single A100 GPU, CGAR achieves a 1.71 times speedup in training, reducing wall-clock time from 10.93 hours to 6.38 hours while maintaining competitive accuracy with only a 0.63% drop. A detailed comparison against a baseline TRM trained at fixed depth shows that CGAR preserves model quality while substantially improving training efficiency.
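As a quick arithmetic check on the reported figures (using only the numbers quoted above):

```python
# Reported wall-clock training times on a single A100 GPU.
baseline_hours, cgar_hours = 10.93, 6.38
print(f"speedup: {baseline_hours / cgar_hours:.2f}x")                  # ~1.71x
print(f"wall-clock reduction: {1 - cgar_hours / baseline_hours:.1%}")  # ~41.6%
```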
Code Snippet
The implementation of CGAR involves integrating dynamic recursion and hierarchical weighting into the training loop:
```python
def train_cgar(D, E, C_PDC, lambda_decay, eta, N_sup):
    theta = INIT_PARAMS()
    OPT = ADAMW(theta, lr=eta)
    # Normalization constant for the exponentially decaying supervision weights
    Z_lambda = (1 - lambda_decay**N_sup) / (1 - lambda_decay)
    for e in range(1, E + 1):
        # Progressive Depth Curriculum: recursion depth (n, T) grows with training progress
        n, T = C_PDC(e / E)
        for X, Y_true in D:
            Y = EMBED(X)
            Z = ZERO_STATE_like(Y)
            L = 0.0
            for t in range(1, N_sup + 1):
                Y, Z = deep_recursion(Y, Z, X, n, T)
                logits = OUT_HEAD(Y)
                q = SIGMOID(HALT_HEAD(Y))
                # Hierarchical Supervision Weighting: later steps get exponentially smaller weight
                w = lambda_decay**(t - 1)
                L += w * CE(logits, Y_true) + BCE(q, MATCH(logits, Y_true))
                if MAX(q) > 0.5:          # adaptive halting once the model predicts a match
                    Y, Z = DETACH(Y), DETACH(Z)
                    break
                Y, Z = DETACH(Y), DETACH(Z)  # truncate backprop between supervision steps
            loss = L / Z_lambda
            OPT.zero_grad(); loss.backward(); OPT.step()
    return theta
```
This snippet captures the adaptive recursion and hierarchical supervision mechanisms central to CGAR, expressed in PyTorch-like pseudocode; the uppercase helpers stand in for the corresponding model components and utilities.
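For context, a hypothetical invocation might look as follows; the data loader, hyperparameter values, and the pdc_schedule helper sketched earlier are illustrative assumptions rather than the paper's configuration:

```python
theta = train_cgar(
    D=train_loader,         # yields (X, Y_true) batches; assumed to exist
    E=100,                  # total training epochs (illustrative)
    C_PDC=pdc_schedule,     # depth curriculum sketched earlier
    lambda_decay=0.7,       # HSW decay factor (assumed value)
    eta=1e-4,               # learning rate (assumed value)
    N_sup=6,                # maximum supervision steps per example
)
```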
Discussion and Future Work
CGAR demonstrates the feasibility of applying curriculum learning to architectural parameters rather than to data ordering, offering significant improvements in training time and efficiency without compromising accuracy. Future work may explore automated curriculum scheduling, broader task-domain evaluations, and techniques for integrating CGAR principles into larger, pre-trained LLMs. This could reshape training paradigms across domains where logical reasoning and inference efficiency are critical.