
Commit 79ea8eb

[BUG] fixes in kandinsky pipeline (#11080)
* bug fix kandinsky pipeline
1 parent e7f3a73 commit 79ea8eb

5 files changed (+52, -125 lines)

src/diffusers/image_processor.py (+6, -1)
@@ -116,6 +116,7 @@ def __init__(
         vae_scale_factor: int = 8,
         vae_latent_channels: int = 4,
         resample: str = "lanczos",
+        reducing_gap: int = None,
         do_normalize: bool = True,
         do_binarize: bool = False,
         do_convert_rgb: bool = False,
@@ -498,7 +499,11 @@ def resize(
             raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
         if isinstance(image, PIL.Image.Image):
             if resize_mode == "default":
-                image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+                image = image.resize(
+                    (width, height),
+                    resample=PIL_INTERPOLATION[self.config.resample],
+                    reducing_gap=self.config.reducing_gap,
+                )
             elif resize_mode == "fill":
                 image = self._resize_and_fill(image, width, height)
             elif resize_mode == "crop":
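
For orientation, a minimal sketch of how the new `reducing_gap` config reaches PIL once a processor is built the way the Kandinsky img2img pipelines below build it; the standalone construction and the sizes are illustrative, not part of the diff.

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

# Illustrative: the same settings the Kandinsky img2img pipelines now pass in __init__.
processor = VaeImageProcessor(resample="bicubic", reducing_gap=1)

image = PIL.Image.new("RGB", (1024, 768))
# resize() now forwards self.config.reducing_gap to PIL, roughly equivalent to
#   image.resize((512, 512), resample=PIL.Image.BICUBIC, reducing_gap=1)
resized = processor.resize(image, height=512, width=512)
print(resized.size)  # (512, 512)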

src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py (+13, -20)
@@ -14,14 +14,13 @@
 
 from typing import Callable, List, Optional, Union
 
-import numpy as np
 import PIL.Image
 import torch
-from PIL import Image
 from transformers import (
     XLMRobertaTokenizer,
 )
 
+from ...image_processor import VaeImageProcessor
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDIMScheduler
 from ...utils import (
@@ -95,15 +94,6 @@ def get_new_h_w(h, w, scale_factor=8):
     return new_h * scale_factor, new_w * scale_factor
 
 
-def prepare_image(pil_image, w=512, h=512):
-    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
-
-
 class KandinskyImg2ImgPipeline(DiffusionPipeline):
     """
     Pipeline for image-to-image generation using Kandinsky
@@ -143,7 +133,16 @@ def __init__(
             scheduler=scheduler,
             movq=movq,
         )
-        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        self.movq_scale_factor = (
+            2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+        )
+        movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.movq_scale_factor,
+            vae_latent_channels=movq_latent_channels,
+            resample="bicubic",
+            reducing_gap=1,
+        )
 
     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
@@ -417,7 +416,7 @@ def __call__(
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
             )
 
-        image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
         image = image.to(dtype=prompt_embeds.dtype, device=device)
 
         latents = self.movq.encode(image)["latents"]
@@ -498,13 +497,7 @@ def __call__(
         if output_type not in ["pt", "np", "pil"]:
             raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
 
-        if output_type in ["np", "pil"]:
-            image = image * 0.5 + 0.5
-            image = image.clamp(0, 1)
-            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-            if output_type == "pil":
-                image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type)
 
         if not return_dict:
             return (image,)
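
The deleted `prepare_image` helper and the new `VaeImageProcessor` path are intended to produce the same [-1, 1] float tensor. A rough sanity-check sketch: the helper is reproduced from the removed code, while the comparison script and the placeholder input are illustrative.

import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import VaeImageProcessor


def prepare_image_old(pil_image, w=512, h=512):
    # The helper removed by this commit, kept here only for comparison.
    pil_image = pil_image.resize((w, h), resample=PIL.Image.BICUBIC, reducing_gap=1)
    arr = np.array(pil_image.convert("RGB")).astype(np.float32) / 127.5 - 1
    return torch.from_numpy(np.transpose(arr, [2, 0, 1])).unsqueeze(0)


processor = VaeImageProcessor(resample="bicubic", reducing_gap=1)
pil = PIL.Image.open("example.png").convert("RGB")  # placeholder input file

old = prepare_image_old(pil, 512, 512)
new = processor.preprocess(pil, height=512, width=512)
print(torch.allclose(old, new, atol=1e-4))  # expected to print True for RGB inputs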

src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py (+11, -33)
@@ -14,11 +14,10 @@
 
 from typing import Callable, List, Optional, Union
 
-import numpy as np
 import PIL.Image
 import torch
-from PIL import Image
 
+from ...image_processor import VaeImageProcessor
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -105,27 +104,6 @@
 """
 
 
-# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
-def downscale_height_and_width(height, width, scale_factor=8):
-    new_height = height // scale_factor**2
-    if height % scale_factor**2 != 0:
-        new_height += 1
-    new_width = width // scale_factor**2
-    if width % scale_factor**2 != 0:
-        new_width += 1
-    return new_height * scale_factor, new_width * scale_factor
-
-
-# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
-def prepare_image(pil_image, w=512, h=512):
-    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
-
-
 class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
     """
     Pipeline for image-to-image generation using Kandinsky
@@ -157,7 +135,14 @@ def __init__(
             scheduler=scheduler,
             movq=movq,
         )
-        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+        movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=movq_scale_factor,
+            vae_latent_channels=movq_latent_channels,
+            resample="bicubic",
+            reducing_gap=1,
+        )
 
     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
     def get_timesteps(self, num_inference_steps, strength, device):
@@ -316,15 +301,14 @@ def __call__(
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
             )
 
-        image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
         image = image.to(dtype=image_embeds.dtype, device=device)
 
         latents = self.movq.encode(image)["latents"]
         latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
         latents = self.prepare_latents(
             latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
         )
@@ -379,13 +363,7 @@ def __call__(
         if output_type not in ["pt", "np", "pil"]:
             raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
 
-        if output_type in ["np", "pil"]:
-            image = image * 0.5 + 0.5
-            image = image.clamp(0, 1)
-            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-            if output_type == "pil":
-                image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type)
 
         if not return_dict:
             return (image,)
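
`postprocess` folds the deleted denormalize/clamp/`numpy_to_pil` branch into a single call. Roughly, it behaves like this hand-rolled equivalent of the removed block; a sketch for orientation, not the library implementation.

import numpy as np
import torch
from PIL import Image


def postprocess_equivalent(image: torch.Tensor, output_type: str = "pil"):
    # Mirrors the branch this commit removes from the pipelines.
    if output_type == "pt":
        return image
    image = (image * 0.5 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()  # NCHW -> NHWC
    if output_type == "np":
        return image
    # output_type == "pil": same conversion numpy_to_pil performed
    return [Image.fromarray((img * 255).round().astype(np.uint8)) for img in image]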

src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py (+11, -39)
@@ -14,11 +14,10 @@
 
 from typing import Callable, Dict, List, Optional, Union
 
-import numpy as np
 import PIL.Image
 import torch
-from PIL import Image
 
+from ...image_processor import VaeImageProcessor
 from ...models import UNet2DConditionModel, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import deprecate, is_torch_xla_available, logging
@@ -76,27 +75,6 @@
 """
 
 
-# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
-def downscale_height_and_width(height, width, scale_factor=8):
-    new_height = height // scale_factor**2
-    if height % scale_factor**2 != 0:
-        new_height += 1
-    new_width = width // scale_factor**2
-    if width % scale_factor**2 != 0:
-        new_width += 1
-    return new_height * scale_factor, new_width * scale_factor
-
-
-# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
-def prepare_image(pil_image, w=512, h=512):
-    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
-
-
 class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
     """
     Pipeline for image-to-image generation using Kandinsky
@@ -129,7 +107,14 @@ def __init__(
             scheduler=scheduler,
             movq=movq,
         )
-        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+        movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+        movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=movq_scale_factor,
+            vae_latent_channels=movq_latent_channels,
+            resample="bicubic",
+            reducing_gap=1,
+        )
 
     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
     def get_timesteps(self, num_inference_steps, strength, device):
@@ -319,15 +304,14 @@ def __call__(
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
             )
 
-        image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
         image = image.to(dtype=image_embeds.dtype, device=device)
 
         latents = self.movq.encode(image)["latents"]
         latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
         latents = self.prepare_latents(
             latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
         )
@@ -383,21 +367,9 @@ def __call__(
             if XLA_AVAILABLE:
                 xm.mark_step()
 
-        if output_type not in ["pt", "np", "pil", "latent"]:
-            raise ValueError(
-                f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
-            )
-
         if not output_type == "latent":
-            # post-processing
             image = self.movq.decode(latents, force_not_quantize=True)["sample"]
-            if output_type in ["np", "pil"]:
-                image = image * 0.5 + 0.5
-                image = image.clamp(0, 1)
-                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-                if output_type == "pil":
-                    image = self.numpy_to_pil(image)
+            image = self.image_processor.postprocess(image, output_type)
         else:
             image = latents

src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py (+11, -32)
@@ -1,12 +1,12 @@
 import inspect
 from typing import Callable, Dict, List, Optional, Union
 
-import numpy as np
 import PIL
 import PIL.Image
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
 
+from ...image_processor import VaeImageProcessor
 from ...loaders import StableDiffusionLoraLoaderMixin
 from ...models import Kandinsky3UNet, VQModel
 from ...schedulers import DDPMScheduler
@@ -53,24 +53,6 @@
 """
 
 
-def downscale_height_and_width(height, width, scale_factor=8):
-    new_height = height // scale_factor**2
-    if height % scale_factor**2 != 0:
-        new_height += 1
-    new_width = width // scale_factor**2
-    if width % scale_factor**2 != 0:
-        new_width += 1
-    return new_height * scale_factor, new_width * scale_factor
-
-
-def prepare_image(pil_image):
-    arr = np.array(pil_image.convert("RGB"))
-    arr = arr.astype(np.float32) / 127.5 - 1
-    arr = np.transpose(arr, [2, 0, 1])
-    image = torch.from_numpy(arr).unsqueeze(0)
-    return image
-
-
 class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
     model_cpu_offload_seq = "text_encoder->movq->unet->movq"
     _callback_tensor_inputs = [
@@ -94,6 +76,14 @@ def __init__(
         self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
         )
+        movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+        movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=movq_scale_factor,
+            vae_latent_channels=movq_latent_channels,
+            resample="bicubic",
+            reducing_gap=1,
+        )
 
     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
@@ -566,7 +556,7 @@ def __call__(
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
             )
 
-        image = torch.cat([prepare_image(i) for i in image], dim=0)
+        image = torch.cat([self.image_processor.preprocess(i) for i in image], dim=0)
         image = image.to(dtype=prompt_embeds.dtype, device=device)
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -630,20 +620,9 @@ def __call__(
                 xm.mark_step()
 
         # post-processing
-        if output_type not in ["pt", "np", "pil", "latent"]:
-            raise ValueError(
-                f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
-            )
         if not output_type == "latent":
             image = self.movq.decode(latents, force_not_quantize=True)["sample"]
-
-            if output_type in ["np", "pil"]:
-                image = image * 0.5 + 0.5
-                image = image.clamp(0, 1)
-                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-                if output_type == "pil":
-                    image = self.numpy_to_pil(image)
+            image = self.image_processor.postprocess(image, output_type)
         else:
            image = latents
