local MapTable, parent = torch.class('nn.MapTable', 'nn.Container')

function MapTable:__init(module, shared)
    parent.__init(self)
    self.shared = shared or {'weight', 'bias', 'gradWeight', 'gradBias'}
    self.output = {}
    self.gradInput = {}
    self:add(module)
end

function MapTable:_extend(n)
    self.modules[1] = self.module
    for i = 2, n do
        if not self.modules[i] then
            self.modules[i] = self.module:clone(table.unpack(self.shared))
        end
    end
end

function MapTable:resize(n)
    self:_extend(n)
    for i = n + 1, #self.modules do
        self.modules[i] = nil
    end
end

function MapTable:add(module)
    assert(not self.module, 'Single module required')
    self.module = module
    self.modules[1] = self.module
    return self
end

function MapTable:updateOutput(input)
    self.output = {}
    self:_extend(#input)
    for i = 1, #input do
        self.output[i] = self:rethrowErrors(self.modules[i], i, 'updateOutput', input[i])
    end
    return self.output
end

function MapTable:updateGradInput(input, gradOutput)
    self.gradInput = {}
    self:_extend(#input)
    for i = 1, #input do
        self.gradInput[i] = self:rethrowErrors(self.modules[i], i, 'updateGradInput', input[i], gradOutput[i])
    end
    return self.gradInput
end

function MapTable:accGradParameters(input, gradOutput, scale)
    scale = scale or 1
    self:_extend(#input)
    for i = 1, #input do
        self:rethrowErrors(self.modules[i], i, 'accGradParameters', input[i], gradOutput[i], scale)
    end
end

function MapTable:accUpdateGradParameters(input, gradOutput, lr)
    lr = lr or 1
    self:_extend(#input)
    for i = 1, #input do
        self:rethrowErrors(self.modules[i], i, 'accUpdateGradParameters', input[i], gradOutput[i], lr)
    end
end

function MapTable:zeroGradParameters()
    if self.module then
        self.module:zeroGradParameters()
    end
end

function MapTable:updateParameters(learningRate)
    if self.module then
        self.module:updateParameters(learningRate)
    end
end

function MapTable:clearState()
    for i = 2, #self.modules do
        self.modules[i] = nil
    end
    parent.clearState(self)
end

function MapTable:__tostring__()
    local tab = '  '
    local line = '\n'
    local extlast = '      '
    local str = torch.type(self)
    if self.module then
        str = str .. ' {' .. line .. tab
        str = str .. tostring(self.module):gsub(line, line .. tab .. extlast) .. line .. '}'
    else
        str = str .. ' { }'
    end
    return str
end
