hotime/db/query.go

392 lines
10 KiB
Go
Raw Permalink Normal View History

package db
import (
"database/sql"
"encoding/json"
"errors"
"reflect"
"strings"
. "code.hoteas.com/golang/hotime/common"
)
// md5 builds a cache key for a query. The key is the Md5 digest of the SQL
// text, a ':' separator, and the JSON encoding of args.
//
// NOTE(review): the json.Marshal error is discarded. Any argument list that
// cannot be JSON-encoded (e.g. NaN, channels, funcs) contributes nothing to
// the key, so all such calls for the same query share one key.
func (that *HoTimeDB) md5(query string, args ...interface{}) string {
	encodedArgs, _ := json.Marshal(args) // error deliberately ignored; see NOTE
	return Md5(strings.Join([]string{query, string(encodedArgs)}, ":"))
}
// Query runs a query statement and returns every result row as a Map keyed by
// column name.
//
// On failure it returns nil, and the error is left in LastErr. If the first
// attempt fails while the database still answers a ping, the query is retried
// once.
func (that *HoTimeDB) Query(query string, args ...interface{}) []Map {
	const isRetry = false // this is the first attempt
	rows := that.queryWithRetry(query, isRetry, args...)
	return rows
}
// queryWithRetry implements Query. retried reports whether this call is
// already the single permitted retry.
//
// What it does, in order:
//   - expands array placeholders;
//   - records LastQuery/LastData under the lock (logged on return when
//     Mode != 0);
//   - runs the statement on Tx if a transaction is open, otherwise on SlaveDB
//     when configured, else on DB;
//   - turns the rows into []Map via Row.
//
// If the query fails on the first attempt and pinging the selected pool
// succeeds, the query is run once more. On failure it returns nil and the
// error stays in LastErr. The *sql.Rows is always closed before returning,
// so its connection goes back to the pool even if Row stops early.
func (that *HoTimeDB) queryWithRetry(query string, retried bool, args ...interface{}) []Map {
	// Expand `IN (?)` / `NOT IN (?)` placeholders bound to slice arguments.
	query, args = that.expandArrayPlaceholder(query, args)

	// Save debug info (guarded by the mutex).
	that.mu.Lock()
	that.LastQuery = query
	that.LastData = args
	that.mu.Unlock()

	defer func() {
		if that.Mode != 0 {
			that.mu.RLock()
			that.Log.Info("SQL:"+that.LastQuery, " DATA:", that.LastData, " ERROR:", that.LastErr.GetError())
			that.mu.RUnlock()
		}
	}()

	// Reads are the only statements routed to the slave database.
	db := that.DB
	if that.SlaveDB != nil {
		db = that.SlaveDB
	}
	if db == nil {
		that.LastErr.SetError(errors.New("没有初始化数据库"))
		return nil
	}

	// Flatten any remaining slice arguments.
	processedArgs := that.processArgs(args)

	var (
		rows *sql.Rows
		err  error
	)
	if that.Tx != nil {
		rows, err = that.Tx.Query(query, processedArgs...)
	} else {
		rows, err = db.Query(query, processedArgs...)
	}
	that.LastErr.SetError(err)
	if err != nil {
		// First failure: if the pool still answers a ping, retry exactly once.
		if !retried {
			if pingErr := db.Ping(); pingErr == nil {
				return that.queryWithRetry(query, true, args...)
			}
		}
		return nil
	}
	// Close is idempotent, so this is harmless when Row drains the result set.
	// It matters when Row returns early, e.g. on a Scan error.
	defer rows.Close()

	return that.Row(rows)
}
// Exec runs a non-query statement and returns the driver's sql.Result together
// with that.LastErr.
//
// LastErr holds the outcome of the last attempt; inspect it via GetError().
// If the first attempt fails while the primary database still answers a ping,
// the statement is executed once more.
func (that *HoTimeDB) Exec(query string, args ...interface{}) (sql.Result, *Error) {
	const isRetry = false // this is the first attempt
	result, lastErr := that.execWithRetry(query, isRetry, args...)
	return result, lastErr
}
// execWithRetry implements Exec. retried reports whether this call is already
// the single permitted retry.
//
// What it does, in order:
//   - expands array placeholders;
//   - records LastQuery/LastData under the lock (logged on return when
//     Mode != 0);
//   - runs the statement on Tx if a transaction is open, otherwise on DB;
//   - stores the outcome in LastErr.
//
// When the first attempt fails and DB.Ping succeeds, the statement is
// executed one more time. The returned *Error is always that.LastErr.
//
// NOTE(review): the retry fires for any error, not only dropped connections.
// A non-idempotent statement whose first attempt actually took effect
// server-side could therefore run twice; confirm callers accept this.
func (that *HoTimeDB) execWithRetry(query string, retried bool, args ...interface{}) (sql.Result, *Error) {
	query, args = that.expandArrayPlaceholder(query, args)

	that.mu.Lock()
	that.LastQuery, that.LastData = query, args
	that.mu.Unlock()

	defer func() {
		if that.Mode == 0 {
			return
		}
		that.mu.RLock()
		that.Log.Info("SQL: "+that.LastQuery, " DATA: ", that.LastData, " ERROR: ", that.LastErr.GetError())
		that.mu.RUnlock()
	}()

	if that.DB == nil {
		that.LastErr.SetError(errors.New("没有初始化数据库"))
		return nil, that.LastErr
	}

	params := that.processArgs(args)

	var (
		result  sql.Result
		execErr error
	)
	if tx := that.Tx; tx != nil {
		result, execErr = tx.Exec(query, params...)
	} else {
		result, execErr = that.DB.Exec(query, params...)
	}
	that.LastErr.SetError(execErr)

	// Ping is only attempted when the first attempt failed.
	if execErr != nil && !retried && that.DB.Ping() == nil {
		return that.execWithRetry(query, true, args...)
	}
	return result, that.LastErr
}
// processArgs returns a copy of args with every list-like argument flattened
// into a comma-separated string. The input slice is not modified.
//
// An argument counts as list-like when its type name contains "[]" or "Slice".
// Its elements are obtained with ObjToSlice, converted with ObjToStr, and
// joined with ",". nil and other arguments are copied unchanged.
//
// []byte is passed through untouched: it is a native driver value (binary
// data), and flattening it would turn []byte("hi") into "104,105".
func (that *HoTimeDB) processArgs(args []interface{}) []interface{} {
	processedArgs := make([]interface{}, len(args))
	copy(processedArgs, args)
	for i, arg := range processedArgs {
		if arg == nil {
			continue
		}
		if _, isBytes := arg.([]byte); isBytes {
			continue
		}
		argType := reflect.ValueOf(arg).Type().String()
		if !strings.Contains(argType, "[]") && !strings.Contains(argType, "Slice") {
			continue
		}
		items := ObjToSlice(arg)
		parts := make([]string, len(items))
		for j := range items {
			parts[j] = ObjToStr(items[j])
		}
		processedArgs[i] = strings.Join(parts, ",")
	}
	return processedArgs
}
// expandArrayPlaceholder expands slice arguments bound to `IN (?)` /
// `NOT IN (?)` placeholders into one `?` per element and returns the
// rewritten query with the matching argument list.
//
// Examples (whitespace in the output is approximate):
//
//	db.Query("SELECT * FROM user WHERE id IN (?)", []int{1, 2, 3})
//	// -> SELECT * FROM user WHERE id IN (?, ?, ?)   args: [1, 2, 3]
//
//	db.Query("SELECT * FROM user WHERE id IN (?)", []int{})
//	// -> SELECT * FROM user WHERE 1=0               args: []  (IN over an empty set is always false)
//
//	db.Query("SELECT * FROM user WHERE id NOT IN (?)", []int{})
//	// -> SELECT * FROM user WHERE 1=1               args: []  (NOT IN over an empty set is always true)
//
//	db.Query("SELECT * FROM user WHERE id = ?", 1)
//	// -> unchanged                                  args: [1]
//
// Details:
//   - An argument is treated as an array when its type name contains "[]" or
//     "Slice"; its elements come from ObjToSlice.
//   - A `?` is "inside IN" only when nothing but whitespace separates it from
//     the closest preceding `IN (` / `NOT IN (`. Array arguments elsewhere are
//     kept as a single argument (processArgs later joins them with commas).
//   - Keyword matching is ASCII case-insensitive and treats tab, CR and LF as
//     spaces, so multi-line SQL is recognised.
//   - For an empty array, the text from the start of the condition up to the
//     closing ")" is replaced by 1=1 / 1=0. The condition starts after the
//     nearest " AND ", " OR ", " WHERE " or unmatched "(" that is not inside
//     an already-closed (...) group. Arguments of `?`s inside the removed text
//     are dropped too.
//   - `?` is matched positionally; a `?` inside a string literal is not
//     distinguished. Extra `?`s beyond len(args) are copied through.
func (that *HoTimeDB) expandArrayPlaceholder(query string, args []interface{}) (string, []interface{}) {
	if len(args) == 0 || !strings.Contains(query, "?") {
		return query, args
	}

	// Fast path: nothing to rewrite unless at least one argument looks like an array.
	hasArray := false
	for _, arg := range args {
		if arg == nil {
			continue
		}
		argType := reflect.ValueOf(arg).Type().String()
		if strings.Contains(argType, "[]") || strings.Contains(argType, "Slice") {
			hasArray = true
			break
		}
	}
	if !hasArray {
		return query, args
	}

	newArgs := make([]interface{}, 0, len(args))
	result := strings.Builder{}
	argIndex := 0
	for i := 0; i < len(query); i++ {
		if query[i] != '?' || argIndex >= len(args) {
			result.WriteByte(query[i])
			continue
		}
		arg := args[argIndex]
		argIndex++

		if arg == nil {
			result.WriteByte('?')
			newArgs = append(newArgs, nil)
			continue
		}
		argType := reflect.ValueOf(arg).Type().String()
		if !strings.Contains(argType, "[]") && !strings.Contains(argType, "Slice") {
			// Scalar argument: keep the placeholder as is.
			result.WriteByte('?')
			newArgs = append(newArgs, arg)
			continue
		}

		// Array argument: work out whether this ? sits directly inside IN ( / NOT IN (.
		prevPart := result.String()
		// Uppercase ASCII and map \t \r \n to spaces byte by byte. Every offset
		// in prevUpper is then also a valid offset in prevPart; strings.ToUpper
		// can change the byte length of some non-ASCII text.
		upper := []byte(prevPart)
		for k, c := range upper {
			switch {
			case 'a' <= c && c <= 'z':
				upper[k] = c - ('a' - 'A')
			case c == '\t' || c == '\n' || c == '\r':
				upper[k] = ' '
			}
		}
		prevUpper := string(upper)

		notInIndex := strings.LastIndex(prevUpper, " NOT IN (")
		if k := strings.LastIndex(prevUpper, " NOT IN("); k > notInIndex {
			notInIndex = k
		}
		inIndex := strings.LastIndex(prevUpper, " IN (")
		if k := strings.LastIndex(prevUpper, " IN("); k > inIndex {
			inIndex = k
		}

		// " NOT IN (" contains " IN (" four bytes in. An IN match that falls
		// inside the last NOT IN match belongs to it.
		isNotIn := false
		matchIndex := -1
		switch {
		case notInIndex != -1 && inIndex != -1 && inIndex >= notInIndex && inIndex <= notInIndex+5:
			isNotIn, matchIndex = true, notInIndex
		case notInIndex != -1 && (inIndex == -1 || notInIndex > inIndex):
			isNotIn, matchIndex = true, notInIndex
		case inIndex != -1:
			matchIndex = inIndex
		}

		// The ? must come right after the "(" (only whitespace in between).
		isInPattern := false
		if matchIndex != -1 {
			afterIn := prevPart[matchIndex:]
			if p := strings.Index(afterIn, "("); p != -1 && strings.TrimSpace(afterIn[p+1:]) == "" {
				isInPattern = true
			}
		}
		if !isInPattern {
			// Not an IN list: keep the array as one argument (processArgs flattens it).
			result.WriteByte('?')
			newArgs = append(newArgs, arg)
			continue
		}

		argList := ObjToSlice(arg)
		if len(argList) > 0 {
			// One placeholder per element.
			for j, v := range argList {
				if j > 0 {
					result.WriteString(", ")
				}
				result.WriteByte('?')
				newArgs = append(newArgs, v)
			}
			continue
		}

		// Empty set: find where the condition starts by scanning backwards.
		// Groups that are already closed, such as LOWER(name), are skipped so
		// we never truncate inside them.
		searchPart := prevUpper[:matchIndex]
		truncateIndex := matchIndex
		depth := 0
	scanBack:
		for k := len(searchPart) - 1; k >= 0; k-- {
			c := searchPart[k]
			switch {
			case c == ')':
				depth++
			case c == '(' && depth > 0:
				depth--
			case depth > 0:
				// Inside a closed (...) group; keep scanning.
			case c == '(':
				truncateIndex = k + 1
				break scanBack
			case strings.HasPrefix(searchPart[k:], " AND "):
				truncateIndex = k + len(" AND ")
				break scanBack
			case strings.HasPrefix(searchPart[k:], " OR "):
				truncateIndex = k + len(" OR ")
				break scanBack
			case strings.HasPrefix(searchPart[k:], " WHERE "):
				truncateIndex = k + len(" WHERE ")
				break scanBack
			}
		}

		// Each ? already written to result has exactly one entry in newArgs.
		// Drop the entries whose ? is in the text being removed.
		if dropped := strings.Count(prevPart[truncateIndex:], "?"); dropped > 0 && dropped <= len(newArgs) {
			newArgs = newArgs[:len(newArgs)-dropped]
		}

		result.Reset()
		result.WriteString(prevPart[:truncateIndex])
		if isNotIn {
			// NOT IN over an empty set is always true.
			result.WriteString(" 1=1 ")
		} else {
			// IN over an empty set is always false.
			result.WriteString(" 1=0 ")
		}
		// Skip up to and including the ")" that closes the IN list.
		for j := i + 1; j < len(query); j++ {
			if query[j] == ')' {
				i = j
				break
			}
		}
	}
	return result.String(), newArgs
}
// Row reads all remaining rows of resl into a slice of Map (column name ->
// value) and closes resl before returning.
//
// []byte values are converted to string; every other value is stored as
// database/sql scanned it into an interface{}.
//
// If reading the column list, scanning a row, or iterating fails, the error
// is stored in LastErr and nil is returned. Partial results are never
// returned. With no rows the result is an empty, non-nil slice.
func (that *HoTimeDB) Row(resl *sql.Rows) []Map {
	// Return the connection to the pool even when we bail out early.
	// Rows.Close is safe to call after Next has already closed the rows.
	defer resl.Close()

	columns, err := resl.Columns()
	if err != nil {
		that.LastErr.SetError(err)
		return nil
	}

	// Scan targets are reused for every row. Scanning into *interface{}
	// replaces the stored value (and copies []byte), so maps built for
	// earlier rows are not affected.
	values := make([]interface{}, len(columns))
	targets := make([]interface{}, len(columns))
	for i := range values {
		targets[i] = &values[i]
	}

	dest := make([]Map, 0)
	for resl.Next() {
		if err := resl.Scan(targets...); err != nil {
			that.LastErr.SetError(err)
			return nil
		}
		row := make(Map, len(columns))
		for i, col := range columns {
			if b, ok := values[i].([]byte); ok {
				row[col] = string(b)
			} else {
				row[col] = values[i] // keep the driver's concrete type
			}
		}
		dest = append(dest, row)
	}
	// Next returns false both at the end of the data and on a driver error;
	// Err distinguishes the two.
	if err := resl.Err(); err != nil {
		that.LastErr.SetError(err)
		return nil
	}
	return dest
}