- 在 Get 方法中添加无参数时的默认字段和 LIMIT 1 处理 - 实现 expandArrayPlaceholder 方法,自动展开 IN (?) 和 NOT IN (?) 中的数组参数 - 为空数组的 IN 条件生成 1=0 永假条件,NOT IN 生成 1=1 永真条件 - 在 queryWithRetry 和 execWithRetry 中集成数组占位符预处理 - 修复 where.go 中空切片条件的处理逻辑 - 添加完整的 IN/NOT IN 数组查询测试用例 - 更新 .gitignore 规则格式
392 lines
10 KiB
Go
392 lines
10 KiB
Go
package db
|
||
|
||
import (
|
||
"database/sql"
|
||
"encoding/json"
|
||
"errors"
|
||
"reflect"
|
||
"strings"
|
||
|
||
. "code.hoteas.com/golang/hotime/common"
|
||
)
|
||
|
||
// md5 生成查询的 MD5 哈希(用于缓存)
|
||
func (that *HoTimeDB) md5(query string, args ...interface{}) string {
|
||
strByte, _ := json.Marshal(args)
|
||
str := Md5(query + ":" + string(strByte))
|
||
return str
|
||
}
|
||
|
||
// Query 执行查询 SQL
|
||
func (that *HoTimeDB) Query(query string, args ...interface{}) []Map {
|
||
return that.queryWithRetry(query, false, args...)
|
||
}
|
||
|
||
// queryWithRetry 内部查询方法,支持重试标记
|
||
func (that *HoTimeDB) queryWithRetry(query string, retried bool, args ...interface{}) []Map {
|
||
// 预处理数组占位符 ?[]
|
||
query, args = that.expandArrayPlaceholder(query, args)
|
||
|
||
// 保存调试信息(加锁保护)
|
||
that.mu.Lock()
|
||
that.LastQuery = query
|
||
that.LastData = args
|
||
that.mu.Unlock()
|
||
|
||
defer func() {
|
||
if that.Mode != 0 {
|
||
that.mu.RLock()
|
||
that.Log.Info("SQL:"+that.LastQuery, " DATA:", that.LastData, " ERROR:", that.LastErr.GetError())
|
||
that.mu.RUnlock()
|
||
}
|
||
}()
|
||
|
||
var err error
|
||
var resl *sql.Rows
|
||
|
||
// 主从数据库切换,只有select语句有从数据库
|
||
db := that.DB
|
||
if that.SlaveDB != nil {
|
||
db = that.SlaveDB
|
||
}
|
||
|
||
if db == nil {
|
||
err = errors.New("没有初始化数据库")
|
||
that.LastErr.SetError(err)
|
||
return nil
|
||
}
|
||
|
||
// 处理参数中的 slice 类型
|
||
processedArgs := that.processArgs(args)
|
||
|
||
if that.Tx != nil {
|
||
resl, err = that.Tx.Query(query, processedArgs...)
|
||
} else {
|
||
resl, err = db.Query(query, processedArgs...)
|
||
}
|
||
|
||
that.LastErr.SetError(err)
|
||
if err != nil {
|
||
// 如果还没重试过,尝试 Ping 后重试一次
|
||
if !retried {
|
||
if pingErr := db.Ping(); pingErr == nil {
|
||
return that.queryWithRetry(query, true, args...)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
return that.Row(resl)
|
||
}
|
||
|
||
// Exec 执行非查询 SQL
|
||
func (that *HoTimeDB) Exec(query string, args ...interface{}) (sql.Result, *Error) {
|
||
return that.execWithRetry(query, false, args...)
|
||
}
|
||
|
||
// execWithRetry 内部执行方法,支持重试标记
|
||
func (that *HoTimeDB) execWithRetry(query string, retried bool, args ...interface{}) (sql.Result, *Error) {
|
||
// 预处理数组占位符 ?[]
|
||
query, args = that.expandArrayPlaceholder(query, args)
|
||
|
||
// 保存调试信息(加锁保护)
|
||
that.mu.Lock()
|
||
that.LastQuery = query
|
||
that.LastData = args
|
||
that.mu.Unlock()
|
||
|
||
defer func() {
|
||
if that.Mode != 0 {
|
||
that.mu.RLock()
|
||
that.Log.Info("SQL: "+that.LastQuery, " DATA: ", that.LastData, " ERROR: ", that.LastErr.GetError())
|
||
that.mu.RUnlock()
|
||
}
|
||
}()
|
||
|
||
var e error
|
||
var resl sql.Result
|
||
|
||
if that.DB == nil {
|
||
err := errors.New("没有初始化数据库")
|
||
that.LastErr.SetError(err)
|
||
return nil, that.LastErr
|
||
}
|
||
|
||
// 处理参数中的 slice 类型
|
||
processedArgs := that.processArgs(args)
|
||
|
||
if that.Tx != nil {
|
||
resl, e = that.Tx.Exec(query, processedArgs...)
|
||
} else {
|
||
resl, e = that.DB.Exec(query, processedArgs...)
|
||
}
|
||
|
||
that.LastErr.SetError(e)
|
||
// 判断是否连接断开了,如果还没重试过,尝试重试一次
|
||
if e != nil {
|
||
if !retried {
|
||
if pingErr := that.DB.Ping(); pingErr == nil {
|
||
return that.execWithRetry(query, true, args...)
|
||
}
|
||
}
|
||
return resl, that.LastErr
|
||
}
|
||
|
||
return resl, that.LastErr
|
||
}
|
||
|
||
// processArgs 处理参数中的 slice 类型
|
||
func (that *HoTimeDB) processArgs(args []interface{}) []interface{} {
|
||
processedArgs := make([]interface{}, len(args))
|
||
copy(processedArgs, args)
|
||
for key := range processedArgs {
|
||
arg := processedArgs[key]
|
||
if arg == nil {
|
||
continue
|
||
}
|
||
argType := reflect.ValueOf(arg).Type().String()
|
||
if strings.Contains(argType, "[]") || strings.Contains(argType, "Slice") {
|
||
argLis := ObjToSlice(arg)
|
||
// 将slice转为逗号分割字符串
|
||
argStr := ""
|
||
for i := 0; i < len(argLis); i++ {
|
||
if i == len(argLis)-1 {
|
||
argStr += ObjToStr(argLis[i])
|
||
} else {
|
||
argStr += ObjToStr(argLis[i]) + ","
|
||
}
|
||
}
|
||
processedArgs[key] = argStr
|
||
}
|
||
}
|
||
return processedArgs
|
||
}
|
||
|
||
// expandArrayPlaceholder 展开 IN (?) / NOT IN (?) 中的数组参数
|
||
// 自动识别 IN/NOT IN (?) 模式,当参数是数组时展开为多个 ?
|
||
//
|
||
// 示例:
|
||
//
|
||
// db.Query("SELECT * FROM user WHERE id IN (?)", []int{1, 2, 3})
|
||
// // 展开为: SELECT * FROM user WHERE id IN (?, ?, ?) 参数: [1, 2, 3]
|
||
//
|
||
// db.Query("SELECT * FROM user WHERE id IN (?)", []int{})
|
||
// // 展开为: SELECT * FROM user WHERE 1=0 参数: [] (空集合的IN永假)
|
||
//
|
||
// db.Query("SELECT * FROM user WHERE id NOT IN (?)", []int{})
|
||
// // 展开为: SELECT * FROM user WHERE 1=1 参数: [] (空集合的NOT IN永真)
|
||
//
|
||
// db.Query("SELECT * FROM user WHERE id = ?", 1)
|
||
// // 保持不变: SELECT * FROM user WHERE id = ? 参数: [1]
|
||
func (that *HoTimeDB) expandArrayPlaceholder(query string, args []interface{}) (string, []interface{}) {
|
||
if len(args) == 0 || !strings.Contains(query, "?") {
|
||
return query, args
|
||
}
|
||
|
||
// 检查是否有数组参数
|
||
hasArray := false
|
||
for _, arg := range args {
|
||
if arg == nil {
|
||
continue
|
||
}
|
||
argType := reflect.ValueOf(arg).Type().String()
|
||
if strings.Contains(argType, "[]") || strings.Contains(argType, "Slice") {
|
||
hasArray = true
|
||
break
|
||
}
|
||
}
|
||
if !hasArray {
|
||
return query, args
|
||
}
|
||
|
||
newArgs := make([]interface{}, 0, len(args))
|
||
result := strings.Builder{}
|
||
argIndex := 0
|
||
|
||
for i := 0; i < len(query); i++ {
|
||
if query[i] == '?' && argIndex < len(args) {
|
||
arg := args[argIndex]
|
||
argIndex++
|
||
|
||
if arg == nil {
|
||
result.WriteByte('?')
|
||
newArgs = append(newArgs, nil)
|
||
continue
|
||
}
|
||
|
||
argType := reflect.ValueOf(arg).Type().String()
|
||
if strings.Contains(argType, "[]") || strings.Contains(argType, "Slice") {
|
||
// 是数组参数,检查是否在 IN (...) 或 NOT IN (...) 中
|
||
prevPart := result.String()
|
||
prevUpper := strings.ToUpper(prevPart)
|
||
|
||
// 查找最近的 NOT IN ( 模式
|
||
notInIndex := strings.LastIndex(prevUpper, " NOT IN (")
|
||
notInIndex2 := strings.LastIndex(prevUpper, " NOT IN(")
|
||
if notInIndex2 > notInIndex {
|
||
notInIndex = notInIndex2
|
||
}
|
||
|
||
// 查找最近的 IN ( 模式(但要排除 NOT IN 的情况)
|
||
inIndex := strings.LastIndex(prevUpper, " IN (")
|
||
inIndex2 := strings.LastIndex(prevUpper, " IN(")
|
||
if inIndex2 > inIndex {
|
||
inIndex = inIndex2
|
||
}
|
||
|
||
// 判断是 NOT IN 还是 IN
|
||
// 注意:" NOT IN (" 包含 " IN (",所以如果找到的 IN 位置在 NOT IN 范围内,应该优先判断为 NOT IN
|
||
isNotIn := false
|
||
matchIndex := -1
|
||
if notInIndex != -1 {
|
||
// 检查 inIndex 是否在 notInIndex 范围内(即 NOT IN 的 IN 部分)
|
||
// NOT IN ( 的 IN ( 部分从 notInIndex + 4 开始
|
||
if inIndex != -1 && inIndex >= notInIndex && inIndex <= notInIndex+5 {
|
||
// inIndex 是 NOT IN 的一部分,使用 NOT IN
|
||
isNotIn = true
|
||
matchIndex = notInIndex
|
||
} else if inIndex == -1 || notInIndex > inIndex {
|
||
// 没有独立的 IN,或 NOT IN 在 IN 之后
|
||
isNotIn = true
|
||
matchIndex = notInIndex
|
||
} else {
|
||
// 有独立的 IN 且在 NOT IN 之后
|
||
matchIndex = inIndex
|
||
}
|
||
} else if inIndex != -1 {
|
||
matchIndex = inIndex
|
||
}
|
||
|
||
// 检查 IN ( 后面是否只有空格(即当前 ? 紧跟在 IN ( 后面)
|
||
isInPattern := false
|
||
if matchIndex != -1 {
|
||
afterIn := prevPart[matchIndex:]
|
||
// 找到 ( 的位置
|
||
parenIdx := strings.Index(afterIn, "(")
|
||
if parenIdx != -1 {
|
||
afterParen := strings.TrimSpace(afterIn[parenIdx+1:])
|
||
if afterParen == "" {
|
||
isInPattern = true
|
||
}
|
||
}
|
||
}
|
||
|
||
if isInPattern {
|
||
// 在 IN (...) 或 NOT IN (...) 模式中
|
||
argList := ObjToSlice(arg)
|
||
if len(argList) == 0 {
|
||
// 空数组处理:需要找到字段名的开始位置
|
||
// 往前找最近的 AND/OR/WHERE/(,以确定条件的开始位置
|
||
truncateIndex := matchIndex
|
||
searchPart := prevUpper[:matchIndex]
|
||
|
||
// 找最近的分隔符位置
|
||
andIdx := strings.LastIndex(searchPart, " AND ")
|
||
orIdx := strings.LastIndex(searchPart, " OR ")
|
||
whereIdx := strings.LastIndex(searchPart, " WHERE ")
|
||
parenIdx := strings.LastIndex(searchPart, "(")
|
||
|
||
// 取最靠后的分隔符
|
||
sepIndex := -1
|
||
sepLen := 0
|
||
if andIdx > sepIndex {
|
||
sepIndex = andIdx
|
||
sepLen = 5 // " AND "
|
||
}
|
||
if orIdx > sepIndex {
|
||
sepIndex = orIdx
|
||
sepLen = 4 // " OR "
|
||
}
|
||
if whereIdx > sepIndex {
|
||
sepIndex = whereIdx
|
||
sepLen = 7 // " WHERE "
|
||
}
|
||
if parenIdx > sepIndex {
|
||
sepIndex = parenIdx
|
||
sepLen = 1 // "("
|
||
}
|
||
|
||
if sepIndex != -1 {
|
||
truncateIndex = sepIndex + sepLen
|
||
}
|
||
|
||
result.Reset()
|
||
result.WriteString(prevPart[:truncateIndex])
|
||
if isNotIn {
|
||
// NOT IN 空集合 = 永真
|
||
result.WriteString(" 1=1 ")
|
||
} else {
|
||
// IN 空集合 = 永假
|
||
result.WriteString(" 1=0 ")
|
||
}
|
||
// 跳过后面的 )
|
||
for j := i + 1; j < len(query); j++ {
|
||
if query[j] == ')' {
|
||
i = j
|
||
break
|
||
}
|
||
}
|
||
} else if len(argList) == 1 {
|
||
// 单元素数组
|
||
result.WriteByte('?')
|
||
newArgs = append(newArgs, argList[0])
|
||
} else {
|
||
// 多元素数组,展开为多个 ?
|
||
for j := 0; j < len(argList); j++ {
|
||
if j > 0 {
|
||
result.WriteString(", ")
|
||
}
|
||
result.WriteByte('?')
|
||
newArgs = append(newArgs, argList[j])
|
||
}
|
||
}
|
||
} else {
|
||
// 不在 IN 模式中,保持原有行为(数组会被 processArgs 转为逗号字符串)
|
||
result.WriteByte('?')
|
||
newArgs = append(newArgs, arg)
|
||
}
|
||
} else {
|
||
// 非数组参数
|
||
result.WriteByte('?')
|
||
newArgs = append(newArgs, arg)
|
||
}
|
||
} else {
|
||
result.WriteByte(query[i])
|
||
}
|
||
}
|
||
|
||
return result.String(), newArgs
|
||
}
|
||
|
||
// Row 数据库数据解析
|
||
func (that *HoTimeDB) Row(resl *sql.Rows) []Map {
|
||
dest := make([]Map, 0)
|
||
strs, _ := resl.Columns()
|
||
|
||
for i := 0; resl.Next(); i++ {
|
||
lis := make(Map, 0)
|
||
a := make([]interface{}, len(strs))
|
||
|
||
b := make([]interface{}, len(a))
|
||
for j := 0; j < len(a); j++ {
|
||
b[j] = &a[j]
|
||
}
|
||
err := resl.Scan(b...)
|
||
if err != nil {
|
||
that.LastErr.SetError(err)
|
||
return nil
|
||
}
|
||
for j := 0; j < len(a); j++ {
|
||
if a[j] != nil && reflect.ValueOf(a[j]).Type().String() == "[]uint8" {
|
||
lis[strs[j]] = string(a[j].([]byte))
|
||
} else {
|
||
lis[strs[j]] = a[j] // 取实际类型
|
||
}
|
||
}
|
||
|
||
dest = append(dest, lis)
|
||
}
|
||
|
||
return dest
|
||
}
|