101 lines
2.3 KiB
Go
101 lines
2.3 KiB
Go
|
|
package db
|
||
|
|
|
||
|
|
import (
|
||
|
|
"code.hoteas.com/golang/hotime/cache"
|
||
|
|
. "code.hoteas.com/golang/hotime/common"
|
||
|
|
"database/sql"
|
||
|
|
_ "github.com/go-sql-driver/mysql"
|
||
|
|
_ "github.com/mattn/go-sqlite3"
|
||
|
|
"github.com/sirupsen/logrus"
|
||
|
|
"sync"
|
||
|
|
)
|
||
|
|
|
||
|
|
// HoTimeDB is the core struct for database operations.
//
// NOTE: it embeds both *sql.DB and *sql.Tx. Methods declared by both types
// (e.g. Exec, Query, QueryRow, Prepare) are promoted at the same depth, so a
// selector such as that.Exec is ambiguous unless HoTimeDB declares its own
// method of that name elsewhere in the package; otherwise use that.DB.X or
// that.Tx.X explicitly.
//
// The struct holds sync.RWMutex and sync.Mutex fields, so it must not be
// copied after first use; pass *HoTimeDB.
type HoTimeDB struct {
	// Master connection; InitDb assigns the first value returned by ConnectFunc.
	*sql.DB

	// Embedded from the common package; its contents are not visible here.
	ContextBase

	DBName string

	*cache.HoTimeCache

	Log *logrus.Logger

	// Database type: mysql, sqlite3, postgres. initDialect also recognises
	// "postgresql" and "sqlite"; any other value selects the MySQL dialect.
	Type string

	// Table-name prefix, returned as-is by GetPrefix.
	Prefix string

	// NOTE(review): presumably the most recently executed SQL text and its
	// bind arguments; they are not written in this file — confirm.
	LastQuery string
	LastData  []interface{}

	// Factory returning the (master, slave) connections. InitDb calls it with
	// LastErr and stores the results in DB and SlaveDB.
	ConnectFunc func(err ...*Error) (*sql.DB, *sql.DB)

	// Error holder passed to ConnectFunc, updated with ping results by InitDb,
	// and returned by InitDb.
	LastErr *Error

	// NOTE(review): presumably guarded by limitMu — confirm at its use sites.
	limit Slice

	// Transaction object.
	*sql.Tx

	// Slave database; may be nil, in which case InitDb does not ping it.
	SlaveDB *sql.DB

	// Mode: 0 = production, 1 = test, 2 = development.
	Mode int

	// Locks; which fields each one protects is not visible in this file.
	mu      sync.RWMutex
	limitMu sync.Mutex

	// SQL dialect adapter; when nil, InitDb and GetDialect derive it from Type.
	Dialect Dialect
}
|
||
|
|
|
||
|
|
// SetConnect 设置数据库配置连接
|
||
|
|
func (that *HoTimeDB) SetConnect(connect func(err ...*Error) (master, slave *sql.DB), err ...*Error) {
|
||
|
|
that.ConnectFunc = connect
|
||
|
|
_ = that.InitDb(err...)
|
||
|
|
}
|
||
|
|
|
||
|
|
// InitDb 初始化数据库连接
|
||
|
|
func (that *HoTimeDB) InitDb(err ...*Error) *Error {
|
||
|
|
if len(err) != 0 {
|
||
|
|
that.LastErr = err[0]
|
||
|
|
}
|
||
|
|
that.DB, that.SlaveDB = that.ConnectFunc(that.LastErr)
|
||
|
|
if that.DB == nil {
|
||
|
|
return that.LastErr
|
||
|
|
}
|
||
|
|
e := that.DB.Ping()
|
||
|
|
|
||
|
|
that.LastErr.SetError(e)
|
||
|
|
|
||
|
|
if that.SlaveDB != nil {
|
||
|
|
e := that.SlaveDB.Ping()
|
||
|
|
that.LastErr.SetError(e)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 根据数据库类型初始化方言适配器
|
||
|
|
if that.Dialect == nil {
|
||
|
|
that.initDialect()
|
||
|
|
}
|
||
|
|
|
||
|
|
return that.LastErr
|
||
|
|
}
|
||
|
|
|
||
|
|
// initDialect 根据数据库类型初始化方言
|
||
|
|
func (that *HoTimeDB) initDialect() {
|
||
|
|
switch that.Type {
|
||
|
|
case "postgres", "postgresql":
|
||
|
|
that.Dialect = &PostgreSQLDialect{}
|
||
|
|
case "sqlite3", "sqlite":
|
||
|
|
that.Dialect = &SQLiteDialect{}
|
||
|
|
default:
|
||
|
|
that.Dialect = &MySQLDialect{}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetDialect 获取当前方言适配器
|
||
|
|
func (that *HoTimeDB) GetDialect() Dialect {
|
||
|
|
if that.Dialect == nil {
|
||
|
|
that.initDialect()
|
||
|
|
}
|
||
|
|
return that.Dialect
|
||
|
|
}
|
||
|
|
|
||
|
|
// SetDialect 设置方言适配器
|
||
|
|
func (that *HoTimeDB) SetDialect(dialect Dialect) {
|
||
|
|
that.Dialect = dialect
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetType 获取数据库类型
|
||
|
|
func (that *HoTimeDB) GetType() string {
|
||
|
|
return that.Type
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetPrefix 获取表前缀
|
||
|
|
func (that *HoTimeDB) GetPrefix() string {
|
||
|
|
return that.Prefix
|
||
|
|
}
|