
This article is part of our coverage of the latest in AI research.
Autoregressive models such as ChatGPT and DALL-E have made great strides in generating realistic text and images. However, autoregressive models have limitations that prevent them from achieving human-like cognitive abilities, which makes it difficult for them to solve problems that require different modes of thinking and reasoning.
Researchers from the University of Virginia, Stanford University, and Amazon GenAI have introduced a novel approach called Energy-Based World Models (EBWM) to address the limitations of traditional autoregressive models.
The researchers use the concept of EBWM to develop a modified version of the Transformer architecture, which is better suited for tasks where traditional models struggle.
The limitations of autoregressive world models
Self-supervised learning (SSL) has become a powerful approach for training large foundation models in computer vision, natural language processing, and speech. In SSL, the model learns from unlabeled data by deriving its training signal from the data itself. One popular SSL method is autoregressive modeling, where the model takes in a sequence of elements and tries to predict the next one. For example, in natural language generation, the model takes part of a text, predicts the next word or token, and compares its prediction against the actual text.
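To make this concrete, here is a minimal sketch of that next-token objective in PyTorch. It is an illustration rather than the researchers' code, and `model` and `token_ids` are placeholder assumptions standing in for any autoregressive network and a batch of tokenized text.

```python
import torch.nn.functional as F

def autoregressive_loss(model, token_ids):
    # Hypothetical: `model` maps token ids of shape (batch, seq_len - 1) to
    # next-token logits of shape (batch, seq_len - 1, vocab_size).
    inputs = token_ids[:, :-1]    # the prefix the model conditions on
    targets = token_ids[:, 1:]    # the actual "next token" at each position
    logits = model(inputs)
    # Cross-entropy compares the predicted distribution with the real next token.
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```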
Traditional autoregressive models (TAMs) are often considered to learn world models of their input spaces. However, TAMs struggle with seemingly simple human capabilities such as reasoning, planning, and thinking over extended time horizons.
The authors of the paper attribute this to fundamental differences in how TAMs work compared to the human brain.
The paper identifies four key cognitive facets found in humans that TAMs lack:
1. In humans, predictions about the future naturally shape the brain’s internal state. In TAMs, the internal state is not influenced by the predictions.
2. Humans naturally evaluate the plausibility of their predictions, while TAMs do not assess the strength or plausibility of their outputs.
3. Humans dedicate varying amounts of computational resources to making predictions or reasoning, depending on the task. This is often referred to as System 1 and System 2 thinking. TAMs spend the same fixed amount of computation on every prediction.
4. Humans can model uncertainty in continuous spaces, while TAMs need to break everything down into discrete tokens.
The researchers argue that achieving these human-like cognitive capabilities in AI models could enhance their ability to reason, plan, and perform System 2 thinking at the same level as humans.
Energy-Based World Models
To address the limitations of TAMs, the researchers propose Energy-Based World Models (EBWM). The key idea is to approach world modeling as “making future state predictions in the input space” and “predicting the energy/compatibility of these future state predictions with the current context” using an Energy-Based Model (EBM).
EBMs use a form of contrastive learning where the model tries to measure the compatibility of various inputs. EBMs associate a scalar energy value with each configuration of input variables, producing lower values for highly compatible inputs and higher values for less compatible ones. At inference time, the model searches for configurations that minimize this energy.
An EBWM receives the current state and several possible future states. Its goal is to predict the compatibility between the current state and each predicted future state. This scheme can be applied across data modalities, including text, images, and sound.
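To illustrate the idea, here is a toy sketch (not the architecture from the paper) of an energy function that scores how compatible each candidate future state is with the current context, where lower energy means a more plausible continuation. The `EnergyModel` network, its dimensions, and the random inputs are assumptions made purely for illustration.

```python
import torch
import torch.nn as nn

class EnergyModel(nn.Module):
    """Toy energy function: E(context, candidate) -> one scalar per candidate."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2 * dim, dim), nn.GELU(), nn.Linear(dim, 1))

    def forward(self, context, candidates):
        # context: (batch, dim); candidates: (batch, k, dim)
        ctx = context.unsqueeze(1).expand(-1, candidates.size(1), -1)
        return self.net(torch.cat([ctx, candidates], dim=-1)).squeeze(-1)  # (batch, k)

model = EnergyModel(dim=64)
context = torch.randn(2, 64)           # embeddings of the current state
candidates = torch.randn(2, 5, 64)     # embeddings of five possible future states
energies = model(context, candidates)  # lower energy = more compatible
best = energies.argmin(dim=-1)         # index of the most plausible future state
```

Because the candidate future state is itself an input to the model, it can be compared, re-evaluated, or refined as many times as needed, which is what opens the door to spending more computation on harder predictions.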
By integrating its predictions in the input space, EBWM achieves the first and third cognitive facets—shaping the model’s internal state with its predictions and allowing dynamic computation. By predicting the compatibility of future states with the current context, EBWM addresses the second and fourth facets—evaluating predictions and modeling uncertainty.
The researchers emphasize that EBWM is a standalone, domain-agnostic solution that does not require a pre-trained model. It uses a reconstruction objective that makes it applicable to various domains. Reconstruction objectives are training methods where the model is given a piece of data (e.g., text or image) with parts of it missing. The model must learn to reconstruct the original data from its incomplete version.
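As a rough example of such an objective (not the exact loss used in the paper), the sketch below masks out part of a continuous input and trains a model to restore the original; the masking scheme and `model` are placeholder assumptions.

```python
import torch
import torch.nn.functional as F

def reconstruction_loss(model, x, mask_prob=0.25):
    # x: (batch, features) continuous inputs; hide a random subset of entries.
    mask = torch.rand_like(x) < mask_prob
    corrupted = x.masked_fill(mask, 0.0)   # the incomplete version the model sees
    reconstruction = model(corrupted)      # the model tries to restore the original
    # Only the hidden positions contribute to the loss.
    return F.mse_loss(reconstruction[mask], x[mask])
```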
The Energy-Based Transformer
To make EBMs competitive with modern architectures, the researchers designed the Energy-Based Transformer (EBT), a domain-agnostic transformer architecture that can learn Energy-Based World Models. EBT borrows elements from diffusion models and modifies the attention structure and computation to handle candidate predictions and better account for future states.
According to the researchers, on computer vision tasks, EBT scales well in terms of data and GPU hours compared to TAMs. While EBWM initially learns more slowly than TAMs, it matches and eventually exceeds TAMs in data and GPU-hour efficiency as scale increases.
“This outcome is promising for higher compute regimes, as the scaling rate of EBWM is higher than TAMs as computation increases,” the researchers write.
Interestingly, the experiments also show that EBWM is less susceptible to overfitting compared to TAMs. The researchers attribute this to EBMs learning a joint distribution rather than a conditional one.
Complementing autoregressive models
While EBWM shows promising results, the researchers do not see it as a drop-in replacement for TAMs. “Having the four aspects of human cognition described, we see EBWM as different and even complementary to TAMs,” they write.
They acknowledge that for certain use cases, such as low-latency serving of large language models, the extra computation of EBWM's inference-time gradient backpropagation may not be worthwhile.
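Roughly speaking, that overhead comes from refining a prediction at inference time by descending the energy surface, which requires a backward pass for every refinement step rather than a single forward pass. The sketch below illustrates this under the assumption of an energy function like the toy one shown earlier; it is not the paper's actual inference procedure.

```python
import torch

def refine_prediction(energy_model, context, init_candidate, steps=10, lr=0.1):
    # Start from an initial guess for the future state and descend the energy.
    candidate = init_candidate.clone().requires_grad_(True)
    for _ in range(steps):
        energy = energy_model(context, candidate.unsqueeze(1)).sum()
        grad, = torch.autograd.grad(energy, candidate)
        # Each step needs gradient backpropagation, which is the extra inference cost.
        candidate = (candidate - lr * grad).detach().requires_grad_(True)
    return candidate.detach()
```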
However, they envision EBWM being particularly useful for scenarios requiring long-term System 2 thinking to solve challenging problems.