Skip to content
Snippets Groups Projects
Commit 4ef17ff0 authored by 柏灌's avatar 柏灌
Browse files

add sr model

parent 7f0b3817
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,9 @@ _<sup>2</sup>[Department of Computing, The Hong Kong Polytechnic University](htt
## News
(2021-07-06) The training code will be released soon. Stay tuned.
(2021-10-11) The Colab demo for GPEN is available now <a href="https://colab.research.google.com/drive/1fPUsJCpQipp2Z5B5GbEXqpBGsMp-nvjm?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
(2021-10-11) The Colab demo for GPEN is available now <a href="https://colab.research.google.com/drive/1fPUsJCpQipp2Z5B5GbEXqpBGsMp-nvjm?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
(2021-10-22) GPEN can now work with SR methods. An SR model trained by myself is released. Replace it with your own model if necessary.
## Usage
......@@ -45,11 +47,11 @@ cd GPEN
```
- Download RetinaFace model and our pre-trained model (not our best model due to commercial issues) and put them into ``weights/``.
[RetinaFace-R50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth) | [GPEN-BFR-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth) | [GPEN-BFR-512-D](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512-D.pth) | [GPEN-BFR-256](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-256.pth) | [GPEN-Colorization-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth) | [GPEN-Inpainting-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Inpainting-1024.pth) | [GPEN-Seg2face-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Seg2face-512.pth)
[RetinaFace-R50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth) | [GPEN-BFR-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth) | [GPEN-BFR-512-D](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512-D.pth) | [GPEN-BFR-256](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-256.pth) | [GPEN-Colorization-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth) | [GPEN-Inpainting-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Inpainting-1024.pth) | [GPEN-Seg2face-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Seg2face-512.pth) | [rrdb_realesrnet_psnr](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/rrdb_realesrnet_psnr.pth)
- Restore face images:
```bash
python face_enhancement.py
python face_enhancement.py --model GPEN-BFR-512 --size 512 --channel_multiplier 2 --narrow 1 --use_sr --indir examples/imgs --outdir examples/outs-BFR
```
- Colorize faces:
......@@ -84,7 +86,7 @@ If our work is useful for your research, please consider citing:
© Alibaba, 2021. For academic and non-commercial use only.
## Acknowledgments
We borrow some codes from [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface) and [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch).
We borrow some code from [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface), [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch), and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN).
## Contact
If you have any questions or suggestions about this paper, feel free to reach me at yangtao9009@gmail.com.
......@@ -16,3 +16,6 @@ add_path(path)
path = osp.join(this_dir, 'face_model')
add_path(path)
path = osp.join(this_dir, 'sr_model')
add_path(path)
\ No newline at end of file
......@@ -6,18 +6,22 @@ import os
import cv2
import glob
import time
import argparse
import numpy as np
from PIL import Image
import __init_paths
from retinaface.retinaface_detection import RetinaFaceDetection
from face_model.face_gan import FaceGAN
from sr_model.real_esrnet import RealESRNet
from align_faces import warp_and_crop_face, get_reference_facial_points
from skimage import transform as tf
class FaceEnhancement(object):
def __init__(self, base_dir='./', size=512, model=None, channel_multiplier=2, narrow=1):
def __init__(self, base_dir='./', size=512, model=None, use_sr=True, sr_model=None, channel_multiplier=2, narrow=1):
self.facedetector = RetinaFaceDetection(base_dir)
self.facegan = FaceGAN(base_dir, size, model, channel_multiplier, narrow)
self.srmodel = RealESRNet(base_dir, sr_model)
self.use_sr = use_sr
self.size = size
self.threshold = 0.9
......@@ -40,6 +44,11 @@ class FaceEnhancement(object):
(self.size, self.size), inner_padding_factor, outer_padding, default_square)
def process(self, img):
if self.use_sr:
img_sr = self.srmodel.process(img)
if img_sr is not None:
img = cv2.resize(img, img_sr.shape[:2][::-1])
facebs, landms = self.facedetector.detect(img)
orig_faces, enhanced_faces = [], []
......@@ -75,37 +84,51 @@ class FaceEnhancement(object):
full_img[np.where(mask>0)] = tmp_img[np.where(mask>0)]
full_mask = full_mask[:, :, np.newaxis]
img = cv2.convertScaleAbs(img*(1-full_mask) + full_img*full_mask)
if self.use_sr and img_sr is not None:
img = cv2.convertScaleAbs(img_sr*(1-full_mask) + full_img*full_mask)
else:
img = cv2.convertScaleAbs(img*(1-full_mask) + full_img*full_mask)
return img, orig_faces, enhanced_faces
if __name__=='__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='GPEN-BFR-512', help='GPEN model')
parser.add_argument('--size', type=int, default=512, help='resolution of GPEN')
parser.add_argument('--channel_multiplier', type=int, default=2, help='channel multiplier of GPEN')
parser.add_argument('--narrow', type=float, default=1, help='channel narrow scale')
parser.add_argument('--use_sr', action='store_true', help='use sr or not')
parser.add_argument('--sr_model', type=str, default='rrdb_realesrnet_psnr', help='SR model')
parser.add_argument('--sr_scale', type=int, default=2, help='SR scale')
parser.add_argument('--indir', type=str, default='examples/imgs', help='input folder')
parser.add_argument('--outdir', type=str, default='results/outs-BFR', help='output folder')
args = parser.parse_args()
#model = {'name':'GPEN-BFR-512', 'size':512, 'channel_multiplier':2, 'narrow':1}
model = {'name':'GPEN-BFR-256', 'size':256, 'channel_multiplier':1, 'narrow':0.5}
#model = {'name':'GPEN-BFR-256', 'size':256, 'channel_multiplier':1, 'narrow':0.5}
indir = 'examples/imgs'
outdir = 'examples/outs-BFR'
os.makedirs(outdir, exist_ok=True)
os.makedirs(args.outdir, exist_ok=True)
faceenhancer = FaceEnhancement(size=model['size'], model=model['name'], channel_multiplier=model['channel_multiplier'], narrow=model['narrow'])
faceenhancer = FaceEnhancement(size=args.size, model=args.model, use_sr=args.use_sr, sr_model=args.sr_model, channel_multiplier=args.channel_multiplier, narrow=args.narrow)
files = sorted(glob.glob(os.path.join(indir, '*.*g')))
files = sorted(glob.glob(os.path.join(args.indir, '*.*g')))
for n, file in enumerate(files[:]):
filename = os.path.basename(file)
im = cv2.imread(file, cv2.IMREAD_COLOR) # BGR
if not isinstance(im, np.ndarray): print(filename, 'error'); continue
im = cv2.resize(im, (0,0), fx=2, fy=2)
#im = cv2.resize(im, (0,0), fx=2, fy=2) # optional
img, orig_faces, enhanced_faces = faceenhancer.process(im)
cv2.imwrite(os.path.join(outdir, '.'.join(filename.split('.')[:-1])+'_COMP.jpg'), np.hstack((im, img)))
cv2.imwrite(os.path.join(outdir, '.'.join(filename.split('.')[:-1])+'_GPEN.jpg'), img)
im = cv2.resize(im, img.shape[:2][::-1])
cv2.imwrite(os.path.join(args.outdir, '.'.join(filename.split('.')[:-1])+'_COMP.jpg'), np.hstack((im, img)))
cv2.imwrite(os.path.join(args.outdir, '.'.join(filename.split('.')[:-1])+'_GPEN.jpg'), img)
for m, (ef, of) in enumerate(zip(enhanced_faces, orig_faces)):
of = cv2.resize(of, ef.shape[:2])
cv2.imwrite(os.path.join(outdir, '.'.join(filename.split('.')[:-1])+'_face%02d'%m+'.jpg'), np.hstack((of, ef)))
cv2.imwrite(os.path.join(args.outdir, '.'.join(filename.split('.')[:-1])+'_face%02d'%m+'.jpg'), np.hstack((of, ef)))
if n%10==0: print(n, filename)
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Every ``nn.Conv2d`` and ``nn.Linear`` found in the given modules gets
    Kaiming-normal weights multiplied by ``scale``; every batch-norm layer
    gets its weight set to 1. Any bias that exists is filled with
    ``bias_fill``. Other layer types are left untouched.

    Args:
        module_list (list[nn.Module] | tuple[nn.Module] | nn.Module):
            Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments forwarded to
            ``init.kaiming_normal_``.
    """
    if not isinstance(module_list, (list, tuple)):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # Batch norms built with affine=False carry weight=None and
                # bias=None, so both must be checked before being touched.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack several freshly built copies of one block type.

    Args:
        basic_block (nn.module): nn.module class (or any callable returning
            a module) used to build each block.
        num_basic_block (int): how many blocks to build.
        kwarg (dict): keyword arguments passed to every ``basic_block``
            call.

    Returns:
        nn.Sequential: The blocks chained in construction order.
    """
    stacked = [basic_block(**kwarg) for _ in range(num_basic_block)]
    return nn.Sequential(*stacked)
class ResidualBlockNoBN(nn.Module):
    """Two-convolution residual block with no batch normalization.

    Data flow::

        x ---Conv-ReLU-Conv--(* res_scale)--+--> out
        |___________________________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Factor applied to the convolution branch before
            it is added to the input. Default: 1.
        pytorch_init (bool): If True, keep PyTorch's default layer
            initialization; otherwise both convolutions are re-initialized
            through default_init_weights with scale 0.1. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if pytorch_init:
            return
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        """Return ``x`` plus the scaled output of the convolution branch."""
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        return x + branch * self.res_scale
class Upsample(nn.Sequential):
    """Upsample module.

    Builds a sequence of (Conv2d, PixelShuffle) pairs: one x2 pair per power
    of two in ``scale``, or a single x3 pair when ``scale`` is 3. A scale of
    1 (2^0) yields an empty sequence.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # The bit trick alone also accepts 0, so positivity is checked first.
        if scale >= 1 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # bit_length gives the exact exponent without float rounding.
            for _ in range(int(scale).bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Fold each scale x scale spatial patch into the channel dimension.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw); both spatial
            sizes must be divisible by ``scale``.
        scale (int): Downsample ratio.

    Returns:
        Tensor: Feature with shape (b, c * scale**2, hh // scale,
            hw // scale).
    """
    batch, channels, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale
    # Split each spatial axis into (coarse position, offset inside the patch).
    patches = x.view(batch, channels, out_h, scale, out_w, scale)
    # Move the two in-patch offsets next to the channel axis, then merge them.
    patches = patches.permute(0, 1, 3, 5, 2, 4)
    return patches.reshape(batch, channels * (scale**2), out_h, out_w)
\ No newline at end of file
import os
import torch
import numpy as np
from rrdbnet_arch import RRDBNet
from torch.nn import functional as F
class RealESRNet(object):
    """Inference wrapper around an RRDBNet super-resolution model.

    Args:
        base_dir (str): Directory that contains the ``weights`` folder.
        model (str | None): Checkpoint name without the ``.pth`` extension;
            ``None`` selects ``rrdb_realesrnet_psnr``.
        scale (int): Scale passed to RRDBNet. Default: 2.
        device (str | torch.device): Device the network and inputs are moved
            to. Default: 'cuda'.
    """

    def __init__(self, base_dir='./', model=None, scale=2, device='cuda'):
        self.base_dir = base_dir
        self.scale = scale
        self.device = torch.device(device)
        self.load_srmodel(base_dir, model)

    def load_srmodel(self, base_dir, model):
        """Build the RRDBNet and load ``weights/<model>.pth`` into it.

        The checkpoint is looked up under ``self.base_dir``; the ``base_dir``
        argument is accepted for interface compatibility but is not read.
        The checkpoint must hold the state dict under the ``params_ema`` key.
        """
        self.srmodel = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=32, num_block=23, num_grow_ch=32, scale=self.scale)
        weights_name = 'rrdb_realesrnet_psnr' if model is None else model
        weights_path = os.path.join(self.base_dir, 'weights', weights_name + '.pth')
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        # map_location='cpu' keeps loading independent of the device the
        # checkpoint was saved from; the model is moved afterwards.
        loadnet = torch.load(weights_path, map_location='cpu')
        self.srmodel.load_state_dict(loadnet['params_ema'], strict=True)
        self.srmodel.eval()
        self.srmodel = self.srmodel.to(self.device)

    def process(self, img):
        """Super-resolve one image.

        Args:
            img (np.ndarray): H x W x 3 array with values in 0..255. The
                channel order is reversed before and after the network
                (presumably BGR from cv2.imread -- confirm with the caller).

        Returns:
            np.ndarray | None: uint8 image with the same channel order as
            the input, or ``None`` if the forward pass or post-processing
            failed.
        """
        img = img.astype(np.float32) / 255.
        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img = img.unsqueeze(0).to(self.device)

        # The network pixel-unshuffles by 2 (scale 2) or 4 (scale 1), so the
        # spatial size must be a multiple of that factor.
        if self.scale == 2:
            mod_scale = 2
        elif self.scale == 1:
            mod_scale = 4
        else:
            mod_scale = None

        h_pad, w_pad = 0, 0
        if mod_scale is not None:
            _, _, h, w = img.size()
            h_pad = -h % mod_scale
            w_pad = -w % mod_scale
            img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')

        try:
            with torch.no_grad():
                output = self.srmodel(img)
            # Remove the extra pad. The pad was added at input resolution, so
            # it covers pad * scale pixels in the upscaled output.
            if mod_scale is not None:
                _, _, h, w = output.size()
                output = output[:, :, 0:h - h_pad * self.scale, 0:w - w_pad * self.scale]
            # squeeze(0) drops only the batch axis, so a size-1 spatial
            # dimension survives.
            output = output.detach().squeeze(0).float().cpu().clamp_(0, 1).numpy()
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
            output = (output * 255.0).round().astype(np.uint8)
            return output
        except Exception as err:
            # Best effort: callers treat None as "SR unavailable". Report the
            # cause instead of hiding it, and let KeyboardInterrupt and
            # SystemExit propagate.
            print(f'RealESRNet: super-resolution failed ({err!r}); returning None')
            return None
\ No newline at end of file
import torch
from torch import nn as nn
from torch.nn import functional as F
from arch_util import default_init_weights, make_layer, pixel_unshuffle
class ResidualDenseBlock(nn.Module):
    """Densely connected five-convolution block with a residual shortcut.

    Used in RRDB block in ESRGAN. Each of the first four convolutions sees
    the block input concatenated with all earlier outputs and adds
    ``num_grow_ch`` channels; the fifth maps everything back to ``num_feat``
    channels.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, kernel_size=3, stride=1, padding=1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Re-initialize all five convolutions with weights scaled by 0.1.
        dense_convs = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
        default_init_weights(dense_convs, 0.1)

    def forward(self, x):
        """Run the dense convolutions and add the scaled result to ``x``."""
        feats = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        dense_out = self.conv5(torch.cat(feats, 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return dense_out * 0.2 + x
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN. Chains three ResidualDenseBlock modules and
    wraps them in an outer residual connection.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        """Apply the three dense blocks in order and add the scaled result to ``x``."""
        trunk = self.rdb3(self.rdb2(self.rdb1(x)))
        # Empirically, we use 0.2 to scale the residual for better performance
        return trunk * 0.2 + x
class RRDBNet(nn.Module):
    """Network built from Residual in Residual Dense Blocks, as used in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ pixel-unshuffle (the inverse operation of pixelshuffle)
    to reduce the spatial size and enlarge the channel size before feeding
    inputs into the main ESRGAN architecture. The trunk output is always
    upsampled by two nearest-neighbour x2 steps.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): 2 unshuffles the input by 2, 1 unshuffles it by 4; any
            other value feeds the input to the trunk unchanged. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super().__init__()
        self.scale = scale
        # Pixel-unshuffle by f multiplies the channel count by f * f.
        channel_gain = 4 if scale == 2 else 16 if scale == 1 else 1
        first_in_ch = num_in_ch * channel_gain
        self.conv_first = nn.Conv2d(first_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Map an input batch to the network output."""
        unshuffle_by = 2 if self.scale == 2 else 4 if self.scale == 1 else None
        feat = x if unshuffle_by is None else pixel_unshuffle(x, scale=unshuffle_by)

        feat = self.conv_first(feat)
        trunk = self.conv_body(self.body(feat))
        feat = feat + trunk

        # upsample: two rounds of nearest x2 interpolation + conv + LeakyReLU
        for conv_up in (self.conv_up1, self.conv_up2):
            feat = self.lrelu(conv_up(F.interpolate(feat, scale_factor=2, mode='nearest')))

        hr_feat = self.lrelu(self.conv_hr(feat))
        return self.conv_last(hr_feat)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment