Masked Diffusion Models: Exploring A New Frontier
Could MDMs introduce the concept of exploration into language models?
One reason why it pays to scout ahead in the field of AI development is that things often get worked out on paper before they hit the market. Ideas in different areas sometimes combine and cross-pollinate, resulting in promising new hybrids. The latest of these hybrids is the Masked Diffusion Model (MDM) for language processing.
MDMs combine two key concepts: diffusion and masking.
Diffusion
Perhaps you’ve heard of diffusion-based image generation models, such as Stable Diffusion or Midjourney. Diffusion models work by applying a noising process to a dataset, gradually corrupting it, and then reconstructing (or “denoising”) the corrupted data to generate new samples or recover missing information. The process mimics physical particle diffusion, where particles move from areas of high concentration to low concentration, eventually reaching equilibrium.
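To make that concrete, here’s a minimal NumPy sketch of the forward (corrupting) half of the process. The linear schedule and scaling are illustrative assumptions on my part, not any particular model’s recipe; a trained network learns to run this in reverse:

```python
import numpy as np

def forward_diffuse(x0, t):
    """Corrupt a clean sample x0 to noise level t (0 = clean, 1 = pure noise).

    A simple linear schedule: the signal shrinks as Gaussian noise grows,
    so by t = 1 nothing of the original remains.
    """
    noise = np.random.randn(*x0.shape)
    return np.sqrt(1.0 - t) * x0 + np.sqrt(t) * noise

# Generation works backwards: start from pure noise (t = 1) and let a
# trained denoiser walk t down toward 0, recovering a plausible sample.
```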
Originally, this process was found to work well for image generation, for reasons I have never seen clearly explained, but which probably have to do with the layered, separable structure of visual information. An object’s form can be detached from its content, which is why you can visualize a pink apple, or a diamond in the shape of an apple. Diffusion lets an object be scrambled through a chaotic transformation of its content and then rebuilt in its original form, with the changed characteristics preserved in the newly generated samples. Researchers are now applying the method to language processing, and are seeing some interesting results.
Masking
Masking is a technique that shows up frequently in machine learning. It involves hiding or replacing pieces of the input data and having the model predict the missing or replaced elements. This forces the model to rely on context, patterns, and dependencies rather than memorizing individual inputs, which helps it learn meaningful relationships in the data.
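As a toy illustration (a sketch, not any library’s actual API; the reserved [MASK] id and the 15% rate below are common conventions, not requirements):

```python
import numpy as np

MASK_ID = 0  # a reserved "[MASK]" token id -- an assumption for this toy

def mask_tokens(tokens, mask_prob=0.15, rng=None):
    """Randomly hide tokens; the model's job is to predict the originals."""
    rng = rng or np.random.default_rng()
    tokens = np.asarray(tokens)
    is_masked = rng.random(tokens.shape) < mask_prob
    corrupted = np.where(is_masked, MASK_ID, tokens)
    return corrupted, is_masked  # training targets are tokens[is_masked]

corrupted, is_masked = mask_tokens([17, 42, 7, 99, 23], mask_prob=0.4)
```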
Masked Diffusion
In simple terms, MDMs train by masking out random portions of the input, with the fraction of masked tokens standing in for the diffusion noise level, and then having the model reconstruct the original inputs. This allows the model to learn flexibly by combining both concepts.
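Here’s a rough sketch of what a single training example might look like under the common “absorbing-state” formulation, where the mask fraction plays the role of the diffusion timestep. This is my reading of the general recipe, not any specific paper’s code:

```python
import numpy as np

MASK_ID = 0  # reserved "[MASK]" token id (same assumption as above)

def mdm_training_example(tokens, rng=None):
    """Corrupt a sequence for one masked-diffusion training step.

    t ~ Uniform(0, 1) acts as the noise level: small t masks a few
    tokens, t near 1 masks almost everything. The model is trained to
    recover the original token wherever a mask appears.
    """
    rng = rng or np.random.default_rng()
    tokens = np.asarray(tokens)
    t = rng.random()                          # sampled noise level / timestep
    is_masked = rng.random(tokens.shape) < t  # heavier t => more masks
    corrupted = np.where(is_masked, MASK_ID, tokens)
    return corrupted, tokens, is_masked       # model input, targets, loss mask
```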
What’s Cool About MDMs?
In traditional autoregressive language models like the GPT family, sequences are processed from left to right, one token at a time: each newly predicted token is appended on the right, and the model then conditions on the whole extended sequence to predict the next one. This process is structured and ordered, but it has a few problems. It bakes a left-to-right bias into the model, which may arbitrarily learn patterns that flow in that direction. It also tends to struggle as sequences grow very long.
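In sketch form, with a hypothetical model function that maps a token sequence to next-token probabilities, the autoregressive loop looks like this:

```python
import numpy as np

def generate_autoregressive(model, prompt_tokens, max_new=50, rng=None):
    """Classic left-to-right decoding: one new token per forward pass."""
    rng = rng or np.random.default_rng()
    seq = list(prompt_tokens)
    for _ in range(max_new):
        probs = model(seq)                      # distribution over next token
        next_token = rng.choice(len(probs), p=probs)
        seq.append(next_token)                  # append on the right, repeat
    return seq
```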
MDMs approach generation differently. Instead of extending the input prompt left to right as autoregressive models do, you can think of an MDM as embedding the prompt in the middle of a canvas of mask tokens. Positions on either side of the prompt, and any gaps within it, start out masked, and during the denoising (reconstruction) phase the model tries to “inpaint” or fill them in, many positions per pass. This makes MDMs particularly effective for tasks like text completion, inpainting, and reasoning, since the iterative nature allows better contextual reconstruction than a single left-to-right sweep. MDMs allow non-sequential token updates (parallel sampling) and leverage the iterative refinement process to capture complex dependencies in text.
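A minimal sampler sketch in that spirit (confidence-based unmasking, loosely in the style of MaskGIT-like samplers; model here is a hypothetical network returning per-position token probabilities):

```python
import numpy as np

MASK_ID = 0  # reserved "[MASK]" token id (same assumption as above)

def generate_mdm(model, canvas, num_steps=8):
    """Iteratively in-paint a partially masked sequence, many tokens per pass.

    `canvas` holds prompt tokens plus MASK_ID placeholders anywhere --
    left of the prompt, right of it, or inside gaps. Each step, the
    model scores every masked position in parallel; we commit only the
    most confident guesses and leave the rest for later, better-informed
    passes.
    """
    canvas = np.asarray(canvas).copy()
    for step in range(num_steps):
        masked = np.where(canvas == MASK_ID)[0]
        if masked.size == 0:
            break
        probs = model(canvas)                   # (seq_len, vocab_size)
        guesses = probs[masked].argmax(axis=1)  # best token per masked slot
        confidence = probs[masked].max(axis=1)
        k = max(1, masked.size // (num_steps - step))  # simple unmask schedule
        keep = np.argsort(-confidence)[:k]
        canvas[masked[keep]] = guesses[keep]
    return canvas
```

Notice there is nothing left-to-right about the loop: the order in which positions get filled is decided by the model’s confidence, not by the sequence’s geometry.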
MDMs are still in their early stages, roughly as capable as the now-obsolete GPT-2. That’s not bad for a technique that’s only a few years old. And they already exhibit crucial advantages over their autoregressive cousins. They are particularly adept at learning logical reversals. An autoregressive model might learn that A=B, since it sees that pattern flow from left to right in its training data, yet struggle to reverse the generalization and figure out that if A=B, then B=A. MDMs excel at this sort of flexibility, since they are not constrained by a fixed left-to-right factorization.
I am also excited by their ability to predict masked sequences. The inclusion of masking is powerful because it lends inferential power to the model’s training. The model doesn’t simply learn patterns in place; it must reconstruct patterns by predicting values for placeholders. This sort of metonymy, or algebraic substitution, is at the heart of complex reasoning. Human beings reason with generic, algebra-like schemas: an overall conceptual frame, or invariant, into which we slot specific, contingent, contextual data when a situation calls for it. This allows us to bind the universal to the particular in a flexible and inexhaustible way.
Already, researchers are finding that diffusion-based language models blow autoregressive models out of the water on certain fill-in puzzles such as Sudoku. They could also be excellent for reconstructing damaged texts and similar fill-in-the-blanks tasks.
MDMs have some drawbacks. For instance, they’re more computationally expensive to run in their current incarnations, because each generation requires many iterative denoising passes over the whole sequence. This is something researchers will surely learn to optimize in future iterations; indeed, energy-based diffusion models are already making strides in this direction.
I’m going out on a limb here, but I also see MDMs as capable of doing something I’ve long sought to see done in AI, which is modeling a concept of “relative ignorance.” True intelligence does not just involve applying fixed reasoning operations over static, closed knowledge. It involves negotiating a complex, uncertain world, and having a relationship with one’s own ignorance. Curiosity is a cognitive and emotional process that involves recognizing one’s own ignorance and being compelled to explore new information in the desire to “satiate” one’s lack of knowledge. Combining masking and diffusion could provide a framework for artificial curiosity that could have profound implications. By blending the concepts of “relative ignorance” (masks) and exploration (diffusion), these models may be able to transcend the face-value content of their initial training data in surprising ways.
If you now happen to be curious about MDMs, you can read up on them in detail here: