# coding: utf-8

# ---------------------------------------------------------------------------------
# MW-Linux面板
# ---------------------------------------------------------------------------------
# copyright (c) 2018-∞(https://github.com/midoks/mdserver-web) All rights reserved.
# ---------------------------------------------------------------------------------
# Author: midoks <midoks@163.com>
# ---------------------------------------------------------------------------------

# ---------------------------------------------------------------------------------
# sqlite3操作
# ---------------------------------------------------------------------------------

import os
import sys
import sqlite3

import core.mw as mw

def getPanelDir():
    """Return the parent directory of the current working directory.

    Used as the panel root: database files are resolved under its ``data``
    sub-directory.
    """
    working_dir = os.getcwd()
    parent_dir, _tail = os.path.split(working_dir)
    return parent_dir

def getTracebackInfo():
    """Return the traceback of the exception currently being handled, formatted as a string."""
    from traceback import format_exc
    text = format_exc()
    return text

class Sql():
    """Chainable query builder / executor for a single sqlite3 database file.

    Query options (table, where, group, order, limit, field) are set with
    chainable methods. select/add/save/delete consume them and then reset
    them via __close(). inquiry() deliberately leaves them in place.
    """

    __DB_FILE = None            # path of the database file
    __DB_CONN = None            # sqlite3 connection object (opened lazily)
    __DB_TABLE = ""             # name of the table being operated on
    __OPT_WHERE = ""            # WHERE clause
    __OPT_LIMIT = ""            # LIMIT clause
    __OPT_GROUP = ""            # GROUP BY clause
    __OPT_ORDER = ""            # ORDER BY clause
    __OPT_FIELD = "*"           # selected column list
    __OPT_PARAM = ()            # values bound to the WHERE placeholders

    __debug = False             # when True, print generated SQL and parameters

    def __init__(self):
        # Default database: <panel dir>/data/panel.db
        self.__DB_FILE = getPanelDir() + '/data/panel.db'

    def __getConn(self):
        """Open the connection on first use.

        Returns None on success, or an "error: ..." string if connecting failed.
        """
        try:
            if self.__DB_CONN is None:
                self.__DB_CONN = sqlite3.connect(self.__DB_FILE)
                self.__DB_CONN.text_factory = str
        except Exception as ex:
            print(getTracebackInfo())
            return "error: " + str(ex)

    def changeTextFactoryToBytes(self):
        """Make TEXT columns come back as raw bytes instead of str."""
        # Open the connection first so this also works before any query has run.
        self.__getConn()
        self.__DB_CONN.text_factory = bytes
        return True

    def debug(self, debug=False):
        """Enable/disable printing of generated SQL; returns self for chaining."""
        self.__debug = debug
        return self

    def autoTextFactory(self):
        """Decode TEXT columns as UTF-8, silently dropping undecodable bytes."""
        self.__getConn()
        if sys.version_info[0] == 3:
            self.__DB_CONN.text_factory = lambda x: str(x, encoding="utf-8", errors='ignore')
        else:
            self.__DB_CONN.text_factory = lambda x: unicode(x, "utf-8", "ignore")

    def dbfile(self, name):
        """Target <panel dir>/data/<name>.db; returns self for chaining."""
        self.__DB_FILE = getPanelDir() + '/data/' + name + '.db'
        return self

    def dbPos(self, path, name, suffix_name='db'):
        """Target <path>/<name>.<suffix_name>; returns self for chaining."""
        self.__DB_FILE = path + '/' + name + '.' + suffix_name
        return self

    def table(self, table):
        """Set the table to operate on; returns self for chaining."""
        self.__DB_TABLE = table
        return self

    def where(self, where, param=()):
        """Set the WHERE clause and its bound values.

        A falsy ``where`` leaves any existing condition untouched.
        """
        if where:
            self.__OPT_WHERE = " WHERE " + where
            self.__OPT_PARAM = param
        return self

    def andWhere(self, where, param):
        """AND an extra condition onto the WHERE clause and append its values.

        If no WHERE clause has been set yet, this one starts it. Otherwise the
        generated SQL would begin with a bare " and ".
        """
        if where:
            if self.__OPT_WHERE:
                self.__OPT_WHERE = self.__OPT_WHERE + " and " + where
            else:
                self.__OPT_WHERE = " WHERE " + where
            # tuple() on both sides so a list passed to where()/andWhere() still concatenates
            self.__OPT_PARAM = tuple(self.__OPT_PARAM) + tuple(param)
        return self

    def order(self, order):
        """Set the ORDER BY clause; an empty value clears it."""
        if order:
            self.__OPT_ORDER = " ORDER BY " + order
        else:
            self.__OPT_ORDER = ""
        return self

    def group(self, group):
        """Set the GROUP BY clause; an empty value clears it."""
        if group:
            self.__OPT_GROUP = " GROUP BY " + group
        else:
            self.__OPT_GROUP = ""
        return self

    def limit(self, limit):
        """Set the LIMIT clause from a string (e.g. "1"); an empty value clears it."""
        if limit:
            self.__OPT_LIMIT = " LIMIT " + limit
        else:
            self.__OPT_LIMIT = ""
        return self

    def field(self, field):
        """Set the comma-separated column list; an empty value keeps the current one."""
        if field:
            self.__OPT_FIELD = field
        return self

    def getDbField(self, name):
        """Return the column names of table ``name`` in declaration order."""
        self.__getConn()
        sql = "PRAGMA table_info(%s)" % name
        rows = self.__DB_CONN.execute(sql).fetchall()
        # Column 1 of each table_info row is the column name.
        return [row[1] for row in rows]

    def getDbFieldString(self, name):
        """Return the column names of table ``name`` joined with commas."""
        return ','.join(self.getDbField(name))

    @staticmethod
    def __rowsToDicts(fields, rows):
        # Map each row tuple onto a dict keyed positionally by ``fields``.
        # Raises IndexError when there are more field names than columns.
        return [{key: row[i] for i, key in enumerate(fields)} for row in rows]

    def __buildSelect(self):
        # Assemble the SELECT statement from the current query options.
        return "SELECT " + self.__OPT_FIELD + " FROM " + self.__DB_TABLE + \
            self.__OPT_WHERE + self.__OPT_GROUP + self.__OPT_ORDER + self.__OPT_LIMIT

    def select(self):
        """Run the SELECT built from the current options.

        Returns a list of dicts keyed by column name. Returns [] when there are
        no rows or on any error. The query options are always reset afterwards,
        including on empty results and errors, so they do not leak into the
        next query.
        """
        self.__getConn()
        try:
            sql = self.__buildSelect()

            if self.__debug:
                print(sql)
                print(self.__OPT_PARAM)

            data = self.__DB_CONN.execute(sql, self.__OPT_PARAM).fetchall()
            if len(data) == 0:
                return data

            if self.__OPT_FIELD != "*":
                fields = self.__OPT_FIELD.split(',')
            else:
                fields = self.getDbField(self.__DB_TABLE)
            return self.__rowsToDicts(fields, data)
        except Exception as ex:
            if self.__debug:
                print("error: " + str(ex))
            return []
        finally:
            self.__close()

    def inquiry(self, input_field=''):
        """Run the SELECT built from the current options WITHOUT resetting them.

        With an explicit field list, returns a list of dicts keyed by
        ``input_field`` (if given) or by the field list. With "*", returns a
        list of row lists. Returns an "error: ..." string on failure.
        """
        self.__getConn()
        try:
            sql = self.__buildSelect()
            data = self.__DB_CONN.execute(sql, self.__OPT_PARAM).fetchall()
            if self.__OPT_FIELD != "*":
                if input_field != "":
                    fields = input_field.split(',')
                else:
                    fields = self.__OPT_FIELD.split(',')
                return self.__rowsToDicts(fields, data)
            # list(): on Python 3 map() is a one-shot iterator, not a list.
            return list(map(list, data))
        except Exception as ex:
            return "error: " + str(ex)

    def getField(self, keyName):
        """Return column ``keyName`` of the single matching row, else None."""
        result = self.field(keyName).select()
        if len(result) == 1:
            return result[0][keyName]
        return None

    def setField(self, keyName, keyValue):
        """Update a single column on the rows matching the current WHERE."""
        return self.save(keyName, (keyValue,))

    def find(self):
        """Return the first matching row as a dict, or None if there is none."""
        result = self.limit("1").select()
        if len(result) == 1:
            return result[0]
        return None

    def count(self):
        """Return COUNT(*) for the current options, or 0 if it cannot be read."""
        key = "COUNT(*)"
        data = self.field(key).select()
        try:
            return int(data[0][key])
        except (IndexError, KeyError, TypeError, ValueError):
            return 0

    def add(self, keys, param):
        """INSERT one row and commit.

        ``keys`` is a comma-separated column list and ``param`` holds the
        matching values. Returns the new row's lastrowid, or 0 on error.
        """
        self.__getConn()
        try:
            placeholders = ','.join(['?'] * len(keys.split(',')))
            sql = "INSERT INTO " + self.__DB_TABLE + \
                "(" + keys + ") " + "VALUES(" + placeholders + ")"
            result = self.__DB_CONN.execute(sql, param)
            self.__DB_CONN.commit()
            return result.lastrowid
        except Exception as ex:
            print(str(ex))
            return 0
        finally:
            self.__close()

    def insert(self, pdata):
        """INSERT the column->value dict ``pdata``; False if it is empty."""
        if not pdata:
            return False
        keys, param = self.__format_pdata(pdata)
        return self.add(keys, param)

    def update(self, pdata):
        """UPDATE matching rows from the column->value dict ``pdata``; False if it is empty."""
        if not pdata:
            return False
        keys, param = self.__format_pdata(pdata)
        return self.save(keys, param)

    def __format_pdata(self, pdata):
        # Split a dict into a comma-joined key string and a same-order value tuple.
        keys = list(pdata.keys())
        return ','.join(keys), tuple([pdata[k] for k in keys])

    def checkInput(self, data):
        """Replace HTML/SQL-sensitive ASCII characters in a str with full-width lookalikes.

        Non-str and empty values are returned unchanged.
        """
        if not data:
            return data
        if not isinstance(data, str):
            return data
        replacements = (
            ('<', '＜'),
            ('>', '＞'),
            ('\'', '‘'),
            ('"', '“'),
            ('&', '＆'),
            ('#', '＃'),
        )
        for src, dst in replacements:
            data = data.replace(src, dst)
        return data

    def addAll(self, keys, param):
        """INSERT one row WITHOUT committing (batch mode; call commit() afterwards).

        Returns True, or an "error: ..." string on failure.
        """
        self.__getConn()
        try:
            placeholders = ','.join(['?'] * len(keys.split(',')))
            sql = "INSERT INTO " + self.__DB_TABLE + \
                "(" + keys + ") " + "VALUES(" + placeholders + ")"
            self.__DB_CONN.execute(sql, param)
            return True
        except Exception as ex:
            return "error: " + str(ex)

    def commit(self):
        """Reset query options and commit pending changes."""
        self.__close()
        self.__DB_CONN.commit()

    def save(self, keys, param):
        """UPDATE rows matching the current WHERE and commit.

        ``keys`` is a comma-separated column list and ``param`` holds the new
        values. Returns the affected row count, or an "error: ..." string.
        """
        self.__getConn()
        try:
            assignments = ','.join([key + "=?" for key in keys.split(',')])
            sql = "UPDATE " + self.__DB_TABLE + " SET " + assignments + self.__OPT_WHERE

            if self.__debug:
                print(sql, param)

            # SET values bind first, followed by the WHERE placeholders' values.
            self.__OPT_PARAM = tuple(param) + tuple(self.__OPT_PARAM)
            result = self.__DB_CONN.execute(sql, self.__OPT_PARAM)
            self.__DB_CONN.commit()
            return result.rowcount
        except Exception as ex:
            return "error: " + str(ex)
        finally:
            self.__close()

    def delete(self, id=None):
        """DELETE rows matching the current WHERE, or the row with ``id`` if given.

        Returns the affected row count, or an "error: ..." string.
        """
        self.__getConn()
        try:
            if id:
                self.__OPT_WHERE = " WHERE id=?"
                self.__OPT_PARAM = (id,)
            sql = "DELETE FROM " + self.__DB_TABLE + self.__OPT_WHERE
            result = self.__DB_CONN.execute(sql, self.__OPT_PARAM)
            self.__DB_CONN.commit()
            return result.rowcount
        except Exception as ex:
            return "error: " + str(ex)
        finally:
            self.__close()

    def originExecute(self, sql, param=()):
        """Execute raw SQL, commit, and return the cursor (or an "error: ..." string)."""
        self.__getConn()
        try:
            result = self.__DB_CONN.execute(sql, param)
            self.__DB_CONN.commit()
            return result
        except Exception as ex:
            return "error: " + str(ex)

    def execute(self, sql, param=()):
        """Execute raw SQL, commit, and return the affected row count (or an "error: ..." string)."""
        self.__getConn()
        try:
            result = self.__DB_CONN.execute(sql, param)
            self.__DB_CONN.commit()
            return result.rowcount
        except Exception as ex:
            return "error: " + str(ex)

    def query(self, sql, param=()):
        """Execute raw SQL and return the cursor without committing (or an "error: ..." string)."""
        self.__getConn()
        try:
            return self.__DB_CONN.execute(sql, param)
        except Exception as ex:
            return "error: " + str(ex)

    def create(self, name):
        """Run the script data/<name>.sql (relative to the cwd) and commit."""
        self.__getConn()
        # Use the module-level ``core.mw`` import; a local ``import mw`` would
        # shadow it and depends on ``mw`` being importable as a top-level module.
        script = mw.readFile('data/' + name + '.sql')
        result = self.__DB_CONN.executescript(script)
        self.__DB_CONN.commit()
        return result.rowcount

    def fofile(self, filename):
        """Run the SQL script stored in ``filename`` and commit."""
        self.__getConn()
        script = mw.readFile(filename)
        result = self.__DB_CONN.executescript(script)
        self.__DB_CONN.commit()
        return result.rowcount

    def __close(self):
        # Reset every query option (including GROUP BY) so it cannot leak
        # into the next query built on this object.
        self.__OPT_WHERE = ""
        self.__OPT_FIELD = "*"
        self.__OPT_GROUP = ""
        self.__OPT_ORDER = ""
        self.__OPT_LIMIT = ""
        self.__OPT_PARAM = ()

    def close(self):
        """Close the underlying connection, if one is open."""
        if self.__DB_CONN is None:
            return
        try:
            self.__DB_CONN.close()
        except sqlite3.Error:
            pass
        self.__DB_CONN = None
