hotime/db/query.go
hoteas c2955d2500 feat(db): 实现数据库查询中的数组参数展开和空数组处理
- 在 Get 方法中添加无参数时的默认字段和 LIMIT 1 处理
- 实现 expandArrayPlaceholder 方法,自动展开 IN (?) 和 NOT IN (?) 中的数组参数
- 为空数组的 IN 条件生成 1=0 永假条件,NOT IN 生成 1=1 永真条件
- 在 queryWithRetry 和 execWithRetry 中集成数组占位符预处理
- 修复 where.go 中空切片条件的处理逻辑
- 添加完整的 IN/NOT IN 数组查询测试用例
- 更新 .gitignore 规则格式
2026-01-22 07:16:42 +08:00

392 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package db
import (
"database/sql"
"encoding/json"
"errors"
"reflect"
"strings"
. "code.hoteas.com/golang/hotime/common"
)
// md5 生成查询的 MD5 哈希(用于缓存)
func (that *HoTimeDB) md5(query string, args ...interface{}) string {
strByte, _ := json.Marshal(args)
str := Md5(query + ":" + string(strByte))
return str
}
// Query 执行查询 SQL
func (that *HoTimeDB) Query(query string, args ...interface{}) []Map {
return that.queryWithRetry(query, false, args...)
}
// queryWithRetry 内部查询方法,支持重试标记
func (that *HoTimeDB) queryWithRetry(query string, retried bool, args ...interface{}) []Map {
// 预处理数组占位符 ?[]
query, args = that.expandArrayPlaceholder(query, args)
// 保存调试信息(加锁保护)
that.mu.Lock()
that.LastQuery = query
that.LastData = args
that.mu.Unlock()
defer func() {
if that.Mode != 0 {
that.mu.RLock()
that.Log.Info("SQL:"+that.LastQuery, " DATA:", that.LastData, " ERROR:", that.LastErr.GetError())
that.mu.RUnlock()
}
}()
var err error
var resl *sql.Rows
// 主从数据库切换只有select语句有从数据库
db := that.DB
if that.SlaveDB != nil {
db = that.SlaveDB
}
if db == nil {
err = errors.New("没有初始化数据库")
that.LastErr.SetError(err)
return nil
}
// 处理参数中的 slice 类型
processedArgs := that.processArgs(args)
if that.Tx != nil {
resl, err = that.Tx.Query(query, processedArgs...)
} else {
resl, err = db.Query(query, processedArgs...)
}
that.LastErr.SetError(err)
if err != nil {
// 如果还没重试过,尝试 Ping 后重试一次
if !retried {
if pingErr := db.Ping(); pingErr == nil {
return that.queryWithRetry(query, true, args...)
}
}
return nil
}
return that.Row(resl)
}
// Exec 执行非查询 SQL
func (that *HoTimeDB) Exec(query string, args ...interface{}) (sql.Result, *Error) {
return that.execWithRetry(query, false, args...)
}
// execWithRetry 内部执行方法,支持重试标记
func (that *HoTimeDB) execWithRetry(query string, retried bool, args ...interface{}) (sql.Result, *Error) {
// 预处理数组占位符 ?[]
query, args = that.expandArrayPlaceholder(query, args)
// 保存调试信息(加锁保护)
that.mu.Lock()
that.LastQuery = query
that.LastData = args
that.mu.Unlock()
defer func() {
if that.Mode != 0 {
that.mu.RLock()
that.Log.Info("SQL: "+that.LastQuery, " DATA: ", that.LastData, " ERROR: ", that.LastErr.GetError())
that.mu.RUnlock()
}
}()
var e error
var resl sql.Result
if that.DB == nil {
err := errors.New("没有初始化数据库")
that.LastErr.SetError(err)
return nil, that.LastErr
}
// 处理参数中的 slice 类型
processedArgs := that.processArgs(args)
if that.Tx != nil {
resl, e = that.Tx.Exec(query, processedArgs...)
} else {
resl, e = that.DB.Exec(query, processedArgs...)
}
that.LastErr.SetError(e)
// 判断是否连接断开了,如果还没重试过,尝试重试一次
if e != nil {
if !retried {
if pingErr := that.DB.Ping(); pingErr == nil {
return that.execWithRetry(query, true, args...)
}
}
return resl, that.LastErr
}
return resl, that.LastErr
}
// processArgs 处理参数中的 slice 类型
func (that *HoTimeDB) processArgs(args []interface{}) []interface{} {
processedArgs := make([]interface{}, len(args))
copy(processedArgs, args)
for key := range processedArgs {
arg := processedArgs[key]
if arg == nil {
continue
}
argType := reflect.ValueOf(arg).Type().String()
if strings.Contains(argType, "[]") || strings.Contains(argType, "Slice") {
argLis := ObjToSlice(arg)
// 将slice转为逗号分割字符串
argStr := ""
for i := 0; i < len(argLis); i++ {
if i == len(argLis)-1 {
argStr += ObjToStr(argLis[i])
} else {
argStr += ObjToStr(argLis[i]) + ","
}
}
processedArgs[key] = argStr
}
}
return processedArgs
}
// expandArrayPlaceholder 展开 IN (?) / NOT IN (?) 中的数组参数
// 自动识别 IN/NOT IN (?) 模式,当参数是数组时展开为多个 ?
//
// 示例:
//
// db.Query("SELECT * FROM user WHERE id IN (?)", []int{1, 2, 3})
// // 展开为: SELECT * FROM user WHERE id IN (?, ?, ?) 参数: [1, 2, 3]
//
// db.Query("SELECT * FROM user WHERE id IN (?)", []int{})
// // 展开为: SELECT * FROM user WHERE 1=0 参数: [] (空集合的IN永假)
//
// db.Query("SELECT * FROM user WHERE id NOT IN (?)", []int{})
// // 展开为: SELECT * FROM user WHERE 1=1 参数: [] (空集合的NOT IN永真)
//
// db.Query("SELECT * FROM user WHERE id = ?", 1)
// // 保持不变: SELECT * FROM user WHERE id = ? 参数: [1]
func (that *HoTimeDB) expandArrayPlaceholder(query string, args []interface{}) (string, []interface{}) {
if len(args) == 0 || !strings.Contains(query, "?") {
return query, args
}
// 检查是否有数组参数
hasArray := false
for _, arg := range args {
if arg == nil {
continue
}
argType := reflect.ValueOf(arg).Type().String()
if strings.Contains(argType, "[]") || strings.Contains(argType, "Slice") {
hasArray = true
break
}
}
if !hasArray {
return query, args
}
newArgs := make([]interface{}, 0, len(args))
result := strings.Builder{}
argIndex := 0
for i := 0; i < len(query); i++ {
if query[i] == '?' && argIndex < len(args) {
arg := args[argIndex]
argIndex++
if arg == nil {
result.WriteByte('?')
newArgs = append(newArgs, nil)
continue
}
argType := reflect.ValueOf(arg).Type().String()
if strings.Contains(argType, "[]") || strings.Contains(argType, "Slice") {
// 是数组参数,检查是否在 IN (...) 或 NOT IN (...) 中
prevPart := result.String()
prevUpper := strings.ToUpper(prevPart)
// 查找最近的 NOT IN ( 模式
notInIndex := strings.LastIndex(prevUpper, " NOT IN (")
notInIndex2 := strings.LastIndex(prevUpper, " NOT IN(")
if notInIndex2 > notInIndex {
notInIndex = notInIndex2
}
// 查找最近的 IN ( 模式(但要排除 NOT IN 的情况)
inIndex := strings.LastIndex(prevUpper, " IN (")
inIndex2 := strings.LastIndex(prevUpper, " IN(")
if inIndex2 > inIndex {
inIndex = inIndex2
}
// 判断是 NOT IN 还是 IN
// 注意:" NOT IN (" 包含 " IN (",所以如果找到的 IN 位置在 NOT IN 范围内,应该优先判断为 NOT IN
isNotIn := false
matchIndex := -1
if notInIndex != -1 {
// 检查 inIndex 是否在 notInIndex 范围内(即 NOT IN 的 IN 部分)
// NOT IN ( 的 IN ( 部分从 notInIndex + 4 开始
if inIndex != -1 && inIndex >= notInIndex && inIndex <= notInIndex+5 {
// inIndex 是 NOT IN 的一部分,使用 NOT IN
isNotIn = true
matchIndex = notInIndex
} else if inIndex == -1 || notInIndex > inIndex {
// 没有独立的 IN或 NOT IN 在 IN 之后
isNotIn = true
matchIndex = notInIndex
} else {
// 有独立的 IN 且在 NOT IN 之后
matchIndex = inIndex
}
} else if inIndex != -1 {
matchIndex = inIndex
}
// 检查 IN ( 后面是否只有空格(即当前 ? 紧跟在 IN ( 后面)
isInPattern := false
if matchIndex != -1 {
afterIn := prevPart[matchIndex:]
// 找到 ( 的位置
parenIdx := strings.Index(afterIn, "(")
if parenIdx != -1 {
afterParen := strings.TrimSpace(afterIn[parenIdx+1:])
if afterParen == "" {
isInPattern = true
}
}
}
if isInPattern {
// 在 IN (...) 或 NOT IN (...) 模式中
argList := ObjToSlice(arg)
if len(argList) == 0 {
// 空数组处理:需要找到字段名的开始位置
// 往前找最近的 AND/OR/WHERE/(,以确定条件的开始位置
truncateIndex := matchIndex
searchPart := prevUpper[:matchIndex]
// 找最近的分隔符位置
andIdx := strings.LastIndex(searchPart, " AND ")
orIdx := strings.LastIndex(searchPart, " OR ")
whereIdx := strings.LastIndex(searchPart, " WHERE ")
parenIdx := strings.LastIndex(searchPart, "(")
// 取最靠后的分隔符
sepIndex := -1
sepLen := 0
if andIdx > sepIndex {
sepIndex = andIdx
sepLen = 5 // " AND "
}
if orIdx > sepIndex {
sepIndex = orIdx
sepLen = 4 // " OR "
}
if whereIdx > sepIndex {
sepIndex = whereIdx
sepLen = 7 // " WHERE "
}
if parenIdx > sepIndex {
sepIndex = parenIdx
sepLen = 1 // "("
}
if sepIndex != -1 {
truncateIndex = sepIndex + sepLen
}
result.Reset()
result.WriteString(prevPart[:truncateIndex])
if isNotIn {
// NOT IN 空集合 = 永真
result.WriteString(" 1=1 ")
} else {
// IN 空集合 = 永假
result.WriteString(" 1=0 ")
}
// 跳过后面的 )
for j := i + 1; j < len(query); j++ {
if query[j] == ')' {
i = j
break
}
}
} else if len(argList) == 1 {
// 单元素数组
result.WriteByte('?')
newArgs = append(newArgs, argList[0])
} else {
// 多元素数组,展开为多个 ?
for j := 0; j < len(argList); j++ {
if j > 0 {
result.WriteString(", ")
}
result.WriteByte('?')
newArgs = append(newArgs, argList[j])
}
}
} else {
// 不在 IN 模式中,保持原有行为(数组会被 processArgs 转为逗号字符串)
result.WriteByte('?')
newArgs = append(newArgs, arg)
}
} else {
// 非数组参数
result.WriteByte('?')
newArgs = append(newArgs, arg)
}
} else {
result.WriteByte(query[i])
}
}
return result.String(), newArgs
}
// Row 数据库数据解析
func (that *HoTimeDB) Row(resl *sql.Rows) []Map {
dest := make([]Map, 0)
strs, _ := resl.Columns()
for i := 0; resl.Next(); i++ {
lis := make(Map, 0)
a := make([]interface{}, len(strs))
b := make([]interface{}, len(a))
for j := 0; j < len(a); j++ {
b[j] = &a[j]
}
err := resl.Scan(b...)
if err != nil {
that.LastErr.SetError(err)
return nil
}
for j := 0; j < len(a); j++ {
if a[j] != nil && reflect.ValueOf(a[j]).Type().String() == "[]uint8" {
lis[strs[j]] = string(a[j].([]byte))
} else {
lis[strs[j]] = a[j] // 取实际类型
}
}
dest = append(dest, lis)
}
return dest
}