from typing import List, Optional, Union

from typing_extensions import Literal

from config.config import Config


class Layer(Config):
    """Base configuration shared by every layer type in this module.

    Subclasses declare their own hyperparameters as annotated class
    attributes; instances are created with keyword arguments, e.g.
    ``Conv(out_channels=32, kernel_size=3)``.
    """

    # Name of the activation for this layer (e.g. "relu", "sigmoid"), or None
    # for none. Annotated Optional because None is both the default and is
    # passed explicitly (``activation=None``) in DEFAULT_LAYERS; a bare ``str``
    # annotation would reject that explicit None under strict validation.
    activation: Optional[str] = None
    # Extra parameters for the activation. Not read anywhere in this module;
    # presumably forwarded by the model builder -- confirm.
    activation_params: Optional[dict] = None


class Dense(Layer):
    """Configuration for a Dense layer (activation fields come from Layer)."""

    # Layer width. DEFAULT_LAYERS also uses -1 here (projection presets);
    # presumably a placeholder replaced with a concrete size at build time --
    # TODO confirm against the model builder.
    n_units: int
    # Whether a bias term is used; the projection presets set this to False.
    bias: bool = True


class Conv(Layer):
    """Configuration for a convolution layer (activation fields come from Layer)."""

    out_channels: int
    # A single int or a list of ints (presumably one value per spatial
    # dimension -- confirm with the model builder).
    kernel_size: Union[List[int], int]
    stride: Union[List[int], int] = 1
    # Only "valid" is used in this module; the set of accepted values is
    # decided by whatever consumes this config, not visible here.
    padding: str = "valid"


class ConvTranspose(Layer):
    """Configuration for a transposed-convolution layer.

    Unlike ``Conv``, no ``stride`` field is declared here.
    NOTE(review): confirm whether that omission is intentional.
    """

    out_channels: int
    # A single int or a list of ints (presumably per spatial dimension).
    kernel_size: Union[List[int], int]
    # Only "valid" is used in this module; accepted values are not visible here.
    padding: str = "valid"


class MaxPool(Layer):
    """Configuration for a max-pooling layer.

    No stride field is declared; how the pooling stride is chosen is up to
    the code that consumes this config (not visible here).
    """

    # A single int or a list of ints; DEFAULT_LAYERS always uses 2.
    kernel_size: Union[List[int], int]
    padding: str = "valid"


class UpSample(Layer):
    """Configuration for an upsampling layer."""

    # DEFAULT_LAYERS always uses 2; presumably the upsampling factor despite
    # the name -- confirm with the model builder.
    kernel_size: int
    # Interpolation mode name; only the "nearest" default appears in this module.
    mode: str = "nearest"


class AdaptiveAvgPool(Layer):
    """Configuration for an adaptive average-pooling layer."""

    # Target output size; defaults to [1, 1].
    # NOTE(review): this is a mutable list default. Whether instances share
    # this one list depends on how Config handles defaults (pydantic-style
    # models copy them; a plain class attribute would be shared) -- confirm,
    # and do not mutate output_size in place.
    output_size: List[int] = [1, 1]


class BatchNormalization(Layer):
    """Configuration for a batch-normalization layer."""

    # Presumably toggles the learnable scale/shift parameters -- confirm with
    # the model builder. Every use in DEFAULT_LAYERS keeps the default.
    affine: bool = True


class Activation(Layer):
    """Stand-alone activation layer.

    Declares no fields of its own; it only carries Layer's ``activation`` and
    ``activation_params``. Used to build the RELU/SOFTMAX/SIGMOID constants.
    """
    pass


class Flatten(Layer):
    """Flatten layer config; declares no fields beyond those inherited from Layer.

    Not referenced by any preset in DEFAULT_LAYERS.
    """
    pass


# Reusable stand-alone activation layers. Each name is bound to exactly one
# Activation instance, so every preset list that includes, say, RELU holds a
# reference to the very same object rather than a copy.
RELU, SOFTMAX, SIGMOID = (
    Activation(activation=name) for name in ("relu", "softmax", "sigmoid")
)


# Named layer presets: each value is an ordered list of layer configs
# (presumably consumed in order by a model builder elsewhere -- confirm).
# Keys ending in "_decoder" appear, by name, to pair with the encoder preset of
# the same prefix.
#
# NOTE: RELU is one shared Activation instance that appears many times across
# these lists, so mutating it in place would affect every preset that uses it.
# Pattern used below: a conv with activation=None is followed by
# BatchNormalization and then RELU, so the nonlinearity is applied after
# normalization rather than inside the conv.
DEFAULT_LAYERS = {
    # Three 32-channel 3x3 convs with batch norm; the last one ends on
    # BatchNormalization with no trailing activation.
    "cnn_tiny": [
        Conv(out_channels=32, kernel_size=3),
        BatchNormalization(),
        RELU,
        Conv(out_channels=32, kernel_size=3),
        BatchNormalization(),
        RELU,
        Conv(out_channels=32, kernel_size=3),
        BatchNormalization(),
    ],

    # Two 32-channel transposed convs, then a 1-channel sigmoid output.
    "cnn_tiny_decoder": [
        ConvTranspose(out_channels=32, kernel_size=3),
        BatchNormalization(),
        RELU,

        ConvTranspose(out_channels=32, kernel_size=3),
        BatchNormalization(),
        RELU,

        ConvTranspose(out_channels=1, kernel_size=3, activation="sigmoid"),
    ],

    # Two stages of: conv(relu) -> conv -> BN -> ReLU -> max-pool(2).
    "cnn_small": [
        Conv(out_channels=64, kernel_size=3, activation="relu"),
        Conv(out_channels=64, kernel_size=3, activation=None),
        BatchNormalization(),
        RELU,
        MaxPool(kernel_size=2),

        Conv(out_channels=64, kernel_size=3, activation="relu"),
        Conv(out_channels=64, kernel_size=3, activation=None),
        BatchNormalization(),
        RELU,
        MaxPool(kernel_size=2),
    ],

    # Two UpSample(2) stages (one per MaxPool in "cnn_small"), ending in a
    # 1-channel sigmoid output.
    "cnn_small_decoder": [
        UpSample(kernel_size=2),

        ConvTranspose(out_channels=64, kernel_size=3, activation="relu"),
        ConvTranspose(out_channels=64, kernel_size=3, activation=None),
        BatchNormalization(),
        RELU,

        UpSample(kernel_size=2),

        ConvTranspose(out_channels=64, kernel_size=3, activation="relu"),
        ConvTranspose(out_channels=1, kernel_size=3, activation="sigmoid"),
    ],

    # Three stages of: conv(relu) -> conv -> BN -> ReLU -> max-pool(2).
    "cnn_large": [
        Conv(out_channels=64, kernel_size=3, activation="relu"),
        Conv(out_channels=64, kernel_size=3, activation=None),
        BatchNormalization(),
        RELU,
        MaxPool(kernel_size=2),

        Conv(out_channels=64, kernel_size=3, activation="relu"),
        Conv(out_channels=64, kernel_size=3, activation=None),
        BatchNormalization(),
        RELU,
        MaxPool(kernel_size=2),

        Conv(out_channels=64, kernel_size=3, activation="relu"),
        Conv(out_channels=64, kernel_size=3, activation=None),
        BatchNormalization(),
        RELU,
        MaxPool(kernel_size=2),
    ],

    # Three UpSample(2) stages (one per MaxPool in "cnn_large"), then a final
    # pair of transposed convs ending in a 1-channel sigmoid output.
    "cnn_large_decoder": [
        UpSample(kernel_size=2),
        ConvTranspose(out_channels=64, kernel_size=3, activation="relu"),
        ConvTranspose(out_channels=64, kernel_size=3, activation=None),
        BatchNormalization(),
        RELU,

        UpSample(kernel_size=2),
        ConvTranspose(out_channels=64, kernel_size=3, activation="relu"),
        ConvTranspose(out_channels=64, kernel_size=3, activation=None),
        BatchNormalization(),
        RELU,

        UpSample(kernel_size=2),
        ConvTranspose(out_channels=64, kernel_size=3, activation="relu"),
        ConvTranspose(out_channels=64, kernel_size=3, activation=None),
        BatchNormalization(),
        RELU,

        ConvTranspose(out_channels=64, kernel_size=3, activation="relu"),
        ConvTranspose(out_channels=1, kernel_size=3, activation="sigmoid"),
    ],

    # Five Dense layers: four 1024-wide Dense -> BN -> ReLU blocks, then a
    # 256-wide Dense with no activation set.
    "dense_5": [
        Dense(n_units=1024),
        BatchNormalization(),
        RELU,

        Dense(n_units=1024),
        BatchNormalization(),
        RELU,

        Dense(n_units=1024),
        BatchNormalization(),
        RELU,

        Dense(n_units=1024),
        BatchNormalization(),
        RELU,

        Dense(n_units=256),
    ],

    # Dense(8) -> BN -> ReLU, then Dense(8) with a relu activation.
    "dense_2d": [
        Dense(n_units=8),
        BatchNormalization(),
        RELU,

        Dense(n_units=8, activation="relu"),
    ],

    # Dense(8) -> BN -> ReLU, then a 2-unit Dense with no activation set.
    "dense_2d_decoder": [
        Dense(n_units=8),
        BatchNormalization(),
        RELU,

        Dense(n_units=2)

    ],

    # Two-layer head ending in a bias-free Dense with no activation.
    # n_units=-1 is presumably a placeholder the builder replaces with a
    # concrete width -- TODO confirm.
    "projection_head": [
        Dense(n_units=-1),
        BatchNormalization(),
        RELU,
        Dense(n_units=-1, activation=None, bias=False),
    ],

    # A single bias-free Dense with no activation (same -1 placeholder).
    "linear_projection": [
        Dense(n_units=-1, activation=None, bias=False)
    ],
}