Emergent Mind

Abstract

Pre-trained language models (LMs) are able to perform complex reasoning without explicit fine-tuning. To understand how pre-training with a next-token prediction objective contributes to the emergence of such reasoning capability, we propose that we can view an LM as deriving new conclusions by aggregating indirect reasoning paths seen at pre-training time. We found this perspective effective in two important cases of reasoning: logic reasoning with knowledge graphs (KGs) and math reasoning with math word problems (MWPs). More specifically, we formalize the reasoning paths as random walk paths on the knowledge/reasoning graphs. Analyses of learned LM distributions suggest that a weighted sum of relevant random walk path probabilities is a reasonable way to explain how LMs reason. Experiments and analysis on multiple KG and MWP datasets reveal the effect of training on random walk paths and suggest that augmenting unlabeled random walk reasoning paths can improve real-world multi-step reasoning performance. Code: https://github.com/WANGXinyiLinda/LM_random_walk

Hypothesis: Language models learn by aggregating random walk paths on a knowledge reasoning graph.

Overview

  • The paper introduces a new perspective for understanding the reasoning abilities of pre-trained language models (LMs), focusing on logic reasoning with knowledge graphs (KGs) and math reasoning with math word problems (MWPs).

  • It proposes that LMs generate new conclusions by aggregating indirect reasoning paths encountered during pre-training, a hypothesis tested with experiments involving KGs and MWPs.

  • For logical reasoning, a small Transformer model trained on random walk paths from KGs showed that LMs can perform reasoning tasks effectively by optimally weighting logical rules.

  • The study also demonstrates improved math reasoning on MWPs when LMs are trained with random walk reasoning paths derived from existing Chain of Thought (CoT) training data.

Summary

This paper introduces a novel perspective on the reasoning abilities of pre-trained language models (LMs), specifically how LMs can perform complex reasoning tasks without explicit fine-tuning. The authors propose viewing LMs as systems that aggregate indirect reasoning paths seen during pre-training, applying this perspective to two key areas: logic reasoning with knowledge graphs (KGs) and math reasoning with math word problems (MWPs).

Reasoning Paths Aggregation

The core hypothesis is that LMs can derive new conclusions by aggregating indirect reasoning paths encountered during pre-training. The authors test this hypothesis in scenarios involving KGs and MWPs, formalizing the reasoning paths as random walks on knowledge/reasoning graphs. The paper demonstrates that a weighted sum of relevant random walk path probabilities can explain how LMs reason, and that training on random walk paths improves real-world multi-step reasoning performance.
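As a toy illustration of this weighted-aggregation view, the sketch below scores candidate conclusions on a small graph by summing the probabilities of all random walk paths that reach them. The entities, relations, and the choice of uniform weights over path lengths are all illustrative assumptions, not the paper's actual setup:

```python
from collections import defaultdict

# Toy knowledge graph of (head, relation, tail) triples.
# All names are hypothetical; the paper uses larger KG benchmarks.
triples = [
    ("alice", "mother_of", "bob"),
    ("alice", "mother_of", "carol"),
    ("bob", "brother_of", "carol"),
]

# Outgoing edges per entity, defining a uniform random walk on the graph.
out_edges = defaultdict(list)
for h, r, t in triples:
    out_edges[h].append((r, t))

def path_probs(start, max_len):
    """Aggregate random-walk paths from `start` up to `max_len` hops.

    Each path contributes the product of its uniform step probabilities
    (1 / out-degree at every hop); contributions from all path lengths
    are summed per reachable entity (uniform length weights assumed).
    """
    reach = defaultdict(float)
    frontier = [(start, 1.0)]
    for _ in range(max_len):
        nxt = []
        for node, p in frontier:
            edges = out_edges[node]
            if not edges:
                continue
            step_p = p / len(edges)
            for _, t in edges:
                reach[t] += step_p
                nxt.append((t, step_p))
        frontier = nxt
    return dict(reach)

# Score candidate conclusions reachable from "alice".
scores = path_probs("alice", 3)
best = max(scores, key=scores.get)
```

Here "carol" outscores "bob" because both a direct edge and an indirect two-hop path lead to her; the paper's claim is that an LM's learned distribution implicitly performs a similar, learned weighting rather than this uniform one.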

Logical Reasoning Analysis

The investigation begins with logical reasoning over KGs. The authors train a small Transformer model on random walk paths sampled from KGs and find that the learned LM distribution closely resembles a weighted aggregation of possible random walk paths, indicating that LMs can reason effectively by weighting logical rules near-optimally. The analyses also identify an optimal random walk path length for effective reasoning, and further experiments show that augmenting the training data with unlabeled random walk reasoning paths improves reasoning performance.
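A minimal sketch of how such training sequences can be generated: sample random walks over the KG and serialize each one as an alternating entity/relation token sequence for next-token-prediction training. The triples and the serialization format here are hypothetical simplifications:

```python
import random
from collections import defaultdict

# Toy KG triples (illustrative names, not the paper's datasets).
triples = [
    ("paris", "capital_of", "france"),
    ("france", "in_continent", "europe"),
    ("berlin", "capital_of", "germany"),
    ("germany", "in_continent", "europe"),
]

out_edges = defaultdict(list)
for h, r, t in triples:
    out_edges[h].append((r, t))

def sample_walk(start, max_hops, rng):
    """One random walk serialized as 'entity relation entity ...' tokens."""
    tokens = [start]
    node = start
    for _ in range(max_hops):
        edges = out_edges[node]
        if not edges:
            break  # dead end: stop the walk early
        r, t = rng.choice(edges)
        tokens += [r, t]
        node = t
    return " ".join(tokens)

# A tiny pre-training corpus of serialized random walks.
rng = random.Random(0)
starts = list(out_edges)
corpus = [sample_walk(rng.choice(starts), 2, rng) for _ in range(5)]
```

Training a small Transformer with a standard next-token objective on such a corpus is the setup the analysis probes; the model never sees multi-hop conclusions directly, only the walks they can be composed from.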

Mathematical Reasoning Expansion

The study extends these findings to math reasoning, focusing on the challenge of solving MWPs. Here the methodology is to continue training a pre-trained base LM on random walk reasoning paths generated from existing Chain of Thought (CoT) training data. The results show consistent improvements over vanilla supervised fine-tuning on MWP datasets, supporting the hypothesis that LMs benefit from aggregating random walk reasoning paths.
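One way to picture this construction: treat each CoT solution as a chain of atomic steps, link steps whose outputs and inputs match into a shared reasoning graph, and sample walks that may recombine steps across different solutions. The step representation below is a hypothetical simplification of real CoT data:

```python
import random

# Each CoT solution as a list of atomic steps; a step is
# (input_value, step_text, output_value). Purely illustrative.
cot_solutions = [
    [(2, "2 + 3 = 5", 5), (5, "5 * 4 = 20", 20)],
    [(5, "5 - 1 = 4", 4), (4, "4 * 2 = 8", 8)],
]

# Reasoning graph: each known value points to the steps applicable to it.
edges = {}
for solution in cot_solutions:
    for inp, text, out in solution:
        edges.setdefault(inp, []).append((text, out))

def random_walk_path(start, max_steps, rng):
    """Sample a chain of compatible steps, possibly crossing CoT examples."""
    path, value = [], start
    for _ in range(max_steps):
        if value not in edges:
            break  # no step applies to the current value
        text, value = rng.choice(edges[value])
        path.append(text)
    return path

rng = random.Random(1)
path = random_walk_path(2, 3, rng)
```

A walk starting from 2 can continue into steps taken from the second solution (via the shared value 5), yielding a reasoning path that appears in neither original CoT example; such recombined paths are what augment the fine-tuning data.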

Implications and Future Work

The findings have significant implications for both academic research and practical applications in AI. Understanding how LMs harness pre-training data to acquire reasoning abilities could lead to more advanced AI systems capable of complex problem-solving. The research also motivates further exploration of data augmentation techniques, particularly those built around random walk reasoning paths, to further improve LM performance on reasoning tasks.

Conclusion

This paper provides insightful analyses and robust empirical evidence supporting the hypothesis that LMs can aggregate indirect reasoning paths to enhance their reasoning capabilities. By dissecting the reasoning ability of LMs from the perspective of reasoning paths aggregation, the authors offer a compelling framework that not only sheds light on how these models learn but also opens pathways for future research aimed at optimizing LM pre-training and fine-tuning processes for advanced reasoning tasks.
