package impt

import (
	"dbweb/core"
	"fmt"
	"io"
	"path/filepath"
	"strings"
	"time"

	sqlrender "github.com/linlexing/dbx/render"

	"github.com/linlexing/dbx/ddb"

	"dbweb/lib/lsession"

	"github.com/linlexing/dbx/data"
	"github.com/linlexing/dbx/schema"
)

// addProgresser is the sink for human-readable progress messages produced
// while an import is running.
type addProgresser interface {
	// AddProgress is handed one progress message.
	AddProgress(msg string)
}

// importData reads every row of f, renders it with rd and writes the rendered
// values through wr.
//
// Write outcomes reported by wr are handled as follows:
//   - *errNullPk: the row is counted as skipped (empty key) and a line naming
//     the empty field is added to the report;
//   - *errDupPk: the row is counted as a duplicate and a line naming the
//     earlier row it overwrites is added to the report;
//   - any other non-nil error aborts the import.
//
// While reading, a progress message is sent to ap at most once every five
// seconds. When the input is exhausted, a summary line is appended to the
// report and a final progress message (including the elapsed time) is sent.
//
// Row numbers used in the report and in *importError.RowNo are zero-based
// counts of rows read from f.
//
// On success it returns the report lines. Any read, render or unexpected
// write error is returned as an *importError carrying the row number and the
// raw row; the report is nil in that case.
func importData(f *file, rd *render, wr *tabWriter, ap addProgresser) ([]string, error) {
	insertReport := []string{}
	var icount, iskip, idup int64
	startTime := time.Now()
	prevTime := startTime

	for {
		row, err := f.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			// Return the failure to the caller like every other error path
			// instead of panicking, so the caller decides how to handle it.
			return nil, &importError{RowNo: icount, Data: row, Err: err}
		}
		// Named values (not data) to avoid shadowing the imported dbx/data package.
		values, err := rd.renderLine(row)
		if err != nil {
			return nil, &importError{RowNo: icount, Data: row, Err: err}
		}
		err = wr.write(values)
		switch terr := err.(type) {
		case nil:
		case *errNullPk:
			iskip++
			insertReport = append(insertReport,
				fmt.Sprintf("第%d行中[%s]为空，被跳过", icount,
					terr.Field()))
		case *errDupPk:
			idup++
			insertReport = append(insertReport,
				fmt.Sprintf("第%d行中[%s]的值[%s]和第%d行的重复，第%d行的数据被该行覆盖",
					icount, terr.pkNames, terr.pkValues, terr.preLine, terr.preLine))
		default:
			return nil, &importError{RowNo: icount, Data: row, Err: err}
		}

		// Throttle progress reporting to one message per five seconds.
		if time.Since(prevTime) >= 5*time.Second {
			ap.AddProgress(fmt.Sprintf("处理了%d条记录,其中空码%d,重码%d,实际导入%d条记录", icount, iskip, idup, icount-iskip-idup))
			prevTime = time.Now()
		}
		icount++
	}

	imported := icount - iskip - idup
	insertReport = append(insertReport,
		fmt.Sprintf("共处理了%d条记录,其中空码:%d,重码:%d,实际导入%d条记录",
			icount, iskip, idup, imported))
	ap.AddProgress(fmt.Sprintf("处理了%d条记录,其中空码:%d,重码:%d,实际导入%d条记录,耗时%.2f秒,下一步是合并...",
		icount, iskip, idup, imported, time.Since(startTime).Seconds()))
	return insertReport, nil
}

// Imprt imports the file importFile (resolved relative to ImportPath) into the
// table importParam.Table of tabDB.
//
// The file is opened using the charset, format, custom separator (only the
// first rune of CustomSplit is used) and first-row-is-header settings from
// postParam.Param. Each row is rendered by a renderer compiled from
// importParam, postParam, more, user and sess.
//
// The rows are written into a staging table named tmpTableName. It is created
// with the same columns and primary keys as the target table. Then, inside a
// single ddb.RunAtTx transaction, Imprt:
//   - runs the optional importParam.Before SQL template,
//   - merges the staging table into the target table over the rendered
//     columns (rows with the same primary key are overwritten; target rows
//     not present in the file are not deleted),
//   - runs the optional importParam.After SQL template,
//   - drops the staging table,
//   - for the "oci8" driver only, gathers table statistics via
//     DBMS_STATS.gather_table_stats.
//
// Both SQL templates are rendered with "User" and "ImportTempTable" available.
// Progress messages are sent to ap along the way.
//
// On success, result is the insert report joined with newlines. The report
// lists skipped empty-key rows, overwritten duplicate rows and a summary line.
//
// NOTE(review): the staging table is created outside the transaction and only
// dropped as its last DDL step, so it appears to be left behind whenever any
// later step fails — confirm whether callers clean it up.
func Imprt(tabDB ddb.TxDB, importFile, tmpTableName string, importParam *ImportParam,
	postParam *ImportPostParam, more map[string]interface{}, user *core.User,
	sess *lsession.Session, ap addProgresser) (result string, err error) {

	// Insert report: records skipped empty-key rows and ignored duplicate rows.
	insertReport := []string{}

	tab, err := schema.Find(tabDB.DriverName()).OpenTable(tabDB, importParam.Table)
	if err != nil {
		return
	}
	var cusSplit rune
	if len(postParam.Param.CustomSplit) > 0 {
		cusSplit = []rune(postParam.Param.CustomSplit)[0]
	}
	file, err := openFile(filepath.Join(ImportPath, importFile),
		postParam.Param.Charset, postParam.Param.Format, cusSplit, postParam.Param.FirstHead)
	if err != nil {
		return
	}
	defer file.close()
	cmp := &renderCompiler{
		param:   importParam,
		post:    postParam,
		more:    more,
		user:    user,
		session: sess,
		table:   tab,
		title:   file.Title,
	}
	rd, err := cmp.compile()
	if err != nil {
		return
	}

	// Create the staging table with the target table's columns and primary keys.
	tT := schema.NewTable(tmpTableName)
	tT.Columns = tab.Columns
	tT.PrimaryKeys = tab.PrimaryKeys
	if err = tT.Update(tabDB.DriverName(), tabDB); err != nil {
		return
	}

	if err = ddb.RunAtTx(tabDB, func(tx ddb.Txer) error {
		tmpTab := data.NewTable(tx.DriverName(), tx, tT)
		wr, err := newTabWriter(tmpTab, rd.ColumnNames(), tx)
		if err != nil {
			return err
		}
		// NOTE(review): wr is only closed when this closure returns, i.e. after
		// the merge below; if tabWriter buffers rows until close they would
		// miss the merge — confirm.
		defer wr.close()

		insertReport, err = importData(file, rd, wr, ap)
		if err != nil {
			return err
		}
		// importData reports its own elapsed time, so the merge timer starts
		// only after the import phase has finished.
		startTime := time.Now()
		// Merge the staging table into the target table: rows with the same
		// primary key are overwritten, extra target rows are not deleted.
		dataDest := data.NewTable(tx.DriverName(), tx, tab)
		// Run the Before SQL first.
		if len(importParam.Before) > 0 {
			strSQL, err := sqlrender.RenderSQL(importParam.Before, map[string]interface{}{
				"User":            user,
				"ImportTempTable": tmpTableName,
			})
			if err != nil {
				return err
			}
			if _, err := tx.Exec(strSQL); err != nil {
				return err
			}
		}
		if err = dataDest.Merge(tmpTableName, rd.ColumnNames()...); err != nil {
			return err
		}
		if len(importParam.After) > 0 {
			strSQL, err := sqlrender.RenderSQL(importParam.After, map[string]interface{}{
				"User":            user,
				"ImportTempTable": tmpTableName,
			})
			if err != nil {
				return err
			}
			if _, err := tx.Exec(strSQL); err != nil {
				return err
			}
		}
		if _, err := tx.Exec("drop table " + tmpTableName); err != nil {
			return err
		}
		ap.AddProgress(fmt.Sprintf("合并 %s 到 %s,耗时:%.2f秒", tmpTableName, tab.FullName(), time.Since(startTime).Seconds()))
		startTime = time.Now()
		// Finally, gather table statistics (Oracle only).
		switch tx.DriverName() {
		case "oci8":
			schema := tab.Schema
			if len(schema) == 0 {
				// No explicit schema: fall back to the current session user.
				if err := tx.QueryRow("select user from dual").Scan(&schema); err != nil {
					return err
				}
			}
			if _, err := tx.Exec(fmt.Sprintf(
				"begin DBMS_STATS.gather_table_stats(ownname => '%s',tabname => '%s');end;",
				schema,
				dataDest.Name)); err != nil {
				return err
			}
			ap.AddProgress(fmt.Sprintf("收集表[%s]的统计信息,耗时:%.2f秒", dataDest.FullName(), time.Since(startTime).Seconds()))
		}
		return nil

	}); err != nil {
		return
	}

	return strings.Join(insertReport, "\n"), nil
}
