# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.backbones import CGNet
from mmseg.models.backbones.cgnet import (ContextGuidedBlock,
                                          GlobalContextExtractor)


def test_cgnet_GlobalContextExtractor():
    """A checkpointed GlobalContextExtractor forward keeps the input shape."""
    extractor = GlobalContextExtractor(16, 16, with_cp=True)
    # requires_grad=True, presumably so the checkpointed branch is actually
    # taken during the forward pass (TODO confirm against the implementation).
    feats = torch.randn(2, 16, 64, 64, requires_grad=True)
    assert extractor(feats).shape == torch.Size([2, 16, 64, 64])


def test_cgnet_context_guided_block():
    """Exercise ContextGuidedBlock construction, sub-layer configuration and
    forward output shapes, with/without checkpointing and downsampling."""

    def check_attrs(module, **expected):
        # Assert that every named attribute of ``module`` equals the given
        # expected value.
        for attr_name, attr_value in expected.items():
            assert getattr(module, attr_name) == attr_value

    # 8 in/out channels violates the channel/reduction constraint of the
    # GlobalContextExtractor inside ContextGuidedBlock.
    with pytest.raises(AssertionError):
        ContextGuidedBlock(8, 8)

    # Checkpointed forward with a PReLU activation.
    cp_block = ContextGuidedBlock(
        16, 16, act_cfg=dict(type='PReLU'), with_cp=True)
    assert cp_block.with_cp
    inp = torch.randn(2, 16, 64, 64, requires_grad=True)
    assert cp_block(inp).shape == torch.Size([2, 16, 64, 64])

    # Plain forward; checkpointing is off unless requested.
    plain_block = ContextGuidedBlock(32, 32)
    assert not plain_block.with_cp
    inp = torch.randn(3, 32, 32, 32)
    assert plain_block(inp).shape == torch.Size([3, 32, 32, 32])

    # Downsampling variant: strided 3x3 entry conv, full-width branches.
    down_block = ContextGuidedBlock(32, 32, downsample=True)
    check_attrs(
        down_block.conv1x1.conv,
        in_channels=32,
        out_channels=32,
        kernel_size=(3, 3),
        stride=(2, 2),
        padding=(1, 1))
    check_attrs(
        down_block.f_loc,
        in_channels=32,
        out_channels=32,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(1, 1),
        groups=32,
        dilation=(1, 1))
    assert down_block.f_loc.bias is None
    check_attrs(
        down_block.f_sur,
        in_channels=32,
        out_channels=32,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(2, 2),
        groups=32,
        dilation=(2, 2))
    assert down_block.f_sur.bias is None
    check_attrs(
        down_block.bottleneck,
        in_channels=64,
        out_channels=32,
        kernel_size=(1, 1),
        stride=(1, 1))
    assert down_block.bottleneck.bias is None
    inp = torch.randn(1, 32, 32, 32)
    assert down_block(inp).shape == torch.Size([1, 32, 16, 16])

    # Non-downsampling variant: 1x1 entry conv halves the channels.
    flat_block = ContextGuidedBlock(32, 32, downsample=False)
    check_attrs(
        flat_block.conv1x1.conv,
        in_channels=32,
        out_channels=16,
        kernel_size=(1, 1),
        stride=(1, 1),
        padding=(0, 0))
    check_attrs(
        flat_block.f_loc,
        in_channels=16,
        out_channels=16,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(1, 1),
        groups=16,
        dilation=(1, 1))
    assert flat_block.f_loc.bias is None
    check_attrs(
        flat_block.f_sur,
        in_channels=16,
        out_channels=16,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(2, 2),
        groups=16,
        dilation=(2, 2))
    assert flat_block.f_sur.bias is None
    inp = torch.randn(1, 32, 32, 32)
    assert flat_block(inp).shape == torch.Size([1, 32, 32, 32])


def test_cgnet_backbone():
    """Check CGNet argument validation, ``norm_eval`` handling and the
    shapes of the three feature maps produced for a 224x224 input."""
    from torch.nn.modules.batchnorm import _BatchNorm

    def _check_output_shapes(model):
        # Forward a 2x3x224x224 batch and check the three outputs expected
        # for the default channel configuration.
        imgs = torch.randn(2, 3, 224, 224)
        feat = model(imgs)
        assert len(feat) == 3
        assert feat[0].shape == torch.Size([2, 35, 112, 112])
        assert feat[1].shape == torch.Size([2, 131, 56, 56])
        assert feat[2].shape == torch.Size([2, 256, 28, 28])

    with pytest.raises(AssertionError):
        # check invalid num_channels (4 entries)
        CGNet(num_channels=(32, 64, 128, 256))

    with pytest.raises(AssertionError):
        # check invalid num_blocks (3 entries)
        CGNet(num_blocks=(3, 21, 3))

    with pytest.raises(AssertionError):
        # check invalid num_blocks (not a tuple)
        CGNet(num_blocks=2)

    with pytest.raises(AssertionError):
        # check invalid dilation (not a tuple); previously this case passed
        # ``num_blocks=2`` and so never reached the dilations validation
        CGNet(dilations=2)

    with pytest.raises(AssertionError):
        # check invalid reduction
        CGNet(reductions=16)

    with pytest.raises(AssertionError):
        # check invalid num_channels and reduction
        CGNet(num_channels=(32, 64, 128), reductions=(64, 129))

    # Test CGNet with default settings
    model = CGNet()
    model.init_weights()
    model.train()
    _check_output_shapes(model)

    # Test CGNet with norm_eval True and with_cp True
    model = CGNet(norm_eval=True, with_cp=True)
    with pytest.raises(TypeError):
        # check invalid pretrained
        model.init_weights(pretrained=8)
    model.init_weights()
    model.train()

    # norm_eval=True must keep the BatchNorm layers in eval mode even
    # after model.train(); guard against a vacuous pass with no BN layers.
    bn_layers = [m for m in model.modules() if isinstance(m, _BatchNorm)]
    assert bn_layers
    assert all(not m.training for m in bn_layers)

    _check_output_shapes(model)
