Reason2Decide: Rationale-Driven Learning

Updated 26 December 2025
  • Reason2Decide is a framework for rationale-driven multi-task learning that decouples predictive accuracy from rationale-faithfulness through a structured two-stage training process.
  • It utilizes scheduled sampling to gradually replace gold labels with model predictions, effectively reducing exposure bias during rationale generation.
  • The framework delivers state-of-the-art performance in clinical and biomedical decision support using smaller models, ensuring efficient, on-premise deployment.

The Reason2Decide framework advances rationale-driven multi-task learning by explicitly addressing the separation between predictive accuracy and rationale-faithfulness in self-rationalizing models. Developed as a two-stage curriculum with scheduled sampling, Reason2Decide is designed to jointly optimize for robust, aligned predictions and high-fidelity free-text explanations, particularly in clinical and biomedical decision-support contexts. The framework reduces exposure bias, is robust to the source of rationales (human or LLM-generated), and delivers state-of-the-art performance with models that are substantially smaller than contemporary LLMs, thereby offering practical deployment advantages for resource-constrained environments (Hasan et al., 23 Dec 2025).

1. Model Objectives and Problem Formulation

Reason2Decide employs a T5-based text-to-text architecture $f_\theta$ trained over input-label-rationale triples $(x, y, r)$, where $x \in X$ is the task input (e.g., clinical text or question), $y \in \mathcal{Y}$ is a discrete target label (e.g., triage category, Yes/No), and $r \in \mathcal{R}$ is a free-text rationale sequence. Training objectives are defined in two stages:

  • Stage-1 (Rationale Foundation):
    • Minimize rationale generation loss only:

    $$L_r = -\mathbb{E}_{(x, r^*)}[\log P_\theta(r^*\,|\,\text{explain:}\,x)].$$

  • Stage-2 (Joint Prediction + Rationale):

    • Joint multi-task loss, with adaptive task-weighting $\alpha_t$:

    $$L_\text{total}(t) = \alpha_t L_y + (1-\alpha_t) L_r,$$

    where

    $$L_y = -\mathbb{E}_{(x, y^*)}[\log P_\theta(y^*\,|\,\text{predict:}\,x)]$$

    and

    $$L_r = -\mathbb{E}_{(x, r^*)}[\log P_\theta(r^*\,|\,\text{given label:}\,\tilde{y},\,\text{explain:}\,x)].$$

Here, $\tilde{y}$ is a mixed label, derived via scheduled sampling (see Section 2). The output generation is always autoregressive.
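
To make the two objectives concrete, the following is a minimal sketch of the Stage-2 joint loss with a HuggingFace seq2seq model; Stage-1 corresponds to optimizing $L_r$ alone, without the label prefix. The function name, argument names, and the `t5-large` checkpoint are illustrative assumptions, not details from the paper.

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Illustrative setup; the paper uses fine-tuned T5 checkpoints.
# model = T5ForConditionalGeneration.from_pretrained("t5-large")
# tok   = T5Tokenizer.from_pretrained("t5-large")

def stage2_loss(model, tok, x, y_star, r_star, y_tilde, alpha_t):
    """Joint Stage-2 objective: alpha_t * L_y + (1 - alpha_t) * L_r.

    y_tilde is the mixed label chosen by scheduled sampling (gold label
    or the model's own prediction). All names here are illustrative.
    """
    # L_y: cross-entropy of the gold label under the "predict:" prompt
    enc = tok(f"predict: {x}", return_tensors="pt")
    labels = tok(y_star, return_tensors="pt").input_ids
    L_y = model(**enc, labels=labels).loss

    # L_r: cross-entropy of the gold rationale, conditioned on y_tilde
    enc = tok(f"given label: {y_tilde} explain: {x}", return_tensors="pt")
    labels = tok(r_star, return_tensors="pt").input_ids
    L_r = model(**enc, labels=labels).loss

    return alpha_t * L_y + (1.0 - alpha_t) * L_r
```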

2. Two-Stage Curriculum and Scheduled Sampling

Reason2Decide employs a training curriculum specifically structured to overcome exposure bias and misalignment between labels and explanations:

  • Stage-1: Rationale Foundation Training

    • The model is initialized and trained solely to generate rationales $r^*$ given the input $x$, using either human-annotated or LLM-generated rationales as supervision.
    • Early stopping occurs on rationale validation loss.
    • The resulting parameters $\theta_1$ initialize Stage-2.
  • Stage-2: Joint Prediction + Explanation with Task-Level Scheduled Sampling
    • At each step, the model computes both the label prediction loss $L_y$ and the rationale loss $L_r$; explanations are conditioned on a mixed label $\tilde{y}$:
    • With probability $1-\pi_t$, use the gold label $y^*$;
    • With probability $\pi_t$, use the model’s own prediction $\hat{y}$.
    • The sampling rate $\pi_t$ is increased from 0 to 0.9 according to a schedule:

    $$\pi_t = \begin{cases} 0 & t < w \\ \min(0.9,\ (t-w)/m) & w \leq t < w+m \\ 0.9 & t \geq w+m \end{cases}$$

    where $w$ is the warm-up period (5% of training steps) and $m$ is the transition duration (60% of training steps).
    • The multi-task weighting $\alpha_t$ is linearly annealed from 0 to 0.7 during warm-up, remaining fixed thereafter.

Early stopping in Stage-2 is delayed until scheduled sampling plateaus and is monitored on validation macro-F1.
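
These schedules admit a direct implementation. Below is a minimal sketch under the stated fractions (5% warm-up, 60% ramp, 0.9 cap, $\alpha$ annealed to 0.7); the helper names and the convention of measuring $t$ in optimizer steps are assumptions for exposition.

```python
import random

def sampling_rate(t: int, total_steps: int,
                  warmup_frac: float = 0.05, ramp_frac: float = 0.60,
                  cap: float = 0.9) -> float:
    """Piecewise schedule for pi_t: 0 during warm-up, then a linear
    ramp over the transition duration, capped at 0.9."""
    w = warmup_frac * total_steps  # warm-up period
    m = ramp_frac * total_steps    # transition duration
    if t < w:
        return 0.0
    return min(cap, (t - w) / m)

def task_weight(t: int, total_steps: int,
                warmup_frac: float = 0.05, alpha_max: float = 0.7) -> float:
    """alpha_t: linearly annealed 0 -> 0.7 during warm-up, then fixed."""
    w = warmup_frac * total_steps
    return alpha_max * min(1.0, t / w)

# Mixed label for rationale conditioning at step t (y_hat = model's
# own prediction, y_star = gold label):
# y_tilde = y_hat if random.random() < sampling_rate(t, T) else y_star
```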

3. Model Architecture and Implementation

The backbone is a standard encoder–decoder T5, with model sizes ranging from 77M (Small) through 250M (Base) to 800M (Large). Typical architectural details:

  • Encoder/decoder stack: hidden sizes 512/768/1024 with 12/12/24 layers (Small/Base/Large, respectively)

  • Prompt engineering:

    • Label prediction: “predict:” $x \rightarrow y$
    • Rationale generation: “given label:” $y$ “explain:” $x \rightarrow r$
  • Hyperparameters: AdamW, learning rate $5 \times 10^{-5}$, effective batch size 64, max sequence length 1024, early stopping
  • Implementation: HuggingFace Transformers, 4 × A100 GPUs

Conditioning rationale generation on the model’s own predictions in Stage-2 ensures that the model is exposed at training time to the same conditioning regime it encounters at inference.
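
At inference this yields a simple two-step pipeline: predict the label, then explain conditioned on that prediction. The following minimal sketch uses the prompt formats above; the `t5-large` checkpoint name is a placeholder, since the paper’s fine-tuned weights are not referenced here.

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-large")    # placeholder checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-large")

def predict_then_explain(x: str) -> tuple[str, str]:
    # Step 1: label prediction with the "predict:" prompt
    ids = tok(f"predict: {x}", return_tensors="pt").input_ids
    y_hat = tok.decode(model.generate(ids, max_new_tokens=8)[0],
                       skip_special_tokens=True)
    # Step 2: rationale generation, conditioned on the model's own
    # prediction -- the same regime scheduled sampling trains for
    ids = tok(f"given label: {y_hat} explain: {x}",
              return_tensors="pt").input_ids
    r_hat = tok.decode(model.generate(ids, max_new_tokens=128)[0],
                       skip_special_tokens=True)
    return y_hat, r_hat
```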

4. Evaluation: Metrics and Empirical Results

Datasets:

  • Proprietary clinical triage (12-way classification; ~200K examples); rationales from three provenance sources (nurse-authored, post-processed, LLM-generated)
  • PubMedQA (biomedical Yes/No/Maybe, gold long_answer as rationale)
  • BioASQ Yes/No (snippets summarized as rationale)

Prediction Performance (Macro-F1, T5-Large):

| Dataset | SFT (no rationale) | DSS-Loss | DSS-F1 | Reason2Decide |
| --- | --- | --- | --- | --- |
| Clinical Triage | 56.85 ±7.21 | 58.09 ±0.77 | 59.43 ±1.07 | 60.58 ±0.46 |
| PubMedQA | 59.60 ±0.26 | 59.74 ±0.26 | 59.92 ±0.21 | 60.28 ±0.05 |
| BioASQ (T5-Base) | 53.70 ±12.29 | 59.99 ±13.58 | 66.28 ±0.54 | 68.02 ±2.19 |

Rationale Fidelity (Clinical Triage, T5-Large):

  • BERTScore (F1): 92.30 (Reason2Decide best)
  • BLEU: 24.13
  • LLM-as-a-Judge (Likert scale, 1–5):
    • Coverage: 4.80
    • Correctness: 4.43 (clear gain)
    • Overlap: 2.63
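
The quantitative metrics above can be reproduced with standard libraries; the sketch below uses `scikit-learn`, `bert-score`, and `sacrebleu` with toy placeholder data (the LLM-as-a-Judge scores require a separate judging model and are not shown).

```python
from sklearn.metrics import f1_score
from bert_score import score as bertscore
from sacrebleu import corpus_bleu

# Toy placeholders; in practice these come from the evaluation split
y_true, y_pred = ["urgent", "routine"], ["urgent", "urgent"]
generated_rationales = ["Patient reports acute chest pain."]
reference_rationales = ["Chest pain indicates an urgent triage category."]

# Prediction quality: macro-F1 over the discrete labels
macro_f1 = f1_score(y_true, y_pred, average="macro")

# Rationale fidelity: BERTScore F1 and corpus BLEU against references
_, _, bs_f1 = bertscore(generated_rationales, reference_rationales, lang="en")
bleu = corpus_bleu(generated_rationales, [reference_rationales]).score
```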

Ablation Findings:

  • Removing Stage-1 pretraining: F1 drops by 1.7–2.0 points.
  • Removing scheduled sampling: F1 drops by 2.4–2.6 points.
  • Skipping the warm-up leads to a 0.4–6.7 point decrease in F1, depending on the task.
  • Excluding Stage-2 eliminates predictive capability entirely (F1 = 0), since label prediction is trained only in Stage-2.

These results demonstrate that both the rationale pretraining and scheduled sampling are critical for maximizing predictive and explanation performance.

5. Analysis: Exposure Bias, Rationale Robustness, and Alignment

A major challenge in self-rationalizing models is exposure bias: during training, rationales are conditioned on gold labels, but at inference only model predictions are available as labels (a “label-conditioning mismatch”). Reason2Decide mitigates this via task-level scheduled sampling, progressively exposing the model to its own predictions while generating explanations, which leads to better alignment between decisions and explanations.

Rationale Source Robustness: Rationale fidelity remains stable (within 1 F1 point) regardless of whether the Stage-1 rationales are human-authored, post-processed, or LLM-generated. This suggests that LLM-generated rationales are suitable for building explanation-robust models without costly manual annotation.

Model Efficiency and Deployment: Using an 800M-parameter T5-Large model, Reason2Decide matches or surpasses much larger (8B–32B) LLM baselines on both predictive and rationale-fidelity metrics. The smaller footprint makes on-premise deployment feasible in clinical environments.

6. Positioning and Implications within Rationale-Driven Multi-Task Learning

Reason2Decide occupies a distinctive position among rationale-driven learning frameworks:

  • Unlike approaches such as RationaleCL (multi-task rationale tuning and contrastive replay) or RaDME (sequential score→rationale generation with LLM-to-student distillation), Reason2Decide directly intervenes on self-conditioning—explicitly aligning the rationale generation process to the prediction context through scheduled sampling (Xiong et al., 2023, Do et al., 28 Feb 2025).
  • It addresses scenarios with variable (including synthetic) rationale supervision, and empirically demonstrates negligible rationale-quality loss when using LLM-generated rationales in Stage-1, significantly reducing annotation costs.
  • The two-stage paradigm and exposure-bias mitigation are especially relevant for self-explaining biomedical and clinical AI, which demand both practical accuracy and transparent, human-interpretable reasoning.

In summary, Reason2Decide establishes a generalizable and practical template for rationale-driven multi-task learning. It alleviates conditioning mismatch, is robust across explanation sources, and achieves state-of-the-art efficiency and alignment in high-stakes medical decision support (Hasan et al., 23 Dec 2025).
