package core

import (
	"dbweb/lib/model"
	"fmt"
	"os"
	"strings"

	"dbweb/lib/bill"
	"unicode"

	"github.com/Sirupsen/logrus"
	"github.com/linlexing/dbx/common"
	"github.com/linlexing/dbx/ddb"
	"github.com/linlexing/dbx/schema"
)

type internalModel struct {
	name    string
	crtable bool //是否自动创建crtable
	dbname  string
	mtype   interface{}
	oldver  int64
	dbver   int64
	m       *model.Model
}

//Model汉字字段 嵌入此结构，将自动将所有字段的名称首字母去除作为数据库字段名
type Model汉字字段 struct{}

//ConvertFieldName 去除首字母
func (m Model汉字字段) ConvertFieldName(childName, fieldName string) string {
	return fieldName[1:]
}

//ModelUnderScore 驼峰转换成下划线,连续大写不加下划线
type ModelUnderScore struct{}

//ConvertFieldName 驼峰转换成下划线
func (m ModelUnderScore) ConvertFieldName(childName, fieldName string) string {
	rev := ""
	preUpper := true
	for i, c := range fieldName {
		if i == 0 {
			rev += string(unicode.ToLower(c))
			continue
		}
		if !preUpper && unicode.IsUpper(c) {
			rev += "_" + string(unicode.ToLower(c))
		} else {
			rev += string(c)
		}
		preUpper = unicode.IsUpper(c)
	}
	return rev
}

type modelListType struct {
	models map[string]*internalModel
	names  []string
}

func (m *modelListType) Names() []string {
	return m.names
}

//FindByName 按名称搜索，不区分大小写
func (m *modelListType) FindByName(name string) *internalModel {
	if f, ok := m.models[name]; ok {
		return f
	}
	//进行曾用名的查询
	for k, one := range m.models {
		if strings.EqualFold(name, k) {
			return one
		}
		if evt, ok := one.mtype.(bill.FormerNameEvent); ok {
			nm, _ := evt.OnFormerName()
			for _, s := range nm {
				if strings.EqualFold(s, name) {
					return one
				}
			}
		}
	}
	return nil
}

var (
	modelList = &modelListType{
		models: map[string]*internalModel{},
	}
)

type afterSchemaUpdateEvent interface {
	OnAfterSchemaUpdate(md *model.Model, oldver int64) error
}
type modelInitEvent interface {
	OnModelInit(md *model.Model, oldver int64) error
}

//RegisterModel 注册一个Model，并没有应用到数据库
func RegisterModel(md interface{}, dbver int64, crtable bool, name ...string) {
	checkLog()
	var dbname, mname string
	if len(name) == 2 {
		mname = name[0]
		dbname = name[1]
	}
	if len(name) == 1 {
		mname = name[0]
	}
	if len(name) > 2 {
		LOG.Panic(fmt.Errorf("names %v length >2", name))
	}
	LOG.Println("register model ", mname)
	modelList.models[mname] = &internalModel{name: mname, dbname: dbname,
		crtable: crtable, mtype: md, dbver: dbver}
	modelList.names = append(modelList.names, mname)
	return
}

//NewModel 分配一个新的Model，挂载在一个db中，但没有数据库操作
func NewModel(db ddb.DB, name string) (*model.Model, error) {
	im, err := findModel(name)
	if err != nil {
		return nil, err
	}
	return im.new(db)
}

//MustModel 必须返回一个model，找不到则异常
func MustModel(db ddb.DB, name string) *model.Model {
	im, err := findModel(name)
	if err != nil {
		LOG.Panic(err)
	}
	m, err := im.new(db)
	if err != nil {
		LOG.Panic(err)
	}
	return m
}
func findModel(name string) (*internalModel, error) {
	if md := modelList.FindByName(name); md != nil {
		return md, nil
	}
	return nil, fmt.Errorf("findModel:not found model %s", name)
}
func mustFindModel(name string) *internalModel {
	if md := modelList.FindByName(name); md != nil {
		return md
	}
	LOG.WithFields(logrus.Fields{
		"name": name,
	}).Panic("findModel:not found model")
	return nil
}

//ModelNames 返回所有model的名称，主要用于dump
func ModelNames() []string {
	return modelList.Names()
}

//ModelVersion 返回model的版本，主要用于导出模块
func ModelVersion(name string) int64 {
	return mustFindModel(name).dbver
}

//LoadModel 调入一个Model，使用默认的数据库连接
func LoadModel(name string) *model.Model {
	return mustFindModel(name).m
}

//new 分配一个新的可操作数据库的Model,根据dbname 去获取新的db
func (m *internalModel) new(db ddb.DB) (*model.Model, error) {
	return model.New(db, m.mtype, m.name)
}
func (m *internalModel) checkResultTableName() string {
	ns := strings.Split(m.name, ".")

	if len(ns) == 2 {
		return fmt.Sprintf("%s.CR_%s", ns[0], ns[1])
	}
	if len(ns) == 1 {
		return fmt.Sprintf("CR_%s", ns[0])
	}
	LOG.Panic(fmt.Errorf("model %s invalid name", m.name))
	return ""
}
func (m *internalModel) checkResultTableFormerName() []string {

	if evt, ok := m.mtype.(bill.FormerNameEvent); ok {
		var rev []string
		list, _ := evt.OnFormerName()
		for _, one := range list {
			rev = append(rev, fmt.Sprintf("CR_%s", one))
		}
		return rev
	}
	return nil
}

//setup 初始化一个Model，更新其schema至数据库，会判断版本
func (m *internalModel) setup(metadb ddb.DB) error {
	if m.dbver < m.oldver {
		return fmt.Errorf("model %s oldver %d > newver %d", m.name, m.oldver, m.dbver)
	}
	db := metadb
	//如果是缺省数据库名称，则直接用当前的db，节省一个事务，且大部分情况都适合，且sqlite3不能多事务
	if len(m.dbname) > 0 {
		db = LoadOuterDB(metadb, m.dbname)
	}
	md, err := m.new(db)
	if err != nil {
		return err
	}

	//当版本相等时，不需要更新数据库结构
	if m.dbver == m.oldver {
		return nil
	}
	//一般情况下没有修改的sql了，但是modelver除外，会需要首次生成
	list, err := m.extractSchemaChanges(db)
	if err != nil {
		return err
	}
	if len(list) > 0 {
		if err = common.BatchRunAndPrint(db, list); err != nil {
			return err
		}
	}

	if evt, ok := m.mtype.(afterSchemaUpdateEvent); ok {
		if err := evt.OnAfterSchemaUpdate(md, m.oldver); err != nil {
			return err
		}
	}
	return nil
}

//extractChanges 比较当前定义和数据库实际结构，返回修改的语句，如果一致，返回nil
func (m *internalModel) extractSchemaChanges(metadb ddb.DB) (rev []string, err error) {
	if m.dbver < m.oldver {
		err = fmt.Errorf("model %s oldver %d > newver %d", m.name, m.oldver, m.dbver)
		return
	}
	db := metadb
	//如果是缺省数据库名称，则直接用当前的db，节省一个事务，且大部分情况都适合，且sqlite3不能多事务
	if len(m.dbname) > 0 {
		db = LoadOuterDB(metadb, m.dbname)
	}
	md, err := m.new(db)
	if err != nil {
		return
	}
	//初始化事件里面，一般存放数据结构的程序化赋值，所以要执行
	if evt, ok := m.mtype.(modelInitEvent); ok {
		if err := evt.OnModelInit(md, m.oldver); err != nil {
			return nil, err
		}
	}
	//当版本相等时，不需要更新数据库结构
	if m.dbver == m.oldver {
		return nil, nil
	}
	var list []string
	if list, err = md.ExtractSchemaChanges(); err != nil {
		return
	}
	rev = append(rev, list...)
	//如果有必要，创建CRTable
	if m.crtable {
		crTable, err := schema.TableFromStruct(modelCRTable{}, m.checkResultTableName())
		if err != nil {
			return nil, err
		}
		for _, one := range md.Bill().Main.PrimaryKeys {
			crTable[0].Columns = append(crTable[0].Columns,
				md.Bill().Main.ColumnByName(one).Clone())
			crTable[0].PrimaryKeys = append(crTable[0].PrimaryKeys, one)
		}
		//如果有曾用名，则也加上
		crTable[0].FormerName = m.checkResultTableFormerName()

		if list, err = crTable[0].Extract(db.DriverName(), db); err != nil {
			return nil, err
		}
		rev = append(rev, list...)
	}

	return
}

//初始化模块版本清单
func initModelVersion(db ddb.TxDB) error {
	//读取所有数据库model的版本，并检查代码中是否都存在
	exists, err := schema.Find(db.DriverName()).TableExists(db, ModelVerName)
	if err != nil {
		return err
	}
	if exists {
		modelVer, err := NewModel(db, ModelVerName)
		if err != nil {
			LOG.Println("new model", ModelVerName, "error")
			return err
		}
		rows, err := modelVer.Query("")
		if err != nil {
			LOG.Println(ModelVerName, "query error")
			return err
		}
		defer rows.Close()
		for rows.Next() {
			var out ModelModelVer
			if err = rows.Scan(&out); err != nil {
				LOG.Println(ModelVerName, "scan error")
				return err
			}
			//从数据库取得的名称，要进行全范围查询，包括曾用名
			if m := modelList.FindByName(out.Name); m == nil {
				return fmt.Errorf("initModelVersion:not found model %s", out.Name)
			} else {
				//顺便保存数据库版本
				m.oldver = out.Version
				if m.dbver < m.oldver {
					return fmt.Errorf("model:%s oldver:%d > current ver:%d", out.Name, m.oldver, m.dbver)
				}
			}
		}
	}
	return nil
}
func extractChangeSQL(db ddb.TxDB) (rev []string, err error) {
	//开始获取修改SQL
	for _, mname := range modelList.Names() {
		m := modelList.FindByName(mname)
		var sqlList []string
		if sqlList, err = m.extractSchemaChanges(db); err != nil {
			LOG.Println("setup model", m.name, "error")
			return
		}
		rev = append(rev, sqlList...)
	}
	return
}
func setupModels(db ddb.TxDB) error {
	return ddb.RunAtTx(db, func(tx ddb.Txer) (err error) {
		modelVer, err := NewModel(tx, ModelVerName)
		if err != nil {
			LOG.Println("new model", ModelVerName, "error")
			return err
		}
		for _, mname := range modelList.Names() {
			m := modelList.FindByName(mname)
			//数据库结构相等则不需要安装
			if m.oldver == m.dbver {
				continue
			}
			out := ModelModelVer{
				Name:    mname,
				Version: m.dbver,
			}
			if evt, ok := m.mtype.(afterSchemaUpdateEvent); ok {
				md, err := NewModel(tx, m.name)
				if err != nil {
					return err
				}

				if err := evt.OnAfterSchemaUpdate(md, m.oldver); err != nil {
					return err
				}
			}

			//最后更新数据库中的版本
			if err = modelVer.Set(out); err != nil {
				LOG.Println("set modelver", out.Name, out.Version, "error")
				return err
			}
		}
		return nil
	})
}

//initModels 初始化所有的Model,检查数据库结构，如果有不一致的，提示sql语句
func initModels(db ddb.TxDB) {
	//创建数据库中的model的数据库版本存放表
	var err error
	if err = initModelVersion(db); err != nil {
		LOG.Panic(err)
	}
	var changes []string
	if changes, err = extractChangeSQL(db); err != nil {
		LOG.Panic(err)
	}
	//如果有变动sql，则提示并退出
	if len(changes) > 0 {
		fmt.Fprintln(os.Stdout, "------> detect schema change,please post sql to db <------")
		for _, one := range changes {
			fmt.Fprintln(os.Stdout, one+";")
		}
		fmt.Fprintln(os.Stdout, "-----------------> has schema change <--------------------")
		LOG.Panicln("SCHEMA_CHANGE")
	}
	if err = setupModels(db); err != nil {
		LOG.Panic(err)
	}

	//最后设置每个Model的实际类，主要是不能用到tx
	for _, mname := range modelList.Names() {
		m := modelList.FindByName(mname)
		newdb := db
		if len(m.dbname) >= 0 {
			newdb = LoadOuterDB(db, m.dbname)
		}
		if mm, er := m.new(newdb); er != nil {
			LOG.Panic(er)
		} else {
			m.m = mm
		}
	}
}
