package impt

import (
	"database/sql"
	"github.com/linlexing/dbx/ddb"
	"errors"
	"fmt"
	"strings"

	"github.com/linlexing/dbx/data"
	"github.com/linlexing/dbx/schema"
)

// errNullPk is the error produced when a primary-key column of an imported
// row carries no value.
type errNullPk struct {
	// field is the name of the primary-key column that was empty.
	field string
}

// Error implements the error interface. The message is the column name
// followed by a Chinese suffix meaning "is empty".
func (e *errNullPk) Error() string {
	const emptySuffix = "为空"
	return e.field + emptySuffix
}

// Field returns the name of the primary-key column that was empty.
func (e *errNullPk) Field() string { return e.field }

// errDupPk is the error produced when a row's primary key has already been
// seen earlier by the same writer.
type errDupPk struct {
	// preLine is the row number at which the same primary key was first seen.
	preLine int64
	// pkValues holds the primary-key values of the row, joined into one string.
	pkValues string
	// pkNames holds the primary-key column names, joined into one string.
	pkNames string
}

// Error implements the error interface. The message is a fixed Chinese text
// meaning "duplicate primary key"; the details stay in the struct fields.
func (d *errDupPk) Error() string {
	const message = "主键重复"
	return message
}

// tabWriter writes imported rows into one table through a prepared insert
// statement and remembers which primary keys it has already handled.
type tabWriter struct {
	pkInserted map[string]int64 // primary keys already inserted, each with the row number that inserted it
	// table is the destination table; it is also used to save duplicate rows.
	table      *data.Table
	// cols holds the column definitions in the same order as the import list.
	cols       []*schema.Column
	// pkIndeies holds, for each primary key of the table, its position in cols.
	pkIndeies  []int
	// ins is the prepared insert statement executed for every new row.
	ins        *sql.Stmt
	// tx is the handle the insert statement was prepared on.
	tx         ddb.Txer
	// rowNum counts the rows passed to write so far.
	rowNum     int64
}

// newTabWriter builds a tabWriter that inserts rows into table using the
// columns named in cols, in that order.
//
// Every name in cols must resolve to a column of table and every primary key
// of table must be present in cols; otherwise an error is returned. The
// insert statement is prepared on tx, and a Prepare error is returned as is.
func newTabWriter(table *data.Table, cols []string, tx ddb.Txer) (*tabWriter, error) {
	// Resolve each import column name to its column definition.
	resolved := make([]*schema.Column, 0, len(cols))
	for _, name := range cols {
		col := table.ColumnByName(name)
		if col == nil {
			return nil, errors.New("not found the field:" + name)
		}
		resolved = append(resolved, col)
	}

	// positionOf reports where name sits in cols, or -1 when it is absent.
	positionOf := func(name string) int {
		for pos, candidate := range cols {
			if candidate == name {
				return pos
			}
		}
		return -1
	}

	// Record the position of every primary-key column inside the import list.
	pkPositions := make([]int, 0, len(table.PrimaryKeys))
	for _, pk := range table.PrimaryKeys {
		pos := positionOf(pk)
		if pos < 0 {
			return nil, errors.New("the pk field:" + pk + " not found at import list")
		}
		pkPositions = append(pkPositions, pos)
	}

	// One "?" placeholder per column, separated by commas.
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(cols)), ",")
	query := fmt.Sprintf("insert into %s(%s)values(%s)",
		table.FullName(), strings.Join(cols, ","), placeholders)

	stmt, err := tx.Prepare(query)
	if err != nil {
		return nil, err
	}

	w := &tabWriter{
		pkInserted: map[string]int64{},
		table:      table,
		cols:       resolved,
		pkIndeies:  pkPositions,
		ins:        stmt,
		tx:         tx,
	}
	return w, nil
}
// write imports one row whose values are ordered like the writer's columns.
//
// The row counter is advanced first, so row numbers are 1-based and also
// count rejected rows. Possible results:
//   - a plain error when the number of values differs from the number of
//     columns (instead of an index-out-of-range panic);
//   - *errNullPk when a primary-key value is nil, an empty string or an
//     empty []byte;
//   - *errDupPk when this writer has already seen the same primary key; the
//     row has then been stored through table.Save, and a Save failure is
//     returned instead;
//   - otherwise the result of executing the prepared insert statement.
func (t *tabWriter) write(data []interface{}) error {
	t.rowNum++

	// The loops below index data and t.cols with each other's positions, so
	// a row of the wrong width must be rejected up front.
	if len(data) != len(t.cols) {
		return fmt.Errorf("row %d has %d values, but %d columns are expected",
			t.rowNum, len(data), len(t.cols))
	}

	pkvals := make([]string, 0, len(t.pkIndeies))
	for _, i := range t.pkIndeies {
		switch tv := data[i].(type) {
		case nil:
			return &errNullPk{field: t.cols[i].Name}
		case string:
			if len(tv) == 0 {
				return &errNullPk{field: t.cols[i].Name}
			}
		case []byte:
			if len(tv) == 0 {
				return &errNullPk{field: t.cols[i].Name}
			}
		}
		pkvals = append(pkvals, t.cols[i].Type.ToString(data[i]))
	}

	// Without primary-key columns there is nothing to identify a duplicate
	// by: every row would share the same empty key and all rows after the
	// first would be reported as duplicates. Skip the tracking in that case.
	if len(t.pkIndeies) > 0 {
		// Quote each value so that composite keys cannot collide: a plain
		// comma join makes ("a,b","c") and ("a","b,c") look identical.
		key := fmt.Sprintf("%q", pkvals)

		// A repeated key is written as an update rather than an insert.
		if iline, ok := t.pkInserted[key]; ok {
			rowUpdate := make(map[string]interface{}, len(data))
			for i, v := range data {
				rowUpdate[t.cols[i].Name] = v
			}
			// Save is used because, per the original author's note, it tries
			// an update first (so there is no performance problem) and leaves
			// columns outside the import list untouched.
			if err := t.table.Save(rowUpdate); err != nil {
				return err
			}
			return &errDupPk{
				preLine:  iline,
				pkValues: strings.Join(pkvals, ","),
				pkNames:  strings.Join(t.table.PrimaryKeys, ","),
			}
		}
		t.pkInserted[key] = t.rowNum
	}

	_, err := t.ins.Exec(data...)
	return err
}
// close closes the prepared insert statement and returns the error, if any,
// reported by that Close call.
func (t *tabWriter) close() error {
	err := t.ins.Close()
	return err
}
