hotime/db/db.go

148 lines
3.6 KiB
Go
Raw Permalink Normal View History

package db
import (
"code.hoteas.com/golang/hotime/cache"
. "code.hoteas.com/golang/hotime/common"
"database/sql"
"strings"
"sync"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"github.com/sirupsen/logrus"
)
// HoTimeDB 数据库操作核心结构体
type HoTimeDB struct {
*sql.DB
ContextBase
DBName string
*cache.HoTimeCache
Log *logrus.Logger
Type string // 数据库类型: mysql, sqlite3, postgres
Prefix string
LastQuery string
LastData []interface{}
ConnectFunc func(err ...*Error) (*sql.DB, *sql.DB)
LastErr *Error
limit Slice
*sql.Tx //事务对象
SlaveDB *sql.DB // 从数据库
Mode int // mode为0生产模式,1为测试模式,2为开发模式
mu sync.RWMutex
limitMu sync.Mutex
Dialect Dialect // 数据库方言适配器
}
// SetConnect 设置数据库配置连接
func (that *HoTimeDB) SetConnect(connect func(err ...*Error) (master, slave *sql.DB), err ...*Error) {
that.ConnectFunc = connect
_ = that.InitDb(err...)
}
// InitDb 初始化数据库连接
func (that *HoTimeDB) InitDb(err ...*Error) *Error {
if len(err) != 0 {
that.LastErr = err[0]
}
that.DB, that.SlaveDB = that.ConnectFunc(that.LastErr)
if that.DB == nil {
return that.LastErr
}
e := that.DB.Ping()
that.LastErr.SetError(e)
if that.SlaveDB != nil {
e := that.SlaveDB.Ping()
that.LastErr.SetError(e)
}
// 根据数据库类型初始化方言适配器
if that.Dialect == nil {
that.initDialect()
}
return that.LastErr
}
// initDialect 根据数据库类型初始化方言
func (that *HoTimeDB) initDialect() {
switch that.Type {
case "postgres", "postgresql":
that.Dialect = &PostgreSQLDialect{}
case "sqlite3", "sqlite":
that.Dialect = &SQLiteDialect{}
default:
that.Dialect = &MySQLDialect{}
}
}
// GetDialect 获取当前方言适配器
func (that *HoTimeDB) GetDialect() Dialect {
if that.Dialect == nil {
that.initDialect()
}
return that.Dialect
}
// SetDialect 设置方言适配器
func (that *HoTimeDB) SetDialect(dialect Dialect) {
that.Dialect = dialect
}
// GetType 获取数据库类型
func (that *HoTimeDB) GetType() string {
return that.Type
}
// GetPrefix 获取表前缀
func (that *HoTimeDB) GetPrefix() string {
return that.Prefix
}
// GetProcessor 获取标识符处理器
// 用于处理表名、字段名的前缀添加和引号转换
func (that *HoTimeDB) GetProcessor() *IdentifierProcessor {
return NewIdentifierProcessor(that.GetDialect(), that.Prefix)
}
// T 辅助方法:获取带前缀和引号的表名
// 用于手动构建 SQL 时使用
// 示例: db.T("order") 返回 "`app_order`" (MySQL) 或 "\"app_order\"" (PostgreSQL)
func (that *HoTimeDB) T(table string) string {
return that.GetProcessor().ProcessTableName(table)
}
// C 辅助方法:获取带前缀和引号的 table.column
// 支持两种调用方式:
// - db.C("order", "name") 返回 "`app_order`.`name`"
// - db.C("order.name") 返回 "`app_order`.`name`"
func (that *HoTimeDB) C(args ...string) string {
if len(args) == 0 {
return ""
}
if len(args) == 1 {
return that.GetProcessor().ProcessColumn(args[0])
}
// 两个参数: table, column
dialect := that.GetDialect()
table := args[0]
column := args[1]
// 去除已有引号
table = trimQuotes(table)
column = trimQuotes(column)
return dialect.QuoteIdentifier(that.Prefix+table) + "." + dialect.QuoteIdentifier(column)
}
// trimQuotes 去除字符串两端的引号
func trimQuotes(s string) string {
s = strings.TrimSpace(s)
if len(s) >= 2 {
if (s[0] == '`' && s[len(s)-1] == '`') || (s[0] == '"' && s[len(s)-1] == '"') {
return s[1 : len(s)-1]
}
}
return s
}