Building components of Stable Diffusion

CV, Deep Learning

Author: Vishwas Pai

Published: October 30, 2022

In this post we will build, piece by piece, the Stable Diffusion pipeline we saw in the previous blog post.

First, a quick recap.

Which building blocks have we seen so far?

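1. A model that predicts the noise in a noisy image, so that removing it leaves a less noisy image.

For this we generally use a specific type of model called a U-Net. (In the snippets below, DEVICE is the torch device the models run on, for example "cuda".)

from diffusers import UNet2DConditionModel
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(DEVICE)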

2. A model that turns the text into an embedding.

For this we use a model called CLIP. As we are dealing with text, we need to load its tokenizer along with the model.

from transformers import CLIPTextModel, CLIPTokenizer
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)

3. A model that compresses and decompresses images. (The outputs of the compression step are called latents.)

The model used for this is a VAE (variational autoencoder). Its encoder compresses images into latents during training. During inference we start from noise, so we do not need the encoder; we only use the VAE decoder at the very end to turn the final latents back into an image.

from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(DEVICE)

Let's understand what happens inside StableDiffusionPipeline.

Our input will be text (prompt).

prompt = ["man riding blue bike with red dress"]

Now we tokenize the prompt with the CLIP tokenizer and convert the tokens into an embedding with the CLIP text encoder.

text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
text_embeddings = text_encoder(text_input.input_ids.to(DEVICE))[0]
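To make the effect of padding="max_length" and truncation=True concrete, here is a minimal pure-Python sketch. The function name pad_or_truncate, the token ids, and the pad id are hypothetical, and the real CLIP tokenizer also adds start and end tokens, which this sketch ignores. The point is only that every prompt ends up with the same length (77 for this tokenizer), so the embeddings of different prompts have matching shapes.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Return exactly max_length ids (toy stand-in for tokenizer padding/truncation)."""
    clipped = token_ids[:max_length]                          # drop tokens past the limit
    return clipped + [pad_id] * (max_length - len(clipped))   # fill the rest with padding

MAX_LENGTH = 77  # model_max_length of the CLIP tokenizer loaded above
PAD_ID = 0       # hypothetical pad id, for illustration only

short_prompt = [11, 22, 33]     # hypothetical token ids for a short prompt
long_prompt = list(range(100))  # more tokens than the model accepts

padded = pad_or_truncate(short_prompt, MAX_LENGTH, PAD_ID)
truncated = pad_or_truncate(long_prompt, MAX_LENGTH, PAD_ID)
print(len(padded), len(truncated))  # 77 77
```

With the real models loaded above, text_embeddings for a single prompt should have shape (1, 77, 768): one prompt, 77 token positions, and a 768-dimensional vector per token.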

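Next we need noise as the second input (the input latent).

We will use a random noise tensor with the input shape required by the U-Net we loaded.

import torch
batch_size, height, width = len(prompt), 512, 512  # assumed: 512x512, the resolution Stable Diffusion v1.4 was trained at
latents = torch.randn((batch_size, unet.config.in_channels, height // 8, width // 8))
latents = latents.to(DEVICE)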
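The shape arithmetic above can be checked without loading any model. The sketch below assumes the Stable Diffusion v1 setup, where the U-Net works on 4 latent channels and the VAE shrinks height and width by a factor of 8; the helper name latent_shape is hypothetical.

```python
def latent_shape(batch_size, height, width, latent_channels=4, downscale=8):
    """Shape of the noise tensor for a given output image size."""
    if height % downscale or width % downscale:
        raise ValueError("height and width must be multiples of the downscale factor")
    return (batch_size, latent_channels, height // downscale, width // downscale)

shape = latent_shape(1, 512, 512)
print(shape)  # (1, 4, 64, 64)

# How much smaller is one latent than the RGB image it represents?
pixel_values = 3 * 512 * 512
latent_values = shape[1] * shape[2] * shape[3]
print(pixel_values // latent_values)  # 48
```

This is why the denoising loop is cheap compared with working on pixels: each step processes a tensor 48 times smaller than the final image.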

Now that we have both inputs, we can give them to the U-Net to predict the noise in the latent. One issue is that a single prediction conditioned on the text gives us no way to control the trade-off between following the prompt closely and being creative.

So what we do is also generate an embedding for the empty string and concatenate it with the text embedding. The intuition is that we make two predictions: one based entirely on the input text, and one that ignores the text and is completely creative. We then combine the two with a weight to get the best of both.

final_prediction = creative_prediction + guidance_scale * (prediction_based_on_text - creative_prediction)

The guidance_scale parameter lets the user control this trade-off. Smaller values tend to generate more creative but less accurate images, and higher values generate images that are more accurate but less creative. A value of 7.5, the default of StableDiffusionPipeline in the documentation, is a good balance between the two.
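To see what the combination pred_uncond + guidance_scale * (pred_text - pred_uncond) does to the numbers, here is a minimal sketch with plain Python lists standing in for tensors. The function name apply_guidance and the sample values are hypothetical; with real tensors the same arithmetic is applied elementwise.

```python
def apply_guidance(pred_uncond, pred_text, guidance_scale):
    """Move from the unconditional prediction toward (and past) the text prediction."""
    return [u + guidance_scale * (t - u) for u, t in zip(pred_uncond, pred_text)]

pred_uncond = [0.0, 1.0]  # hypothetical prediction for the empty prompt
pred_text = [1.0, 3.0]    # hypothetical prediction for the real prompt

print(apply_guidance(pred_uncond, pred_text, 0.0))  # [0.0, 1.0]  -> ignores the text
print(apply_guidance(pred_uncond, pred_text, 1.0))  # [1.0, 3.0]  -> plain text prediction
print(apply_guidance(pred_uncond, pred_text, 7.5))  # [7.5, 16.0] -> pushed further along the text direction
```

A scale of 0 returns the unconditional prediction, a scale of 1 returns the text-conditioned one, and larger values extrapolate beyond it, which is why high scales follow the prompt more strictly.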

Let's code this.

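from tqdm.auto import tqdm
from diffusers import PNDMScheduler

# Assumed setup: the scheduler shipped with the checkpoint, 50 steps, and guidance_scale = 7.5 are illustrative choices
guidance_scale = 7.5
scheduler = PNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
scheduler.set_timesteps(50)
latents = latents * scheduler.init_noise_sigma

# Embed the empty prompt and stack it in front of the text embedding: [unconditional, text]
uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(DEVICE))[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# Remove noise step by step, once per scheduler timestep
for i, t in enumerate(tqdm(scheduler.timesteps)):
    # We predict for both the empty prompt and the text, so we duplicate the latents
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    # Give both copies to the U-Net to predict the noise
    with torch.no_grad():
        pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

    # Perform guidance: split the batch and combine the two predictions
    pred_uncond, pred_text = pred.chunk(2)
    pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

    # Compute the "previous" (less noisy) sample
    latents = scheduler.step(pred, t, latents).prev_sample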

Now the last step: we convert our latents into a proper image using the VAE decoder.

# Undo the 0.18215 scaling that was applied to the latents during training, then decode
with torch.no_grad():
    image = vae.decode(1 / 0.18215 * latents).sample
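The decoded tensor is not yet a displayable image: its values lie roughly in the range [-1, 1], while image files expect integers from 0 to 255. The sketch below shows that conversion on a few hypothetical values in plain Python; the helper name to_uint8 is an assumption for illustration, not part of any library.

```python
def to_uint8(value):
    """Map a decoder output value from [-1, 1] to an integer in [0, 255]."""
    scaled = value / 2 + 0.5              # [-1, 1] -> [0, 1]
    clamped = min(max(scaled, 0.0), 1.0)  # decoder outputs can overshoot slightly
    return round(clamped * 255)

decoded = [-1.2, -1.0, 0.0, 1.0, 1.3]  # hypothetical decoder outputs
pixels = [to_uint8(v) for v in decoded]
print(pixels)  # [0, 0, 128, 255, 255]
```

On the real tensor, the same steps would be written as (image / 2 + 0.5).clamp(0, 1), followed by moving the channel dimension last and multiplying by 255 before handing the array to an image library.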

In the next post we will understand and build negative prompts in Stable Diffusion.