Multi-modal data and the magic of GANs
One of the things I’m most excited about with all these new foundation models in biotech is the potential to build models that leverage lots of different data modalities. So this week, to start getting warmed up for what this might look like, I want to dig into a trick that some well-known models use to merge information from two popular data modalities: text and images. That’s right, this week it’s an overview of how text-based image generation works, told through the GAN-based models that came before DALL-E and Stable Diffusion.
But first, to give some more biological motivation, let’s start with a scenario you may have encountered: After a year or two of screening thousands of compounds / proteins / sequences / cells, you want to train a model to predict new compounds / proteins / sequences / cells. The problem is that the primary assay that you’ve run on all of them isn’t very predictive. You get much more signal from all the secondary assays. But you’ve only run those on a small subset of the overall library. In fact, you’ve run different secondary assays on different (small) subsets.
So how do you build a model that uses all of this data when each data point has a different set of measurements?
Biological foundation models have the potential to start solving this problem. And image generation models provide an example of one way models like this can connect these different modalities.
With an image generation model like DALL-E or Stable Diffusion, you type in a description of an image and the model generates it for you. These models are trained on large collections of images with text labels (written by people) attached.
At first this looks like a supervised learning problem, because you have data points (images) with labels (text). But there are two problems with that: 1) Free text isn’t a good label. As I described last week, it needs its own model to interpret. And 2) the text descriptions are all positive labels. They tell us what a picture of a sailboat looks like, but not what an image that isn’t a sailboat looks like.
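This second problem may seem pedantic, but it’s really important: Machine learning models need to be wrong sometimes, and know they’re wrong, so they can adjust their weights to be (more) right the next time. If every training example is a positive match, a model that always answers “yes, this caption fits this image” is never wrong, so it never learns anything. Without negative examples, there’s nothing to correct.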
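To make the negative-example point concrete, here is a minimal plain-Python sketch. The “model” is a hypothetical yes-man that ignores its input and always predicts that a caption matches its image; it is scored with binary cross-entropy, a standard classification loss. With positive-only labels its loss is almost zero, so nothing pushes it toward a better answer. Once mismatched pairs (label 0) show up, the same lazy behavior is heavily penalized.

```python
import math

def binary_cross_entropy(p, label, eps=1e-12):
    """Standard log loss for a predicted probability p that the label is 1."""
    return -(label * math.log(p + eps) + (1 - label) * math.log(1 - p + eps))

def yes_man_loss(labels, p=0.999):
    """Mean loss of a hypothetical model that ignores its input and always
    says "yes, this caption matches this image" with probability p."""
    return sum(binary_cross_entropy(p, y) for y in labels) / len(labels)

# 1 = caption matches the image, 0 = caption does not match.
positives_only = [1, 1, 1, 1]
with_negatives = [1, 1, 0, 0]

print(yes_man_loss(positives_only))  # tiny: the lazy model looks nearly perfect
print(yes_man_loss(with_negatives))  # large: the negatives expose it
```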
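But it turns out there’s a very clever trick to get around this, called a Generative Adversarial Network (GAN). A GAN consists of two models: The first one, called the generator, generates images from text descriptions. The second one, called the discriminator, predicts whether an image is real and paired with its correct description, or was created by the generator and/or paired with a fake (mismatched) description. (DALL-E and Stable Diffusion replace this adversarial setup with other kinds of generative models, such as diffusion models, but they also generate images from a numerical representation of the text.)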
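The three kinds of examples the discriminator sees can be sketched in plain Python as follows. This assumes `images[i]` and `captions[i]` form a real, correctly matched pair, and `generate` is a hypothetical stand-in for the generator; strings stand in for actual images. A real image paired with another item’s caption plays the role of the “fake label,” which is exactly the kind of negative example the positive-only captions were missing.

```python
import random

def discriminator_examples(images, captions, generate, rng):
    """Build (image, caption, label) triples for one discriminator update.

    Label 1 means "real image with its correct caption"; label 0 covers
    everything else. `generate` is a hypothetical generator function.
    """
    n = len(images)
    if n < 2:
        raise ValueError("need at least two pairs to build mismatched captions")
    examples = []
    for i in range(n):
        # 1) Real image + its own caption: the only positive example.
        examples.append((images[i], captions[i], 1))
        # 2) Real image + a caption borrowed from a different pair: fake label.
        j = rng.choice([k for k in range(n) if k != i])
        examples.append((images[i], captions[j], 0))
        # 3) Generated image + the caption it was generated from: fake image.
        examples.append((generate(captions[i]), captions[i], 0))
    return examples

def fake_generator(caption):
    # Hypothetical placeholder for a trained generator.
    return f"<generated image for '{caption}'>"

# Toy usage: strings stand in for images.
images = ["<photo of a sailboat>", "<photo of a cat>", "<photo of a tree>"]
captions = ["a sailboat at sea", "a cat on a sofa", "an oak tree"]

batch = discriminator_examples(images, captions, fake_generator, random.Random(0))
for image, caption, label in batch:
    print(label, image, "|", caption)
```

Note that two of every three examples are negatives, so the discriminator always has something to be wrong about.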
This setup gives both models something to learn from: The discriminator is trained to be right, so when it makes a bad prediction, it updates its weights accordingly. The generator is trained to fool the discriminator, so when the discriminator correctly identifies one of its images as generated, the generator updates its own weights to make that harder next time.
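Here is a deliberately tiny plain-Python sketch of one round of that back-and-forth, with the gradients worked out by hand. The whole setup is an assumption chosen for illustration: the real “data” is the single number 3.0, the generator’s only weight is `theta` (it always outputs `theta`), and the discriminator is a one-feature logistic regression `D(x) = sigmoid(w * x + c)`. The discriminator minimizes the usual GAN loss, and the generator minimizes `-log D(fake)`, the commonly used “non-saturating” form of its loss. Text conditioning is left out to keep the arithmetic visible.

```python
import math

# Toy setup (assumptions for illustration only):
#   - real data is the single number REAL,
#   - the generator's only weight is theta (it always outputs theta),
#   - the discriminator is D(x) = sigmoid(w * x + c), its probability that x is real.
REAL = 3.0

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def discriminator_step(w, c, theta, lr=0.1):
    """One gradient-descent step on -log D(real) - log(1 - D(fake))."""
    d_real = sigmoid(w * REAL + c)
    d_fake = sigmoid(w * theta + c)
    # Hand-derived gradients of the discriminator loss.
    grad_w = -(1.0 - d_real) * REAL + d_fake * theta
    grad_c = -(1.0 - d_real) + d_fake
    return w - lr * grad_w, c - lr * grad_c

def generator_step(w, c, theta, lr=0.1):
    """One gradient-descent step on -log D(fake): the generator wants D fooled."""
    d_fake = sigmoid(w * theta + c)
    # This gradient flows *through* the discriminator: the closer d_fake is
    # to 0 (D confidently rejects the fake), the larger it is for a given w.
    grad_theta = -(1.0 - d_fake) * w
    return theta - lr * grad_theta

w, c, theta = 0.0, 0.0, 0.0  # clueless discriminator; generator outputs 0
w, c = discriminator_step(w, c, theta)  # D learns that larger x looks more real
theta = generator_step(w, c, theta)     # G nudges its output toward REAL
print(f"w={w:.3f}  c={c:.3f}  theta={theta:.4f}")
```

In a real GAN both players are neural networks, automatic differentiation computes these gradients, and the two steps repeat over many batches. GAN training is also known to be unstable, so the alternating loop does not always settle down neatly.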
If you let these two models go at it for long enough, with a big enough collection of images and descriptions, the generator eventually gets pretty good at producing images that the discriminator can’t tell apart from real, correctly labeled ones.
So that takes care of the second problem, but what about the first? Well, to turn the text description into something usable, we use a language model. As I touched on last week, a language model can turn the text description into an embedding: a vector representing the underlying concepts expressed by the string. So the generator actually generates images from these embedding vectors, not directly from the text. And the discriminator takes in the image combined with the embedding vector, so it can judge whether the two belong together.
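A rough plain-Python sketch of how the embedding flows through both models is below. The pieces are hypothetical stand-ins: `toy_text_encoder` is a crude bag-of-words hash rather than a real language model, the “networks” are single untrained random linear layers, and an “image” is just a vector of 16 numbers. Concatenation is used here as one simple way to combine the image and the embedding for the discriminator; real models use more elaborate schemes.

```python
import math
import random
import zlib

TEXT_DIM, NOISE_DIM, IMAGE_DIM = 8, 4, 16

def toy_text_encoder(caption):
    """Hypothetical stand-in for the language model: hashes each word into a
    fixed-length count vector. A real system uses a trained text encoder."""
    vec = [0.0] * TEXT_DIM
    for word in caption.lower().split():
        vec[zlib.crc32(word.encode()) % TEXT_DIM] += 1.0
    return vec

def linear(weights, x):
    """Multiply a weight matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in weights]

def random_weights(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
G = random_weights(IMAGE_DIM, NOISE_DIM + TEXT_DIM, rng)  # untrained "generator"
D = random_weights(1, IMAGE_DIM + TEXT_DIM, rng)          # untrained "discriminator"

def generate(caption):
    # The generator never sees raw text, only noise plus the text embedding.
    noise = [rng.gauss(0.0, 1.0) for _ in range(NOISE_DIM)]
    return linear(G, noise + toy_text_encoder(caption))

def discriminate(image, caption):
    # The discriminator judges the image *together with* the embedding, so it
    # can learn to reject good-looking images paired with the wrong caption.
    combined = image + toy_text_encoder(caption)
    logit = linear(D, combined)[0]
    return 1.0 / (1.0 + math.exp(-logit))

image = generate("a sailboat at sea")
print(len(image), discriminate(image, "a sailboat at sea"))
```

Training would adjust `G` and `D` with the alternating updates described earlier; the point here is only the wiring of text embedding, generator, and discriminator.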
Now, obviously, there’s a lot more to what these models do with the image and the embedding vector. I’ve also left out a lot of technical details from the parts I did cover. But this is already a lot longer than I like, so hopefully this was enough to spark some ideas for what this kind of thing might look like for biological data.
We’ll start swinging back in that direction next week.