hotime/db/where.go
hoteas c2955d2500 feat(db): 实现数据库查询中的数组参数展开和空数组处理
- 在 Get 方法中添加无参数时的默认字段和 LIMIT 1 处理
- 实现 expandArrayPlaceholder 方法,自动展开 IN (?) 和 NOT IN (?) 中的数组参数
- 为空数组的 IN 条件生成 1=0 永假条件,NOT IN 生成 1=1 永真条件
- 在 queryWithRetry 和 execWithRetry 中集成数组占位符预处理
- 修复 where.go 中空切片条件的处理逻辑
- 添加完整的 IN/NOT IN 数组查询测试用例
- 更新 .gitignore 规则格式
2026-01-22 07:16:42 +08:00

577 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package db
import (
. "code.hoteas.com/golang/hotime/common"
"reflect"
"sort"
"strings"
)
// Keyword tables consulted while parsing a condition Map.
var (
	// condition lists the logical connectors that introduce a nested
	// condition group.
	condition = []string{"AND", "OR"}

	// vcond lists the special clause keywords. Keys are compared against
	// these after upper-casing, so either letter case is accepted.
	vcond = []string{"GROUP", "ORDER", "LIMIT", "DISTINCT", "HAVING", "OFFSET"}
)
// normalizeKey returns the canonical spelling of k when its upper-cased form
// is listed in vcond or condition; any other key is returned unchanged.
func normalizeKey(k string) string {
	upper := strings.ToUpper(k)
	// vcond is searched first, then condition.
	for _, table := range [][]string{vcond, condition} {
		for _, keyword := range table {
			if keyword == upper {
				return keyword
			}
		}
	}
	return k
}
// isConditionKey reports whether k, once upper-cased, is one of the logical
// connectors listed in condition.
func isConditionKey(k string) bool {
	candidate := strings.ToUpper(k)
	matched := false
	for idx := 0; idx < len(condition) && !matched; idx++ {
		matched = condition[idx] == candidate
	}
	return matched
}
// isVcondKey reports whether k, once upper-cased, is one of the special
// clause keywords listed in vcond.
func isVcondKey(k string) bool {
	target := strings.ToUpper(k)
	for idx := range vcond {
		if vcond[idx] != target {
			continue
		}
		return true
	}
	return false
}
// where renders a condition Map into the tail of a SQL statement (WHERE,
// GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET) and returns it together with the
// positional arguments for its "?" placeholders.
//
// Keys are processed in sorted order so the generated SQL is deterministic:
//   - "AND" / "OR" (any letter case) hold a nested Map rendered by cond. Each
//     group is wrapped in parentheses and joined to the other top-level
//     conditions with AND, so mixing a group with plain fields keeps both
//     valid syntax and the intended precedence.
//   - special clause keywords (see vcond) are rendered after the WHERE part
//     in a fixed order; a nil value for such a keyword is ignored.
//   - every other key is a field condition rendered by varCond; successive
//     conditions are joined with AND.
//
// An empty slice value yields the always-false condition "1=0", unless the
// key ends in "[!]" (NOT IN), in which case the condition is omitted.
//
// The value of an "AND"/"OR" key must be a Map; any other type panics on the
// type assertion.
func (that *HoTimeDB) where(data Map) (string, []interface{}) {
	where := ""
	res := make([]interface{}, 0)

	// Normalise keyword keys so that e.g. "order" and "ORDER" address the
	// same entry.
	normalized := Map{}
	for k, v := range data {
		normalized[normalizeKey(k)] = v
	}
	data = normalized

	keys := make([]string, 0, len(data))
	for key := range data {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	// condCount is the number of top-level conditions already emitted; every
	// condition after the first is prefixed with AND.
	condCount := 0
	appendCond := func(sql string) {
		if condCount > 0 {
			where += " AND "
		}
		where += sql
		condCount++
	}

	// isEmptyList reports whether v is a slice-like value with no elements.
	isEmptyList := func(v interface{}) bool {
		if v == nil {
			return false
		}
		typeName := reflect.ValueOf(v).Type().String()
		if typeName == "common.Slice" {
			return len(v.(Slice)) == 0
		}
		return strings.Contains(typeName, "[]") && len(ObjToSlice(v)) == 0
	}

	for _, k := range keys {
		v := data[k]

		// Nested AND/OR group.
		if isConditionKey(k) {
			group, groupArgs := that.cond(strings.ToUpper(k), v.(Map))
			if group = strings.TrimSpace(group); group != "" {
				appendCond("(" + group + ") ")
				res = append(res, groupArgs...)
			}
			continue
		}

		// Special clause keywords are rendered after the WHERE part.
		if isVcondKey(k) {
			continue
		}

		if isEmptyList(v) {
			// IN () can never match; NOT IN () always matches and is
			// therefore simply left out.
			if !strings.HasSuffix(k, "[!]") {
				appendCond("1=0 ")
			}
			continue
		}

		if tv, vv := that.varCond(k, v); tv != "" {
			appendCond(tv)
			res = append(res, vv...)
		}
	}

	// Prefix WHERE unless the text already starts with a clause keyword
	// (e.g. a raw fragment such as "ORDER BY ..."), in which case it is
	// kept exactly as built.
	trimmed := strings.TrimSpace(where)
	if trimmed == "" {
		where = ""
	} else {
		startsWithClause := false
		for _, keyword := range vcond {
			if strings.HasPrefix(trimmed, keyword) {
				startsWithClause = true
				break
			}
		}
		if !startsWithClause {
			where = " WHERE " + trimmed + " "
		}
	}

	// Special clauses are emitted in SQL order regardless of key order.
	for _, clause := range []string{"GROUP", "HAVING", "ORDER", "LIMIT", "OFFSET", "DISTINCT"} {
		v, exists := data[clause]
		if !exists || v == nil {
			continue
		}
		switch clause {
		case "GROUP":
			where += " GROUP BY " + that.formatVcondValue(v)
		case "HAVING":
			// Only a Map value is understood; anything else is ignored.
			if havingMap, ok := v.(Map); ok {
				havingWhere, havingRes := that.cond("AND", havingMap)
				// cond always returns at least a space, so test the
				// trimmed text to avoid emitting a bare HAVING.
				if having := strings.TrimSpace(havingWhere); having != "" {
					where += " HAVING " + having + " "
					res = append(res, havingRes...)
				}
			}
		case "ORDER":
			where += " ORDER BY " + that.formatVcondValue(v)
		case "LIMIT":
			where += " LIMIT " + that.formatVcondValue(v)
		case "OFFSET":
			where += " OFFSET " + ObjToStr(v) + " "
		case "DISTINCT":
			// Not rendered here; DISTINCT belongs to the SELECT list.
		}
	}
	return where, res
}
// formatVcondValue renders the value of a special clause keyword (GROUP,
// ORDER, LIMIT) as SQL text. The text is inlined, not bound as a parameter.
//
// A slice-like value becomes its elements, each padded with spaces and
// separated by ", "; any other value becomes its string form padded with
// spaces. A nil value yields the empty string instead of panicking in the
// reflection call.
func (that *HoTimeDB) formatVcondValue(v interface{}) string {
	if v == nil {
		// reflect.ValueOf(nil) is the zero Value and its Type() panics.
		return ""
	}

	typeName := reflect.ValueOf(v).Type().String()
	if typeName != "common.Slice" && !strings.Contains(typeName, "[]") {
		return " " + ObjToStr(v) + " "
	}

	vs := ObjToSlice(v)
	parts := make([]string, 0, len(vs))
	for i := 0; i < len(vs); i++ {
		parts = append(parts, " "+vs.GetString(i)+" ")
	}
	return strings.Join(parts, ", ")
}
// varCond renders one key/value pair of a condition Map as a SQL fragment
// and returns it with the positional arguments for its placeholders.
//
// A key is a column name, optionally followed by a bracketed operator:
//
//	[>] [<] [>=] [<=]   comparison against a bound parameter
//	[!]                 negation, rendered by notIn
//	[~]                 LIKE '%v%'
//	[!~]                LIKE '%v'
//	[~!]                LIKE 'v%'
//	[~~]                LIKE with the value passed through unchanged
//	[<>]                BETWEEN ? AND ?      (value holds the two bounds)
//	[><]                NOT BETWEEN ? AND ?  (value holds the two bounds)
//	[#]                 column = value, value inlined as raw SQL
//	[#!] [!#]           column != value, value inlined as raw SQL
//	[##]                the value alone, inlined as raw SQL
//
// The bare keys "[#]" and "[##]" emit the value verbatim. A key with any
// other bracketed suffix goes to handleDefaultCondition, and a key without
// brackets goes to handlePlainField. Column names that contain no "." are
// wrapped in backticks.
//
// Values inlined as raw SQL are not escaped; callers must never pass
// untrusted input through the [#]-family operators.
//
// The operator is taken from the last "[" in the key, so all four-character
// operators are recognised.
func (that *HoTimeDB) varCond(k string, v interface{}) (string, []interface{}) {
	where := ""
	res := make([]interface{}, 0)

	// Bare raw-SQL keys: the value is the whole fragment.
	if k == "[#]" || k == "[##]" {
		return " " + ObjToStr(v) + " ", res
	}

	length := len(k)
	if length == 0 || !strings.Contains(k, "[") || k[length-1] != ']' {
		return that.handlePlainField(k, v, where, res)
	}

	// Every operator is "[" + symbols + "]" at the very end of the key, so
	// the text from the last "[" onwards is the operator candidate.
	op := k[strings.LastIndex(k, "["):]

	// column strips the operator from the key and quotes the remaining name.
	column := func() string {
		name := strings.Replace(k, op, "", -1)
		if !strings.Contains(name, ".") {
			name = "`" + name + "` "
		}
		return name
	}

	switch op {
	case "[>]", "[<]", "[>=]", "[<=]":
		where += column() + strings.Trim(op, "[]") + "? "
		res = append(res, v)
	case "[!]":
		where, res = that.notIn(column(), v, where, res)
	case "[#]":
		where += " " + column() + "=" + ObjToStr(v) + " "
	case "[#!]", "[!#]":
		where += " " + column() + "!=" + ObjToStr(v) + " "
	case "[##]":
		// Raw SQL: the key only labels the entry, the value is the fragment.
		where += " " + ObjToStr(v)
	case "[~]": // wildcard on both sides
		where += column() + " LIKE ? "
		res = append(res, "%"+ObjToStr(v)+"%")
	case "[!~]": // wildcard on the left
		where += column() + " LIKE ? "
		res = append(res, "%"+ObjToStr(v))
	case "[~!]": // wildcard on the right
		where += column() + " LIKE ? "
		res = append(res, ObjToStr(v)+"%")
	case "[~~]": // caller supplies the wildcards
		where += column() + " LIKE ? "
		res = append(res, v)
	case "[><]":
		// NOTE(review): fewer than two bounds panics on the index below.
		where += column() + " NOT BETWEEN ? AND ? "
		vs := ObjToSlice(v)
		res = append(res, vs[0], vs[1])
	case "[<>]":
		// NOTE(review): fewer than two bounds panics on the index below.
		where += column() + " BETWEEN ? AND ? "
		vs := ObjToSlice(v)
		res = append(res, vs[0], vs[1])
	default:
		where, res = that.handleDefaultCondition(k, v, where, res)
	}
	return where, res
}
// handleDefaultCondition renders a key that carries a bracketed suffix which
// is not a recognised operator. The whole key is treated as the column name.
//
// The rendering rules are exactly those of handlePlainField (equality for a
// scalar, "1=0" for an empty list, equality for a one-element list, the
// optimised IN form otherwise), so this method delegates to it. Delegating
// also gives a nil value the "IS NULL" rendering instead of a panic in the
// reflection call.
func (that *HoTimeDB) handleDefaultCondition(k string, v interface{}, where string, res []interface{}) (string, []interface{}) {
	return that.handlePlainField(k, v, where, res)
}
// handlePlainField renders a condition for a key without a bracketed
// operator and appends it to where, returning the extended SQL and arguments.
//
//   - nil value            -> column IS NULL
//   - empty list           -> 1=0 (an empty IN can never match)
//   - one-element list     -> column =?
//   - longer list          -> optimizeInCondition
//   - any other value      -> column =?
//
// Column names that contain no "." are wrapped in backticks.
func (that *HoTimeDB) handlePlainField(k string, v interface{}, where string, res []interface{}) (string, []interface{}) {
	column := k
	if !strings.Contains(column, ".") {
		column = "`" + column + "` "
	}

	if v == nil {
		return where + column + " IS NULL ", res
	}

	typeName := reflect.ValueOf(v).Type().String()
	if typeName != "common.Slice" && !strings.Contains(typeName, "[]") {
		// Scalar value: plain equality.
		return where + column + "=? ", append(res, v)
	}

	items := ObjToSlice(v)
	switch len(items) {
	case 0:
		return where + "1=0 ", res
	case 1:
		return where + column + "=? ", append(res, items[0])
	default:
		return that.optimizeInCondition(column, items, where, res)
	}
}
// optimizeInCondition renders "k IN (vs...)" and appends it to where.
//
// When every element is an integer (ObjToStr(vs.GetCeilInt64(i)) equals
// vs.GetString(i)), runs of consecutive ascending values, taken in the order
// given, are collapsed into "k BETWEEN ? AND ?". The remaining isolated
// values become "k = ?" or "k IN (?,...)". The pieces are joined with OR and
// wrapped in parentheses. If any element is not an integer, a plain
// "k IN (?,...)" over the original values is produced.
//
// Runs are detected by comparing neighbouring elements, so the end of the
// list can never be confused with an element value.
func (that *HoTimeDB) optimizeInCondition(k string, vs Slice, where string, res []interface{}) (string, []interface{}) {
	// Convert to int64, giving up at the first non-integer element.
	nums := make([]int64, 0, len(vs))
	allInts := true
	for i := 0; i < len(vs); i++ {
		n := vs.GetCeilInt64(i)
		if ObjToStr(n) != vs.GetString(i) {
			allInts = false
			break
		}
		nums = append(nums, n)
	}

	if !allInts {
		// Not all integers: plain IN over the original values. This branch
		// is only reached with at least one element.
		where += k + " IN (" + strings.Repeat("?,", len(vs)-1) + "?) "
		res = append(res, vs...)
		return where, res
	}

	// Split the values into runs of consecutive integers.
	betweenSQL := ""
	betweenArgs := make([]interface{}, 0)
	singles := make([]interface{}, 0)
	for start := 0; start < len(nums); {
		end := start
		for end+1 < len(nums) && nums[end+1] == nums[end]+1 {
			end++
		}
		if end > start {
			if betweenSQL != "" {
				betweenSQL += " OR "
			}
			betweenSQL += k + " BETWEEN ? AND ? "
			betweenArgs = append(betweenArgs, nums[start], nums[end])
		} else {
			singles = append(singles, nums[start])
		}
		start = end + 1
	}

	// BETWEEN pieces come first, followed by the isolated values.
	combined := betweenSQL
	res = append(res, betweenArgs...)
	switch {
	case len(singles) == 1:
		if combined != "" {
			combined += " OR "
		}
		combined += k + " = ? "
	case len(singles) > 1:
		if combined != "" {
			combined += " OR "
		}
		combined += k + " IN (" + strings.Repeat("?,", len(singles)-1) + "?)"
	}
	res = append(res, singles...)

	if combined != "" {
		where += "(" + combined + ")"
	}
	return where, res
}
// notIn renders the negated condition for column k and appends it to where,
// returning the extended SQL and arguments.
//
//   - nil value       -> k IS NOT NULL
//   - empty list      -> nothing is added (NOT IN over nothing always holds)
//   - non-empty list  -> k NOT IN (?,...)
//   - any other value -> k !=?
//
// k is used as given; quoting is the caller's job.
func (that *HoTimeDB) notIn(k string, v interface{}, where string, res []interface{}) (string, []interface{}) {
	if v == nil {
		return where + k + " IS NOT NULL ", res
	}

	typeName := reflect.ValueOf(v).Type().String()
	if typeName != "common.Slice" && !strings.Contains(typeName, "[]") {
		// Scalar value: simple inequality.
		return where + k + " !=? ", append(res, v)
	}

	excluded := ObjToSlice(v)
	if len(excluded) == 0 {
		return where, res
	}

	// One placeholder per value, comma separated, closed by ") ".
	placeholders := strings.Repeat("?,", len(excluded)-1) + "?) "
	return where + k + " NOT IN (" + placeholders, append(res, excluded...)
}
// cond renders the entries of data as conditions joined by tag ("AND" or
// "OR") and returns the SQL, which always starts with a space, together with
// its positional arguments.
//
// Keys are processed in sorted order. A key that is itself "AND" or "OR" (any
// letter case) holds a nested Map, rendered recursively and wrapped in
// parentheses. Its value must be a Map; any other type panics on the type
// assertion. Every other key is rendered by varCond.
//
// Entries that render to nothing (for example a "[!]" key with an empty
// list, or an empty nested group) are skipped. The connector is inserted
// only between the entries that remain, so the result never ends in a
// dangling connector and never contains empty parentheses.
func (that *HoTimeDB) cond(tag string, data Map) (string, []interface{}) {
	res := make([]interface{}, 0)

	keys := make([]string, 0, len(data))
	for key := range data {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	parts := make([]string, 0, len(keys))
	for _, k := range keys {
		v := data[k]

		if isConditionKey(k) {
			group, groupArgs := that.cond(strings.ToUpper(k), v.(Map))
			group = strings.TrimSpace(group)
			if group == "" {
				continue
			}
			parts = append(parts, "("+group+")")
			res = append(res, groupArgs...)
			continue
		}

		tv, vv := that.varCond(k, v)
		if tv == "" {
			continue
		}
		parts = append(parts, tv)
		res = append(res, vv...)
	}

	return " " + strings.Join(parts, " "+tag+" "), res
}