This paper, "Improve Vision LLM Chain-of-thought Reasoning" (Zhang et al., 21 Oct 2024), addresses the challenge of improving Chain-of-Thought (CoT) reasoning capabilities in Vision LLMs (VLMs). The authors note that existing VLM training data is often dominated by short answers with minimal rationales, which hinders the models' ability to generalize to tasks requiring detailed reasoning steps. They demonstrate that training solely on short answers does not effectively improve CoT performance.
To tackle this, the paper proposes a two-pronged approach:
- CoT Data Distillation and Supervised Fine-Tuning (SFT): The authors generate a large dataset of visual CoT examples by distilling rationales from GPT-4o. Leveraging existing VQA datasets with short ground-truth annotations, they prompt GPT-4o to produce detailed reasoning steps that lead to the known correct answer. This process yields the ShareGPT-4o-Reasoning dataset of 193k examples covering several reasoning types (real-world knowledge, chart understanding, document/text understanding, math, and science). The dataset is used to fine-tune an open-source VLM (LLaMA3-LLaVA-NeXT-8B, initialized with Open-LLaVA-NeXT weights) for CoT prediction. SFT training includes both the distilled CoT examples and the corresponding direct answers, using distinct prompt formats for each task type ("Generate a reason first and then output..." for CoT, "Answer the question with a short answer" for direct).
- Reinforcement Learning with Direct Preference Optimization (DPO): To further refine reasoning quality, the SFT model is used to generate multiple candidate reasoning chains for each question. The generated responses are then checked against the ground-truth short answer: responses that reach the correct answer are treated as positive (chosen) examples, and those that reach an incorrect answer as negative (rejected) examples. These pairs are used to train the model with the DPO algorithm, aligning the VLM's reasoning process toward more accurate outcomes using self-generated signals, without requiring additional human-labeled preference data. The DPO training uses a subset of the data (A-OKVQA, ChartQA, and math domains) to construct preference pairs.
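As a concrete illustration of this pairing step, the sketch below splits sampled CoT candidates into chosen/rejected sets by checking each extracted final answer against the ground truth. The `extract_answer` helper, the zip-based pairing, and the `max_pairs` cap are illustrative assumptions, not the paper's exact recipe.

```python
def build_preference_pairs(prompt, candidates, gt_answer, extract_answer, max_pairs=4):
    """Build DPO preference pairs from self-generated CoT candidates.

    candidates: CoT strings sampled from the SFT model for one question.
    extract_answer: assumed helper that parses the text after '### Answer:'.
    """
    norm = lambda s: (s or "").strip().lower()
    chosen = [c for c in candidates if norm(extract_answer(c)) == norm(gt_answer)]
    rejected = [c for c in candidates if norm(extract_answer(c)) != norm(gt_answer)]
    # Pair correct with incorrect chains; cap the number of pairs per question.
    return [{"prompt": prompt, "chosen": w, "rejected": l}
            for w, l in zip(chosen, rejected)][:max_pairs]
```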
The experimental results demonstrate significant improvements. SFT on the distilled CoT data substantially boosts VLM performance on CoT reasoning across a range of benchmark datasets (A-OKVQA, ChartQA, DocVQA, InfoVQA, TextVQA, AI2D, ScienceQA, MathVista). Ablation studies show that training on CoT data alone is more effective at improving CoT performance than training on direct answers, and importantly, training on CoT data also generalizes well to improve direct answer prediction. The best performance is achieved when combining both CoT and direct data during SFT.
The subsequent DPO training further improves CoT reasoning. DPO trained on the model-generated reasoning preference pairs yields larger gains in CoT accuracy than DPO trained on a general hallucination-reduction dataset such as RLAIF-V, indicating that DPO is effective at calibrating reasoning when the preference data is task-specific and derived from the model's own outputs.
The paper also shows that the DPO-trained model can function as an effective verifier for re-ranking candidate reasoning chains generated by the SFT model. This re-ranking, especially using strategies like Best-of-N or Weighted Voting based on DPO scores, leads to improved performance on challenging datasets like MMMU. Analysis of the DPO token-level rewards indicates that the DPO model learns to assign negative scores to errors or hallucinations within the reasoning process, suggesting it develops sensitivity to factual correctness and logical flow.
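The re-ranking idea can be sketched as follows. Here `dpo_scores` is one scalar reward per sampled chain (e.g., the summed log-probability ratio between the DPO policy and the reference model); both that scoring choice and the simple score-sum voting are assumptions, one plausible variant of the strategies described above.

```python
from collections import defaultdict

def rerank_with_dpo(answers, dpo_scores, strategy="best_of_n"):
    """answers: final answer extracted from each of the N sampled chains;
    dpo_scores: one scalar DPO reward per chain."""
    if strategy == "best_of_n":
        # Keep the answer of the single highest-scoring reasoning chain.
        best = max(range(len(answers)), key=lambda i: dpo_scores[i])
        return answers[best]
    # Weighted voting: accumulate scores over chains that agree on an answer.
    votes = defaultdict(float)
    for ans, score in zip(answers, dpo_scores):
        votes[ans] += score
    return max(votes, key=votes.get)
```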
Practical Implementation Details:
- Data Generation: The distillation process can be implemented by iterating through existing VQA datasets, formatting each question and its known short answer into the GPT-4o prompt, and parsing the generated CoT and final answer (a minimal distillation loop is sketched after this list). Filtering is crucial for cases where GPT-4o's generated answer does not match the ground truth, which may indicate errors in either the distilled rationale or the original annotation.
- SFT Training: Requires a base VLM (e.g., LLaMA3-LLaVA-NeXT-8B). The training data must be prepared with distinct prompts for the direct and CoT tasks. Training follows standard fine-tuning procedure; the paper reports one epoch on 8 H100 GPUs.
- DPO Training: Requires generating a pool of responses from the SFT model (e.g., 32 candidates per question with temperature sampling). An automated comparison of each generated final answer against the ground truth labels responses as correct (positive) or incorrect (negative), from which preference pairs are constructed. The DPO algorithm can be implemented with libraries like TRL or with custom code (a compact loss sketch follows this list), initializing both the policy and the reference model from the SFT weights. Hyperparameter tuning (e.g., the DPO β, learning rate, and response truncation length) matters for performance; the paper found truncating responses to 90 tokens beneficial.
- Inference and Evaluation: For CoT evaluation, models need to adhere to a specific format (e.g., ending with "### Answer:") to enable automated extraction of the final answer. Evaluation requires established VQA benchmarks and associated evaluation metrics.
- DPO as Verifier: To use the DPO model for re-ranking, sample multiple responses (N > 1) per question and compute the DPO reward score (or log-probability ratio) for each. These scores can then be used either to select the single highest-scoring response (Best-of-N) or to weight votes over the extracted answers (Weighted Voting), as in the re-ranking sketch above.
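For the data-generation step, a minimal distillation loop might look like the following. It uses the OpenAI chat completions API; the prompt wording is a paraphrase rather than the paper's exact template, and the mismatch filter simply drops examples whose distilled chain does not end in the annotated answer.

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Paraphrased distillation prompt -- the paper's exact wording will differ.
DISTILL_PROMPT = (
    "Question: {question}\n"
    "The correct short answer is: {answer}\n"
    "Explain step by step, grounded in the image, how to arrive at this answer, "
    "then finish with a line of the form '### Answer: <answer>'."
)

def distill_example(question, short_answer, image_url):
    resp = client.chat.completions.create(
        model="gpt-4o",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text",
                 "text": DISTILL_PROMPT.format(question=question, answer=short_answer)},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        }],
    )
    cot = resp.choices[0].message.content
    # Filter: keep the example only if the distilled chain ends in the
    # annotated ground-truth answer; otherwise drop it.
    if "### Answer:" not in cot:
        return None
    predicted = cot.split("### Answer:")[-1].strip().lower()
    if predicted != short_answer.strip().lower():
        return None
    return {"question": question, "rationale": cot, "answer": short_answer}
```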
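For the DPO step itself, if custom code is preferred over a library such as TRL, the core objective is compact. The sketch below assumes per-sequence log-probabilities (summed over response tokens only) have already been computed for the policy and the frozen SFT reference model; β = 0.1 is a placeholder value, not the paper's reported setting.

```python
import torch
import torch.nn.functional as F

def sequence_logprob(logits, labels, response_mask):
    """Sum token log-probs over response tokens only.
    logits: [B, T, vocab]; labels: [B, T] token ids already shifted to align
    with logits (padding positions can hold any valid id, they are masked out);
    response_mask: [B, T] with 1 on response tokens, 0 on prompt/padding."""
    logps = torch.log_softmax(logits, dim=-1)
    token_logps = torch.gather(logps, 2, labels.unsqueeze(-1)).squeeze(-1)
    return (token_logps * response_mask).sum(dim=-1)

def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Standard DPO objective:
    -log sigmoid(beta * ((pi_w - pi_l) - (ref_w - ref_l)))."""
    logits = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)
    return -F.logsigmoid(beta * logits).mean()
```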
Trade-offs and Considerations:
- Data Cost: Distilling data from a powerful proprietary model like GPT-4o incurs API costs. The size of the distilled dataset (193k) is substantial but manageable.
- Computational Resources: SFT and especially RL training are computationally intensive, requiring multiple high-end GPUs.
- Prompt Sensitivity: The performance of distilled data and GPT-4o evaluation is highly sensitive to prompt phrasing, requiring careful engineering and validation on development sets.
- Generality of DPO Data: While DPO on task-specific reasoning data shows strong results, creating diverse reasoning preference data across all potential tasks can be challenging. The paper demonstrates generalization benefits even when DPO data is limited to a subset of domains.
- Answer Extraction: Reliable automatic extraction of final answers from CoT responses is critical for evaluation and generating preference pairs. Rule-based extraction based on markers like "### Answer:" is a practical approach but requires the model to consistently follow the format.
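A minimal version of such rule-based extraction, assuming the "### Answer:" convention described above:

```python
import re

ANSWER_RE = re.compile(r"###\s*Answer\s*:\s*(.+)", re.IGNORECASE | re.DOTALL)

def extract_final_answer(cot_text):
    """Return the short answer following the '### Answer:' marker,
    or None when the model did not follow the expected format."""
    match = ANSWER_RE.search(cot_text)
    return match.group(1).strip() if match else None
```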
The work contributes a valuable dataset (ShareGPT-4o-Reasoning) and demonstrates effective strategies for training VLMs with enhanced CoT reasoning capabilities, which is crucial for developing more transparent, reliable, and general-purpose multimodal AI systems.