14 | 14 |
15 | 15 | from typing import Callable, Dict, List, Optional, Union
16 | 16 |
17 |    | -import numpy as np
18 | 17 | import PIL.Image
19 | 18 | import torch
20 |    | -from PIL import Image
21 | 19 |
   | 20 | +from ...image_processor import VaeImageProcessor
22 | 21 | from ...models import UNet2DConditionModel, VQModel
23 | 22 | from ...schedulers import DDPMScheduler
24 | 23 | from ...utils import deprecate, is_torch_xla_available, logging

76 | 75 | """
77 | 76 |
78 | 77 |
79 |    | -# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
80 |    | -def downscale_height_and_width(height, width, scale_factor=8):
81 |    | -    new_height = height // scale_factor**2
82 |    | -    if height % scale_factor**2 != 0:
83 |    | -        new_height += 1
84 |    | -    new_width = width // scale_factor**2
85 |    | -    if width % scale_factor**2 != 0:
86 |    | -        new_width += 1
87 |    | -    return new_height * scale_factor, new_width * scale_factor
88 |    | -
89 |    | -
90 |    | -# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
91 |    | -def prepare_image(pil_image, w=512, h=512):
92 |    | -    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
93 |    | -    arr = np.array(pil_image.convert("RGB"))
94 |    | -    arr = arr.astype(np.float32) / 127.5 - 1
95 |    | -    arr = np.transpose(arr, [2, 0, 1])
96 |    | -    image = torch.from_numpy(arr).unsqueeze(0)
97 |    | -    return image
98 |    | -
99 |    | -
100 | 78 | class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
101 | 79 |     """
102 | 80 |     Pipeline for image-to-image generation using Kandinsky
@@ -129,7 +107,14 @@ def __init__(
129 | 107 |             scheduler=scheduler,
130 | 108 |             movq=movq,
131 | 109 |         )
132 |     | -        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
    | 110 | +        movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
    | 111 | +        movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
    | 112 | +        self.image_processor = VaeImageProcessor(
    | 113 | +            vae_scale_factor=movq_scale_factor,
    | 114 | +            vae_latent_channels=movq_latent_channels,
    | 115 | +            resample="bicubic",
    | 116 | +            reducing_gap=1,
    | 117 | +        )
133 | 118 |
134 | 119 |     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
135 | 120 |     def get_timesteps(self, num_inference_steps, strength, device):
@@ -319,15 +304,14 @@ def __call__(
319 | 304 |                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
320 | 305 |             )
321 | 306 |
322 |     | -        image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
    | 307 | +        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
323 | 308 |         image = image.to(dtype=image_embeds.dtype, device=device)
324 | 309 |
325 | 310 |         latents = self.movq.encode(image)["latents"]
326 | 311 |         latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
327 | 312 |         self.scheduler.set_timesteps(num_inference_steps, device=device)
328 | 313 |         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
329 | 314 |         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
330 |     | -        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
331 | 315 |         latents = self.prepare_latents(
332 | 316 |             latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
333 | 317 |         )
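
For orientation, a minimal sketch (not part of the diff) of what the new preprocessing path does, using the same `VaeImageProcessor` configuration the constructor hunk above adds. It assumes a diffusers version whose `VaeImageProcessor` accepts the `reducing_gap` argument, as this diff does; the 512x512 size and the input file name are placeholders.

```python
import PIL.Image
import torch

from diffusers.image_processor import VaeImageProcessor

# Same arguments as the constructor hunk above; 8 and 4 are the fallback
# values the diff supplies via getattr() when no movq is attached.
image_processor = VaeImageProcessor(
    vae_scale_factor=8,
    vae_latent_channels=4,
    resample="bicubic",
    reducing_gap=1,
)

pil_image = PIL.Image.open("example.png").convert("RGB")  # placeholder input

# preprocess() resizes with bicubic resampling, rescales pixels from [0, 255]
# to [-1, 1], and returns an NCHW float tensor: the same steps the removed
# prepare_image() helper performed by hand with numpy.
tensor = image_processor.preprocess(pil_image, height=512, width=512)
print(tensor.shape, tensor.dtype)  # torch.Size([1, 3, 512, 512]) torch.float32
```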
@@ -383,21 +367,9 @@ def __call__(
383 | 367 |             if XLA_AVAILABLE:
384 | 368 |                 xm.mark_step()
385 | 369 |
386 |     | -        if output_type not in ["pt", "np", "pil", "latent"]:
387 |     | -            raise ValueError(
388 |     | -                f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
389 |     | -            )
390 |     | -
391 | 370 |         if not output_type == "latent":
392 |     | -            # post-processing
393 | 371 |             image = self.movq.decode(latents, force_not_quantize=True)["sample"]
394 |     | -            if output_type in ["np", "pil"]:
395 |     | -                image = image * 0.5 + 0.5
396 |     | -                image = image.clamp(0, 1)
397 |     | -                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
398 |     | -
399 |     | -            if output_type == "pil":
400 |     | -                image = self.numpy_to_pil(image)
    | 372 | +            image = self.image_processor.postprocess(image, output_type)
401 | 373 |         else:
402 | 374 |             image = latents
403 | 375 |
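
Similarly, a small sketch (again not part of the diff, with a random tensor standing in for the decoded movq sample) of the `postprocess()` call that replaces the manual denormalize/clamp/permute/`numpy_to_pil` branch removed in the last hunk:

```python
import torch

from diffusers.image_processor import VaeImageProcessor

# Scale factor and latent channels as in the constructor hunk above.
image_processor = VaeImageProcessor(vae_scale_factor=8, vae_latent_channels=4)

# Stand-in for movq.decode(...)["sample"]: an NCHW tensor with values in [-1, 1].
decoded = torch.rand(1, 3, 512, 512) * 2 - 1

# postprocess() maps [-1, 1] back to [0, 1], clamps, and converts to the
# requested output type ("pil", "np", or "pt").
pil_images = image_processor.postprocess(decoded, output_type="pil")  # list of PIL images
np_images = image_processor.postprocess(decoded, output_type="np")    # numpy array, NHWC
```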