hotime/db/where.go
hoteas d1b905b780 test(db): 添加 OR 条件处理的单元测试并修复 WHERE 子句逻辑
- 添加了 TestWhereWithORCondition 测试函数验证 OR 条件的括号包裹
- 修复了 where.go 中条件计数变量名称从 normalCondCount 改为 condCount
- 实现了 OR/AND 组条件的括号包裹逻辑确保 SQL 语法正确
- 添加了空条件检查避免生成无效的 SQL 片段
- 更新了 .gitignore 文件添加日志文件忽略规则
2026-01-23 01:51:35 +08:00

556 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package db
import (
. "code.hoteas.com/golang/hotime/common"
"reflect"
"sort"
"strings"
)
// Keyword tables consulted while building WHERE clauses.
var (
	// condition lists the logical grouping keywords.
	condition = []string{"AND", "OR"}

	// vcond lists the clause keywords that are handled separately from
	// ordinary column conditions (keys are matched case-insensitively).
	vcond = []string{"GROUP", "ORDER", "LIMIT", "DISTINCT", "HAVING", "OFFSET"}
)
// normalizeKey returns the canonical upper-case spelling of k when k matches,
// ignoring case, one of the clause keywords in vcond or one of the logical
// keywords in condition. Every other key is returned unchanged.
func normalizeKey(k string) string {
	candidate := strings.ToUpper(k)
	// vcond is consulted before condition, as a single pass over both tables.
	for _, table := range [][]string{vcond, condition} {
		for _, keyword := range table {
			if keyword == candidate {
				return keyword
			}
		}
	}
	return k
}
// isConditionKey reports whether k, compared case-insensitively, is one of the
// logical grouping keywords listed in condition.
func isConditionKey(k string) bool {
	target := strings.ToUpper(k)
	matched := false
	for i := 0; i < len(condition) && !matched; i++ {
		matched = condition[i] == target
	}
	return matched
}
// isVcondKey reports whether k, compared case-insensitively, is one of the
// clause keywords listed in vcond.
func isVcondKey(k string) bool {
	target := strings.ToUpper(k)
	for idx := len(vcond) - 1; idx >= 0; idx-- {
		if vcond[idx] == target {
			return true
		}
	}
	return false
}
// where turns a condition map into the tail of a SQL statement and the list of
// values bound to its "?" placeholders.
//
// Keys are first passed through normalizeKey and then processed in sorted
// order so the generated SQL is deterministic:
//
//   - AND / OR keys hold a nested Map that is rendered by cond, wrapped in
//     parentheses and joined to the other conditions with AND. The value must
//     be a Map; any other type makes the type assertion panic.
//   - GROUP, HAVING, ORDER, LIMIT, OFFSET and DISTINCT keys are skipped in the
//     condition pass and appended afterwards in that fixed order (DISTINCT is
//     currently ignored here).
//   - An empty list value produces the always-false condition "1=0", except
//     for keys ending in "[!]", where the condition is skipped.
//   - Every other key is rendered by varCond.
//
// The GROUP, ORDER, LIMIT and OFFSET values are written into the SQL text
// as-is rather than bound as parameters, so they must not come from untrusted
// input.
func (that *HoTimeDB) where(data Map) (string, []interface{}) {
	where := ""
	res := make([]interface{}, 0)

	// Canonicalise keyword keys so that e.g. "order" and "ORDER" are equivalent.
	normalizedData := Map{}
	for k, v := range data {
		normalizedData[normalizeKey(k)] = v
	}
	data = normalizedData

	keys := make([]string, 0, len(data))
	for key := range data {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	// condCount tracks how many conditions were emitted so far; every
	// condition after the first is prefixed with AND.
	condCount := 0
	appendCond := func(fragment string) {
		if condCount > 0 {
			where += " AND "
		}
		where += fragment
		condCount++
	}

	for _, k := range keys {
		v := data[k]

		// AND / OR groups are rendered recursively and wrapped in parentheses.
		if isConditionKey(k) {
			tw, ts := that.cond(strings.ToUpper(k), v.(Map))
			if group := strings.TrimSpace(tw); group != "" {
				appendCond("(" + group + ")")
				res = append(res, ts...)
			}
			continue
		}

		// Clause keywords are handled after the condition pass.
		if isVcondKey(k) {
			continue
		}

		// An empty list can match nothing: emit an always-false condition
		// instead of dropping the filter. For "[!]" keys an empty list
		// excludes nothing, so the condition is simply skipped.
		if v != nil {
			typeName := reflect.ValueOf(v).Type().String()
			emptyList := false
			switch {
			case typeName == "common.Slice":
				emptyList = len(v.(Slice)) == 0
			case strings.Contains(typeName, "[]"):
				emptyList = len(ObjToSlice(v)) == 0
			}
			if emptyList {
				if !strings.HasSuffix(k, "[!]") {
					appendCond("1=0 ")
				}
				continue
			}
		}

		tv, vv := that.varCond(k, v)
		if tv != "" {
			appendCond(tv)
			res = append(res, vv...)
		}
	}

	// Prefix the collected conditions with WHERE. A fragment that already
	// begins with one of the vcond keywords is left without the prefix
	// (presumably so a raw fragment can consist of a clause on its own —
	// TODO confirm against callers).
	trimmedWhere := strings.TrimSpace(where)
	if trimmedWhere == "" {
		where = ""
	} else {
		hasWhere := true
		for _, keyword := range vcond {
			if strings.HasPrefix(trimmedWhere, keyword) {
				hasWhere = false
			}
		}
		if hasWhere {
			where = " WHERE " + trimmedWhere + " "
		}
	}

	// Append the clause keywords in a fixed order.
	for _, vcondKey := range []string{"GROUP", "HAVING", "ORDER", "LIMIT", "OFFSET", "DISTINCT"} {
		v, exists := data[vcondKey]
		if !exists {
			continue
		}
		switch vcondKey {
		case "GROUP":
			where += " GROUP BY " + that.formatVcondValue(v)
		case "HAVING":
			if havingMap, ok := v.(Map); ok {
				havingWhere, havingRes := that.cond("AND", havingMap)
				// cond always returns at least one space, so the emptiness
				// check has to be made on the trimmed text; otherwise an
				// empty map would emit a bare HAVING clause.
				if having := strings.TrimSpace(havingWhere); having != "" {
					where += " HAVING " + having + " "
					res = append(res, havingRes...)
				}
			}
		case "ORDER":
			where += " ORDER BY " + that.formatVcondValue(v)
		case "LIMIT":
			where += " LIMIT " + that.formatVcondValue(v)
		case "OFFSET":
			where += " OFFSET " + ObjToStr(v) + " "
		case "DISTINCT":
			// DISTINCT belongs to the SELECT list; nothing is emitted here.
		}
	}

	return where, res
}
// formatVcondValue renders the value of a clause keyword (GROUP, ORDER,
// LIMIT) as SQL text.
//
// A list value (a common.Slice or any type whose name contains "[]") is
// rendered as its elements, each padded with one space on both sides and
// separated by ", ". Any other value is rendered with ObjToStr, padded the
// same way. The text is written into the statement as-is.
//
// Note: a nil v panics, because reflect.ValueOf(nil).Type() is called.
func (that *HoTimeDB) formatVcondValue(v interface{}) string {
	typeName := reflect.ValueOf(v).Type().String()
	if typeName != "common.Slice" && !strings.Contains(typeName, "[]") {
		return " " + ObjToStr(v) + " "
	}

	items := ObjToSlice(v)
	padded := make([]string, 0, len(items))
	for i := range items {
		padded = append(padded, " "+items.GetString(i)+" ")
	}
	return strings.Join(padded, ", ")
}
// varCond translates one key/value pair of a condition map into a SQL
// fragment and the values bound to its "?" placeholders.
//
// The key is a column name optionally followed by an operator in square
// brackets. The column part is passed through the processor's ProcessColumn.
//
//	col[>] col[<] col[>=] col[<=]   comparison against a bound value
//	col[!]                          negation, delegated to notIn
//	col[<>] / col[><]               BETWEEN / NOT BETWEEN; v must hold at
//	                                least two elements or indexing panics
//	col[~]                          LIKE '%v%'
//	col[!~] / col[~!]               LIKE '%v' / LIKE 'v%'
//	col[~~]                         LIKE v, wildcards supplied by the caller
//	col[#] / col[#!] / col[!#]      col = v / col != v, v written verbatim
//	col[##]                         v written verbatim, the column is ignored
//
// A key that is exactly "[#]" or "[##]" emits v verbatim. A key ending in "]"
// with any other bracket content goes to handleDefaultCondition, and a key
// without a trailing bracket goes to handlePlainField.
//
// SECURITY: the [#], [##], [#!] and [!#] forms concatenate v into the SQL
// text without binding or escaping. They must never receive untrusted input.
//
// Operators are detected with strings.HasSuffix. Previously the
// four-character operators [##], [#!], [!#], [!~], [~!] and [~~] were cases of
// a switch over only the last three characters of the key, so they could not
// match and such keys fell through to handleDefaultCondition.
func (that *HoTimeDB) varCond(k string, v interface{}) (string, []interface{}) {
	where := ""
	res := make([]interface{}, 0)
	processor := that.GetProcessor()

	// Raw SQL fragment: the key carries no column at all.
	if k == "[#]" || k == "[##]" {
		where += " " + ObjToStr(v) + " "
		return where, res
	}

	// No trailing "[...]" operator: plain column condition.
	if len(k) == 0 || !strings.Contains(k, "[") || k[len(k)-1] != ']' {
		return that.handlePlainField(k, v, where, res)
	}

	// column strips the operator from the key and returns the processed
	// column name followed by one space.
	column := func(op string) string {
		return processor.ProcessColumn(strings.Replace(k, op, "", -1)) + " "
	}

	// No operator is a suffix of another one, so the case order is irrelevant.
	switch {
	case strings.HasSuffix(k, "[>]"):
		where += column("[>]") + ">? "
		res = append(res, v)
	case strings.HasSuffix(k, "[<]"):
		where += column("[<]") + "<? "
		res = append(res, v)
	case strings.HasSuffix(k, "[>=]"):
		where += column("[>=]") + ">=? "
		res = append(res, v)
	case strings.HasSuffix(k, "[<=]"):
		where += column("[<=]") + "<=? "
		res = append(res, v)
	case strings.HasSuffix(k, "[!]"):
		where, res = that.notIn(column("[!]"), v, where, res)
	case strings.HasSuffix(k, "[#]"):
		where += " " + column("[#]") + "=" + ObjToStr(v) + " "
	case strings.HasSuffix(k, "[##]"):
		// The value is the whole fragment, e.g. "a>b"; the column is unused.
		where += " " + ObjToStr(v)
	case strings.HasSuffix(k, "[#!]"):
		where += " " + column("[#!]") + "!=" + ObjToStr(v) + " "
	case strings.HasSuffix(k, "[!#]"):
		where += " " + column("[!#]") + "!=" + ObjToStr(v) + " "
	case strings.HasSuffix(k, "[~]"):
		where += column("[~]") + " LIKE ? "
		res = append(res, "%"+ObjToStr(v)+"%")
	case strings.HasSuffix(k, "[!~]"):
		// Wildcard on the left only.
		where += column("[!~]") + " LIKE ? "
		res = append(res, "%"+ObjToStr(v))
	case strings.HasSuffix(k, "[~!]"):
		// Wildcard on the right only.
		where += column("[~!]") + " LIKE ? "
		res = append(res, ObjToStr(v)+"%")
	case strings.HasSuffix(k, "[~~]"):
		// The caller places the wildcards.
		where += column("[~~]") + " LIKE ? "
		res = append(res, v)
	case strings.HasSuffix(k, "[><]"):
		where += column("[><]") + " NOT BETWEEN ? AND ? "
		vs := ObjToSlice(v)
		res = append(res, vs[0], vs[1])
	case strings.HasSuffix(k, "[<>]"):
		where += column("[<>]") + " BETWEEN ? AND ? "
		vs := ObjToSlice(v)
		res = append(res, vs[0], vs[1])
	default:
		where, res = that.handleDefaultCondition(k, v, where, res)
	}

	return where, res
}
// handleDefaultCondition renders a key that ends in a bracket expression which
// is not one of the recognised operators. The whole key is treated as the
// column name.
//
// The rendering rules are exactly those of handlePlainField, so this method
// delegates to it instead of repeating them:
//
//   - nil value            -> column IS NULL
//   - empty list           -> always-false "1=0"
//   - single-element list  -> column = ?
//   - longer list          -> optimizeInCondition
//   - any other value      -> column = ?
//
// Delegating also covers the nil case, which previously panicked here because
// reflect.ValueOf(nil).Type() was called without a nil check.
//
// where and res are the SQL text and bound values accumulated so far; the
// extended versions are returned.
func (that *HoTimeDB) handleDefaultCondition(k string, v interface{}, where string, res []interface{}) (string, []interface{}) {
	return that.handlePlainField(k, v, where, res)
}
// handlePlainField renders a condition for a key that carries no bracket
// operator. The key is passed through the processor's ProcessColumn and then
// combined with the value:
//
//   - nil value            -> column IS NULL
//   - empty list           -> always-false "1=0"
//   - single-element list  -> column = ?
//   - longer list          -> optimizeInCondition
//   - any other value      -> column = ?
//
// A value counts as a list when its type is common.Slice or its type name
// contains "[]". where and res are the SQL text and bound values accumulated
// so far; the extended versions are returned.
func (that *HoTimeDB) handlePlainField(k string, v interface{}, where string, res []interface{}) (string, []interface{}) {
	column := that.GetProcessor().ProcessColumn(k) + " "

	if v == nil {
		return where + column + " IS NULL ", res
	}

	typeName := reflect.ValueOf(v).Type().String()
	if typeName != "common.Slice" && !strings.Contains(typeName, "[]") {
		return where + column + "=? ", append(res, v)
	}

	switch items := ObjToSlice(v); len(items) {
	case 0:
		// Nothing can match an empty list.
		return where + "1=0 ", res
	case 1:
		return where + column + "=? ", append(res, items[0])
	default:
		return that.optimizeInCondition(column, items, where, res)
	}
}
// optimizeInCondition renders "column IN (...)" for the list vs, rewriting
// runs of consecutive integers as BETWEEN ranges.
//
// k is the column text to emit (callers pass it already processed and with a
// trailing space). where and res are the SQL text and bound values accumulated
// so far; the extended versions are returned.
//
// When every element is an integer (its GetCeilInt64 value prints the same as
// its GetString value) the list is split, in the order given, into runs in
// which each element is exactly one greater than the previous one:
//
//   - a run of two or more elements becomes "k BETWEEN ? AND ?"
//   - the remaining single elements become "k = ?" (one) or "k IN (?,...)"
//
// The parts are joined with OR and wrapped in parentheses, with the BETWEEN
// parts first. Integer values are bound as int64.
//
// When any element is not an integer a plain "k IN (?,...)" over the original
// values is emitted instead.
//
// Runs are detected by comparing neighbouring elements directly. The previous
// implementation flushed the final run by comparing a counter against a
// sentinel value of 0, which silently dropped a final run ending at -1 (for
// example {0, -1} produced only "k = ?" bound to 0).
func (that *HoTimeDB) optimizeInCondition(k string, vs Slice, where string, res []interface{}) (string, []interface{}) {
	// Convert to integers, giving up on the first element that is not one.
	ints := make([]int64, 0, len(vs))
	allInts := true
	for i := range vs {
		n := vs.GetCeilInt64(i)
		if ObjToStr(n) != vs.GetString(i) {
			allInts = false
			break
		}
		ints = append(ints, n)
	}

	// Not a pure integer list: plain IN over the original values.
	if !allInts {
		where += k + " IN ("
		res = append(res, vs...)
		for i := range vs {
			if i+1 != len(vs) {
				where += "?,"
			} else {
				where += "?) "
			}
		}
		return where, res
	}

	betweenSQL := ""
	var betweenArgs, singles []interface{}
	for start := 0; start < len(ints); {
		// Extend the run while the next element continues the sequence.
		end := start
		for end+1 < len(ints) && ints[end+1] == ints[end]+1 {
			end++
		}
		if end > start {
			if betweenSQL != "" {
				betweenSQL += " OR "
			}
			betweenSQL += k + " BETWEEN ? AND ? "
			betweenArgs = append(betweenArgs, ints[start], ints[end])
		} else {
			singles = append(singles, ints[start])
		}
		start = end + 1
	}

	combined := betweenSQL
	res = append(res, betweenArgs...)

	if len(singles) > 0 {
		if combined != "" {
			combined += " OR "
		}
		if len(singles) == 1 {
			combined += k + " = ? "
		} else {
			combined += k + " IN (" + strings.Repeat("?,", len(singles)-1) + "?)"
		}
		res = append(res, singles...)
	}

	if combined != "" {
		where += "(" + combined + ")"
	}
	return where, res
}
// notIn renders the negated form of a condition for column text k:
//
//   - nil value       -> k IS NOT NULL
//   - empty list      -> nothing is emitted; where and res come back unchanged
//   - non-empty list  -> k NOT IN (?,...) with every element bound
//   - any other value -> k !=?
//
// A value counts as a list when its type is common.Slice or its type name
// contains "[]". where and res are the SQL text and bound values accumulated
// so far; the extended versions are returned.
func (that *HoTimeDB) notIn(k string, v interface{}, where string, res []interface{}) (string, []interface{}) {
	if v == nil {
		return where + k + " IS NOT NULL ", res
	}

	typeName := reflect.ValueOf(v).Type().String()
	if typeName != "common.Slice" && !strings.Contains(typeName, "[]") {
		return where + k + " !=? ", append(res, v)
	}

	items := ObjToSlice(v)
	if len(items) == 0 {
		return where, res
	}

	placeholders := strings.Repeat("?,", len(items)-1) + "?) "
	return where + k + " NOT IN (" + placeholders, append(res, items...)
}
// cond renders the entries of data as one group of conditions joined by tag
// ("AND" or "OR") and returns the SQL text together with the bound values.
//
// Keys are processed in sorted order so the output is deterministic. An AND /
// OR key holds a nested Map that is rendered recursively and wrapped in
// parentheses; its value must be a Map, otherwise the type assertion panics.
// Every other key is rendered by varCond.
//
// Entries that produce no SQL (varCond returning "", or a nested group that is
// empty) are skipped. The separator is only written once a following fragment
// actually exists, so a skipped trailing entry can no longer leave a dangling
// "AND"/"OR" behind, and an empty nested group no longer emits "( )".
//
// The returned text always starts with a single space; callers should trim it
// before testing whether the group is empty.
func (that *HoTimeDB) cond(tag string, data Map) (string, []interface{}) {
	res := make([]interface{}, 0)

	keys := make([]string, 0, len(data))
	for key := range data {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	var sql strings.Builder
	sql.WriteString(" ")

	// pending holds the separator owed by the previous fragment; it is
	// flushed only when another fragment follows.
	pending := ""
	emit := func(fragment, separator string) {
		sql.WriteString(pending)
		sql.WriteString(fragment)
		pending = separator
	}

	for _, k := range keys {
		v := data[k]

		if isConditionKey(k) {
			tw, ts := that.cond(strings.ToUpper(k), v.(Map))
			if strings.TrimSpace(tw) == "" {
				continue
			}
			emit("("+tw+") ", tag+" ")
			res = append(res, ts...)
			continue
		}

		tv, vv := that.varCond(k, v)
		if tv == "" {
			continue
		}
		emit(tv, " "+tag+" ")
		res = append(res, vv...)
	}

	return sql.String(), res
}