package common

import (
	"archive/zip"
	"dbweb/core"
	"dbweb/lib/net"
	"dbweb/lib/safe"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"strings"
	"time"

	"github.com/linlexing/dbx/ddb"

	"database/sql"

	"github.com/linlexing/dbx/data"
)

// Names used when this file registers its model and controllers with core
// in init: modelDumpBillName is the model name of modelDumpBill (and the
// model the "dumpbill" bill controller is bound to); ctrlDumpName is the key
// under which dumpControler and modelDumpElement are registered.
const (
	modelDumpBillName = "DUMPBILL"
	ctrlDumpName      = "dump"
)

// dumpTaskParam is the payload Post stores in core.TaskRun.Param and dump
// reads back: the export configuration plus an optional output zip path
// (when OutFileName is empty, dump writes into the zip from TaskRun.NeedZip).
type dumpTaskParam struct {
	DumpConfig  *core.DumpConfig
	OutFileName string
}

// dumpControler is the controller registered under ctrlDumpName ("dump");
// it provides ElementParamRender, Get and Post for the dump work element.
type dumpControler struct{}
// modelDumpBill is the DUMPBILL model registered in init: one row per bill
// taking part in a named dump definition, keyed by (DumpName, BillName).
//
//   - Dept: department code; getDumpBills only returns rows whose DEPT is in
//     the requesting user's ToRootDeptCodesAndSelf list.
//   - DBName: database the bill's tables are read from; empty means the
//     current database (see the options built in dumpBillControler.Get).
//   - TableNames: comma-separated table names.
//   - ActionExists: 0 = overwrite, 1 = skip (form options in
//     dumpBillControler.Get); presumably applied to already-existing rows
//     when the dump is imported — TODO confirm.
//   - SQLWhere, SQLBefore, SQLAfter: copied into core.DumpBill's Where,
//     BeforeSQL and AfterSQL by getDumpBills.
type modelDumpBill struct {
	DumpName     string `dbx:"STR(50) PRIMARY KEY"`
	BillName     string `dbx:"STR(50) PRIMARY KEY"`
	Dept         string `dbx:"STR(50) NOT NULL"`
	DBName       string `dbx:"STR(50)"`
	TableNames   string `dbx:"STR NOT NULL"`
	ActionExists int64  `dbx:"INT NOT NULL"`
	SQLWhere     string `dbx:"STR"`
	SQLBefore    string `dbx:"STR"`
	SQLAfter     string `dbx:"STR"`
}

// dumpElementParam holds the JSON parameters of a dump work element; it is
// decoded from the element's PARAMS value by ElementParamRender, Get and Post.
type dumpElementParam struct {
	Name        string // dump definition name, matched against DUMPBILL.DUMPNAME
	OutFileName string // file to store the result in; if empty, the result goes into the default zip named after the task id
	BeforeSQL   string // SQL run before import (passed on as core.DumpParam.BeforeSQL)
	AfterSQL    string // SQL run after import (passed on as core.DumpParam.AfterSQL)
}
// modelDumpElement is the element type registered under ctrlDumpName. It
// embeds core.ModelElementNoParam (inlined for YAML) and adds Params, stored
// in a STR column; ElementParamRender decodes that PARAMS value as JSON.
type modelDumpElement struct {
	core.ModelElementNoParam `yaml:",inline"`
	Params                   dumpElementParam `dbx:"STR"`
}

// dumpBillControler is the bill controller registered as "dumpbill" on the
// DUMPBILL model; its Get fills the option lists of the edit form.
type dumpBillControler struct{}

// init registers, in order: the DUMPBILL model, the "dump" function
// controller, the "dump" element type, and the "dumpbill" bill controller.
func init() {
	core.RegisterModel(modelDumpBill{}, 1, false, modelDumpBillName)
	core.RegisterFun(ctrlDumpName, &dumpControler{})
	core.RegisterElementType(ctrlDumpName, modelDumpElement{})
	core.RegisterBill("dumpbill", &dumpBillControler{}, modelDumpBillName)
}

// ElementParamRender is called automatically when a dump work element is
// being defined. It decodes the record's PARAMS value (JSON) into a
// dumpElementParam and exposes it to the template as More["JSONParams"].
// A missing or empty PARAMS value yields zero-valued parameters; malformed
// JSON is logged and then reported through core.LOG.Panic.
func (d *dumpControler) ElementParamRender(p *core.BillGetHandleArgs) {
	jsonParams := dumpElementParam{}
	// Check the length, not just nil: an empty non-nil value would make
	// json.Unmarshal fail with "unexpected end of JSON input".
	if bys := safe.Bytea(p.Record.Main["PARAMS"]); len(bys) > 0 {
		if err := json.Unmarshal(bys, &jsonParams); err != nil {
			core.LOG.Println(string(bys))
			core.LOG.Panic(err)
		}
	}
	p.More["JSONParams"] = jsonParams
}

// Get renders the dump element page. It decodes the element's JSON
// parameters, loads the dump bills visible to the current user and hands
// them to the template as More["Params"] (a *core.DumpParam).
func (d *dumpControler) Get(p *core.ElementHandleArgs) {
	var elem dumpElementParam
	if err := json.Unmarshal(p.Element.Params, &elem); err != nil {
		core.LOG.Panic(err)
	}

	bills, err := getDumpBills(p.DB, elem.Name, p.User)
	if err != nil {
		core.LOG.Panic(err)
	}

	p.More["Params"] = &core.DumpParam{
		Name:      elem.Name,
		Bills:     bills,
		BeforeSQL: elem.BeforeSQL,
		AfterSQL:  elem.AfterSQL,
	}
	p.HTML()
}

// dump is the background task body that exports data. It reads the
// dumpTaskParam from t.Param and lets DumpConfig.DumpToFile write the data
// into a temporary file. It then stores that file as "dump.sl3" in a zip
// archive: a new zip at OutFileName, or, when OutFileName is empty, the zip
// writer returned by t.NeedZip().
// The temporary file is removed on every exit path. Failures are reported
// through core.LOG.Panic, like the rest of this file. It always returns "".
func dump(t *core.TaskRun) string {
	taskParam := t.Param.(dumpTaskParam)
	dumpConfig := taskParam.DumpConfig
	outFileName := taskParam.OutFileName

	// Reserve a unique temporary file name; DumpToFile fills the file below.
	tmpFile, err := ioutil.TempFile("", "dump")
	if err != nil {
		core.LOG.Panic(err)
	}
	dbFileName := tmpFile.Name()
	// Registered immediately after creation so the file cannot leak. Deferred
	// calls run LIFO, so this runs last, after every handle is closed.
	defer func() {
		if err := os.Remove(dbFileName); err != nil {
			core.LOG.Panic(err)
		}
	}()
	if err = tmpFile.Close(); err != nil {
		core.LOG.Panic(err)
	}

	if err = dumpConfig.DumpToFile(t.Db, t.User, dbFileName, t.AddProgress); err != nil {
		core.LOG.Panic(err)
	}

	var zipWrite *zip.Writer
	if len(outFileName) > 0 {
		f, err := os.Create(outFileName)
		if err != nil {
			core.LOG.Panic(err)
		}
		// Registered before the zipWrite.Close defer, so it runs after it:
		// the zip's trailing directory is flushed into f before f is closed.
		defer func() {
			if err := f.Close(); err != nil {
				core.LOG.Panic(err)
			}
		}()
		zipWrite = zip.NewWriter(f)
	} else {
		zipWrite = t.NeedZip()
	}
	defer func() {
		if err := zipWrite.Close(); err != nil {
			core.LOG.Panic(err)
		}
	}()

	w, err := zipWrite.Create("dump.sl3")
	if err != nil {
		core.LOG.Panic(err)
	}
	dbFile, err := os.Open(dbFileName)
	if err != nil {
		core.LOG.Panic(err)
	}
	defer func() {
		if err := dbFile.Close(); err != nil {
			core.LOG.Panic(err)
		}
	}()
	if _, err = io.Copy(w, dbFile); err != nil {
		core.LOG.Panic(err)
	}
	return ""
}

// getDumpBills reads the DUMPBILL rows of the dump definition dumpName whose
// DEPT is one of the codes returned by user.ToRootDeptCodesAndSelf().
// It converts each row into a core.DumpBill: TABLENAMES is split on "," into
// Tables, and NULL columns become zero values. DEPT is used only for
// filtering.
// On success the returned slice is non-nil (possibly empty). Any query, scan
// or row-iteration error is returned with a nil slice.
func getDumpBills(db ddb.DB, dumpName string, user *core.User) ([]*core.DumpBill, error) {
	// Read all bill definitions of this dump from the database.
	var vDumpName, vBillName string
	var vDept, vDBName, vTableNames, vSQLWhere, vSQLBefore, vSQLAfter sql.NullString
	var vActionExists sql.NullInt64
	// data.In presumably expands "in(?)" for the slice argument into one
	// placeholder per element.
	strSQL, params, err := data.In(
		`select 
			DUMPNAME,
			BILLNAME,
			DEPT,
			DBNAME,
			TABLENAMES,
			ACTIONEXISTS,
			SQLWHERE,
			SQLBEFORE,
			SQLAFTER 
		from dumpbill where dumpname=? and dept in(?)`, dumpName, user.ToRootDeptCodesAndSelf())
	if err != nil {
		return nil, err
	}
	rows, err := db.Query(strSQL, params...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	bills := []*core.DumpBill{}
	for rows.Next() {
		if err = rows.Scan(&vDumpName, &vBillName, &vDept, &vDBName, &vTableNames, &vActionExists,
			&vSQLWhere, &vSQLBefore, &vSQLAfter); err != nil {
			return nil, err
		}
		bills = append(bills, &core.DumpBill{
			Name:         vBillName,
			DBName:       vDBName.String,
			Tables:       strings.Split(vTableNames.String, ","),
			ActionExists: int32(vActionExists.Int64),
			Where:        vSQLWhere.String,
			BeforeSQL:    vSQLBefore.String,
			AfterSQL:     vSQLAfter.String,
		})
	}
	// rows.Next returns false both at the end of the result set and on
	// error. Without this check, a failed iteration would be returned as a
	// truncated but "successful" bill list.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return bills, nil
}

// Post starts an export.
//
// It decodes the element's parameters and loads the user's dump bills,
// keeping only those selected on the posted form. It builds a
// core.DumpConfig recording:
//   - the requesting user, department and department chain;
//   - the current time;
//   - the server and client IPs;
//   - the current model and module versions.
//
// It then runs dump as a background task. On success the user is sent to a
// message linking to the task's progress page. If the form cannot be parsed
// or the task cannot be started, an error page is rendered instead.
func (d *dumpControler) Post(p *core.ElementHandleArgs) {
	// Previously ignored: a malformed body would silently export with a
	// partial or empty selection.
	if err := p.Req.ParseForm(); err != nil {
		p.RenderError(err.Error())
		return
	}
	params := &dumpElementParam{}
	if err := json.Unmarshal(p.Element.Params, params); err != nil {
		core.LOG.Panic(err)
	}
	// Every posted field whose name contains "_" selects the bill named by
	// the second "_"-separated part of the field name.
	// NOTE(review): a bill name that itself contains "_" is truncated here;
	// confirm against the form's field naming before allowing such names.
	selBills := []string{}
	for k := range p.Req.PostForm {
		if names := strings.Split(k, "_"); len(names) > 1 {
			selBills = append(selBills, names[1])
		}
	}
	bills, err := getDumpBills(p.DB, params.Name, p.User)
	if err != nil {
		core.LOG.Panic(err)
	}
	dumpParams := &core.DumpParam{
		Name:      params.Name,
		Bills:     bills,
		BeforeSQL: params.BeforeSQL,
		AfterSQL:  params.AfterSQL,
	}
	dumpParams.FilterBill(selBills)

	dumpConfig := &core.DumpConfig{
		Param:                  dumpParams,
		UserName:               p.User.Name,
		DeptCode:               p.User.Dept.Code,
		ToRootDeptCodesAndSelf: p.User.ToRootDeptCodesAndSelf(),
		Time:                   time.Now(),
		ServerIP:               net.GetLocalIP4(),
		ClientIP:               p.Req.RemoteAddr,
		DBVersions:             map[string]int64{},
		ModuleVersions:         map[string]int64{},
	}
	// Record the current model versions and, for each module, the Version
	// of the first entry in core.Versions.
	for _, v := range core.ModelNames() {
		dumpConfig.DBVersions[v] = core.ModelVersion(v)
	}
	for k, v := range core.Versions {
		dumpConfig.ModuleVersions[k] = v[0].Version
	}

	// Run the export as a background task.
	task := &core.TaskRun{
		Db:       p.DB,
		Name:     p.Element.DisplayLabel(),
		User:     p.User,
		ClientIP: p.Req.RemoteAddr,
		Param:    dumpTaskParam{DumpConfig: dumpConfig, OutFileName: params.OutFileName},
		Func:     dump,
	}
	if err := task.GoRun(); err != nil {
		p.RenderError(err.Error())
		return
	}
	p.GotoMessage(fmt.Sprintf("导出操作将在后台运行，可以<a href='%s' target='_blank'>点击查看进度</a>", p.User.Sign("/browsetask/"+task.ID())))
}
// Get prepares the DUMPBILL edit form and renders it as HTML. It fills the
// option lists of two fields:
//   - DBNAME: an entry for the current database with an empty value, then
//     every name returned by core.OuterDBNames.
//   - ACTIONEXISTS: "0" (overwrite) and "1" (skip).
func (d *dumpBillControler) Get(p *core.BillGetHandleArgs) {
	option := func(text, value string) map[string]string {
		return map[string]string{"Text": text, "Value": value}
	}

	dbOptions := []map[string]string{option("(当前数据库)", "")}
	for _, name := range core.OuterDBNames(p.DB) {
		dbOptions = append(dbOptions, option(name, name))
	}
	p.Field("DBNAME").More["Options"] = dbOptions

	p.Field("ACTIONEXISTS").More["Options"] = []map[string]string{
		option("覆盖", "0"),
		option("跳过", "1"),
	}
	p.HTML()
}
