
(How) Can Transformers Predict Pseudo-Random Numbers?

Published 14 Feb 2025 in cs.LG, cond-mat.dis-nn, cs.CR, and stat.ML | (2502.10390v1)

Abstract: Transformers excel at discovering patterns in sequential data, yet their fundamental limitations and learning mechanisms remain crucial topics of investigation. In this paper, we study the ability of Transformers to learn pseudo-random number sequences from linear congruential generators (LCGs), defined by the recurrence relation $x_{t+1} = a x_t + c \;\mathrm{mod}\; m$. Our analysis reveals that with sufficient architectural capacity and training data variety, Transformers can perform in-context prediction of LCG sequences with unseen moduli ($m$) and parameters ($a,c$). Through analysis of embedding layers and attention patterns, we uncover how Transformers develop algorithmic structures to learn these sequences in two scenarios of increasing complexity. First, we analyze how Transformers learn LCG sequences with unseen ($a, c$) but fixed modulus, and we demonstrate successful learning up to $m = 2^{32}$. Our analysis reveals that models learn to factorize the modulus and utilize digit-wise number representations to make sequential predictions. In the second, more challenging scenario of unseen moduli, we show that Transformers can generalize to unseen moduli up to $m_{\text{test}} = 2^{16}$. In this case, the model employs a two-step strategy: first estimating the unknown modulus from the context, then utilizing prime factorizations to generate predictions. For this task, we observe a sharp transition in the accuracy at a critical depth of 3. We also find that the number of in-context sequence elements needed to reach high accuracy scales sublinearly with the modulus.

Summary

  • The paper demonstrates that Transformer models can effectively learn and predict sequences from Linear Congruential Generators, succeeding for fixed moduli up to $2^{32}$ and for unseen moduli up to $2^{16}$ when given sufficient architectural capacity, training data, and context.
  • When trained on a fixed modulus, models learn to factorize the modulus and use digit-wise representations. When tested on unseen moduli, they first estimate the modulus greedily from the context, and generalization requires a critical depth of $d=3$.
  • Interpretability analysis reveals that Transformers reverse-engineer the LCG algorithm using mixed-radix representations. Embedding layers encode prime factorizations, and attention heads look back a specific number of steps to make predictions. The number of in-context elements needed for high accuracy scales sublinearly with the modulus.

The paper "(How) Can Transformers Predict Pseudo-Random Numbers?" investigates the capacity of Transformer models to learn and predict sequences generated by Linear Congruential Generators (LCGs). The central question is whether Transformers can discern the underlying mathematical patterns in pseudo-random number generators (PRNGs) or merely identify spurious correlations. The study examines how model architecture, training data, and context length affect the ability of Transformers to predict LCG sequences, including sequences with unseen moduli and parameters.

The authors find that Transformers, given sufficient architectural depth, context length, and training data, can effectively predict LCG sequences. The study demonstrates successful learning for fixed moduli up to $m = 2^{32}$. Interpretability analyses reveal that the models develop emergent structures in their embedding layers and attention heads, effectively reverse-engineering the LCG algorithm.

The paper investigates two training paradigms:

  • Fixed Modulus (FM): Training and testing are conducted on sequences generated with a single modulus.
  • Generalization to Unseen Modulus (UM): Models are trained on sequences with various moduli and tested on sequences with a modulus not seen during training.

Key findings of the paper include:

  • For the FM paradigm, Transformers learn to factorize the modulus and utilize digit-wise number representations for sequential predictions.
  • For the UM paradigm, models adopt a greedy approach to estimate the unknown modulus from the context and use prime factorizations to generate predictions. A critical depth of $d = 3$ is observed, below which the Transformer fails to learn the task.
  • The number of in-context sequence elements needed to achieve high accuracy scales sublinearly with the modulus (equal to the period for the full-period test sequences).

The paper builds upon existing research in Transformer interpretability, modular arithmetic, and PRNG cracking. It connects the pattern-learning capabilities of Transformers to the structured nature of PRNG outputs.

Related works include:

  • Interpretability and Modular Arithmetic: Studies uncovering circuits, algorithms, and emergent structures learned by Transformers in modular arithmetic problems [sharkey2025interp, olsson2022context, Ahn2023gradient, vonoswald2023Transformers, akyurek2023what, hendel2023incontext, liu2024incontextvector, power2022grokking, gromov2022grokking, nanda2023progress, zhong2023clock, doshi2024to, he2024learning].
  • Cracking PRNGs: Deep learning-based attacks on PRNGs [amigo2021].
  • Context-Free Grammar: Use of formal languages to understand the properties of neural networks [chomsky1956three, deletang2023chomsky, allenzhu2024physicslanguagemodels1, cagnetta2024deep, cagnetta2024towards].
  • Chaotic time-series: Application of neural networks in predicting time-series for chaotic dynamics [lam2023learning].

Linear Congruential Generators

LCGs generate a sequence $\{x_n\}_{n=0}^N$ using the recurrence relation:

$x_{n+1} = (a x_n + c) \bmod m,$

where

  • $m > 0$ is the modulus,
  • $0 < a < m$ is the multiplier,
  • $0 \leq c < m$ is the increment.
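The recurrence translates directly into a short generator. The sketch below is a minimal Python version; the function name `lcg` and its signature are illustrative and not taken from the paper.

```python
def lcg(a: int, c: int, m: int, x0: int, length: int) -> list[int]:
    """Return x_0, ..., x_{length-1} for x_{n+1} = (a * x_n + c) mod m."""
    # Enforce the parameter ranges listed above.
    if not (m > 0 and 0 < a < m and 0 <= c < m and 0 <= x0 < m):
        raise ValueError("LCG parameters out of range")
    if length < 1:
        raise ValueError("length must be at least 1")
    seq = [x0]
    for _ in range(length - 1):
        seq.append((a * seq[-1] + c) % m)
    return seq


print(lcg(a=5, c=3, m=16, x0=0, length=8))  # [0, 3, 2, 13, 4, 7, 6, 1]
```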

The period $\mathcal{T}_m(a,c)$ of an LCG sequence is an important factor determining its complexity, where $1 \leq \mathcal{T}_m(a,c) \leq m$. The Hull-Dobell Theorem [hull-dobell] provides criteria for achieving the maximum possible period $\mathcal{T}_m(a,c) = m$:

  • $c$ and $m$ are coprime.
  • $a - 1$ is divisible by all prime factors of $m$.
  • $a - 1$ is divisible by 4 if $m$ is divisible by 4.

The paper evaluates models exclusively on sequences that satisfy these criteria.
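The Hull-Dobell conditions are cheap to test, and for small moduli they can be cross-checked against a brute-force period computation. The sketch below does that. The helper names `prime_factors`, `hull_dobell`, and `period` are illustrative. Trial-division factoring is assumed to be adequate because the moduli here are small.

```python
from math import gcd


def prime_factors(n: int) -> set[int]:
    """Distinct prime factors of n, by trial division (fine for small n)."""
    factors, p = set(), 2
    while p * p <= n:
        while n % p == 0:
            factors.add(p)
            n //= p
        p += 1
    if n > 1:
        factors.add(n)
    return factors


def hull_dobell(a: int, c: int, m: int) -> bool:
    """True iff (a, c, m) satisfies the three Hull-Dobell conditions."""
    return (
        gcd(c, m) == 1
        and all((a - 1) % p == 0 for p in prime_factors(m))
        and (m % 4 != 0 or (a - 1) % 4 == 0)
    )


def period(a: int, c: int, m: int, x0: int = 0) -> int:
    """Length of the cycle eventually reached from x0, found by brute force."""
    first_seen = {}
    x, n = x0, 0
    while x not in first_seen:
        first_seen[x] = n
        x = (a * x + c) % m
        n += 1
    return n - first_seen[x]


# Brute-force check for m = 16: the conditions hold exactly when the period is m.
m = 16
full = [(a, c) for a in range(1, m) for c in range(m) if hull_dobell(a, c, m)]
assert all(hull_dobell(a, c, m) == (period(a, c, m) == m)
           for a in range(1, m) for c in range(m))
print(len(full))  # 32: a in {1, 5, 9, 13} times the 8 odd values of c
```

Starting from $x_0 = 0$ is enough for the check. A cycle of length $m$ that contains 0 visits every residue, so it is the whole state space.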

Training Setup

The study employs two training paradigms: FM and UM.

Fixed Modulus

For a given modulus $m$, the Hull-Dobell Theorem is used to determine valid $(a, c)$ pairs that maximize the sequence period. A test dataset is generated by randomly sampling a set of values for each of $a$ and $c$ from these valid choices. The training dataset consists of 100k LCG sequences whose length equals the context length. These sequences use different $a$, $c$, and initial values $x_0$, excluding the $a$ and $c$ values used in the test dataset.
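A minimal sketch of the fixed-modulus split follows. It uses small hypothetical sizes: 8 held-out values of each parameter and 1,000 training sequences instead of 100k. The function names are illustrative. Drawing training parameters from all remaining values, including non-full-period ones, is an assumption.

```python
import random
from math import gcd


def prime_factors(n):
    """Distinct prime factors of n, by trial division."""
    factors, p = set(), 2
    while p * p <= n:
        while n % p == 0:
            factors.add(p)
            n //= p
        p += 1
    if n > 1:
        factors.add(n)
    return factors


def fm_split(m, n_test_a, n_test_c, n_train, length, seed=0):
    """Hold out test values of a and c; build training sequences from the rest.

    Assumption: training (a, c) are drawn from all remaining values, not only
    full-period ones.
    """
    rng = random.Random(seed)
    ps = prime_factors(m)
    # Hull-Dobell constrains a and c independently, so every combination of a
    # valid a with a valid c yields a full-period test sequence.
    valid_a = [a for a in range(1, m)
               if all((a - 1) % p == 0 for p in ps) and (m % 4 != 0 or (a - 1) % 4 == 0)]
    valid_c = [c for c in range(m) if gcd(c, m) == 1]
    test_a = sorted(rng.sample(valid_a, n_test_a))
    test_c = sorted(rng.sample(valid_c, n_test_c))

    # Training excludes every held-out a value and every held-out c value.
    train_a = sorted(set(range(1, m)) - set(test_a))
    train_c = sorted(set(range(m)) - set(test_c))
    train = []
    for _ in range(n_train):
        a, c, x = rng.choice(train_a), rng.choice(train_c), rng.randrange(m)
        seq = [x]
        for _ in range(length - 1):
            x = (a * x + c) % m
            seq.append(x)
        train.append((a, c, seq))
    return train, test_a, test_c


train, test_a, test_c = fm_split(m=2048, n_test_a=8, n_test_c=8, n_train=1000, length=64)
print(len(train), test_a[:3], test_c[:3])
```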

Generalization to Unseen Modulus

A test modulus $m_{\text{test}}$ is selected and excluded from the training data. The training dataset includes a set of different moduli sampled uniformly from a fixed range. For each modulus $m$, a set of multipliers ($a$) and increments ($c$) are selected uniformly, generating sequences whose length equals the context length, with random initial values $x_0$. The test dataset is constructed by randomly sampling 64 values each of $a$ and $c$ that generate maximum-period sequences for $m_{\text{test}}$.
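A comparable sketch for the unseen-modulus split appears below. The modulus range, the number of moduli, and the per-modulus counts of $a$ and $c$ are hypothetical placeholders. Generating one sequence per $(a, c)$ pair and not restricting training parameters to full-period choices are assumptions. Only the 64 held-out values of each of $a$ and $c$ come from the description above.

```python
import random
from math import gcd


def prime_factors(n):
    """Distinct prime factors of n, by trial division."""
    factors, p = set(), 2
    while p * p <= n:
        while n % p == 0:
            factors.add(p)
            n //= p
        p += 1
    if n > 1:
        factors.add(n)
    return factors


def lcg(a, c, m, x0, length):
    seq = [x0]
    for _ in range(length - 1):
        seq.append((a * seq[-1] + c) % m)
    return seq


def um_split(m_test, m_low, m_high, n_moduli, n_a, n_c, length, n_test=64, seed=0):
    """Train on moduli drawn from [m_low, m_high) excluding m_test; test on m_test."""
    rng = random.Random(seed)
    candidates = [m for m in range(m_low, m_high) if m != m_test]
    train = []
    for m in rng.sample(candidates, n_moduli):
        # Uniformly chosen multipliers and increments (assumption: any values).
        for a in rng.sample(range(1, m), min(n_a, m - 1)):
            for c in rng.sample(range(m), min(n_c, m)):
                train.append((m, a, c, lcg(a, c, m, rng.randrange(m), length)))

    # Test: full-period (Hull-Dobell) values of a and c for the held-out modulus.
    ps = prime_factors(m_test)
    valid_a = [a for a in range(1, m_test)
               if all((a - 1) % p == 0 for p in ps) and (m_test % 4 != 0 or (a - 1) % 4 == 0)]
    valid_c = [c for c in range(m_test) if gcd(c, m_test) == 1]
    test_a = sorted(rng.sample(valid_a, n_test))
    test_c = sorted(rng.sample(valid_c, n_test))
    return train, test_a, test_c


train, test_a, test_c = um_split(m_test=2048, m_low=1024, m_high=2560,
                                 n_moduli=20, n_a=4, n_c=4, length=32)
print(len(train), len(test_a), len(test_c))
```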

In both paradigms, test accuracy is calculated by averaging over all sequences generated from the held-out values of $a$ and $c$, with 128 random initial values $x_0$ for each $(a, c)$ pair.

The model architecture is based on GPT-style Transformers with learnable positional embeddings and weight tying [press2017tying]. The architecture is defined by the number of blocks (depth), the embedding dimension, and the number of attention heads. Models are trained using the AdamW optimizer.

Training Results

The paper focuses on models trained with UM data and investigates the minimal model architecture required to perform the task. Results indicate that a minimum depth of 3 is necessary for good generalization performance, with a sharp transition in accuracy when the depth is reduced below this value.

Analysis of training dynamics reveals that the model first learns to copy sequences with periods shorter than the context length. Subsequently, it "groks" the solution for sequences with longer periods. Generalization to the test modulus $m_{\text{test}}$ emerges simultaneously with this grokking phenomenon. Ablation studies confirm that training exclusively on long-period sequences enables model generalization and eliminates the grokking phenomenon.

Interpreting How Transformers Predict PRNGs

The paper interprets the algorithms implemented by the models for both FM and UM cases, revealing common properties stemming from the underlying LCG structures.

Mixed Radix Representations

Each element of an LCG sequence with modulus $m = 2^{11} = 2048$ can be represented as an 11-digit binary number:

$x_n = \sum_{w=0}^{10} \alpha_w^{(n)} \, 2^w,$

where $\alpha_w^{(n)} \in \{0, 1\}$ are the binary digits (bits), and $w \in \{0, 1, \ldots, 10\}$.

Each bit has its own period along the LCG sequence. For a sequence of period $m = 2^{11}$, the $w$-th lowest digit $\alpha_{w-1}$ has a period of $2^w$ along the sequence. This representation generalizes to composite moduli through a mixed radix form:

$x_n = \sum_{j=1}^{q} \sum_{k=0}^{w_j - 1} \alpha_{j,k}^{(n)} \, p_j^{k} \prod_{i<j} p_i^{w_i},$

where $m = p_1^{w_1} p_2^{w_2} \cdots p_q^{w_q}$ is the prime factorization of $m$, and the digits $\alpha_{j,k}^{(n)} \in \{0, 1, \ldots, p_j - 1\}$.
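Both representations can be reproduced with ordinary integer arithmetic, as in the sketch below. The sketch assumes digits are ordered with the smallest prime first and the lowest digit first, and the helper names are hypothetical. In the binary case it assumes $m = 2048 = 2^{11}$, matching the 11-digit example. The full-period parameters $a = 5$, $c = 1$ are an illustrative choice that satisfies Hull-Dobell.

```python
def radices(m: int) -> list[int]:
    """Prime factors of m with multiplicity, smallest first: 72 -> [2, 2, 2, 3, 3]."""
    out, p = [], 2
    while p * p <= m:
        while m % p == 0:
            out.append(p)
            m //= p
        p += 1
    if m > 1:
        out.append(m)
    return out


def to_mixed_radix(x: int, rs: list[int]) -> list[int]:
    """Digits of x in the mixed radix rs, lowest digit first."""
    digits = []
    for r in rs:
        digits.append(x % r)
        x //= r
    return digits


def from_mixed_radix(digits: list[int], rs: list[int]) -> int:
    """Inverse: sum of digit * (product of all lower radices)."""
    x, place = 0, 1
    for d, r in zip(digits, rs):
        x += d * place
        place *= r
    return x


def cyclic_period(s: list[int]) -> int:
    """Smallest p with s[i] == s[(i + p) % len(s)] for all i (s is one full cycle)."""
    n = len(s)
    return next(p for p in range(1, n + 1)
                if n % p == 0 and all(s[i] == s[(i + p) % n] for i in range(n)))


print(to_mixed_radix(50, radices(72)))  # [0, 1, 0, 0, 2], i.e. 50 = 1*2 + 2*24

# Binary case: one full cycle of a full-period LCG modulo 2**11.
m, a, c = 2048, 5, 1
rs = radices(m)
seq = [0]
for _ in range(m - 1):
    seq.append((a * seq[-1] + c) % m)
bit_periods = [cyclic_period([to_mixed_radix(x, rs)[w - 1] for x in seq])
               for w in range(1, 12)]
print(bit_periods)  # [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
```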

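The $r$-step iteration of the LCG equation is:

$x_{n+r} = \left(a^r x_n + c \sum_{i=0}^{r-1} a^i\right) \bmod m.$

Along the $r$-step subsequence $x_n, x_{n+r}, x_{n+2r}, \ldots$, the period of each digit shrinks. For $m = 2^{11}$, the period of the $w$-th lowest bit reduces from $2^w$ to $2^w / \gcd(2^w, r)$. With $r = 2^k$, the lowest $k$ bits become constant.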
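The $r$-step recurrence is again an LCG, with multiplier $a^r$ and increment $c(1 + a + \dots + a^{r-1})$, both reduced mod $m$. The sketch below checks that closed form numerically and measures the bit periods along the subsequence. It assumes the illustrative values $m = 2048$, $a = 5$, $c = 1$, $r = 4$, and `r_step_params` is a hypothetical helper name.

```python
def r_step_params(a: int, c: int, m: int, r: int) -> tuple[int, int]:
    """Multiplier and increment of the r-step LCG: x_{n+r} = (A * x_n + C) mod m."""
    A = pow(a, r, m)
    C = c * sum(pow(a, i, m) for i in range(r)) % m
    return A, C


def cyclic_period(s):
    """Smallest p with s[i] == s[(i + p) % len(s)] for all i (s is one full cycle)."""
    n = len(s)
    return next(p for p in range(1, n + 1)
                if n % p == 0 and all(s[i] == s[(i + p) % n] for i in range(n)))


m, a, c, r = 2048, 5, 1, 4
seq = [0]
for _ in range(m - 1):
    seq.append((a * seq[-1] + c) % m)

# The closed form reproduces every r-step jump over one full cycle.
A, C = r_step_params(a, c, m, r)
assert all(seq[(n + r) % m] == (A * seq[n] + C) % m for n in range(m))

# Along x_0, x_r, x_2r, ... bit w-1 has period 2**w / gcd(2**w, r).
sub = seq[::r]
periods = [cyclic_period([(x >> (w - 1)) & 1 for x in sub]) for w in range(1, 12)]
print(periods)  # [1, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
```

With $r = 4 = 2^2$, the two lowest bits are constant along the subsequence, and every higher bit has its period divided by 4.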

Interpretability: Fixed Modulus

The algorithm implemented by Transformers trained on LCG with a fixed modulus involves:

  1. Finding mixed-radix representations of inputs from the learned prime factorization of $m$.
  2. Looking back $r$ steps in the context and copying the lowest digits, which are constant along the $r$-step sequence, with $r$ chosen as powers of different prime factors of $m$ (e.g., $r = 2^k$ gives the lowest $k$ bits).
  3. Using these $r$-step sequences to predict the higher digits of the simplified sequence.

PCA of the embedding matrix reveals that the model groups numbers into clusters by their residues modulo small divisors of $m$. One grouping appears along the first principal component and another along the second and third principal components. Analysis of attention weights confirms that the model looks back $r$ steps in the context to predict the token at position $t$, attending to positions $t - r, t - 2r, \ldots$. This allows the model to copy the lowest $k$ bits (for $r = 2^k$) and simplify the prediction of higher bits.
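The copying step works because $x_n \bmod p^k$ is itself a full-period LCG with period $p^k$ whenever $p^k$ divides $m$, so $x_t \equiv x_{t - p^k} \pmod{p^k}$. The sketch below illustrates this with plain arithmetic. It does not model the attention mechanism, and the helper names and parameter values ($m = 2048$ and $m = 72$ with the shown $a$, $c$, $x_0$) are hypothetical choices.

```python
def lcg(a, c, m, x0, length):
    seq = [x0]
    for _ in range(length - 1):
        seq.append((a * seq[-1] + c) % m)
    return seq


def copy_low_digits(context: list[int], t: int, p: int, k: int) -> int:
    """Predict x_t mod p**k by copying x_{t - p**k} mod p**k (needs t >= p**k)."""
    r = p ** k
    return context[t - r] % r


def copy_accuracy(seq, p, k):
    """Fraction of positions where the copied residue equals the true one."""
    r = p ** k
    hits = [copy_low_digits(seq, t, p, k) == seq[t] % r for t in range(r, len(seq))]
    return sum(hits) / len(hits)


# Binary modulus: copying from r = 2**k steps back recovers the lowest k bits.
bits = lcg(a=5, c=1, m=2048, x0=123, length=512)
print([copy_accuracy(bits, 2, k) for k in range(1, 7)])  # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

# Composite modulus 72 = 2**3 * 3**2 (a = 13, c = 5 satisfy Hull-Dobell):
# residues mod 8 and mod 9 are copied from 8 and 9 steps back, respectively.
mixed = lcg(a=13, c=5, m=72, x0=7, length=200)
print(copy_accuracy(mixed, 2, 3), copy_accuracy(mixed, 3, 2))  # 1.0 1.0
```

Knowing the lowest $k$ bits leaves only $m / 2^k$ candidates for $x_t$, which is the simplification the higher-digit prediction builds on.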

Interpretability: Generalization to Unseen Modulus

The algorithm implemented by Transformers trained on LCG with an unseen modulus involves:

  1. Encoding information about many possible prime factorizations.
  2. Using the largest number in the context to estimate the modulus.
  3. Implementing steps 2 and 3 of the fixed-modulus algorithm.

PCA of the embedding layer shows a semi-circular structure in the first two principal components, resembling patterns observed in modular arithmetic tasks. Further analysis shows that the 2nd and 3rd principal components each categorize numbers by their residue modulo a different small integer.

Attention heads in the first layer group numbers according to their residues with respect to various prime factors. One specific attention head in the first layer is responsible for estimating the unknown modulus, focusing on the largest numbers within the context.
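Why the largest numbers carry information about the modulus can be shown with a hypothetical arithmetic stand-in. Every value in the context is below $m$, so $\max + 1$ is a lower bound on $m$, and it tightens as the context grows. This is not the attention head's actual computation. The held-out modulus $m = 2048$, the parameters, and the helper names below are illustrative assumptions.

```python
import random


def lcg(a, c, m, x0, length):
    seq = [x0]
    for _ in range(length - 1):
        seq.append((a * seq[-1] + c) % m)
    return seq


def estimate_modulus(context: list[int]) -> int:
    """Naive estimate: the modulus must exceed every value seen, so m_hat = max + 1."""
    return max(context) + 1


def mean_relative_gap(m, a, c, n, trials=200, seed=0):
    """Average (m - m_hat) / m over random starting points, for context length n."""
    rng = random.Random(seed)
    gaps = [(m - estimate_modulus(lcg(a, c, m, rng.randrange(m), n))) / m
            for _ in range(trials)]
    return sum(gaps) / len(gaps)


m_test, a, c = 2048, 5, 1  # hypothetical held-out modulus and full-period (a, c)
for n in (8, 32, 128, 512):
    print(n, round(mean_relative_gap(m_test, a, c, n), 4))
```

The estimate never exceeds the true modulus. It reaches it exactly once the context covers a full period.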

Scaling Up the Modulus

The paper investigates Transformer training when scaling up the modulus of LCG. It implements a byte-level tokenization scheme and uses Abacus embeddings [mcleish2024abacus] to encode positional information.
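Under a byte-level scheme, a number is split into base-256 digits before being fed to the model. The sketch below assumes each 32-bit value becomes four tokens, most significant first. The token order is an assumption, the function names are hypothetical, and the Abacus embeddings are not reproduced here.

```python
def to_tokens(x: int, base: int = 256, n_digits: int = 4) -> list[int]:
    """Split x into n_digits base-`base` tokens, most significant first."""
    if not 0 <= x < base ** n_digits:
        raise ValueError("x does not fit in n_digits tokens")
    tokens = []
    for _ in range(n_digits):
        x, d = divmod(x, base)
        tokens.append(d)
    return tokens[::-1]


def from_tokens(tokens: list[int], base: int = 256) -> int:
    """Inverse of to_tokens."""
    x = 0
    for d in tokens:
        x = x * base + d
    return x


print(to_tokens(0xDEADBEEF))  # [222, 173, 190, 239]
```

Because 256 divides $2^{32}$, the least significant token $x \bmod 256$ of a full-period sequence with $m = 2^{32}$ follows its own full-period LCG with period 256.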

For FM training with $m = 2^{32}$, the model performs worse on sequences generated with spectrally optimal Steele multipliers than on sequences generated with arbitrary multipliers. The number of in-context examples needed for high test accuracy scales sublinearly with the modulus $m$, as $m^{\gamma}$ with $\gamma < 1$.

For the UM case, the number of in-context sequence elements required to reach high test accuracy scales sublinearly with $m_{\text{test}}$. Test performance also depends on the tokenization base.

Discussion

The paper provides insights into how neural networks learn deterministic sequences generated by LCGs. It uncovers the algorithms used by the models and highlights the model components that implement the steps of the algorithms. The model finds and utilizes prime factorizations of $m$ and mixed-radix representations of numbers to simplify the sequences and make predictions. The paper also provides a modified training recipe for scaling up the modulus in both FM and UM settings.
