init.go 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. package migration
  2. import (
  3. "fmt"
  4. "log"
  5. "path/filepath"
  6. "sort"
  7. "strconv"
  8. "sync"
  9. "github.com/spf13/cast"
  10. "gorm.io/gorm"
  11. )
  12. var Migrate = &Migration{
  13. version: make(map[int]func(db *gorm.DB, version string) error),
  14. }
  15. type Migration struct {
  16. db *gorm.DB
  17. version map[int]func(db *gorm.DB, version string) error
  18. mutex sync.Mutex
  19. }
  20. func (e *Migration) GetDb() *gorm.DB {
  21. return e.db
  22. }
  23. func (e *Migration) SetDb(db *gorm.DB) {
  24. e.db = db
  25. }
  26. //version 初始化 init 时调用, 将数据库信息导入到 migration.
  27. func (e *Migration) SetVersion(k int, f func(db *gorm.DB, version string) error) {
  28. e.mutex.Lock()
  29. defer e.mutex.Unlock()
  30. e.version[k] = f
  31. }
  32. func (e *Migration) Migrate() {
  33. versions := make([]int, 0)
  34. for k := range e.version {
  35. versions = append(versions, k)
  36. }
  37. sort.IntsAreSorted(versions)
  38. var err error
  39. var count int64
  40. for _, v := range versions {
  41. fmt.Println(v)
  42. err = e.db.Debug().Table("sys_migration").Where("version = ?", v).Count(&count).Error
  43. if err != nil {
  44. log.Fatalln(err)
  45. }
  46. if count > 0 {
  47. log.Println(count)
  48. count = 0
  49. continue
  50. }
  51. err = (e.version[v])(e.db.Debug(), strconv.Itoa(v))
  52. if err != nil {
  53. log.Fatalln(err)
  54. }
  55. }
  56. }
  57. func GetFilename(s string) int {
  58. s = filepath.Base(s)
  59. fmt.Println(s)
  60. return cast.ToInt(s[:13])
  61. }