Why machine learning struggles with causality


This article is part of our reviews of AI research papers, a series of posts that explore the latest findings in artificial intelligence.

When you look at the following short video sequence, you can make inferences about causal relations between different elements. For instance, you can see the bat and the baseball player’s arm moving in unison, but you also know that it is the player’s arm that is causing the bat’s movement and not the other way around. You also don’t need to be told that the bat is causing the sudden change in the ball’s direction.

Likewise, you can think about counterfactuals, such as what would happen if the ball flew a bit higher and didn’t hit the bat.

baseball bat hitting ball

Such inferences come to us humans intuitively. We learn them at a very early age, without being explicitly instructed by anyone and just by observing the world. But for machine learning algorithms, which have managed to outperform humans in complicated tasks such as Go and chess, causality remains a challenge. Machine learning algorithms, particularly deep neural networks, are especially good at ferreting out subtle patterns in huge sets of data. They can transcribe audio in real time, label thousands of images and video frames per second, and examine X-ray and MRI scans for cancerous patterns. But they struggle to make simple causal inferences like the ones we just saw in the baseball video above.

In a paper titled “Towards Causal Representation Learning,” researchers at the Max Planck Institute for Intelligent Systems, the Montreal Institute for Learning Algorithms (Mila), and Google Research discuss the challenges arising from the lack of causal representations in machine learning models and provide directions for creating artificial intelligence systems that can learn causal representations.

This is one of several efforts that aim to explore and solve machine learning’s lack of causality, which can be key to overcoming some of the major challenges the field faces today.

Independent and identically distributed data

Why do machine learning models fail at generalizing beyond their narrow domains and training data?

“Machine learning often disregards information that animals use heavily: interventions in the world, domain shifts, temporal structure — by and large, we consider these factors a nuisance and try to engineer them away,” write the authors of the causal representation learning paper. “In accordance with this, the majority of current successes of machine learning boil down to large scale pattern recognition on suitably collected independent and identically distributed (i.i.d.) data.”

The term i.i.d. is often used in machine learning. It assumes that random observations in a problem space do not depend on each other and are all drawn from the same probability distribution. The simplest example of i.i.d. is flipping a coin or tossing a die. The result of each new flip or toss is independent of previous ones, and the probability of each outcome remains constant.
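To make the definition concrete, here is a minimal Python sketch (the function names are illustrative, not taken from the paper) that simulates i.i.d. coin flips. Because every flip uses the same probability and ignores all earlier flips, both the overall frequency of heads and the frequency of heads immediately after a heads should land near 0.5.

```python
import random

def flip_coins(n, p_heads=0.5, seed=0):
    """Draw n i.i.d. coin flips: each flip uses the same p_heads and ignores earlier flips."""
    rng = random.Random(seed)
    return [1 if rng.random() < p_heads else 0 for _ in range(n)]

def summarize(flips):
    """Return the empirical P(heads) and the empirical P(heads | previous flip was heads)."""
    heads_rate = sum(flips) / len(flips)
    after_heads = [b for a, b in zip(flips, flips[1:]) if a == 1]
    heads_after_heads = sum(after_heads) / len(after_heads)
    return heads_rate, heads_after_heads

flips = flip_coins(100_000)
rate, cond_rate = summarize(flips)
print(f"P(heads) ~ {rate:.3f}, P(heads | previous heads) ~ {cond_rate:.3f}")
```

If the two numbers diverged noticeably, the flips would not be independent; if the first drifted over time, they would not be identically distributed.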

When it comes to more complicated areas such as computer vision, machine learning engineers try to turn the problem into an i.i.d. domain by training the model on very large corpora of examples. The assumption is that, with enough examples, the machine learning model will be able to encode the general distribution of the problem into its parameters. But in the real world, distributions often change due to factors that cannot be accounted for or controlled in the training data. For instance, convolutional neural networks trained on millions of images can fail when they see objects under new lighting conditions or from slightly different angles or against new backgrounds.

Objects in training datasets vs objects in the real world (source: objectnet.dev)
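The failure mode described above can be reproduced with a deliberately simplified toy, sketched below under loud assumptions: each "image" is reduced to two hypothetical binary features, a shape feature that is caused by the label and a background feature that merely co-occurs with the label in the training environment. A learner that only looks for the strongest statistical association latches onto the background and collapses when that association flips at test time, while the causal feature keeps working.

```python
import random

def make_data(n, background_agreement, rng):
    """Hypothetical toy images reduced to two binary features.

    'shape' is caused by the label and agrees with it 90% of the time in every environment.
    'background' agrees with the label at an environment-dependent rate (a spurious cue).
    """
    data = []
    for _ in range(n):
        y = rng.randint(0, 1)
        shape = y if rng.random() < 0.9 else 1 - y
        background = y if rng.random() < background_agreement else 1 - y
        data.append(({"shape": shape, "background": background}, y))
    return data

def accuracy(feature, data):
    """Accuracy of the rule 'predict the label equal to this feature'."""
    return sum(x[feature] == y for x, y in data) / len(data)

rng = random.Random(0)
train = make_data(10_000, background_agreement=0.99, rng=rng)  # training environment
test = make_data(10_000, background_agreement=0.10, rng=rng)   # shifted environment

# A pure pattern-matcher picks whichever feature best fits the training data.
best = max(["shape", "background"], key=lambda f: accuracy(f, train))
print("chosen feature:", best)
print(f"train acc: {accuracy(best, train):.3f}, test acc: {accuracy(best, test):.3f}")
print(f"shape-only test acc: {accuracy('shape', test):.3f}")
```

Adding more training data from the same environment would not help here: it would only make the spurious background cue look even more reliable.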

Efforts to address these problems mostly involve training machine learning models on more examples. But as the environment grows in complexity, it becomes impossible to cover the entire distribution by adding more training examples. This is especially true in domains where AI agents must interact with the world, such as robotics and self-driving cars. Lack of causal understanding makes it very hard to make predictions and deal with novel situations. This is why you see self-driving cars make weird and dangerous mistakes even after having trained for millions of miles.

“Generalizing well outside the i.i.d. setting requires learning not mere statistical associations between variables, but an underlying causal model,” the AI researchers write.

Causal models also allow humans to repurpose previously gained knowledge for new domains. For instance, when you learn a real-time strategy game such as Warcraft, you can quickly apply your knowledge to other similar games such as StarCraft and Age of Empires. Transfer learning in machine learning algorithms, however, is limited to very superficial uses, such as fine-tuning an image classifier to detect new types of objects. In more complex tasks, such as learning video games, machine learning models need huge amounts of training (thousands of years’ worth of play) and respond poorly to minor changes in the environment (e.g., playing on a new map or with a slight change to the rules).

“When learning a causal model, one should thus require fewer examples to adapt as most knowledge, i.e., modules, can be reused without further training,” the authors of the causal machine learning paper write.
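One way to see why reusable modules should need fewer examples is the principle of independent mechanisms. The sketch below uses made-up probabilities for a binary cause A and a binary effect B, and assumes that only the distribution of the cause shifts between two domains. Factorized in the causal direction, P(A) P(B | A), the shift touches a single parameter; factorized in the anticausal direction, P(B) P(A | B), every parameter changes and would have to be relearned.

```python
def causal_params(p_a, p_b_given_a):
    """Parameters of the causal factorization P(A) * P(B | A)."""
    return {"P(A=1)": p_a, "P(B=1|A=0)": p_b_given_a[0], "P(B=1|A=1)": p_b_given_a[1]}

def anticausal_params(p_a, p_b_given_a):
    """Parameters of the same joint distribution factorized the other way: P(B) * P(A | B)."""
    joint = {
        (a, b): (p_a if a else 1 - p_a) * (p_b_given_a[a] if b else 1 - p_b_given_a[a])
        for a in (0, 1)
        for b in (0, 1)
    }
    p_b = joint[(0, 1)] + joint[(1, 1)]
    return {
        "P(B=1)": p_b,
        "P(A=1|B=0)": joint[(1, 0)] / (1 - p_b),
        "P(A=1|B=1)": joint[(1, 1)] / p_b,
    }

def changed(before, after, tol=1e-9):
    """Names of the parameters whose values differ between two domains."""
    return [k for k in before if abs(before[k] - after[k]) > tol]

mechanism = (0.2, 0.9)     # hypothetical P(B=1 | A=0), P(B=1 | A=1); assumed stable across domains
p_a_old, p_a_new = 0.3, 0.8  # hypothetical shift: only P(A=1) changes

print(changed(causal_params(p_a_old, mechanism), causal_params(p_a_new, mechanism)))
print(changed(anticausal_params(p_a_old, mechanism), anticausal_params(p_a_new, mechanism)))
```

In the causal factorization, the mechanism P(B | A) is the reusable "module": adapting to the new domain only requires estimating P(A) again.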

Causal learning


So, why has i.i.d. remained the dominant form of machine learning despite its known weaknesses? Pure observation-based approaches are scalable. You can continue to achieve incremental gains in accuracy by adding more training data, and you can speed up the training process by adding more compute power. In fact, one of the key factors behind the recent success of deep learning is the availability of more data and stronger processors.

i.i.d.-based models are also easy to evaluate: Take a large dataset, split it into training and test sets, tune the model on the training data, and validate its performance by measuring the accuracy of its predictions on the test set. Continue the training until you reach the accuracy you require. There are already many public datasets that provide such benchmarks, such as ImageNet, CIFAR-10, and MNIST. There are also task-specific datasets such as the COVIDx dataset for COVID-19 diagnosis and the Wisconsin Breast Cancer Diagnosis dataset. In all cases, the challenge is the same: Develop a machine learning model that can predict outcomes based on statistical regularities.
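The evaluation recipe above fits in a few lines. The following is a minimal sketch on a hypothetical one-dimensional dataset, with a simple threshold classifier standing in for a real model; the helper names are illustrative. Because the training and test sets are shuffled out of the same pool, the test score mainly measures how well the model fits that one distribution.

```python
import random

def train_test_split(data, test_fraction=0.2, seed=0):
    """Shuffle a copy of the data and hold out the last test_fraction of it for evaluation."""
    rows = data[:]
    random.Random(seed).shuffle(rows)
    cut = int(len(rows) * (1 - test_fraction))
    return rows[:cut], rows[cut:]

def fit_threshold(train):
    """'Train' a 1-D classifier: put the threshold halfway between the two class means."""
    xs0 = [x for x, y in train if y == 0]
    xs1 = [x for x, y in train if y == 1]
    return (sum(xs0) / len(xs0) + sum(xs1) / len(xs1)) / 2

def threshold_accuracy(threshold, rows):
    """Fraction of rows where 'x above threshold' matches 'label is 1'."""
    return sum((x > threshold) == (y == 1) for x, y in rows) / len(rows)

rng = random.Random(1)
# Hypothetical i.i.d. data: class 0 centered at 0.0, class 1 centered at 2.0, unit variance.
data = [(rng.gauss(2.0 * y, 1.0), y) for y in (rng.randint(0, 1) for _ in range(5_000))]
train, test = train_test_split(data)
t = fit_threshold(train)
print(f"threshold={t:.2f}  train acc={threshold_accuracy(t, train):.3f}  "
      f"test acc={threshold_accuracy(t, test):.3f}")
```

A high test score here says nothing about how the same threshold would fare if the class means drifted after deployment.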

But as the AI researchers observe in their paper, accurate predictions are often not sufficient to inform decision-making. For instance, during the coronavirus pandemic, many machine learning systems began to fail because they had been trained on statistical regularities instead of causal relations. As life patterns changed, the accuracy of the models dropped.

Causal models remain robust when interventions change the statistical distributions of a problem. For instance, when you see an object for the first time, your mind will subconsciously factor out lighting from its appearance. That’s why, in general, you can recognize the object when you see it under new lighting conditions.

Causal models also allow us to respond to situations we haven’t seen before and think about counterfactuals. We don’t need to drive a car off a cliff to know what will happen. Counterfactuals play an important role in cutting down the number of training examples a machine learning model needs.
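Counterfactual questions can be answered mechanically once a structural causal model is available, using Judea Pearl's three-step recipe of abduction, action, and prediction. The sketch below assumes a toy linear model, X := U_X and Y := 2X + U_Y, that is not taken from the paper: it infers the noise term of one observed case, overrides X, and recomputes Y with that same noise.

```python
# Hypothetical linear SCM:  X := U_X,  Y := SLOPE * X + U_Y
SLOPE = 2.0

def f_y(x, u_y):
    """Structural equation for Y."""
    return SLOPE * x + u_y

def counterfactual_y(observed_x, observed_y, new_x):
    """Answer 'what would Y have been, for this same case, had X been new_x?'"""
    # 1. Abduction: invert f_y to recover this case's exogenous noise from what actually happened.
    u_y = observed_y - SLOPE * observed_x
    # 2. Action: replace X's structural equation with the constant new_x, i.e. do(X = new_x).
    x = new_x
    # 3. Prediction: push the recovered noise through the unchanged mechanism for Y.
    return f_y(x, u_y)

print(counterfactual_y(observed_x=1.0, observed_y=3.5, new_x=0.0))  # 1.5
```

No new data is needed to answer the question: the model itself, plus one observed case, determines the counterfactual outcome.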

Causality can also be crucial to dealing with adversarial attacks, subtle manipulations that force machine learning systems to fail in unexpected ways. “These attacks clearly constitute violations of the i.i.d. assumption that underlies statistical machine learning,” the authors of the paper write, adding that adversarial vulnerabilities are proof of the differences in the robustness mechanisms of human intelligence and machine learning algorithms. The researchers also suggest that causality can be a possible defense against adversarial attacks.

Adversarial attacks exploit machine learning’s reliance on the i.i.d. assumption. In this image, adding an imperceptible layer of noise to a panda picture causes a convolutional neural network to mistake it for a gibbon.
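The panda example used a deep convolutional network, but the underlying arithmetic already shows up in a linear model. In this hedged sketch with random, hypothetical weights, a fast-gradient-sign-style step moves each of 10,000 inputs by at most 0.01. Because every tiny change pushes the score in the same direction, the changes add up to a shift of 0.01 times the sum of the absolute weights, which is enough to flip the prediction.

```python
import random

def score(w, b, x):
    """Linear classifier: a positive score means one class, a negative score the other."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b

def sign(v):
    return (v > 0) - (v < 0)

def fgsm_linear(w, x, eps):
    """Fast-gradient-sign step against a linear model: the gradient of the score w.r.t. x is w,
    so moving each input by -eps * sign(w_i) lowers the score by eps * sum(|w_i|)."""
    return [xi - eps * sign(wi) for wi, xi in zip(w, x)]

rng = random.Random(0)
d = 10_000                                  # number of input "pixels"
w = [rng.uniform(-1, 1) for _ in range(d)]  # hypothetical trained weights
x = [rng.uniform(0, 1) for _ in range(d)]   # hypothetical clean input
b = 1.0 - score(w, 0.0, x)                  # bias chosen so the clean input scores +1
x_adv = fgsm_linear(w, x, eps=0.01)         # each pixel moves by at most 0.01

print(f"clean score: {score(w, b, x):.2f}, adversarial score: {score(w, b, x_adv):.2f}")
print("max per-pixel change:", max(abs(a - c) for a, c in zip(x_adv, x)))
```

The more input dimensions a model has, the larger the total effect of a perturbation that stays imperceptible in each individual dimension.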

In a broad sense, causality can address machine learning’s lack of generalization. “It is fair to say that much of the current practice (of solving i.i.d. benchmark problems) and most theoretical results (about generalization in i.i.d. settings) fail to tackle the hard open challenge of generalization across problems,” the researchers write.

Adding causality to machine learning

In their paper, the AI researchers bring together several concepts and principles that can be essential to creating causal machine learning models.

Two of these concepts are “structural causal models” and “independent causal mechanisms.” In general, the principles state that instead of looking for superficial statistical correlations, an AI system should be able to identify causal variables and separate their effects on the environment.

This is the kind of mechanism that enables you to recognize objects regardless of viewing angle, background, lighting, and other noise. Disentangling these causal variables would make AI systems more robust against unpredictable changes and interventions. As a result, causal AI models should not need huge training datasets.

“Once a causal model is available, either by external human knowledge or a learning process, causal reasoning allows to draw conclusions on the effect of interventions, counterfactuals and potential outcomes,” the authors of the causal machine learning paper write.
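The difference between observing and intervening can be shown with a small simulated SCM. In the hypothetical model below, a confounder Z drives both X and Y, and the true causal effect of X on Y is 1.0. Comparing averages in observational data, as a pure pattern-recognition model would, overstates that effect; simulating the intervention do(X = x) by replacing X's structural equation with a constant recovers it.

```python
import random

def sample(n, rng, do_x=None):
    """Draw n samples from a hypothetical SCM with a confounder Z:
        Z := Bernoulli(0.5)
        X := Bernoulli(0.8 if Z else 0.2)   (replaced by the constant do_x under an intervention)
        Y := 1.0 * X + 3.0 * Z + Gaussian noise
    """
    rows = []
    for _ in range(n):
        z = int(rng.random() < 0.5)
        x = do_x if do_x is not None else int(rng.random() < (0.8 if z else 0.2))
        y = 1.0 * x + 3.0 * z + rng.gauss(0.0, 1.0)
        rows.append((z, x, y))
    return rows

def mean_y(rows, x_value=None):
    """Average Y, optionally restricted to rows where X == x_value (i.e., conditioning)."""
    ys = [y for _, x, y in rows if x_value is None or x == x_value]
    return sum(ys) / len(ys)

rng = random.Random(0)
obs = sample(100_000, rng)
observed_gap = mean_y(obs, 1) - mean_y(obs, 0)
causal_effect = mean_y(sample(100_000, rng, do_x=1)) - mean_y(sample(100_000, rng, do_x=0))
print(f"E[Y|X=1] - E[Y|X=0]         ~ {observed_gap:.2f}")   # close to 1.0 + 3.0 * (0.8 - 0.2) = 2.8
print(f"E[Y|do(X=1)] - E[Y|do(X=0)] ~ {causal_effect:.2f}")  # close to the true effect, 1.0
```

The observational gap mixes the effect of X with the effect of Z, because seeing X = 1 makes Z = 1 more likely; the intervention cuts the arrow from Z to X and isolates the causal effect.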

The authors also explore how these concepts can be applied to different branches of machine learning, including reinforcement learning, which is crucial to problems where an intelligent agent relies heavily on exploring environments and discovering solutions through trial and error. Causal structures can help make reinforcement learning more efficient by allowing agents to make informed decisions from the start of their training instead of taking random and irrational actions.

The researchers provide ideas for AI systems that combine machine learning mechanisms and structural causal models: “To combine structural causal modeling and representation learning, we should strive to embed an SCM into larger machine learning models whose inputs and outputs may be high-dimensional and unstructured, but whose inner workings are at least partly governed by an SCM (that can be parameterized with a neural network). The result may be a modular architecture, where the different modules can be individually fine-tuned and re-purposed for new tasks.”

Such concepts bring us closer to the modular approach the human mind uses (at least as far as we know) to link and reuse knowledge and skills across different domains and areas of the brain.

Combining causal graphs with machine learning will enable AI agents to create modules that can be applied to different tasks without much training
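The modular idea in the authors' quote can be sketched structurally, with plain Python functions standing in for the neural-network-parameterized mechanisms they describe. In this hypothetical example (the class, method, and variable names are illustrative only), each variable is produced by its own module, and adapting to a new domain swaps out one module while every other module is reused as-is.

```python
import random
from typing import Callable, Dict

# A mechanism computes one variable from the already-sampled values of its parents plus its own noise.
Mechanism = Callable[[Dict[str, float], random.Random], float]

class ModularSCM:
    """Hypothetical sketch: an SCM stored as independent, named mechanism modules.

    Modules are kept in topological order (dicts preserve insertion order),
    so sampling simply runs them front to back.
    """

    def __init__(self, modules: Dict[str, Mechanism]):
        self.modules = dict(modules)

    def sample(self, rng: random.Random) -> Dict[str, float]:
        values: Dict[str, float] = {}
        for name, mechanism in self.modules.items():
            values[name] = mechanism(values, rng)
        return values

    def replace(self, name: str, mechanism: Mechanism) -> "ModularSCM":
        """Return a new SCM that swaps out one module and reuses every other one unchanged."""
        modules = dict(self.modules)
        modules[name] = mechanism
        return ModularSCM(modules)

scm = ModularSCM({
    "object_size": lambda v, rng: rng.uniform(1.0, 2.0),
    "lighting": lambda v, rng: rng.uniform(0.8, 1.2),
    "pixel_intensity": lambda v, rng: v["object_size"] * v["lighting"],
})
# New domain (hypothetical): only the lighting mechanism changes, e.g. a much darker scene.
dark_scm = scm.replace("lighting", lambda v, rng: rng.uniform(0.2, 0.4))

rng = random.Random(0)
print(scm.sample(rng))
print(dark_scm.sample(rng))
```

Only the lighting module was rewritten; `object_size` and `pixel_intensity` are the very same function objects in both models, which is the kind of reuse the authors have in mind when they describe modules that can be individually fine-tuned and repurposed.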

It is worth noting, however, that the ideas presented in the paper are at the conceptual level. As the authors acknowledge, implementing these concepts faces several challenges: “(a) in many cases, we need to infer abstract causal variables from the available low-level input features; (b) there is no consensus on which aspects of the data reveal causal relations; (c) the usual experimental protocol of training and test set may not be sufficient for inferring and evaluating causal relations on existing data sets, and we may need to create new benchmarks, for example with access to environment information and interventions; (d) even in the limited cases we understand, we often lack scalable and numerically sound algorithms.”

But what’s interesting is that the researchers draw inspiration from much of the parallel work being done in the field. The paper contains references to the work done by Judea Pearl, a Turing Award–winning scientist best known for his work on causal inference. Pearl is a vocal critic of pure deep learning methods. Meanwhile, Yoshua Bengio, one of the co-authors of the paper and another Turing Award winner, is one of the pioneers of deep learning.

The paper also contains several ideas that overlap with the hybrid AI models proposed by Gary Marcus, which combine the reasoning power of symbolic systems with the pattern-recognition power of neural networks. The paper does not, however, make any direct reference to hybrid systems.

The paper is also in line with system 2 deep learning, a concept first proposed by Bengio in a talk at the NeurIPS 2019 AI conference. The idea behind system 2 deep learning is to create a type of neural network architecture that can learn higher-level representations from data. Higher-level representations are crucial to causality, reasoning, and transfer learning.

While it’s not clear which of the several proposed approaches will help solve machine learning’s causality problem, the convergence of ideas from different, and often conflicting, schools of thought is likely to produce interesting results.

“At its core, i.i.d. pattern recognition is but a mathematical abstraction, and causality may be essential to most forms of animate learning,” the authors write. “Until now, machine learning has neglected a full integration of causality, and this paper argues that it would indeed benefit from integrating causal concepts.”
