package common

import (
	"dbweb/core"
	"dbweb/lib/safe"
	"encoding/csv"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/url"
	"os"
	"path/filepath"
	"strings"

	"github.com/linlexing/dbx/ddb"

	"github.com/linlexing/dbx/data"
	"github.com/linlexing/dbx/render"

	"dbweb/modules/common/impt"

	log "github.com/Sirupsen/logrus"
	"github.com/linlexing/dbx/schema"

	"github.com/spkg/bom"

	"golang.org/x/text/encoding/simplifiedchinese"
	"golang.org/x/text/transform"
)

const (
	// ctrlImportName is the name under which init registers both the Import
	// handler and the import element type with the core package.
	ctrlImportName = "import"
)

// importTaskParam is the value Post stores in core.TaskRun.Param and that
// imprt reads back when the background import task runs.
type importTaskParam struct {
	// ImportParam is decoded by Post from the element's Params JSON.
	ImportParam   *impt.ImportParam
	// PostParam is decoded by Post from the "param" form value of the request.
	PostParam     *impt.ImportPostParam
	// ImportFile is taken by Post from the "f" query value, which the upload
	// step sets to the base name of the temp file it created.
	ImportFile    string
	// TempTableName is the name obtained from ddb.GetTempTableName for the
	// target database.
	TempTableName string
}

// Import imports data into an arbitrary table; the first primary key value is
// recognized as the table name.
type Import struct {
}
// ImportParamField holds the settings of a single field of an import.
// NOTE(review): nothing in the visible part of this file references this
// type — confirm where it is consumed.
type ImportParamField struct {
	Name   string
	Style  string // NORMAL - field imported normally, MUST - field that must be imported, FILL - automatically filled field, SKIP - ignored field
	Value  string
	Remark string
}

// importRemoveResult is the JSON payload Get renders for the
// "removetabledata" action.
type importRemoveResult struct {
	// TableTotal is the row count of the whole table, taken after the delete.
	TableTotal     int64
	// TableCanRemove is the row count matching the configured where clause,
	// taken after the delete.
	TableCanRemove int64
}

func init() {
	core.RegisterFun(ctrlImportName, new(Import))
	core.RegisterElementType(ctrlImportName, impt.ModelImportElement{})
}

// ElementParamRender is called automatically while a work element is being
// defined. It decodes the record's PARAMS column into an impt.ImportParam (a
// decode failure is only logged) and publishes three values through p.More:
//   - "OuterdbNames": the outer database names, preceded by an empty entry;
//   - "TableNames":   the tables of the configured database, only when that
//     database exists;
//   - "JSONParams":   the decoded parameters.
func (m *Import) ElementParamRender(p *core.BillGetHandleArgs) {
	var params impt.ImportParam
	if raw := safe.Bytea(p.Record.Main["PARAMS"]); raw != nil {
		err := json.Unmarshal(raw, &params)
		if err != nil {
			log.Println(err, string(raw))
		}
	}

	// Prepend an empty name in front of the real outer database names.
	outerNames := core.OuterDBNames(p.DB)
	outerNames = append([]string{""}, outerNames...)
	p.More["OuterdbNames"] = outerNames

	if dbName := safe.String(params.DB); core.OuterDBExists(p.DB, dbName) {
		outer := core.LoadOuterDB(p.DB, dbName)
		tableNames, err := schema.Find(outer.DriverName()).TableNames(outer)
		if err != nil {
			core.LOG.Panic(err)
		}
		p.More["TableNames"] = tableNames
	}

	p.More["JSONParams"] = params
}
// ElementAjaxProcess answers the AJAX look-ups of the element editor. The
// "action" query value selects the operation:
//   - "getTableList":  renders the table names of the outer database named by
//     the "db" query value;
//   - "getColumnList": renders the column names of the table named by the
//     "table" query value, or an empty list when no table is given.
// Any other action is reported through core.LOG.Panic.
func (m *Import) ElementAjaxProcess(p *core.BillGetHandleArgs) {
	values := p.Req.URL.Query()
	action := values.Get("action")

	if action == "getTableList" {
		outer := core.LoadOuterDB(p.DB, values.Get("db"))
		tableNames, err := schema.Find(outer.DriverName()).TableNames(outer)
		if err != nil {
			core.LOG.Error(err)
			core.LOG.Panic(err)
		}
		p.Render.JSON(200, tableNames)
		return
	}

	if action != "getColumnList" {
		core.LOG.Panic("not impl")
		return
	}

	outer := core.LoadOuterDB(p.DB, values.Get("db"))
	tableName := values.Get("table")
	if len(tableName) == 0 {
		p.Render.JSON(200, []string{})
		return
	}

	table, err := schema.Find(outer.DriverName()).OpenTable(outer, tableName)
	if err != nil {
		core.LOG.Error(err)
		core.LOG.Panic(err)
	}
	// Kept non-nil so that a table without columns still renders as a list.
	columnNames := []string{}
	for _, column := range table.Columns {
		columnNames = append(columnNames, column.Name)
	}
	p.Render.JSON(200, columnNames)
}
// imprt is the function Post installs as the task's Func. It collects every
// "more_"-prefixed query value of the task's request (prefix stripped, value
// decoded through the user), unpacks the importTaskParam stored in t.Param and
// hands everything to impt.Imprt. An import error is reported through
// core.LOG.Panic; otherwise the string returned by impt.Imprt is returned.
func (m *Import) imprt(t *core.TaskRun) string {
	const morePrefix = "more_"

	extras := make(map[string]interface{})
	for key, values := range t.Req.URL.Query() {
		if !strings.HasPrefix(key, morePrefix) {
			continue
		}
		extras[key[len(morePrefix):]] = t.User.DecodeQueryValue(values[0])
	}

	cfg := t.Param.(importTaskParam)
	targetDB := core.LoadOuterDB(t.Db, cfg.ImportParam.DB)
	result, err := impt.Imprt(targetDB, cfg.ImportFile, cfg.TempTableName,
		cfg.ImportParam, cfg.PostParam, extras, t.User, t.LSession, t)
	if err != nil {
		core.LOG.Panic(err)
	}
	return result
}
// Post handles both POST phases of the import.
//
// Without a "step" query value it receives the uploaded "upFile", stores it
// in a uniquely named temp file under impt.ImportPath and redirects to the
// same URL with step=1 and f=<temp file base name>.
//
// With a "step" query value it decodes the element parameters and the posted
// "param" JSON, reserves a temp table name in the target database and starts
// the import as a background task; the response is either the task start
// error or a message linking to the task progress page.
//
// Unexpected failures are reported through core.LOG.Panic.
func (m *Import) Post(p *core.ElementHandleArgs) {
	q := p.Req.URL.Query()
	if step := q.Get("step"); len(step) == 0 {
		if err := p.Req.ParseMultipartForm(1024 * 1024 * 10); err != nil {
			core.LOG.Panic(err)
		}
		file, _, err := p.Req.FormFile("upFile")
		if err != nil {
			core.LOG.Panic(err)
		}
		defer file.Close()
		// Generate a unique handle used as the file name; the file lives in a
		// dedicated directory whose expired entries are cleaned up in the
		// background.
		// os.ModeDir alone carries no permission bits, so a freshly created
		// directory would get mode 0000 on Unix and the TempFile call below
		// would fail; give it explicit permissions instead.
		if err := os.MkdirAll(impt.ImportPath, 0755); err != nil {
			core.LOG.Panic(err)
		}
		f, err := ioutil.TempFile(impt.ImportPath, "IMP")
		if err != nil {
			core.LOG.Println(err)
			core.LOG.Panic(err)
		}
		// Only relevant on the panic paths; the regular path closes f
		// explicitly below so that a failed close is not silently lost.
		defer f.Close()
		if _, err := io.Copy(f, file); err != nil {
			core.LOG.Panic(err)
		}
		if err := f.Close(); err != nil {
			core.LOG.Panic(err)
		}
		// Redirect to the second step.
		u, err := url.Parse(p.Req.URL.String())
		if err != nil {
			core.LOG.Panic(err)
		}
		q.Set("step", "1")
		q.Set("f", filepath.Base(f.Name()))
		u.RawQuery = q.Encode()
		p.Redirect(u.String())
		return
	}

	jsonParams := &impt.ImportParam{}
	if err := json.Unmarshal(p.Element.Params, jsonParams); err != nil {
		// Best effort: a broken element configuration is only logged.
		core.LOG.Println(err, string(p.Element.Params))
	}
	importPostParam := &impt.ImportPostParam{}
	if err := json.Unmarshal([]byte(p.Req.FormValue("param")), importPostParam); err != nil {
		core.LOG.Panic(err)
	}
	// The upload step always sets "f" to a bare file name. The value comes
	// back from the client, so strip any directory part before handing it on.
	importFile := q.Get("f")
	if len(importFile) > 0 {
		importFile = filepath.Base(importFile)
	}
	// Start the import.
	imptDB := core.LoadOuterDB(p.DB, jsonParams.DB)
	tmpTableName, err := ddb.GetTempTableName(imptDB, "T")
	if err != nil {
		core.LOG.Panic(err)
	}
	config := importTaskParam{
		ImportParam:   jsonParams,
		PostParam:     importPostParam,
		ImportFile:    importFile,
		TempTableName: tmpTableName,
	}
	task := &core.TaskRun{
		Db:       p.DB,
		Name:     p.Element.DisplayLabel(),
		User:     p.User,
		LSession: p.LSession,
		Req:      p.Req,
		ClientIP: p.Req.RemoteAddr,
		Param:    config,
		Func:     m.imprt,
	}
	if err := task.GoRun(); err != nil {
		p.RenderError(err.Error())
		return
	}
	p.GotoMessage(fmt.Sprintf("导入操作将在后台运行，可以<a href='%s' target='_blank'>点击查看进度</a>", p.User.Sign("/browsetask/"+task.ID())))
}
// trimStrings strips leading and trailing white space from every element of
// list, modifying the slice in place.
func trimStrings(list []string) {
	for idx := 0; idx < len(list); idx++ {
		trimmed := strings.TrimSpace(list[idx])
		list[idx] = trimmed
	}
}
// buildSampleData reads the first lines of an uploaded file as a preview.
//
// Parameters:
//   - fileName:    path of the file to read;
//   - charset:     "UTF8" or "GBK"; anything else goes to core.LOG.Panic;
//   - format:      "CSV" (comma), "TAB" (tab) or "CUSTOM" (customSplit);
//     anything else goes to core.LOG.Panic;
//   - customSplit: field delimiter used when format is "CUSTOM";
//   - firstTitle:  whether the first line holds the column titles. When it
//     does not, titles "col0", "col1", ... are generated and the first line
//     counts as data.
//
// Returns the titles, at most 100 data rows (all fields trimmed) and an
// error. err is io.EOF when the file ended before the row limit was reached,
// which callers treat as success; title and rows are then still valid. Any
// other error, including one on the very first line, aborts the read.
func (m *Import) buildSampleData(fileName, charset, format string, customSplit rune, firstTitle bool) (title []string, rows [][]string, err error) {
	file, err := os.Open(fileName)
	if err != nil {
		return
	}
	defer file.Close()
	// The BOM has to be stripped, see https://github.com/golang/go/issues/9588
	f := bom.NewReader(file)
	var r io.Reader
	switch charset {
	case "UTF8":
		r = f
	case "GBK":
		r = transform.NewReader(f, simplifiedchinese.GBK.NewDecoder())
	default:
		core.LOG.Panic("not impl")
	}
	csvReader := csv.NewReader(r)
	switch format {
	case "CSV":
	case "TAB":
		csvReader.Comma = '\t'
	case "CUSTOM":
		csvReader.Comma = customSplit
	default:
		core.LOG.Panic("not impl")
	}
	// Non-nil empty values so that an empty file still yields lists.
	title = []string{}
	rows = [][]string{}
	// Read the first line.
	row, err := csvReader.Read()
	if err != nil {
		// Previously only io.EOF stopped here; a parse error on the first
		// line was dropped and a nil row ended up as title or first data row.
		if err == io.EOF {
			core.LOG.Println("read first line eof")
		}
		return
	}
	trimStrings(row)
	maxLines := 100
	if firstTitle {
		title = row
	} else {
		rows = append(rows, row)
		for i := range row {
			title = append(title, fmt.Sprintf("col%d", i))
		}
		// The first line already counts as one data row.
		maxLines--
	}
	// Every following line must have as many fields as the first one.
	csvReader.FieldsPerRecord = len(title)
	for i := 0; i < maxLines; i++ {
		row, err = csvReader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			// Goes to the application log instead of stdout.
			core.LOG.Println("read sample row", i, "failed:", err)
			return
		}
		trimStrings(row)
		rows = append(rows, row)
	}
	return
}
// getColumns gathers what the import page shows about the target table.
//
// It decodes the element parameters (a decode failure is only logged),
// renders the configured where clause with the user, the session and the
// "more_"-prefixed query values, opens the table and describes every
// configured field whose style is "MUST" or "NORMAL".
//
// Returns, in order: the table name, the total row count, the row count
// matching the rendered where clause, and the list of field descriptions.
func (m *Import) getColumns(p *core.ElementHandleArgs) (string, int64, int64, interface{}) {
	var (
		cfg impt.ImportParam
		err error
	)
	if err = json.Unmarshal(p.Element.Params, &cfg); err != nil {
		core.LOG.Println(err, string(p.Element.Params))
	}

	// Collect the query values that are passed as "More" to the SQL rendering.
	extras := map[string]interface{}{}
	for key, values := range p.Req.URL.Query() {
		if !strings.HasPrefix(key, "more_") {
			continue
		}
		extras[key[len("more_"):]] = p.User.DecodeQueryValue(values[0])
	}

	filter := cfg.Where
	if len(filter) > 0 {
		renderArgs := map[string]interface{}{
			"User":     p.User,
			"LSession": p.LSession,
			"More":     extras,
		}
		filter, err = render.RenderSQL(filter, renderArgs)
		if err != nil {
			core.LOG.Panic(err)
		}
	}

	targetDB := core.LoadOuterDB(p.DB, cfg.DB)
	tab, err := data.OpenTable(targetDB.DriverName(), targetDB, cfg.Table)
	if err != nil {
		core.LOG.Panic(err)
	}

	// Look up the database type of every field that takes part in the import.
	described := []*struct {
		Name       string
		Type       string
		MaxLength  int
		Style      string
		PrimaryKey bool
		Remark     string
	}{}
	primaryKeys := tab.PrimaryKeys
	for _, field := range cfg.Fields {
		if field.Style != "MUST" && field.Style != "NORMAL" {
			continue
		}
		entry := &struct {
			Name       string
			Type       string
			MaxLength  int
			Style      string
			PrimaryKey bool
			Remark     string
		}{
			Name:   field.Name,
			Style:  field.Style,
			Remark: field.Remark,
		}
		column := tab.ColumnByName(field.Name)
		if column != nil {
			entry.Type = column.Type.ChineseString()
			entry.MaxLength = column.MaxLength
		} else {
			core.LOG.Println("field:", field.Name, "can't found")
		}
		for _, pk := range primaryKeys {
			if pk == field.Name {
				entry.PrimaryKey = true
				break
			}
		}
		described = append(described, entry)
	}

	total := tab.MustCount("")          // total row count
	manageable := tab.MustCount(filter) // rows matching the where clause
	return cfg.Table, total, manageable, described
}
// removeTableData deletes the rows of the configured table that match the
// configured where clause, or every row when the rendered clause is blank.
//
// The where clause is rendered with the user, the session and the
// "more_"-prefixed query values. Rendering, open and delete failures are
// reported through core.LOG.Panic; a parameter decode failure is only logged.
//
// Returns the total row count and the row count matching the where clause,
// both taken after the delete.
func (m *Import) removeTableData(p *core.ElementHandleArgs) (int64, int64) {
	var err error
	cfg := impt.ImportParam{}
	if decodeErr := json.Unmarshal(p.Element.Params, &cfg); decodeErr != nil {
		core.LOG.Println(decodeErr, string(p.Element.Params))
	}

	// Collect the query values that are passed as "More" to the SQL rendering.
	extras := map[string]interface{}{}
	for key, values := range p.Req.URL.Query() {
		if !strings.HasPrefix(key, "more_") {
			continue
		}
		extras[key[len("more_"):]] = p.User.DecodeQueryValue(values[0])
	}

	filter := cfg.Where
	if len(filter) > 0 {
		renderArgs := map[string]interface{}{
			"User":     p.User,
			"LSession": p.LSession,
			"More":     extras,
		}
		filter, err = render.RenderSQL(filter, renderArgs)
		if err != nil {
			core.LOG.Panic(err)
		}
	}

	targetDB := core.LoadOuterDB(p.DB, cfg.DB)
	tab, err := data.OpenTable(targetDB.DriverName(), targetDB, cfg.Table)
	if err != nil {
		core.LOG.Panic(err)
	}

	stmt := fmt.Sprintf("delete from %s", cfg.Table)
	if len(strings.TrimSpace(filter)) > 0 {
		stmt += fmt.Sprintf(" where %s", filter)
	}
	if _, execErr := targetDB.Exec(stmt); execErr != nil {
		core.LOG.Panic(execErr)
	}

	total := tab.MustCount("")         // total row count
	removable := tab.MustCount(filter) // rows still matching the where clause
	return total, removable
}
// Get serves the import page and its two JSON side actions, selected by the
// "action" query value:
//   - "removetabledata": deletes the manageable rows through removeTableData
//     and renders the resulting counts as JSON;
//   - "refreshsample":   re-reads the uploaded file named by "f" using the
//     "Charset", "Format", "CustomSplit" and "FirstHead" query values and
//     renders the sample (or an error text) as JSON.
//
// Without one of these actions it fills p.More with the table information
// and, when a "step" query value is present, with the previous-step URL and a
// default sample (UTF8, CSV, first line as titles), then renders the HTML
// page.
func (m *Import) Get(p *core.ElementHandleArgs) {
	q := p.Req.URL.Query()
	if q.Get("action") == "removetabledata" {
		rev := importRemoveResult{}
		rev.TableTotal, rev.TableCanRemove = m.removeTableData(p)
		p.Render.JSON(200, rev)
		return
	}
	if q.Get("action") == "refreshsample" {
		// Build the sample data.
		// "f" comes from the client: keep only its base name so that the path
		// cannot leave impt.ImportPath.
		fileName := filepath.Join(impt.ImportPath, filepath.Base(q.Get("f")))
		sampleData := &struct {
			Title []string
			Rows  [][]string
			Error string
		}{}
		splitRune := ','
		if q.Get("Format") == "CUSTOM" {
			custom := []rune(q.Get("CustomSplit"))
			if len(custom) == 0 {
				// Indexing the empty rune slice would panic; report the
				// problem the same way as any other sample error.
				sampleData.Error = "custom split character is empty"
				p.Render.JSON(200, sampleData)
				return
			}
			splitRune = custom[0]
		}
		title, rows, err := m.buildSampleData(fileName, q.Get("Charset"), q.Get("Format"), splitRune, safe.Bool(q.Get("FirstHead")))
		if err != nil && err != io.EOF {
			core.LOG.Println("error:", err)
			sampleData.Error = err.Error()
		} else {
			sampleData.Title = title
			sampleData.Rows = rows
		}
		p.Render.JSON(200, sampleData)
		return
	}
	p.More["TableName"], p.More["TableTotal"],
		p.More["TableCanRemove"], p.More["Fields"] = m.getColumns(p)
	if step := q.Get("step"); len(step) > 0 {
		prevURL, err := url.Parse(p.Req.URL.String())
		if err != nil {
			core.LOG.Panic(err)
		}
		// The previous step is the same URL without the "step" value.
		q.Del("step")
		prevURL.RawQuery = q.Encode()
		p.More["PrevUrl"] = prevURL.String()
		// Build the sample data; same base-name restriction as above.
		fileName := filepath.Join(impt.ImportPath, filepath.Base(q.Get("f")))
		sampleData := &struct {
			Title []string
			Rows  [][]string
			Error string
		}{}
		title, rows, err := m.buildSampleData(fileName, "UTF8", "CSV", ',', true)
		if err != nil && err != io.EOF {
			sampleData.Error = err.Error()
		} else {
			sampleData.Title = title
			sampleData.Rows = rows
		}
		p.More["SampleData"] = sampleData
	}
	p.HTML()
}
