package bill

import (
	"errors"
	"fmt"

	"github.com/linlexing/dbx/ddb"

	"github.com/linlexing/mapfun"

	"log"

	"database/sql"
	"strings"

	"github.com/linlexing/dbx/schema"
)

// FormerNameEvent is implemented by model structs that want to report the
// former (previous) names of their tables. OnFormerName returns the former
// names of the main table, and a map from child-table full name to that child
// table's former names.
type FormerNameEvent interface {
	OnFormerName() (mainNames []string, childNames map[string][]string)
}

// Bill is a simple main-table / child-table collection (one main table plus
// any number of detail tables) and provides methods to read and write it.
type Bill struct {
	// Main is the main (header) table.
	Main  *ddb.Table
	// Child holds the child (detail) tables, keyed by each table's full name
	// (FullName()).
	Child map[string]*ddb.Table
}

// NewBill returns a new Bill built directly from a struct definition; no table
// metadata is read from the database.
//
// meta and tabNames are passed to schema.TableFromStruct. The first table it
// returns becomes the main table, and every remaining table becomes a child
// table keyed by its full name. If meta implements FormerNameEvent, the former
// names it reports are applied to the main table and to the matching child
// tables before the ddb tables are created.
//
// An error is returned when schema.TableFromStruct fails, when it yields no
// tables, or when OnFormerName reports former names for a child table that is
// not part of the bill.
func NewBill(db ddb.DB, meta interface{}, tabNames ...string) (*Bill, error) {
	tabs, err := schema.TableFromStruct(meta, tabNames...)
	if err != nil {
		return nil, err
	}
	if len(tabs) == 0 {
		return nil, errors.New("no table")
	}
	childTables := make(map[string]*schema.Table, len(tabs)-1)
	for _, t := range tabs[1:] {
		childTables[t.FullName()] = t
	}
	if evt, ok := meta.(FormerNameEvent); ok {
		mainFName, childFName := evt.OnFormerName()
		tabs[0].FormerName = mainFName
		for k, one := range childFName {
			ct, found := childTables[k]
			if !found {
				// An unknown name would otherwise dereference a nil *schema.Table.
				return nil, fmt.Errorf("former name given for unknown child table %s", k)
			}
			ct.FormerName = one
		}
	}
	rev := &Bill{
		Main:  ddb.NewTable(db, tabs[0]),
		Child: make(map[string]*ddb.Table, len(tabs)-1),
	}
	for _, t := range tabs[1:] {
		rev.Child[t.FullName()] = ddb.NewTable(db, t)
	}
	return rev, nil
}

// OpenBill returns a bill for existing tables. It loads the structure of the
// main table and of every child table from the database with ddb.OpenTable.
//
// main must be non-empty, and every table must have a primary key. Child
// tables are keyed by their full name (FullName()), not by the name passed in.
func OpenBill(db ddb.DB, main string, child ...string) (*Bill, error) {
	if main == "" {
		return nil, errors.New("not main table")
	}
	mainTab, err := ddb.OpenTable(db, main)
	if err != nil {
		return nil, err
	}
	if len(mainTab.PrimaryKeys) == 0 {
		return nil, errors.New("table:" + main + " no primary key")
	}
	children := map[string]*ddb.Table{}
	for _, name := range child {
		tab, openErr := ddb.OpenTable(db, name)
		if openErr != nil {
			return nil, openErr
		}
		if len(tab.PrimaryKeys) == 0 {
			return nil, errors.New("table:" + name + " no primary key")
		}
		// Key by the table's full name.
		children[tab.FullName()] = tab
	}
	return &Bill{Main: mainTab, Child: children}, nil
}

// Clone returns a new Bill whose main and child tables are rebuilt with
// ddb.NewTable on top of db, reusing each table's existing definition
// (Table.Table). It is used instead of changing the DB in place (SetDB-style)
// because a Bill may be accessed concurrently, and an in-place change would
// affect other users.
func (b *Bill) Clone(db ddb.DB) *Bill {
	mainTab := ddb.NewTable(db, b.Main.Table.Table)
	children := map[string]*ddb.Table{}
	for name, tab := range b.Child {
		children[name] = ddb.NewTable(db, tab.Table.Table)
	}
	return &Bill{Main: mainTab, Child: children}
}

// DB returns the ddb.DB the bill operates on, as reported by its main table.
func (b *Bill) DB() ddb.DB {
	mainTab := b.Main
	return mainTab.DB()
}
// keyValues extracts the main table's primary-key values from record.Main.
func (b *Bill) keyValues(record *Record) []interface{} {
	mainRow := record.Main
	return b.Main.KeyValues(mainRow)
}
// exists reports whether the main table has a row with the given primary-key values.
func (b *Bill) exists(keyValues ...interface{}) (bool, error) {
	found, err := b.Main.KeyExists(keyValues...)
	return found, err
}

// ChangeKeyValues rewrites the primary key of record to pks.
//
// Each main-table primary-key column is set to the matching value in pks,
// converted with that column's Type.ParseScan. Then, in every child row, the
// leading primary-key columns are set to the new main-table key values. This
// works because a child table's primary key must begin with the main table's
// primary-key columns. A record for which IsEmpty() is true is left untouched.
//
// NOTE(review): pks must hold at least len(b.Main.PrimaryKeys) values. A shorter
// pks, or a child table with fewer key columns than the main table (when it has
// rows), panics with an index out of range.
func (b *Bill) ChangeKeyValues(record *Record, pks ...interface{}) {
	if record.IsEmpty() {
		return
	}
	mainKeys := b.Main.PrimaryKeys
	// Update the main row.
	for idx := range mainKeys {
		col := mainKeys[idx]
		record.Main[col] = b.Main.ColumnByName(col).Type.ParseScan(pks[idx])
	}
	// Propagate the new key into the leading key columns of every child row.
	for _, tab := range b.Child {
		rows := record.Child[tab.FullName()]
		childKeys := tab.PrimaryKeys
		for idx, col := range mainKeys {
			newVal := record.Main[col]
			for _, row := range rows {
				row[childKeys[idx]] = newVal
			}
		}
	}
}

// Remove deletes oldRecord from the database. The main-table row is removed
// first, then the listed rows of each child table.
//
// An error is returned when:
//   - oldRecord.Child has rows for a table that is not part of the bill (this
//     is checked before anything is deleted);
//   - the main row is not found, or a child delete affects no rows. This is
//     taken to mean another user changed the data;
//   - rows linked to the main key still remain in a child table after its
//     listed rows are deleted, meaning the detail data was modified.
//
// NOTE(review): the main row is deleted before the children. Callers
// presumably run this inside a transaction so a failure can be rolled back.
func (b *Bill) Remove(oldRecord *Record) (err error) {
	// Validate up front: an unknown table name would make b.Child[k] nil.
	for k, v := range oldRecord.Child {
		if _, ok := b.Child[k]; !ok && len(v) > 0 {
			return errors.New("table " + k + " not found at the bill")
		}
	}
	var iCount int64
	if iCount, err = b.Main.Remove(oldRecord.Main); err != nil {
		return err
	}
	if iCount == 0 {
		err = errors.New("can't foud then record")
		log.Println(err)
		return err
	}
	if len(oldRecord.Child) == 0 {
		return nil
	}
	mainKeyValues := b.Main.KeyValues(oldRecord.Main)
	for k, v := range oldRecord.Child {
		if len(v) == 0 {
			continue
		}
		child := b.Child[k]
		if iCount, err = child.Delete(v); err != nil {
			return err
		}
		// No rows deleted means someone else already changed the data.
		if iCount == 0 {
			err = errors.New("can't foud then record")
			log.Println(err)
			return err
		}
		// Any row still linked to the main key means the detail rows were changed.
		where := make([]string, 0, len(mainKeyValues))
		for i := range mainKeyValues {
			where = append(where, fmt.Sprintf("%s=?", child.PrimaryKeys[i]))
		}
		var has bool
		if has, err = child.Exists(strings.Join(where, " and\n"), mainKeyValues...); err != nil {
			return err
		}
		if has {
			return fmt.Errorf("删除完成后，%s 中还含有数据", k)
		}
	}
	return nil
}

// Insert inserts a complete bill record: the main row first, then the rows of
// each child table listed in record.Child.
//
// If record.Child has rows for a table that is not part of the bill, an error
// is returned before anything is written. Previously this case dereferenced a
// nil table. Any insert error is returned as-is.
func (b *Bill) Insert(record *Record) error {
	for name, rows := range record.Child {
		if _, ok := b.Child[name]; !ok && len(rows) > 0 {
			return errors.New("table " + name + " not found at the bill")
		}
	}
	if err := b.Main.Insert([]map[string]interface{}{record.Main}); err != nil {
		return err
	}
	for name, rows := range record.Child {
		tab, ok := b.Child[name]
		if !ok {
			// Unknown table with no rows (validated above): nothing to insert.
			continue
		}
		if err := tab.Insert(rows); err != nil {
			return err
		}
	}
	return nil
}

// Save writes record to the database, overwriting any existing data with the
// same key.
//
// The main row is saved with Main.Save. For each child table of the bill:
//  1. the rows currently stored under the main key are loaded;
//  2. stored rows missing from record.Child, compared on the child primary key
//     with mapfun.Difference, are removed;
//  3. every row in record.Child for that table is saved.
//
// An error is returned if a stale child row cannot be found when removing it.
func (b *Bill) Save(record *Record) error {
	if err := b.Main.Save(record.Main); err != nil {
		return err
	}
	if len(b.Child) == 0 {
		return nil
	}
	keyVals := b.Main.KeyValues(record.Main)
	for tabName, tab := range b.Child {
		keyCols := tab.PrimaryKeys
		// The leading child key columns hold the main key.
		conds := make([]string, 0, len(keyVals))
		params := make([]interface{}, 0, len(keyVals))
		for i, val := range keyVals {
			conds = append(conds, fmt.Sprintf("%s=?", keyCols[i]))
			params = append(params, val)
		}
		stored, err := tab.QueryRows(strings.Join(conds, " and\n"), params...)
		if err != nil {
			return err
		}
		wanted := record.Child[tabName]
		// Drop stored rows that no longer appear in the record.
		for _, stale := range mapfun.Difference(stored, wanted, keyCols) {
			n, rmErr := tab.Remove(stale)
			if rmErr != nil {
				return rmErr
			}
			if n == 0 {
				return errors.New("can't found remove record")
			}
		}
		for _, row := range wanted {
			if saveErr := tab.Save(row); saveErr != nil {
				return saveErr
			}
		}
	}
	return nil
}

// Update replaces oldRecord with newRecord.
//
// The main row is updated with Main.Update. The old record's values must all
// equal the stored values. Then each child table's rows are replaced with
// Replace(old rows, new rows).
//
// If the main row is not found, an error is returned immediately and no child
// table is touched.
func (b *Bill) Update(oldRecord, newRecord *Record) (err error) {
	var iCount int64
	if iCount, err = b.Main.Update(oldRecord.Main, newRecord.Main); err != nil {
		return err
	}
	if iCount == 0 {
		// Stop here. Continuing would modify child rows of a main row that is
		// missing or presumably was changed by someone else, and the error
		// would then be replaced by nil.
		return errors.New("can't found the update record at table " + b.Main.FullName())
	}
	for _, v := range b.Child {
		name := v.FullName()
		if _, _, _, err = v.Replace(oldRecord.Child[name], newRecord.Child[name]); err != nil {
			return err
		}
	}
	return nil
}

// Record returns the first bill whose main-table primary-key columns equal
// keyValues (given in primary-key column order), as produced by Rows.Scan.
// It returns sql.ErrNoRows when no main row matches.
//
// keyValues may not hold more values than the main table has primary-key
// columns. If the query fails while iterating, that error is returned instead
// of sql.ErrNoRows.
func (b *Bill) Record(keyValues ...interface{}) (result *Record, err error) {
	pks := b.Main.PrimaryKeys
	if len(keyValues) > len(pks) {
		return nil, fmt.Errorf("table %s has %d primary key columns, got %d values",
			b.Main.FullName(), len(pks), len(keyValues))
	}
	where := make([]string, 0, len(keyValues))
	for _, col := range pks[:len(keyValues)] {
		where = append(where, fmt.Sprintf("%s=?", col))
	}
	rows, err := b.Query(strings.Join(where, " and\n"), keyValues...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	if !rows.Next() {
		// Next also returns false on an error; report that rather than "no rows".
		if err = rows.rows.Err(); err != nil {
			return nil, err
		}
		return nil, sql.ErrNoRows
	}
	return rows.Scan()
}

// Query mimics db.Query for bills. It queries the main table with where and
// args and wraps the resulting rows in a *Rows that supports Next and Scan.
// The caller should Close the returned Rows when done.
func (b *Bill) Query(where string, args ...interface{}) (result *Rows, err error) {
	sqlRows, qErr := b.Main.Query(where, args...)
	if qErr != nil {
		return nil, qErr
	}
	return &Rows{rows: sqlRows, bill: b}, nil
}

// QueryOrder is like Query, but the main-table rows are ordered by the orderby
// terms, which are passed to Main.QueryOrder. The caller should Close the
// returned Rows when done.
func (b *Bill) QueryOrder(orderby []string, where string, args ...interface{}) (result *Rows, err error) {
	sqlRows, qErr := b.Main.QueryOrder(orderby, where, args...)
	if qErr != nil {
		return nil, qErr
	}
	return &Rows{rows: sqlRows, bill: b}, nil
}

// ImportFrom bulk-imports data from srcBill, whose structure must match this
// bill exactly.
//
// The main table is imported first, filtered by where and args. Each child
// table is then imported, restricted to rows whose leading primary-key columns
// match a srcBill main row selected by the same filter, so args are reused.
//
// progressFunc receives progress messages and may be nil. totalImport is the
// combined number of rows imported into the main and child tables. Every child
// table of srcBill must exist in b; this is checked before anything is
// imported. If a child table import fails, -1 is returned with the error.
func (b *Bill) ImportFrom(srcBill *Bill, progressFunc func(message string),
	where string, args ...interface{}) (totalImport int64, err error) {
	if progressFunc == nil {
		progressFunc = func(string) {}
	}
	// Validate first so a missing child table does not leave a partial import.
	for name := range srcBill.Child {
		if _, ok := b.Child[name]; !ok {
			err = errors.New("table " + name + " not found at the bill")
			log.Println(err)
			return 0, err
		}
	}
	tableCount := len(srcBill.Child) + 1
	// Import the main table first.
	progressFunc(fmt.Sprintf("start import bill,total %d table", tableCount))
	progressFunc(fmt.Sprintf("import table 1/%d", tableCount))

	totalImport, err = b.Main.ImportFromTable(srcBill.Main.Table, progressFunc, where, args...)
	if err != nil {
		return totalImport, err
	}
	// Then import each child table.
	tabNum := 1
	for name, tab := range srcBill.Child {
		tabNum++
		progressFunc(fmt.Sprintf("import table %d/%d", tabNum, tableCount))
		btab := b.Child[name]
		// Build a correlated EXISTS against the source main table (alias m_tab).
		childWhere := make([]string, 0, len(srcBill.Main.PrimaryKeys)+1)
		if len(where) > 0 {
			// Parenthesize so an OR in the caller's filter cannot escape the AND chain.
			childWhere = append(childWhere, "("+where+")")
		}
		for i, v := range srcBill.Main.PrimaryKeys {
			childWhere = append(childWhere, fmt.Sprintf("%s.%s=m_tab.%s",
				tab.FullName(), tab.PrimaryKeys[i], v))
		}
		childWhereStr := fmt.Sprintf("exists(select 1 from %s m_tab where %s)", srcBill.Main.FullName(),
			strings.Join(childWhere, " and\n"))
		icount, childErr := btab.ImportFromTable(tab.Table, progressFunc, childWhereStr, args...)
		if childErr != nil {
			return -1, childErr
		}
		totalImport += icount
	}
	return totalImport, nil
}
