This is an unofficial PyTorch implementation of the paper *Reference-guided structure-aware deep sketch colorization for cartoons*.
[Paper]
This repository contains the implementation of a sketch colorization network. The network produces an output image that adopts the colors of a reference image while preserving the structure of the input sketch.
This project involves two training stages. First, the main model is trained on the same dataset as the original authors; the resulting model is then fine-tuned on a custom dataset. The fine-tuned model includes a slight architectural modification to improve results.
The main model consists of three components.
- A color style extractor network that extracts the colors from a reference image using 4 ResNet blocks followed by global average pooling. Its output is the color style code, a vector of dimension 256 (a minimal sketch is given below).
- A colorization network that generates multi-scale output images by combining the input sketch with the colors from the extractor. It follows a U-Net structure, with a downscaling sketch encoder and an upscaling style-content fusion decoder.
- A multi-scale discriminator to improve the results by ensuring a realistic colorization.
Please refer to the original paper for more details.
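Below is a minimal, unofficial sketch of the color style extractor. Only the four ResNet blocks, the global average pooling, and the 256-dimensional style code come from the description above; the convolution layout, normalization, and channel sizes are assumptions and do not reflect the authors' exact architecture.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Basic residual block (layer layout is illustrative, not from the paper)."""
    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return torch.relu(x + self.body(x))


class ColorStyleExtractor(nn.Module):
    """Extracts a 256-d color style code from a reference image
    using 4 ResNet blocks followed by global average pooling."""
    def __init__(self, in_channels=3, channels=256):
        super().__init__()
        self.stem = nn.Conv2d(in_channels, channels, kernel_size=3, padding=1)
        self.blocks = nn.ModuleList([ResBlock(channels) for _ in range(4)])
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, reference):
        x = self.stem(reference)
        for block in self.blocks:
            x = block(x)
        return self.pool(x).flatten(1)  # (batch, 256) color style code


# Example: extract a style code from a 256x256 reference image.
# extractor = ColorStyleExtractor()
# style_code = extractor(torch.randn(1, 3, 256, 256))  # shape (1, 256)
```

The colorization network then fuses this code with the decoder blocks while the encoder processes the sketch, following the U-Net structure described above.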
After training the main model, I wanted to use it in a more complex scenario, with sketches that require more colors (such as backgrounds or objects), by fine-tuning the model to colorize sketches from a specific manga. For this experiment, the One Punch Man manga was chosen.
For the fine-tune model, the way in which colors are extracted and combined was modified. In this setting, the color style extractor outputs 4 different color style codes, one from each of the last four ResNet blocks; each of these representations is then fused with a decoder block of the colorization network. This change was made because more vivid colors were observed under this configuration (a minimal sketch is given below).
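For reference, a minimal sketch of this modification, reusing the `ColorStyleExtractor` sketch above. Everything except the idea of returning one style code per ResNet block and fusing each with a decoder block is an assumption.

```python
class MultiCodeStyleExtractor(ColorStyleExtractor):
    """Fine-tune variant: outputs one pooled style code per ResNet block
    (four codes in total), one for each decoder block of the colorization
    network. How each code is fused with its decoder block is left to the
    decoder implementation and is not specified here."""
    def forward(self, reference):
        x = self.stem(reference)
        codes = []
        for block in self.blocks:
            x = block(x)
            codes.append(self.pool(x).flatten(1))  # (batch, 256) per block
        return codes  # list of 4 style codes, shallowest to deepest
```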
- Anime Sketch Colorization Pair: Publicly available dataset with 17,769 images, each containing a sketch and its color version. 14,224 images are used for training and 3,545 for evaluation (testing).
- One Punch Man Manga Colorization: Custom dataset with 3,546 pairs of images (color and sketch). 2,837 pairs are used for training, 177 for validation and 532 for testing. The color images were obtained by scraping the One Punch Man subreddit, and the corresponding sketches were extracted using Anime2Sketch.
- Python 3
- PyTorch >= 1.6 (torch and torchvision)
- NumPy
- PIQ (for evaluation metrics)
- Download the desired dataset and extract it to its corresponding folder in `./data`.
- Change the settings for training or testing the model in `./configs/training.cfg`.
- Run the `main.py` script.
The `training.cfg` file manages the instructions for the execution of the code; it contains several parameters such as the batch size, the learning rate and other constants. To train or test the model, the keys `train_model` and `test_model` in `[Commons]` must be set to `True` or `False`, depending on the case. For the Main Model or the Fine-Tune Model, additional parameters are required, as shown in the examples below.
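For example, to run training without the evaluation step, the `[Commons]` section could contain the following (only the two keys shown here come from the description above; any other keys in the file are omitted):

[Commons]
train_model = True
test_model = False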
Configuration for the Main Model:
[Training]
dataset_name = anime_sketch_colorization_pair
split_img = True
fixed_size_imgs = True
fine_tune = False
[Testing]
dataset_name = anime_sketch_colorization_pair
split_img = True
fixed_size_imgs = True
Configuration for the Fine-Tune Model:
[Training]
dataset_name = opm_colorization_dataset
split_img = False
fixed_size_imgs = False
fine_tune = True
[Testing]
dataset_name = opm_colorization_dataset
split_img = False
fixed_size_imgs = False
For training the Main Model, the learning rate was set to 1e-04; for training the Fine-Tune Model, the best results were observed with a learning rate of 1e-03.
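As a small illustration, assuming an Adam optimizer (the optimizer type is not stated above), these learning rates would be set as follows:

```python
import torch

# Learning rates reported above; the choice of Adam is an assumption.
MAIN_LR = 1e-4       # Main Model
FINE_TUNE_LR = 1e-3  # Fine-Tune Model

# optimizer = torch.optim.Adam(model.parameters(), lr=MAIN_LR)
```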
The results obtained in this implementation are not as good as the ones reported in the paper; in particular, some colors do not propagate correctly. The model was trained for 100 epochs, as the colorization did not improve after this point.
During training, the adversarial loss dropped close to zero after epoch 25, which may be one of the reasons why the results are not better. Also, because there is no validation split, it is difficult to tell whether the model overfitted.
Each metric is evaluated in two ways: using the correct paired sketch/reference and using a different reference, as shown in the images below. The paired sketch/reference results are similar to those reported in the paper.
| | PSNR (paired sketch/reference) | PSNR (different reference) | SSIM (paired sketch/reference) | SSIM (different reference) | FID (paired sketch/reference) | FID (different reference) |
|---|---|---|---|---|---|---|
| Reference | 22.43 | 17.25 | 0.85 | 0.74 | - | 27.99 |
| Obtained | 20.11 | 13.41 | 0.84 | 0.56 | 155.46 | 174.0 |
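The PSNR and SSIM values above can be computed with PIQ. A minimal sketch, assuming the generated and ground-truth images are tensors scaled to [0, 1] (FID additionally requires extracting features over the whole test set with `piq.FID` and is omitted here):

```python
import torch
import piq

# Placeholder batches; in practice these would be the colorized outputs
# and the ground-truth color images, both scaled to [0, 1].
result = torch.rand(4, 3, 256, 256)
target = torch.rand(4, 3, 256, 256)

psnr = piq.psnr(result, target, data_range=1.0)
ssim = piq.ssim(result, target, data_range=1.0)
print(f"PSNR: {psnr.item():.2f}  SSIM: {ssim.item():.4f}")
```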
(Qualitative examples: reference image, input sketch, and colorized result for two samples.)
The fine-tune model was trained for 60 epochs, as performance on the validation set did not improve after this point. It is worth mentioning that the adversarial loss remained relatively constant and did not drop to zero.
The results are not perfect, but the model shows potential, considering that this task is much more difficult and the dataset is smaller.
All metrics are evaluated using the correct paired sketch/reference.
| | PSNR | SSIM | FID |
|---|---|---|---|
| Obtained | 19.542 | 0.748 | 151.057 |