#include "GLShader.h"
#include <iostream>
#include <sstream>
#include <hilog/log.h>
#include "common/common.h"

// Builds a shader wrapper from its description.
// Copies the name and the vertex/fragment sources out of `data` and keeps a shared reference
// to the owning context. No GL objects are created here: `program` and `buffer` start at 0
// and are only populated by makeProgram() (reached through ensureCompile()).
GLShader::GLShader(std::shared_ptr<GLShaderData>& data, int id, std::shared_ptr<RNGLContext>& rnglContext)
    : name(data->name),
      vert(data->vert),
      frag(data->frag),
      id(id),
      rnglContext(rnglContext),
      program(0),
      buffer(0),
      pointerLoc(-1),
      compilationFailed(nullptr) {
}

// Releases the GL objects and the cached compilation error owned by this shader.
// Program and buffer are released independently: a makeProgram() that fails after
// glCreateProgram() (link/validate/meta errors) leaves `program` non-zero while `buffer`
// is still 0, and that program must not leak.
GLShader::~GLShader() {
    if (program != 0) {
        glDeleteProgram(program);
    }
    if (buffer != 0) {
        glDeleteBuffers(1, &buffer);
    }
    // Deleting a null pointer is a no-op, so no guard is needed.
    delete compilationFailed;
}

void GLShader::bind() {
    ensureCompile();
    if (!glIsProgram(program)) {
        runtimeException("not a program");
    }
    glUseProgram(program);
    glBindBuffer(GL_ARRAY_BUFFER, buffer);
    glEnableVertexAttribArray(pointerLoc);
    glVertexAttribPointer(pointerLoc, 2, GL_FLOAT, GL_FALSE, 0, 0);
}

// Runs glValidateProgram on `program` and throws (via runtimeException, carrying the program
// info log) if GL reports it as not valid.
void GLShader::validate() {
    glValidateProgram(program);
    // Start from GL_FALSE so a failed query is treated as "invalid" rather than reading an
    // uninitialized value.
    GLint validSuccess = GL_FALSE;
    glGetProgramiv(program, GL_VALIDATE_STATUS, &validSuccess);
    if (validSuccess == GL_FALSE) {
        // Zero-initialized so the message is a valid C string even if GL writes nothing.
        GLchar infoLog[1024] = {};
        glGetProgramInfoLog(program, 1024, nullptr, infoLog);
        runtimeException(infoLog);
    }
}

// Uploads a single integer to the uniform `name` of the currently used program.
// `uniformLocations.at()` throws if `name` has no recorded location.
void GLShader::setUniform(const std::string& name, GLint i) {
    const auto location = uniformLocations.at(name);
    glUniform1i(location, i);
}

// Uploads a single float to the uniform `name` of the currently used program.
// `uniformLocations.at()` throws if `name` has no recorded location.
void GLShader::setUniform(const std::string& name, GLfloat f) {
    const auto location = uniformLocations.at(name);
    glUniform1f(location, f);
}

// Uploads one float vector (vec2/3/4) or matrix (mat2/3/4, not transposed) read from `buf`
// to the uniform `name`, dispatching on the GL type enum `type`.
// Any other `type` is rejected through runtimeException. The location lookup only happens
// for supported types, so an unknown name is reported (by `.at()`) only in that case.
void GLShader::setUniform(const std::string& name, const GLfloat* buf, int type) {
    auto location = [&]() { return uniformLocations.at(name); };

    if (type == GL_FLOAT_VEC2) {
        glUniform2fv(location(), 1, buf);
    } else if (type == GL_FLOAT_VEC3) {
        glUniform3fv(location(), 1, buf);
    } else if (type == GL_FLOAT_VEC4) {
        glUniform4fv(location(), 1, buf);
    } else if (type == GL_FLOAT_MAT2) {
        glUniformMatrix2fv(location(), 1, GL_FALSE, buf);
    } else if (type == GL_FLOAT_MAT3) {
        glUniformMatrix3fv(location(), 1, GL_FALSE, buf);
    } else if (type == GL_FLOAT_MAT4) {
        glUniformMatrix4fv(location(), 1, GL_FALSE, buf);
    } else {
        std::ostringstream message;
        message << "Unsupported case: uniform '" << name << "' type: " << type;
        runtimeException(message.str());
    }
}

// Uploads one integer or boolean vector (ivec2/3/4, bvec2/3/4) read from `buf` to the
// uniform `name`. Boolean vectors share the integer upload path.
// Any other `type` is rejected through runtimeException. The location lookup only happens
// for supported types, so an unknown name is reported (by `.at()`) only in that case.
void GLShader::setUniform(const std::string& name, const GLint* buf, int type) {
    auto location = [&]() { return uniformLocations.at(name); };

    if (type == GL_INT_VEC2 || type == GL_BOOL_VEC2) {
        glUniform2iv(location(), 1, buf);
    } else if (type == GL_INT_VEC3 || type == GL_BOOL_VEC3) {
        glUniform3iv(location(), 1, buf);
    } else if (type == GL_INT_VEC4 || type == GL_BOOL_VEC4) {
        glUniform4iv(location(), 1, buf);
    } else {
        std::ostringstream message;
        message << "Unsupported case: uniform '" << name << "' type: " << type;
        runtimeException(message.str());
    }
}

const std::string& GLShader::getName() const {
    return name;
}

const std::unordered_map<std::string, GLint>& GLShader::getUniformTypes() const {
    return uniformTypes;
}

const std::unordered_map<std::string, GLint>& GLShader::getUniformSizes() const {
    return uniformSizes;
}

const std::vector<std::string>& GLShader::getUniformNames() const {
    return uniformNames;
}

bool GLShader::isReady() const {
    //return buffer != 0 && !uniformLocations.empty();
    return buffer != 0 && program != 0;
}

// Builds the GL program on first use and reports whether the shader is ready.
// A compilation failure is cached in `compilationFailed`: the first failure rethrows the
// original exception, and every later call throws a copy of the cached error without
// retrying. Returns isReady(), which can still be false if makeProgram() returned early
// without throwing.
bool GLShader::ensureCompile() {
    if (!isReady()) {
        if (compilationFailed) {
            OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_PRINT_DOMAIN, "GLShader", "compilationFailed");
            throw *compilationFailed;
        }
        try {
            OH_LOG_Print(LOG_APP, LOG_INFO, LOG_PRINT_DOMAIN, "GLShader", "makeProgram");
            makeProgram();
            // TODO: the rnglContext callback still needs to be implemented here:
            // rnglContext->shaderSucceedToCompile(id, uniformTypes);
        } catch (const std::runtime_error& e) {
            OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_PRINT_DOMAIN, "GLShader", "makeProgram compilationFailed");
            compilationFailed = new std::runtime_error(e);
            // TODO: the rnglContext callback still needs to be implemented here:
            // rnglContext->shaderFailedToCompile(id, *compilationFailed);
            // Rethrow the in-flight exception object itself; `throw e;` would copy it and
            // slice any type derived from std::runtime_error.
            throw;
        }
    }
    return isReady();
}

// Creates and compiles a shader object of `shaderType` from the GLSL source `code`.
// Returns the shader handle on success, or 0 if glCreateShader() could not create one.
// On a compile error the shader object is deleted and runtimeException is thrown with the
// shader info log.
GLuint GLShader::compileShader(const std::string& code, GLenum shaderType) {
    const GLuint shaderHandle = glCreateShader(shaderType);
    if (shaderHandle == 0) {
        // Querying/compiling handle 0 would only raise GL errors and leave the status unset;
        // report the failure with 0, which the caller already treats as "no shader".
        OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_PRINT_DOMAIN, "GLShader", "glCreateShader failed");
        return 0;
    }
    const char* codePtr = code.c_str();
    glShaderSource(shaderHandle, 1, &codePtr, nullptr);
    glCompileShader(shaderHandle);

    // Start from GL_FALSE so a failed query is treated as a failed compile.
    GLint compileSuccess = GL_FALSE;
    glGetShaderiv(shaderHandle, GL_COMPILE_STATUS, &compileSuccess);
    if (compileSuccess == GL_FALSE) {
        GLchar infoLog[1024] = {};
        glGetShaderInfoLog(shaderHandle, 1024, nullptr, infoLog);
        // The caller never receives this handle, so release it before reporting.
        glDeleteShader(shaderHandle);
        runtimeException(infoLog);
        return 0;
    }
    return shaderHandle;
}

void GLShader::computeMeta() {
    GLint nbUniforms;
    glGetProgramiv(program, GL_ACTIVE_UNIFORMS, &nbUniforms);
    for (GLint i = 0; i < nbUniforms; ++i) {
        GLint size;
        GLenum type;
        char uniformName[1024];
        glGetActiveUniform(program, i, 1024, nullptr, &size, &type, uniformName);
        std::string nameStr(uniformName);
        if (nameStr.find("[0]") != std::string::npos) {
            nameStr = nameStr.substr(0, nameStr.length() - 3);
        }
        uniformNames.push_back(nameStr);
        uniformTypes[nameStr] = type;
        uniformSizes[nameStr] = size;
        if (size == 1) {
            GLint location = glGetUniformLocation(program, nameStr.c_str());
            uniformLocations[nameStr] = location;
        } else {
            for (GLint j = 0; j < size; ++j) {
                std::ostringstream oss;
                oss << nameStr << "[" << j << "]";
                std::string uniformIndexName = oss.str();
                GLint location = glGetUniformLocation(program, uniformIndexName.c_str());
                uniformLocations[uniformIndexName] = location;
                uniformTypes[uniformIndexName] = type;
                uniformSizes[uniformIndexName] = 1;
            }
        }
    }
}

void GLShader::makeProgram() {
    GLuint vertex = compileShader(vert, GL_VERTEX_SHADER);
    if (!vertex) return;

    GLuint fragment = compileShader(frag, GL_FRAGMENT_SHADER);
    if (!fragment) return;

    program = glCreateProgram();
    if (program == 0) return;
    glAttachShader(program, vertex);
    glAttachShader(program, fragment);
    glLinkProgram(program);

    GLint linkSuccess;
    glGetProgramiv(program, GL_LINK_STATUS, &linkSuccess);
    if (linkSuccess == GL_FALSE) {
        GLchar infoLog[1024];
        glGetProgramInfoLog(program, 1024, nullptr, infoLog);
        runtimeException(infoLog);
    }

    glUseProgram(program);

    validate();

    computeMeta();

    pointerLoc = glGetAttribLocation(program, "position");

    glGenBuffers(1, &buffer);
    glBindBuffer(GL_ARRAY_BUFFER, buffer);

    const float buf[] = {
        -1.0f, -1.0f,
        -1.0f, 4.0f,
        4.0f, -1.0f
    };
    glBufferData(GL_ARRAY_BUFFER, sizeof(buf), buf, GL_STATIC_DRAW);
}

// Reports an error for this shader. It logs `msg` (tagged with the shader name) to hilog,
// then throws std::runtime_error("GLShader <name>: <msg>"). It never returns normally.
void GLShader::runtimeException(const std::string& msg) {
    // Logged before throwing so the message reaches hilog however the caller handles the exception.
    OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_PRINT_DOMAIN, "GLShader", "%{public}s runtimeException: '%{public}s'",
                 name.c_str(), msg.c_str());
    std::string what = "GLShader ";
    what += name;
    what += ": ";
    what += msg;
    throw std::runtime_error(what);
}