package common

import (
	"archive/zip"
	"dbweb/core"
	"dbweb/lib/safe"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"mime/multipart"
	"net/url"
	"os"
	"strings"

	"github.com/pborman/uuid"
)

const (
	// ctrlRestoreName is the name under which init registers both the
	// Restore controller and the restore element type.
	ctrlRestoreName = "restore"
)

// restoreTaskParam is the payload stored in core.TaskRun.Param for the
// background restore task; restoreFile type-asserts it back out.
type restoreTaskParam struct {
	FileName  string   // path of the dump file extracted in step 1 (looked up from the session)
	DumpName  string   // DumpName taken from the element's RestoreParam
	BillNames []string // bill names collected from the step-2 form fields
}

// Restore is the controller registered under ctrlRestoreName; its Get and
// Post methods drive the two-step ("step1"/"step2") restore pages.
type Restore struct{}

// Size is the interface for values that can report their file size via
// Size().
type Size interface {
	Size() int64
}

// Stat is the interface for values that can return file information via
// Stat(), such as *os.File.
type Stat interface {
	Stat() (os.FileInfo, error)
}

func init() {
	core.RegisterFun(ctrlRestoreName, new(Restore))
	core.RegisterElementType(ctrlRestoreName, core.ModelRestoreElement{})
}
// ElementParamRender decodes the record's "PARAMS" value as JSON into a
// core.RestoreParam and stores the result in p.More["JSONParams"].
// When safe.Bytea yields nil, the zero RestoreParam is stored. Malformed JSON
// is logged verbatim and then raised through core.LOG.Panic.
func (r *Restore) ElementParamRender(p *core.BillGetHandleArgs) {
	var jsonParams core.RestoreParam
	raw := safe.Bytea(p.Record.Main["PARAMS"])
	if raw != nil {
		if err := json.Unmarshal(raw, &jsonParams); err != nil {
			core.LOG.Println(string(raw))
			core.LOG.Panic(err)
		}
	}
	p.More["JSONParams"] = jsonParams
}
// Get renders the "restore/step2" template when the query carries step=2,
// and the "restore/step1" template otherwise.
func (r *Restore) Get(p *core.ElementHandleArgs) {
	page := "restore/step1"
	if p.Req.URL.Query().Get("step") == "2" {
		page = "restore/step2"
	}
	p.Render.HTML(200, page, p)
}

// checkIsDumpFile reports whether an uploaded file is a valid dump archive.
// A valid archive is a zip file containing exactly one entry, named
// "dump.sl3". On success it returns a reader over the archive; otherwise it
// returns an error.
func checkIsDumpFile(upFile multipart.File) (*zip.Reader, error) {
	// zip.NewReader needs the archive's total size. A multipart.File may expose
	// it directly via Size(), or only through Stat() (as *os.File does).
	// Size() is preferred when both are available.
	var fileSize int64
	switch f := upFile.(type) {
	case interface{ Size() int64 }:
		fileSize = f.Size()
	case interface{ Stat() (os.FileInfo, error) }:
		fileInfo, err := f.Stat()
		if err != nil {
			// fileInfo is nil here; using it would panic.
			return nil, err
		}
		fileSize = fileInfo.Size()
	}
	// If neither method is available fileSize stays 0 and zip.NewReader
	// rejects the input.
	reader, err := zip.NewReader(upFile, fileSize)
	if err != nil {
		return nil, err
	}
	if len(reader.File) != 1 || reader.File[0].Name != "dump.sl3" {
		return nil, fmt.Errorf("不是有效的zip文件，找不到文件dump.sl3")
	}
	return reader, nil
}
// decompressDumpFile extracts the first entry of a dump archive (as validated
// by checkIsDumpFile) into a new temporary file and returns that file's path.
// The caller owns the returned file. If extraction fails, the partially
// written temporary file is removed and the error is returned.
func decompressDumpFile(zipArc *zip.Reader) (string, error) {
	if len(zipArc.File) == 0 {
		return "", fmt.Errorf("不是有效的zip文件，找不到文件dump.sl3")
	}
	rd, err := zipArc.File[0].Open()
	if err != nil {
		return "", err
	}
	defer rd.Close()

	tmpFile, err := ioutil.TempFile("", "imp_")
	if err != nil {
		return "", err
	}
	tmpName := tmpFile.Name()
	_, copyErr := io.Copy(tmpFile, rd)
	// Close explicitly and check it: a failed Close on a freshly written file
	// can mean the data never fully reached disk.
	closeErr := tmpFile.Close()
	if copyErr != nil || closeErr != nil {
		_ = os.Remove(tmpName) // best-effort cleanup of the incomplete file
		if copyErr != nil {
			return "", copyErr
		}
		return "", closeErr
	}
	return tmpName, nil
}
// restoreFile is the Func of the background task started by Post. It unpacks
// the restoreTaskParam from t.Param and passes it to core.RestoreFile, along
// with the task's database, user and progress callback. Errors are raised
// via core.LOG.Panic; the returned string is always empty.
//
// NOTE(review): the extracted temp file at FileName is not deleted here;
// confirm whether core.RestoreFile or another component cleans it up.
func restoreFile(t *core.TaskRun) string {
	param := t.Param.(restoreTaskParam)
	err := core.RestoreFile(param.FileName, param.DumpName, param.BillNames, t.Db, t.User, t.AddProgress)
	if err != nil {
		core.LOG.Panic(err)
	}
	return ""
}

// Post handles form submissions of the two-step restore page.
//
// Without "step=2" in the query (step 1), it performs these steps:
//   - validates the uploaded "impfile" dump archive and extracts it to a
//     temporary file;
//   - checks the archive's configuration against the element's DumpName;
//   - stores the temp file path in the session and renders "restore/step2".
//
// With "step=2", it looks up that temp file by the posted "dumpFileID",
// collects the selected bill names from the form and starts the restore as
// a background task.
func (r *Restore) Post(p *core.ElementHandleArgs) {
	eleParam := &core.RestoreParam{}
	if err := json.Unmarshal(p.Element.Params, eleParam); err != nil {
		p.GotoMessage(err.Error())
		return
	}
	if p.Req.URL.Query().Get("step") == "2" {
		r.postStartRestore(p, eleParam)
		return
	}
	r.postUploadDump(p, eleParam)
}

// postStartRestore handles step 2: it launches the background restore task
// for the dump file extracted in step 1.
func (r *Restore) postStartRestore(p *core.ElementHandleArgs, eleParam *core.RestoreParam) {
	// dumpFileID comes from the client, so a malformed value is reported to
	// the user instead of panicking.
	fileID, err := base64.RawURLEncoding.DecodeString(p.Req.PostFormValue("dumpFileID"))
	if err != nil {
		p.GotoMessage(err.Error())
		return
	}
	// The session may not hold this ID (step 1 skipped, entry dropped); an
	// unchecked type assertion on the missing value would panic.
	fileName, ok := p.LSession.Get(fileID).(string)
	if !ok {
		p.GotoMessage("找不到上传的dump文件，请重新上传")
		return
	}

	// Each form field whose name contains "_" contributes its second
	// "_"-separated segment as a selected bill name.
	// NOTE(review): map iteration order is random, so BillNames order varies
	// between requests; sort it if core.RestoreFile is order-sensitive.
	selBills := []string{}
	for k := range p.Req.PostForm {
		if names := strings.Split(k, "_"); len(names) > 1 {
			selBills = append(selBills, names[1])
		}
	}

	param := restoreTaskParam{
		FileName:  fileName,
		DumpName:  eleParam.DumpName,
		BillNames: selBills,
	}
	task := &core.TaskRun{
		Db:       p.DB,
		Name:     p.Element.DisplayLabel(),
		User:     p.User,
		ClientIP: p.Req.RemoteAddr,
		Param:    param,
		Func:     restoreFile,
	}
	if err := task.GoRun(); err != nil {
		p.RenderError(err.Error())
		return
	}
	p.GotoMessage(fmt.Sprintf("导入操作将在后台运行，可以<a href='%s' target='_blank'>点击查看进度</a>", p.User.Sign("/browsetask/"+task.ID())))
}

// postUploadDump handles step 1: it validates and extracts the uploaded dump,
// then renders the step-2 page that posts back with step=2.
func (r *Restore) postUploadDump(p *core.ElementHandleArgs, eleParam *core.RestoreParam) {
	upFile, _, err := p.Req.FormFile("impfile")
	if err != nil {
		p.GotoMessage(err.Error())
		return
	}
	defer upFile.Close()
	zipArc, err := checkIsDumpFile(upFile)
	if err != nil {
		p.GotoMessage(err.Error())
		return
	}
	tmpFileName, err := decompressDumpFile(zipArc)
	if err != nil {
		p.GotoMessage(err.Error())
		return
	}

	// From here on, a failure must also remove the extracted file; otherwise
	// it is orphaned in the temp directory.
	fail := func(err error) {
		_ = os.Remove(tmpFileName) // best-effort cleanup
		p.GotoMessage(err.Error())
	}

	// Read the dump's configuration and make sure it matches this element.
	config, err := core.ReadDumpConfig(tmpFileName)
	if err != nil {
		fail(err)
		return
	}
	if err := config.Check(eleParam.DumpName); err != nil {
		fail(err)
		return
	}

	// Build the step-2 form target: the current URL with step=2 added.
	postURL, err := url.Parse(p.Req.URL.String())
	if err != nil {
		fail(err)
		return
	}
	q := postURL.Query()
	q.Set("step", "2")
	postURL.RawQuery = q.Encode()

	newid := []byte(uuid.NewUUID())
	p.LSession.Set(newid, tmpFileName)
	p.More["DumpFileID"] = base64.RawURLEncoding.EncodeToString(newid)
	p.More["DumpConfig"] = config
	// Best effort: a marshal failure only leaves DumpConfigJSON empty.
	bys, _ := json.MarshalIndent(config, "", "\t")
	p.More["DumpConfigJSON"] = string(bys)
	p.More["PostUrl"] = postURL
	p.Render.HTML(200, "restore/step2", p)
}
