https://arxiv.org/pdf/2407.07612
Teaching Transformers Causal Reasoning through Axiomatic Training
Aniket Vashishtha Microsoft Research, India
Abhinav Kumar Massachusetts Institute of Technology, USA
Abbavaram Gowtham Reddy IIT Hyderabad, India
Vineeth N Balasubramanian IIT Hyderabad, India
Amit Sharma Microsoft Research, India [email protected]
Abstract
For text-based AI systems to interact in the real world, causal reasoning is an essential skill. Since interventional data is costly to generate, we study to what extent an agent can learn causal reasoning from passive data. Specifically, we consider an axiomatic training setup where an agent learns from multiple demonstrations of a causal axiom (or rule), rather than incorporating the axiom as an inductive bias or inferring it from data values. A key question is whether the agent would learn to generalize from the axiom demonstrations to new scenarios. For example, if a transformer model is trained on demonstrations of the causal transitivity axiom over small graphs, would it generalize to applying the transitivity axiom over large graphs? Our results, based on a novel axiomatic training scheme, indicate that such generalization is possible. We consider the task of inferring whether one variable causes another, given a causal graph structure. We find that a 67-million-parameter transformer model, when trained on linear causal chains (along with some noisy variations), can generalize well to new kinds of graphs, including longer causal chains, causal chains with reversed order, and graphs with branching, even when it is not explicitly trained for such settings. Our model performs on par with (or even better than) many larger language models such as GPT-4, Gemini Pro, and Phi-3. Overall, our axiomatic training framework provides a new paradigm of learning causal reasoning from passive data that can be used to learn arbitrary axioms, as long as sufficient demonstrations can be generated.
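To make the training setup concrete, the following sketch generates premise, hypothesis, and label demonstrations of the transitivity axiom over small linear causal chains of the kind described above. The function name, node naming scheme, natural-language templates, and Yes/No label format are illustrative assumptions, not the paper's exact data-generation pipeline.

```python
import random
import string

def make_chain_demo(chain_length=4, rng=random):
    """Build one (premise, hypothesis, label) demonstration of the
    transitivity axiom over a linear causal chain X1 -> X2 -> ... -> Xn.
    Templates and label format are illustrative assumptions."""
    nodes = rng.sample(string.ascii_uppercase, chain_length)
    # Premise: state every directed edge of the chain in order.
    premise = " ".join(f"{a} causes {b}." for a, b in zip(nodes, nodes[1:]))
    # Hypothesis: query a random ordered pair of distinct nodes.
    i, j = rng.sample(range(chain_length), 2)
    hypothesis = f"Does {nodes[i]} cause {nodes[j]}?"
    # Label: on a chain, transitivity implies causation iff i precedes j.
    label = "Yes" if i < j else "No"
    return premise, hypothesis, label

if __name__ == "__main__":
    for _ in range(3):
        print(make_chain_demo())
```

Longer chains, reversed edge orderings, or branching structures (used only at evaluation time in the paper) could be produced by varying the graph construction step in an analogous way.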
1 Introduction

Causal reasoning can be defined as a set of reasoning procedures consistent with pre-defined axioms or rules that are specific to causality [11]. For instance, d-separation and the rules of do-calculus can be considered axioms, and specifications of a collider or a backdoor set can be considered rules derived from those axioms. Typically, causal reasoning is done over data corresponding to variables in a system. Axioms or rules are incorporated as inductive biases in a machine learning (ML) model, through regularization, model architecture, or the choice of variables for a particular analysis. Depending on the kind of available data (observational, interventional, or counterfactual), Pearl's ladder of causation [5] defines the kinds of causal reasoning that are possible.
As axioms are the building blocks of causality, we study whether it is possible to directly learn the axioms using ML models. That is, rather than learning from data produced by a data-generating process that follows the axioms, what if a model can learn an axiom (and thus causal reasoning) directly from symbolic demonstrations of the axiom? Such a model has the advantage that it can be applied for causal reasoning in diverse downstream scenarios, compared to task-specific causal models built
Preprint. Under review.