import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Upsample



class GLU(nn.Module):
    """Gated linear unit applied along dimension 1.

    The input is split into two equal halves along dim 1; the first half
    is multiplied element-wise by the sigmoid of the second half, so the
    output has half as many entries along dim 1 as the input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x[:, :h] * sigmoid(x[:, h:])`` with ``h = x.size(1) // 2``.

        Fails an assertion when ``x.size(1)`` is odd.
        """
        channels = x.size(1)
        assert channels % 2 == 0, 'channels dont divide 2!'
        half = channels // 2
        values = x[:, :half]
        gates = x[:, half:]
        return values * torch.sigmoid(gates)


def upBlock(in_planes, out_planes):
    """Build a 2x nearest-neighbour upsampling block ending in a GLU.

    The 3x3 convolution (stride 1, padding 1, no bias) produces
    ``out_planes * 2`` channels; the trailing GLU halves that, so the
    block emits ``out_planes`` channels.

    Args:
        in_planes: number of input channels of the convolution.
        out_planes: number of channels after the GLU.

    Returns:
        An ``nn.Sequential`` of Upsample -> Conv2d -> BatchNorm2d -> GLU.
    """
    gated_channels = out_planes * 2
    stages = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(in_planes, gated_channels,
                  kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(gated_channels),
        GLU(),
    ]
    return nn.Sequential(*stages)


def sameBlock(in_planes, out_planes):
    """Build a resolution-preserving conv block ending in a GLU.

    The 3x3 convolution (stride 1, padding 1, no bias) produces
    ``out_planes * 2`` channels; the trailing GLU halves that, so the
    block emits ``out_planes`` channels.

    Args:
        in_planes: number of input channels of the convolution.
        out_planes: number of channels after the GLU.

    Returns:
        An ``nn.Sequential`` of Conv2d -> BatchNorm2d -> GLU.
    """
    gated_channels = out_planes * 2
    conv = nn.Conv2d(in_planes, gated_channels,
                     kernel_size=3, stride=1, padding=1, bias=False)
    norm = nn.BatchNorm2d(gated_channels)
    return nn.Sequential(conv, norm, GLU())


class ResBlock(nn.Module):
    """Residual block: ``x + block(x)`` with a channel-preserving body.

    The body is Conv2d -> BatchNorm2d -> GLU -> Conv2d -> BatchNorm2d.
    The first convolution doubles the channels so that the GLU brings
    them back to ``channel_num`` before the second convolution.
    """

    def __init__(self, channel_num):
        """Create the block for inputs with ``channel_num`` channels."""
        super().__init__()
        doubled = channel_num * 2
        body = [
            nn.Conv2d(channel_num, doubled,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(doubled),
            GLU(),
            # GLU halves the channels, so this conv sees channel_num again.
            nn.Conv2d(channel_num, channel_num,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(channel_num),
        ]
        self.block = nn.Sequential(*body)

    def forward(self, x):
        """Return the input plus the output of the residual body."""
        residual = self.block(x)
        return x + residual


def multi_ResBlock(num_residual, ngf):
    """Chain ``num_residual`` independent ``ResBlock(ngf)`` modules.

    Args:
        num_residual: how many residual blocks to stack.
        ngf: channel count passed to every ``ResBlock``.

    Returns:
        An ``nn.Sequential`` holding the blocks in order (empty when
        ``num_residual`` is 0).
    """
    return nn.Sequential(*(ResBlock(ngf) for _ in range(num_residual)))


def encode_img(ndf=64, in_c=3):
    """Build the convolutional image encoder.

    The network is six bias-free convolutions, each followed by
    ``LeakyReLU(0.2, inplace=True)``; every convolution except the first
    is also followed by ``BatchNorm2d``. The first five convolutions use
    a 4x4 kernel with stride 2 and padding 1, the last one a 3x3 kernel
    with stride 1 and padding 1. Channel widths go
    ``in_c -> ndf -> 2*ndf -> 4*ndf -> 8*ndf -> 8*ndf -> 8*ndf``.

    Args:
        ndf: base channel width.
        in_c: number of channels of the input image tensor.

    Returns:
        An ``nn.Sequential`` with 17 layers.
    """
    def leaky():
        return nn.LeakyReLU(0.2, inplace=True)

    # Stem: the only stage without batch normalisation.
    stages = [
        nn.Conv2d(in_c, ndf, kernel_size=4, stride=2, padding=1, bias=False),
        leaky(),
    ]

    # (in_channels, out_channels, kernel_size, stride) for the remaining stages.
    plan = (
        (ndf, ndf * 2, 4, 2),
        (ndf * 2, ndf * 4, 4, 2),
        (ndf * 4, ndf * 8, 4, 2),
        (ndf * 8, ndf * 8, 4, 2),
        (ndf * 8, ndf * 8, 3, 1),
    )
    for c_in, c_out, kernel, stride in plan:
        stages.append(nn.Conv2d(c_in, c_out, kernel_size=kernel,
                                stride=stride, padding=1, bias=False))
        stages.append(nn.BatchNorm2d(c_out))
        stages.append(leaky())

    return nn.Sequential(*stages)