# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.decode_heads import DPTHead


def test_dpt_head():

    with pytest.raises(AssertionError):
        # input_transform must be 'multiple_select'
        head = DPTHead(
            in_channels=[768, 768, 768, 768],
            channels=4,
            num_classes=19,
            in_index=[0, 1, 2, 3])

    head = DPTHead(
        in_channels=[768, 768, 768, 768],
        channels=4,
        num_classes=19,
        in_index=[0, 1, 2, 3],
        input_transform='multiple_select')

    inputs = [[torch.randn(4, 768, 2, 2),
               torch.randn(4, 768)] for _ in range(4)]
    output = head(inputs)
    assert output.shape == torch.Size((4, 19, 16, 16))

    # test readout operation
    head = DPTHead(
        in_channels=[768, 768, 768, 768],
        channels=4,
        num_classes=19,
        in_index=[0, 1, 2, 3],
        input_transform='multiple_select',
        readout_type='add')
    output = head(inputs)
    assert output.shape == torch.Size((4, 19, 16, 16))

    head = DPTHead(
        in_channels=[768, 768, 768, 768],
        channels=4,
        num_classes=19,
        in_index=[0, 1, 2, 3],
        input_transform='multiple_select',
        readout_type='project')
    output = head(inputs)
    assert output.shape == torch.Size((4, 19, 16, 16))
