CATR: Confounding-Aware Token Rationalization
- The paper introduces a novel framework that uses an HSIC-based residual-independence diagnostic to select a sparse subset of tokens from high-dimensional text that suffices to block confounding.
- It employs a relaxed-Bernoulli mask and a multi-term loss to balance predictive utility, sparsity, and preservation of confounding signals.
- Empirical evaluations demonstrate that CATR reduces bias and variance while increasing effective sample sizes and providing stable treatment effect estimates.
Confounding-Aware Token Rationalization (CATR) is a methodological framework designed to address the challenges posed by high-dimensional text covariates in causal effect estimation from observational data. By selectively identifying a sparse subset of tokens carrying the requisite confounding signal, CATR mitigates positivity violations and instability in inverse-probability weighted estimators that often arise when full text representations are used. CATR employs a residual-independence diagnostic based on the Hilbert–Schmidt Independence Criterion (HSIC) to ensure the selected token set suffices for unconfoundedness, optimizing a multi-term loss that balances predictive utility, parsimony, and preservation of confounding structure (Zhang et al., 5 Dec 2025).
1. Problem Setting and Motivation
Consider an observational study where each unit $i$ is characterized by:
- A high-dimensional text sequence $X_i = (X_{i,1}, \dots, X_{i,d})$,
- A binary treatment $T_i \in \{0,1\}$,
- An outcome $Y_i \in \mathbb{R}$ or $Y_i \in \{0,1\}$.
Under the potential-outcomes framework, each unit possesses counterfactuals $Y_i(0), Y_i(1)$. The scientific estimand is the average treatment effect (ATE):

$$\tau = \mathbb{E}[Y(1) - Y(0)].$$
Identification requires:
- Unconfoundedness: $(Y(0), Y(1)) \perp T \mid X$,
- Positivity: $0 < e(x) = P(T = 1 \mid X = x) < 1$ for all $x$,
- Consistency: $Y = Y(T)$.
Only a typically unknown subset $S^*$ of tokens contains the true confounding information, while the redundant tokens are unrelated. Conditioning on the entire sequence $X$ is standard practice but, in high-dimensional settings, leads to "observational-level positivity violations", where the estimated propensity $\hat{e}(X)$ falls outside the desired interval $[\epsilon, 1-\epsilon]$ even if the true propensity $e(X_{S^*})$ satisfies overlap.
Positivity violations manifest as extreme propensity scores, large inverse-probability weights, high estimator variance, and an inflated frequency of propensity score clipping. Toy examples (see Figure 1 in (Zhang et al., 5 Dec 2025)) demonstrate the undesirable concentration of $\hat{e}(X)$ at the boundaries and an increased clipping fraction as the proxy dimension rises.
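The mechanism is easy to reproduce. The following toy simulation is an illustration in the spirit of the paper's Figure 1, not a reproduction of it: the data-generating process, the 0.05/0.95 extremity thresholds, and the logistic propensity model are all assumptions. It fits a propensity model on an increasing number of treatment-correlated distractor columns and reports the fraction of extreme fitted propensities.

```python
# Illustrative only: as the proxy dimension d grows, fitted propensities
# saturate toward {0, 1} even though the true propensity (a function of a
# single confounder c) is bounded well away from the boundaries.
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n = 2000
c = rng.normal(size=n)                  # the single true confounder
p_true = 1.0 / (1.0 + np.exp(-c))       # true propensity: has overlap
t = rng.binomial(1, p_true)

for d in [1, 10, 100, 500]:
    # proxies: the confounder plus d-1 treatment-leaking distractor columns
    noise = rng.normal(size=(n, d - 1)) + 0.5 * t[:, None]
    X = np.column_stack([c, noise])
    e_hat = LogisticRegression(max_iter=1000).fit(X, t).predict_proba(X)[:, 1]
    clip_frac = np.mean((e_hat < 0.05) | (e_hat > 0.95))
    print(f"d={d:4d}  clipped fraction={clip_frac:.2f}")
```

The clipped fraction rises with $d$ because the extra columns are predictive of treatment without carrying confounding information, which is exactly the failure mode CATR's sparse token selection is designed to avoid.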
2. Core Methodology: CATR Framework
2.1 Token Selection and Predictive Model
CATR employs a selector network that assigns each token $j$ a score $s_j$. A relaxed-Bernoulli distribution (e.g., Gumbel-Softmax) provides a continuous mask $m_j \in (0,1)$, forming the rationalized subsequence $\tilde{X} = m \odot X$.
A shared predictor network receives $\tilde{X}$ and outputs:
- a propensity estimate $\hat{e}(\tilde{X}) = \hat{P}(T = 1 \mid \tilde{X})$,
- outcome predictions $\hat{\mu}_t(\tilde{X}) \approx \mathbb{E}[Y \mid T = t, \tilde{X}]$.
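A minimal PyTorch sketch of the selection step follows. The paper specifies a selector, a relaxed-Bernoulli mask, and a shared predictor, but not this exact architecture, so the module names, sizes, and the logistic-noise relaxation below are illustrative assumptions.

```python
# Sketch of CATR-style token selection with a relaxed-Bernoulli (binary
# Concrete) mask; gradients flow through the soft mask to the selector.
import torch
import torch.nn as nn

class TokenSelector(nn.Module):
    def __init__(self, emb_dim: int):
        super().__init__()
        self.score = nn.Linear(emb_dim, 1)  # per-token logit s_j

    def forward(self, h: torch.Tensor, temperature: float = 0.5):
        # h: (batch, seq_len, emb_dim) token embeddings
        logits = self.score(h).squeeze(-1)                  # (batch, seq_len)
        u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)   # uniform noise
        noise = torch.log(u) - torch.log(1 - u)             # logistic noise
        m = torch.sigmoid((logits + noise) / temperature)   # soft mask in (0, 1)
        return m

selector = TokenSelector(emb_dim=768)
h = torch.randn(4, 128, 768)       # stand-in for frozen encoder token embeddings
m = selector(h)                    # relaxed mask, differentiable w.r.t. logits
x_tilde = h * m.unsqueeze(-1)      # rationalized representation for the predictor
```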
2.2 Residual-Independence Diagnostic
To verify sufficiency, CATR introduces a residual-independence diagnostic. For a candidate subset $S$ with rationalized text $\tilde{X}$, define batch residuals

$$r^T_i = T_i - \hat{e}(\tilde{X}_i), \qquad r^Y_i = Y_i - \hat{\mu}_{T_i}(\tilde{X}_i).$$

The empirical Hilbert–Schmidt Independence Criterion (HSIC) over a batch of size $m$ is

$$\widehat{\mathrm{HSIC}}(r^T, r^Y) = \frac{1}{(m-1)^2}\,\operatorname{tr}(KHLH),$$

where $K_{ij} = k(r^T_i, r^T_j)$ and $L_{ij} = \ell(r^Y_i, r^Y_j)$ are kernel Gram matrices, and $H = I_m - \frac{1}{m}\mathbf{1}\mathbf{1}^\top$ is the centering matrix.
Proposition 1 demonstrates that if the selected tokens suffice for unconfoundedness, the treatment and outcome residuals are independent, so the population HSIC vanishes. Therefore, a nonzero observed HSIC indicates the selection is insufficient to block confounding paths.
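The batch statistic can be computed directly from the trace formula above. A sketch with RBF kernels and a median-heuristic bandwidth follows; the kernel choice and bandwidth rule are assumptions, not specifics from the paper.

```python
# Biased batch HSIC estimator, tr(K H L H) / (m-1)^2, with RBF kernels.
import torch

def rbf_gram(x: torch.Tensor) -> torch.Tensor:
    # x: (m, 1) residual vector; returns the (m, m) RBF Gram matrix
    d2 = (x - x.T) ** 2
    with torch.no_grad():                       # bandwidth treated as a constant
        pos = d2[d2 > 0]
        sigma2 = pos.median() if pos.numel() > 0 else torch.tensor(1.0)
    return torch.exp(-d2 / (2 * sigma2))

def batch_hsic(r_t: torch.Tensor, r_y: torch.Tensor) -> torch.Tensor:
    # r_t, r_y: (m,) treatment and outcome residuals from the current batch
    m = r_t.shape[0]
    K = rbf_gram(r_t.view(-1, 1))
    L = rbf_gram(r_y.view(-1, 1))
    H = torch.eye(m) - torch.ones(m, m) / m     # centering matrix
    return torch.trace(K @ H @ L @ H) / (m - 1) ** 2
```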
2.3 Optimization Objective
The full optimization problem,

$$\min_{\text{selector},\,\text{predictor}} \; \mathcal{L} = \mathcal{L}_{\mathrm{pred}} + \lambda_s\,\mathcal{L}_{\mathrm{sparse}} + \lambda_h\,\mathcal{L}_{\mathrm{HSIC}},$$

comprises:
- $\mathcal{L}_{\mathrm{pred}}$: sum of cross-entropy losses for $T$ and $Y$,
- $\mathcal{L}_{\mathrm{sparse}}$: sparsity penalty on the mask (entropy/KL),
- $\mathcal{L}_{\mathrm{HSIC}}$: the residual-independence metric $\widehat{\mathrm{HSIC}}(r^T, r^Y)$.
Hyperparameters $\lambda_s$ and $\lambda_h$ control the sparsity and independence trade-offs.
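Assembled in code, the objective might look as follows. The KL-to-budget form of the sparsity term and the default weights are assumptions consistent with the description above, not values from the paper.

```python
# Sketch of the multi-term CATR objective: prediction + sparsity + HSIC.
import torch
import torch.nn.functional as F

def catr_loss(t_logits, t, y_logits, y, mask, hsic_value,
              budget: float = 0.1, lam_s: float = 1.0, lam_h: float = 10.0):
    # L_pred: cross-entropy for the treatment (propensity) and outcome heads;
    # t and y are float tensors of 0/1 labels.
    l_pred = F.binary_cross_entropy_with_logits(t_logits, t) \
           + F.binary_cross_entropy_with_logits(y_logits, y)
    # L_sparse: KL between the mean selection rate and a token-budget prior
    p = mask.mean().clamp(1e-6, 1 - 1e-6)
    l_sparse = p * torch.log(p / budget) \
             + (1 - p) * torch.log((1 - p) / (1 - budget))
    # L_HSIC: residual-independence penalty
    return l_pred + lam_s * l_sparse + lam_h * hsic_value
```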
3. Algorithmic Implementation
Algorithm 1 in (Zhang et al., 5 Dec 2025) proceeds via stochastic gradient descent in minibatches of size $m$:
- Sample a minibatch $\{(X_i, T_i, Y_i)\}_{i=1}^{m}$.
- Compute selection scores $s_{ij}$ for the tokens of each $X_i$.
- Sample a relaxed mask $m_i$ from the relaxed-Bernoulli distribution.
- Form the rationalized text $\tilde{X}_i = m_i \odot X_i$.
- Predict $\hat{e}(\tilde{X}_i)$ and $\hat{\mu}_{T_i}(\tilde{X}_i)$.
- Compute residuals $r^T_i, r^Y_i$ and the batch HSIC.
- Accumulate losses and backpropagate.
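Putting the pieces together, one training step might read as follows, reusing the `TokenSelector`, `batch_hsic`, and `catr_loss` sketches from the sections above; the two-headed predictor, the mean pooling, and the optimizer settings are illustrative assumptions.

```python
# One minibatch step of the Algorithm 1 loop, as sketched in this article.
import torch
import torch.nn as nn

class Predictor(nn.Module):
    def __init__(self, emb_dim: int):
        super().__init__()
        self.t_head = nn.Linear(emb_dim, 1)      # propensity head
        self.y_head = nn.Linear(emb_dim + 1, 1)  # outcome head, conditions on T

    def forward(self, x_tilde, t):
        z = x_tilde.mean(dim=1)                  # mean-pool masked embeddings
        t_logits = self.t_head(z).squeeze(-1)
        y_logits = self.y_head(torch.cat([z, t.unsqueeze(-1)], dim=-1)).squeeze(-1)
        return t_logits, y_logits

selector, predictor = TokenSelector(768), Predictor(768)
opt = torch.optim.Adam(list(selector.parameters()) + list(predictor.parameters()),
                       lr=1e-4)

# Toy minibatch: embeddings, binary treatment, binary outcome
h = torch.randn(32, 128, 768)
t = torch.randint(0, 2, (32,)).float()
y = torch.randint(0, 2, (32,)).float()

m = selector(h)                               # scores + relaxed mask
x_tilde = h * m.unsqueeze(-1)                 # rationalized text
t_logits, y_logits = predictor(x_tilde, t)    # propensity and outcome heads
r_t = t - torch.sigmoid(t_logits)             # treatment residuals
r_y = y - torch.sigmoid(y_logits)             # outcome residuals
hsic_value = batch_hsic(r_t, r_y)             # batch HSIC
loss = catr_loss(t_logits, t, y_logits, y, m, hsic_value)
opt.zero_grad(); loss.backward(); opt.step()  # accumulate and backpropagate
```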
Primary hyperparameters:
- Relaxed-Bernoulli temperature $\tau$,
- Token budget prior (if the KL penalty is used),
- Sparsity and independence weights $\lambda_s$, $\lambda_h$.
Batchwise HSIC is $O(m^2)$ per batch but tractable for typical minibatch sizes. Computational complexity is dominated by the HSIC term and the text encoder.
4. Theoretical Properties
Theoretical guarantees under regularity assumptions (bounded embeddings, Sobolev smoothness, bounded penalties) include:
- Fast nonasymptotic rates: the fitted nuisance functions $\hat{e}$ and $\hat{\mu}_t$ converge at fast nonasymptotic rates under the stated smoothness conditions.
- Product-rate condition: the nuisance estimators $\hat{e}$, $\hat{\mu}_t$ satisfy a product-rate condition of the form $\|\hat{e} - e\| \cdot \|\hat{\mu}_t - \mu_t\| = o_P(n^{-1/2})$, permitting doubly-robust inference.
- Consistency and efficiency: IPW and AIPW estimators are consistent under empirical overlap; the AIPW estimator is $\sqrt{n}$-consistent and asymptotically normal, $\sqrt{n}(\hat{\tau}_{\mathrm{AIPW}} - \tau) \xrightarrow{d} \mathcal{N}(0, \sigma^2)$, if propensities are bounded away from 0 and 1.
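For reference, a standard AIPW (doubly-robust) ATE estimator of this form, with plug-in influence-function standard errors, is sketched below; the 0.05 clipping threshold is an assumption.

```python
# Standard AIPW estimator: sample mean of the efficient influence function.
import numpy as np

def aipw_ate(y, t, e_hat, mu0_hat, mu1_hat, clip: float = 0.05):
    e = np.clip(e_hat, clip, 1 - clip)   # bound propensities away from 0 and 1
    psi = (mu1_hat - mu0_hat
           + t * (y - mu1_hat) / e
           - (1 - t) * (y - mu0_hat) / (1 - e))
    tau_hat = psi.mean()
    se = psi.std(ddof=1) / np.sqrt(len(psi))   # plug-in standard error
    return tau_hat, se
```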
5. Empirical Evaluation
5.1 Semi-Synthetic MIMIC-III Experiment
- Design: physician notes with the confounder signal encoded via infection-keyword indicators; nonlinear transformations of these indicators define the treatment $T$ and the outcome $Y$.
- Method comparison: TARNet, CFRNet, DragonNet, CausalBERT, CATR.
- Estimators: OR, IPW, AIPW.
- Metrics: Absolute bias, empirical SD, bootstrap SE, CI coverage, effective sample size (ESS) ratio, clipping fraction.
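The overlap-related metrics above can be computed directly from fitted propensities. A sketch follows; the 0.05/0.95 extremity thresholds are assumptions.

```python
# Overlap diagnostics from fitted propensities: Kish effective sample size
# ratio (ESS / n, near 1 when weights are nearly uniform) and the fraction
# of extreme propensities subject to clipping.
import numpy as np

def overlap_diagnostics(t, e_hat, lo: float = 0.05, hi: float = 0.95):
    w = t / e_hat + (1 - t) / (1 - e_hat)               # IPW weights
    ess_ratio = w.sum() ** 2 / (w ** 2).sum() / len(w)  # Kish ESS / n
    clip_frac = np.mean((e_hat < lo) | (e_hat > hi))    # extreme-propensity share
    return ess_ratio, clip_frac
```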
Results: CATR achieves the lowest bias and variance, the highest CI coverage, a better ESS ratio (≈0.31 vs. 0.27 for the baselines), and fewer extreme propensities (≈10.7% vs. 29–31%). Ablations show that removing the HSIC or sparsity terms degrades all performance measures and increases clipping.
Qualitative findings: Without HSIC regularization, top tokens selected are largely spurious. With HSIC, all salient infection confounders are recovered.
5.2 Real-World MIMIC-III Study
- Cohort: Septic ICU patients,
- Treatment: IV fluid bolus,
- Outcome: ICU readmission,
- Covariates: 42 structured covariates plus unstructured text,
- Adjustments: structured-only, multimodal, multimodal+CATR.
Results: Structured-only yields low ESS (0.21) and 8% clipping. Adding text (no CATR) slightly improves ESS (0.28) but with 20% clipping and unstable ATE. Multimodal+CATR achieves ESS ≈0.85, 0% clipped propensities, and stable ATE (≈–0.026, SE ≈0.041). Token selections highlight clinically meaningful terms such as “sepsis” and “infection”.
6. Limitations and Future Directions
Key limitations:
- HSIC provides a soft test; zero HSIC is necessary but not sufficient for unconfoundedness.
- Mini-batch HSIC estimates are noisy with small batch sizes.
- Hyperparameter selection (especially the independence weight $\lambda_h$ and the token allowance) demands careful validation and may be computationally intensive.
- Token selection quality is contingent on the information encoded by the fixed pretrained embedding.
Ongoing and future research directions include:
- Integrating counterfactual necessity and sufficiency diagnostics.
- Combining CATR with post-hoc calibration or covariate balancing to stabilize IPW.
- Extending token selection to multimodal (text and structured) covariates.
- Exploring alternative or differentiable independence criteria.
7. Context and Significance
CATR addresses the unique methodological problem posed by high-dimensional text when used as covariates in causal inference. By selecting a minimal sufficient subset of tokens, CATR avoids the fragility of propensity estimation in large text spaces while preserving requisite confounding information relevant for effect identification. Empirically, CATR demonstrates improved estimator stability, reduced bias/variance, enhanced effective sample size, and improved interpretability over baseline and ablated approaches. The framework thus advances the robust integration of unstructured text into causal effect estimation pipelines, especially in settings where naive adjustment for entire documents would undermine positivity and statistical efficiency.