package export

import (
	"encoding/csv"
	"encoding/json"
	"fmt"
	"html"
	"io"
	"io/ioutil"
	"log"
	"os"
	"strings"

	"github.com/linlexing/dbx/ddb"

	"github.com/jmoiron/sqlx"

	"golang.org/x/text/encoding"

	"time"

	"github.com/linlexing/dbx/common"
	"github.com/linlexing/dbx/data"
	"github.com/linlexing/dbx/scan"
	"github.com/linlexing/dbx/schema"
)

// Exports is the common interface of the stream exporters declared in this
// file (ExportCsv, ExportTxt, ExportJson and ExportExcel).
type Exports interface {
	// Export runs strSql on db with sqlParams as bind values and writes the
	// result set to w. typeColumns describes how each result column is
	// scanned, endi optionally re-encodes the output (nil leaves the bytes
	// as produced), and progressFunc receives human-readable progress
	// messages.
	Export(db ddb.DB, typeColumns []*scan.ColumnType, strSql string,
		w io.Writer, endi encoding.Encoding, progressFunc func(string),
		sqlParams ...interface{}) (err error)
}

// Stateless exporter types. Each one carries an Export method that writes a
// query result in the format named by the type.
type (
	// ExportCsv writes comma-separated values.
	ExportCsv struct{}
	// ExportTxt writes tab-separated text.
	ExportTxt struct{}
	// ExportJson writes one JSON object per line.
	ExportJson struct{}
	// ExportExcel writes an HTML table.
	ExportExcel struct{}
)

// Export runs strSql on db and writes the result set to writer as CSV.
//
// The first record holds the upper-cased column names; every following record
// is one result row, each value rendered through the ToString method of the
// corresponding typeColumns entry. Records end with CRLF. When endi is non-nil
// the output is passed through its encoder. A count(*) query wrapped around
// strSql is executed first so progress messages can include a total.
// progressFunc must not be nil: it is called once before the first row, again
// whenever at least five seconds have passed since the previous report, and
// once after the last row. sqlParams are bound to both queries.
//
// The first error from the database, from scanning, or from the writer is
// returned.
func (e *ExportCsv) Export(db ddb.DB, typeColumns []*scan.ColumnType, strSql string,
	writer io.Writer, endi encoding.Encoding, progressFunc func(string),
	sqlParams ...interface{}) (err error) {
	w := writer
	// Re-encode the output when an encoding is supplied (the original note
	// mentions GBK).
	if endi != nil {
		w = endi.NewEncoder().Writer(w)
	}
	var rowCount int64
	s := data.Bind(db.DriverName(), fmt.Sprintf("select count(*) from (%s) out_count", strSql))
	if err = db.QueryRow(s, sqlParams...).Scan(&rowCount); err != nil {
		err = common.NewSQLError(err, s, sqlParams...)
		return err
	}

	s = data.Bind(db.DriverName(), strSql)
	rows, err := db.Query(s, sqlParams...)
	if err != nil {
		err = common.NewSQLError(err, s, sqlParams...)
		log.Println(err)
		return err
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return err
	}
	for i, v := range cols {
		cols[i] = strings.ToUpper(v)
	}
	cw := csv.NewWriter(w)
	cw.UseCRLF = true
	// Header record.
	if err = cw.Write(cols); err != nil {
		return err
	}
	startTime := time.Now()
	beginTime := startTime
	progressFunc(fmt.Sprintf("start export,total %d records", rowCount))

	// One record per result row.
	var icount int64
	for rows.Next() {
		var values []interface{}
		values, err = scan.TypeScan(rows, typeColumns)
		if err != nil {
			return err
		}
		strList := make([]string, len(values))
		for i, v := range values {
			strList[i] = typeColumns[i].Type.ToString(v)
		}
		// A failed write must abort the export instead of being dropped.
		if err = cw.Write(strList); err != nil {
			return err
		}
		icount++
		totalSec := time.Since(startTime).Seconds()
		if totalSec >= 5 {
			progressFunc(fmt.Sprintf("\t%.2f%%\t%d", 100.0*float64(icount)/float64(rowCount), icount))
			startTime = time.Now()
		}
	}
	if err = rows.Err(); err != nil {
		return err
	}
	// The csv writer buffers; Flush pushes the tail out and Error reports
	// any failure that happened while doing so.
	cw.Flush()
	if err = cw.Error(); err != nil {
		return err
	}
	progressFunc(fmt.Sprintf("total %d records exported %.2fs", icount, time.Since(beginTime).Seconds()))

	return nil
}
// Export runs strSql on db and writes the result set to w as tab-separated
// text.
//
// The first line holds the upper-cased column names; every following line is
// one result row, each value rendered through the ToString method of the
// corresponding typeColumns entry. Fields are joined with a tab and every
// line, including the header, ends with CRLF. Values are written as-is: tabs
// or line breaks inside a value are not escaped. When endi is non-nil the
// output is passed through its encoder. A count(*) query wrapped around strSql
// is executed first so progress messages can include a total. progressFunc
// must not be nil: it is called once before the first row, again whenever at
// least five seconds have passed since the previous report, and once after the
// last row. sqlParams are bound to both queries.
func (e *ExportTxt) Export(db ddb.DB, typeColumns []*scan.ColumnType, strSql string,
	w io.Writer, endi encoding.Encoding, progressFunc func(string),
	sqlParams ...interface{}) (err error) {
	// Re-encode the output when an encoding is supplied (the original note
	// mentions GBK).
	if endi != nil {
		w = endi.NewEncoder().Writer(w)
	}
	var rowCount int64
	s := data.Bind(db.DriverName(), fmt.Sprintf("select count(*) from (%s) out_count", strSql))
	if err = db.QueryRow(s, sqlParams...).Scan(&rowCount); err != nil {
		err = common.NewSQLError(err, s, sqlParams...)
		return err
	}

	s = data.Bind(db.DriverName(), strSql)
	rows, err := db.Query(s, sqlParams...)
	if err != nil {
		err = common.NewSQLError(err, s, sqlParams...)
		log.Println(err)
		return err
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return err
	}
	for i, v := range cols {
		cols[i] = strings.ToUpper(v)
	}
	// Header line. It needs its own terminator, otherwise the first data
	// row would continue on the same line.
	if _, err = w.Write([]byte(strings.Join(cols, "\t") + "\r\n")); err != nil {
		return err
	}
	startTime := time.Now()
	beginTime := startTime
	progressFunc(fmt.Sprintf("start export total %d records", rowCount))

	// One line per result row.
	var icount int64
	for rows.Next() {
		var values []interface{}
		if values, err = scan.TypeScan(rows, typeColumns); err != nil {
			return err
		}
		strList := make([]string, len(values))
		for i, v := range values {
			strList[i] = typeColumns[i].Type.ToString(v)
		}
		if _, err = w.Write([]byte(strings.Join(strList, "\t") + "\r\n")); err != nil {
			return err
		}
		icount++
		totalSec := time.Since(startTime).Seconds()
		if totalSec >= 5 {
			progressFunc(fmt.Sprintf("\t%.2f%%\t%d/%d\t%.2fs", 100.0*float64(icount)/float64(rowCount), icount, rowCount, totalSec))
			startTime = time.Now()
		}
	}
	if err = rows.Err(); err != nil {
		return err
	}
	progressFunc(fmt.Sprintf("total %d records exported %.2fs", icount, time.Since(beginTime).Seconds()))

	return nil
}
// Export runs strSql on db and writes the result set to w as one JSON object
// per line.
//
// Object keys are the upper-cased column names, values are the scanned column
// values encoded with json.Marshal, and every object is followed by CRLF. When
// endi is non-nil the output is passed through its encoder. A count(*) query
// wrapped around strSql is executed first so progress messages can include a
// total. progressFunc is called once before the first row, again whenever at
// least five seconds have passed since the previous report, and once after the
// last row. sqlParams are bound to both queries.
func (e *ExportJson) Export(db ddb.DB, typeColumns []*scan.ColumnType, strSql string,
	w io.Writer, endi encoding.Encoding, progressFunc func(string),
	sqlParams ...interface{}) (err error) {
	// Re-encode the output when an encoding is supplied (the original note
	// mentions GBK).
	if endi != nil {
		w = endi.NewEncoder().Writer(w)
	}

	// Total number of rows, used only for progress reporting.
	var total int64
	countSQL := data.Bind(db.DriverName(), fmt.Sprintf("select count(*) from (%s) out_count", strSql))
	err = db.QueryRow(countSQL, sqlParams...).Scan(&total)
	if err != nil {
		err = common.NewSQLError(err, countSQL, sqlParams...)
		return err
	}

	querySQL := data.Bind(db.DriverName(), strSql)
	rows, err := db.Query(querySQL, sqlParams...)
	if err != nil {
		err = common.NewSQLError(err, querySQL, sqlParams...)
		log.Println(err)
		return err
	}
	defer rows.Close()

	names, err := rows.Columns()
	if err != nil {
		return err
	}
	// Encode every key once up front; the same bytes are reused for all rows.
	keys := make([][]byte, len(names))
	for idx := range names {
		names[idx] = strings.ToUpper(names[idx])
		if keys[idx], err = json.Marshal(names[idx]); err != nil {
			return err
		}
	}

	lastReport := time.Now()
	began := lastReport
	progressFunc(fmt.Sprintf("start export total %d records", total))

	// One object per result row.
	var written int64
	for rows.Next() {
		if _, err = w.Write([]byte("{")); err != nil {
			return err
		}
		var rowVals []interface{}
		rowVals, err = scan.TypeScan(rows, typeColumns)
		if err != nil {
			return err
		}
		for idx, val := range rowVals {
			// Separator between members, never before the first one.
			if idx > 0 {
				if _, err = w.Write([]byte(",")); err != nil {
					return err
				}
			}
			if _, err = w.Write(keys[idx]); err != nil {
				return err
			}
			if _, err = w.Write([]byte(":")); err != nil {
				return err
			}
			var encoded []byte
			if encoded, err = json.Marshal(val); err != nil {
				return err
			}
			if _, err = w.Write(encoded); err != nil {
				return err
			}
		}
		if _, err = w.Write([]byte("}\r\n")); err != nil {
			return err
		}
		written++

		elapsed := time.Since(lastReport).Seconds()
		if elapsed < 5 {
			continue
		}
		progressFunc(fmt.Sprintf("\t%.2f%%\t%d/%d\t%.2fs", 100.0*float64(written)/float64(total), written, total, elapsed))
		lastReport = time.Now()
	}
	if err = rows.Err(); err != nil {
		return err
	}
	progressFunc(fmt.Sprintf("total %d records exported %.2fs", written, time.Since(began).Seconds()))
	return nil
}
// mustWrite writes str to w as bytes and, if the write reports an error, logs
// that error and panics with it (via log.Panic).
func mustWrite(w io.Writer, str string) {
	_, writeErr := w.Write([]byte(str))
	if writeErr == nil {
		return
	}
	log.Panic(writeErr)
}
// Export runs strSql on db and writes the result set to w as an HTML document
// containing a single table.
//
// The first table row holds the upper-cased column names; every following row
// is one result row, each value rendered through the ToString method of the
// corresponding typeColumns entry and HTML-escaped. When endi is non-nil the
// output is passed through its encoder. A count(*) query wrapped around strSql
// is executed first so progress messages can include a total. progressFunc
// must not be nil: it is called once before the first row, again whenever at
// least five seconds have passed since the previous report, and once after the
// last row. sqlParams are bound to both queries.
//
// Write failures are returned as err, like in the other exporters, instead of
// panicking.
func (e *ExportExcel) Export(db ddb.DB, typeColumns []*scan.ColumnType, strSql string,
	w io.Writer, endi encoding.Encoding, progressFunc func(string),
	sqlParams ...interface{}) (err error) {
	charset := "utf-8"
	// Re-encode the output when an encoding is supplied.
	// NOTE(review): the declared charset becomes "gbk" for any non-nil
	// encoding; confirm callers only ever pass GBK here.
	if endi != nil {
		w = endi.NewEncoder().Writer(w)
		charset = "gbk"
	}
	var rowCount int64
	s := data.Bind(db.DriverName(), fmt.Sprintf("select count(*) from (%s) out_count", strSql))
	if err = db.QueryRow(s, sqlParams...).Scan(&rowCount); err != nil {
		err = common.NewSQLError(err, s, sqlParams...)
		return err
	}

	s = data.Bind(db.DriverName(), strSql)
	rows, err := db.Query(s, sqlParams...)
	if err != nil {
		err = common.NewSQLError(err, s, sqlParams...)
		log.Println(err)
		return err
	}
	defer rows.Close()

	// put writes one fragment and remembers the first write failure; once a
	// write has failed every later call is a no-op, so the error only has
	// to be checked at row boundaries.
	var writeErr error
	put := func(fragment string) {
		if writeErr != nil {
			return
		}
		_, writeErr = w.Write([]byte(fragment))
	}

	put("<html><head>")
	put(`<meta http-equiv="Content-Language" content="zh-cn">`)
	put(fmt.Sprintf(`<meta http-equiv="Content-Type" content="text/html;charset=%s">`, charset))
	put(`<style>td{vnd.ms-excel.numberformat:@}</style>`)
	put("</head><body>")
	put("<table cellspacing=0 rules=all border=1 style='border-collapse:collapse;'>")
	put("<tr>")
	if writeErr != nil {
		return writeErr
	}
	cols, err := rows.Columns()
	if err != nil {
		return err
	}
	// Header row.
	for i, v := range cols {
		cols[i] = strings.ToUpper(v)
		put("<td>")
		put(html.EscapeString(cols[i]))
		put("</td>")
	}
	put("</tr>")
	if writeErr != nil {
		return writeErr
	}

	startTime := time.Now()
	beginTime := startTime

	progressFunc(fmt.Sprintf("start export total %d records", rowCount))
	// One table row per result row.
	var icount int64
	for rows.Next() {
		put("<tr>")
		var values []interface{}
		if values, err = scan.TypeScan(rows, typeColumns); err != nil {
			return err
		}
		for i, v := range values {
			put("<td>")
			put(html.EscapeString(typeColumns[i].Type.ToString(v)))
			put("</td>")
		}
		put("</tr>")
		if writeErr != nil {
			return writeErr
		}
		icount++
		totalSec := time.Since(startTime).Seconds()
		if totalSec >= 5 {
			progressFunc(fmt.Sprintf("\t%.2f%%\t%d/%d\t%.2fs", 100.0*float64(icount)/float64(rowCount), icount, rowCount, totalSec))
			startTime = time.Now()
		}
	}
	if err = rows.Err(); err != nil {
		return err
	}
	put("</table></body></html>")
	if writeErr != nil {
		return writeErr
	}
	progressFunc(fmt.Sprintf("total %d records exported %.2fs", icount, time.Since(beginTime).Seconds()))

	return nil
}
// exportSqlite runs strSql on db and copies the result set into a table named
// tableName inside a freshly created SQLite database file.
//
// Column names are upper-cased. A column whose original name matches the Name
// of an entry in typeColumns takes that entry's Type; every other column is
// created as a nullable string column. When every name in ucols is among the
// exported (upper-cased) column names, ucols becomes the primary key of the
// new table. Rows are inserted inside a transaction that is committed and
// reopened whenever at least five seconds have passed since the previous
// commit. progressFunc must not be nil: it is called once before the first
// row, after every intermediate commit, and once at the end. sqlParams are
// bound to both the count(*) query and strSql.
//
// On success tempFileName is the path of the SQLite file; removing it is the
// caller's responsibility. On failure the open transaction is rolled back, the
// file is removed, and tempFileName is returned empty.
func exportSqlite(db ddb.DB, tableName string, typeColumns []*scan.ColumnType, ucols []string,
	strSql string, progressFunc func(string), sqlParams ...interface{}) (tempFileName string, err error) {
	var rowCount int64
	s := data.Bind(db.DriverName(), fmt.Sprintf("select count(*) from (%s) out_count", strSql))
	if err = db.QueryRow(s, sqlParams...).Scan(&rowCount); err != nil {
		err = common.NewSQLError(err, s, sqlParams...)
		return tempFileName, err
	}

	// Create the database file.
	tmpf, err := ioutil.TempFile("", "exp_")
	if err != nil {
		return tempFileName, err
	}
	tempFileName = tmpf.Name()
	// Registered before the database is opened so that it runs after the
	// deferred sqliteDB.Close below: a failed export must not leave a
	// half-written file behind.
	defer func() {
		if err != nil {
			os.Remove(tempFileName)
			tempFileName = ""
		}
	}()
	if err = tmpf.Close(); err != nil {
		return tempFileName, err
	}
	sqliteDB, err := sqlx.Open("sqlite3", tempFileName)
	if err != nil {
		return tempFileName, err
	}
	defer sqliteDB.Close()

	// Read the column names and build the structure of the export table.
	s = data.Bind(db.DriverName(), strSql)
	rows, err := db.Query(s, sqlParams...)
	if err != nil {
		err = common.NewSQLError(err, s, sqlParams...)
		log.Println(err)
		return tempFileName, err
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return tempFileName, err
	}
	colsIndex := map[string]bool{}

	colsDef := []*schema.Column{}
	pkDef := []string{}
	numQUESTION := make([]string, len(cols))
	for i, v := range cols {
		cols[i] = strings.ToUpper(v)
		colsIndex[cols[i]] = true
		colDef := &schema.Column{
			Name:      cols[i],
			Type:      schema.TypeString,
			Null:      true,
			MaxLength: -1,
		}
		// The lookup uses the column name as returned by the driver,
		// before upper-casing.
		for _, one := range typeColumns {
			if one.Name == v {
				colDef.Type = one.Type
				break
			}
		}
		colsDef = append(colsDef, colDef)
		numQUESTION[i] = "?"
	}
	// The unique columns become the primary key only when all of them are
	// part of the export.
	bContain := true
	for _, v := range ucols {
		if _, ok := colsIndex[v]; !ok {
			bContain = false
			break
		}
	}
	if bContain {
		pkDef = ucols
	}
	outTable := schema.NewTable(tableName)
	outTable.Columns = colsDef
	outTable.PrimaryKeys = pkDef

	list, err := schema.Find(sqliteDB.DriverName()).CreateTableSQL(sqliteDB, outTable)
	if err != nil {
		return tempFileName, err
	}
	if err = common.BatchRun(sqliteDB, list); err != nil {
		return tempFileName, err
	}

	// Scan targets: one *interface{} per column, reused for every row.
	values := make([]interface{}, len(cols))
	for i := range values {
		values[i] = new(interface{})
	}
	startTime := time.Now()
	beginTime := startTime
	progressFunc(fmt.Sprintf("start export total %d records", rowCount))
	insertSql := fmt.Sprintf("INSERT INTO %s(%s)values(%s)", tableName, strings.Join(cols, ","), strings.Join(numQUESTION, ","))
	sqliteTx, err := sqliteDB.Beginx()
	if err != nil {
		return tempFileName, err
	}
	// Whatever transaction is current when the function exits is rolled
	// back. After a successful Commit this is a harmless no-op that reports
	// "transaction already done", so the result is ignored.
	defer func() {
		if sqliteTx != nil {
			_ = sqliteTx.Rollback()
		}
	}()
	insertStmt, err := sqliteTx.Prepare(insertSql)
	if err != nil {
		return tempFileName, err
	}
	// Insert every row; commit in batches, roughly every five seconds.
	var icount int64
	for rows.Next() {
		if err = rows.Scan(values...); err != nil {
			return tempFileName, err
		}
		if _, err = insertStmt.Exec(values...); err != nil {
			return tempFileName, err
		}
		icount++
		totalSec := time.Since(startTime).Seconds()
		if totalSec >= 5 {
			// Commit the current batch and start the next one.
			if err = insertStmt.Close(); err != nil {
				return tempFileName, err
			}
			if err = sqliteTx.Commit(); err != nil {
				return tempFileName, err
			}
			if sqliteTx, err = sqliteDB.Beginx(); err != nil {
				return tempFileName, err
			}
			if insertStmt, err = sqliteTx.Prepare(insertSql); err != nil {
				return tempFileName, err
			}
			progressFunc(fmt.Sprintf("\t%.2f%%\t%d/%d\t%.2fs", 100.0*float64(icount)/float64(rowCount), icount, rowCount, totalSec))
			startTime = time.Now()
		}
	}
	// A failed iteration must surface before the last batch is committed.
	if err = rows.Err(); err != nil {
		return tempFileName, err
	}
	// Always finish the current transaction, even when it is empty, so no
	// transaction is left open on the file.
	if err = insertStmt.Close(); err != nil {
		return tempFileName, err
	}
	if err = sqliteTx.Commit(); err != nil {
		return tempFileName, err
	}
	progressFunc(fmt.Sprintf("total %d records exported %.2fs", icount, time.Since(beginTime).Seconds()))

	return tempFileName, nil
}
// ExportSqlite3 exports the result of strSql into a temporary SQLite database
// (see exportSqlite for the meaning of tableName, typeColumns and ucols) and
// then copies the bytes of that database file to w.
//
// progressFunc receives the messages produced during the export plus a final
// message with the time spent copying the file to w. The temporary file is
// removed before the function returns, whether or not it succeeds.
func ExportSqlite3(db ddb.DB, tableName string, typeColumns []*scan.ColumnType,
	ucols []string, strSql string, w io.Writer, progressFunc func(string),
	sqlParams ...interface{}) (err error) {
	fileName, err := exportSqlite(db, tableName, typeColumns, ucols, strSql, progressFunc, sqlParams...)
	// The file name is not handed to anyone else, so the file has to be
	// cleaned up here. The defer is registered before the file is opened,
	// which makes the removal run after the deferred f.Close below.
	if fileName != "" {
		defer os.Remove(fileName)
	}
	if err != nil {
		return err
	}
	// Only the copy phase is timed for the final message.
	startTime := time.Now()
	f, err := os.Open(fileName)
	if err != nil {
		return err
	}
	defer f.Close()
	if _, err = io.Copy(w, f); err != nil {
		return err
	}
	progressFunc(fmt.Sprintf("finished by %.2fs", time.Since(startTime).Seconds()))
	return nil
}
