hotime/db/where.go
hoteas 650fafad1a refactor(db): 重构数据库查询构建器以支持多数据库方言和标识符处理
- 实现了标识符处理器,统一处理表名、字段名的前缀添加和引号转换
- 添加对 MySQL、PostgreSQL、SQLite 三种数据库方言的支持
- 引入 ProcessTableName、ProcessColumn、ProcessConditionString 等方法处理标识符
- 为 HoTimeDB 添加 T() 和 C() 辅助方法用于手动构建 SQL 查询
- 重构 CRUD 操作中的表名和字段名处理逻辑,统一使用标识符处理器
- 添加完整的单元测试验证不同数据库方言下的标识符处理功能
- 优化 JOIN 操作中表名和条件字符串的处理方式
2026-01-22 09:32:01 +08:00

548 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package db
import (
. "code.hoteas.com/golang/hotime/common"
"reflect"
"sort"
"strings"
)
var (
	// condition lists the group keywords whose value is a nested condition
	// Map joined with that keyword (matched case-insensitively).
	condition = []string{"AND", "OR"}
	// vcond lists the special keywords that are not field conditions but are
	// rendered as trailing clauses (matched case-insensitively).
	vcond = []string{"GROUP", "ORDER", "LIMIT", "DISTINCT", "HAVING", "OFFSET"}
)
// normalizeKey returns the upper-case spelling of k when k is one of the
// special keywords (vcond) or group keywords (condition), compared
// case-insensitively via strings.ToUpper; any other key is returned unchanged.
func normalizeKey(k string) string {
	upper := strings.ToUpper(k)
	for _, keywords := range [][]string{vcond, condition} {
		for _, keyword := range keywords {
			if keyword == upper {
				return keyword
			}
		}
	}
	return k
}
// isConditionKey reports whether k, upper-cased, is one of the group
// keywords listed in condition ("AND" / "OR").
func isConditionKey(k string) bool {
	key := strings.ToUpper(k)
	found := false
	for i := 0; i < len(condition) && !found; i++ {
		found = condition[i] == key
	}
	return found
}
// isVcondKey reports whether k, upper-cased, is one of the special keywords
// listed in vcond (GROUP, ORDER, LIMIT, ...).
func isVcondKey(k string) bool {
	wanted := strings.ToUpper(k)
	idx := 0
	for idx < len(vcond) && vcond[idx] != wanted {
		idx++
	}
	return idx < len(vcond)
}
// where builds the WHERE clause for a condition map, followed by any
// GROUP BY / HAVING / ORDER BY / LIMIT / OFFSET tail, and returns it together
// with the positional arguments for its "?" placeholders, in placeholder order.
//
// Keyword keys are matched case-insensitively (see normalizeKey): "AND"/"OR"
// hold a nested condition Map rendered by cond, the keys in vcond are special
// clauses, and every other key is a field condition parsed by varCond. Keys
// are visited in sorted order so the output is deterministic.
//
// All top-level conditions, AND/OR groups included, are joined with AND. When
// there is more than one top-level condition each group is wrapped in
// parentheses; otherwise {"OR": {a, b}, "c": 1} would read as
// "a OR (b AND c)" under SQL precedence.
//
// A top-level empty slice means "IN ()": it becomes the always-false "1=0",
// except for NOT IN keys ("[!]" suffix), which are always true and skipped.
func (that *HoTimeDB) where(data Map) (string, []interface{}) {
	res := make([]interface{}, 0)

	// Normalise keyword keys so "order" / "Order" behave like "ORDER".
	normalizedData := Map{}
	for k, v := range data {
		normalizedData[normalizeKey(k)] = v
	}
	data = normalizedData

	keys := make([]string, 0, len(data))
	for key := range data {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	// isEmptySlice reports whether v is a slice-like value with no elements.
	isEmptySlice := func(v interface{}) bool {
		if v == nil {
			return false
		}
		typeName := reflect.ValueOf(v).Type().String()
		if typeName == "common.Slice" {
			return len(v.(Slice)) == 0
		}
		return strings.Contains(typeName, "[]") && len(ObjToSlice(v)) == 0
	}

	// One SQL fragment per top-level condition; isGroup marks AND/OR groups.
	parts := make([]string, 0, len(keys))
	isGroup := make([]bool, 0, len(keys))
	for _, k := range keys {
		v := data[k]
		switch {
		case isConditionKey(k):
			tw, ts := that.cond(strings.ToUpper(k), v.(Map))
			if strings.TrimSpace(tw) == "" {
				continue // the group produced no condition at all
			}
			parts = append(parts, tw)
			isGroup = append(isGroup, true)
			res = append(res, ts...)
		case isVcondKey(k):
			// Special clauses are appended after the WHERE clause below.
		case isEmptySlice(v):
			if !strings.HasSuffix(k, "[!]") {
				parts = append(parts, "1=0 ")
				isGroup = append(isGroup, false)
			}
		default:
			tv, vv := that.varCond(k, v)
			if tv != "" {
				parts = append(parts, tv)
				isGroup = append(isGroup, false)
				res = append(res, vv...)
			}
		}
	}
	// A lone group needs no parentheses; keeping it bare leaves the output
	// unchanged for the single-group case.
	if len(parts) > 1 {
		for i := range parts {
			if isGroup[i] {
				parts[i] = "(" + strings.TrimSpace(parts[i]) + ")"
			}
		}
	}

	where := strings.Join(parts, " AND ")
	trimmedWhere := strings.TrimSpace(where)
	if trimmedWhere == "" {
		where = ""
	} else {
		// Text that already starts with a special keyword is kept as-is,
		// without a WHERE keyword in front of it.
		hasWhere := true
		for _, kw := range vcond {
			if strings.HasPrefix(trimmedWhere, kw) {
				hasWhere = false
			}
		}
		if hasWhere {
			where = " WHERE " + trimmedWhere + " "
		}
	}

	// Special clauses are appended in SQL order, regardless of map order.
	specialOrder := []string{"GROUP", "HAVING", "ORDER", "LIMIT", "OFFSET", "DISTINCT"}
	for _, vcondKey := range specialOrder {
		v, exists := data[vcondKey]
		if !exists {
			continue
		}
		switch vcondKey {
		case "GROUP":
			where += " GROUP BY " + that.formatVcondValue(v)
		case "HAVING":
			// HAVING takes a condition Map, rendered like an AND group. cond
			// always returns at least " ", so test the trimmed text to avoid
			// emitting an empty "HAVING".
			if havingMap, ok := v.(Map); ok {
				havingWhere, havingRes := that.cond("AND", havingMap)
				if strings.TrimSpace(havingWhere) != "" {
					where += " HAVING " + strings.TrimSpace(havingWhere) + " "
					res = append(res, havingRes...)
				}
			}
		case "ORDER":
			where += " ORDER BY " + that.formatVcondValue(v)
		case "LIMIT":
			where += " LIMIT " + that.formatVcondValue(v)
		case "OFFSET":
			where += " OFFSET " + ObjToStr(v) + " "
		case "DISTINCT":
			// DISTINCT belongs to the SELECT list; it is ignored here.
		}
	}
	return where, res
}
// formatVcondValue renders the value of a GROUP / ORDER / LIMIT key as raw
// SQL text. A slice value is rendered element by element, each element padded
// with a space on both sides and the elements separated by ", "
// (e.g. {"a", "b"} -> " a ,  b "). Any other value goes through ObjToStr,
// padded the same way.
//
// The text is concatenated into the SQL verbatim, not parameterised, so it
// must not come from untrusted input.
func (that *HoTimeDB) formatVcondValue(v interface{}) string {
	typeName := reflect.ValueOf(v).Type().String()
	if typeName != "common.Slice" && !strings.Contains(typeName, "[]") {
		return " " + ObjToStr(v) + " "
	}
	items := ObjToSlice(v)
	padded := make([]string, 0, len(items))
	for i := range items {
		padded = append(padded, " "+items.GetString(i)+" ")
	}
	return strings.Join(padded, ", ")
}
// varCond translates one condition key/value pair into an SQL fragment and
// the positional arguments for its "?" placeholders.
//
// A key is a column name optionally followed by an operator suffix:
//
//	col          handled by handlePlainField (=, IS NULL, IN / BETWEEN for slices)
//	col[>]  col[>=]  col[<]  col[<=]   comparison with a placeholder
//	col[!]       handled by notIn (!=, NOT IN, IS NOT NULL)
//	col[~]       LIKE '%v%'      col[!~]  LIKE '%v'      col[~!]  LIKE 'v%'
//	col[~~]      LIKE v          (the caller supplies the wildcards)
//	col[<>]      BETWEEN v[0] AND v[1]
//	col[><]      NOT BETWEEN v[0] AND v[1]
//	col[#]       col=v           with v inlined into the SQL
//	col[#!] / col[!#]            col!=v with v inlined into the SQL
//	col[##]      v inlined as a raw SQL fragment (the column part is ignored)
//
// The bare keys "[#]" and "[##]" also inline v as a raw SQL fragment. Inlined
// values are NOT escaped, so never pass untrusted input to the "#" forms.
// For [<>] and [><], v must convert to a slice with at least two elements;
// fewer elements panic on the index.
//
// Operators are matched with strings.HasSuffix. The previous fixed-width tail
// comparison could never match the four-character operators it listed in its
// three-character switch ([!~], [~!], [~~], [##], [#!], [!#]). It also mixed
// a byte length with Substr's offsets for non-ASCII column names.
func (that *HoTimeDB) varCond(k string, v interface{}) (string, []interface{}) {
	where := ""
	res := make([]interface{}, 0)
	processor := that.GetProcessor()

	// Bare raw-SQL keys.
	if k == "[#]" || k == "[##]" {
		where += " " + ObjToStr(v) + " "
		return where, res
	}
	// Keys without a trailing "[...]" are plain column conditions.
	if !strings.Contains(k, "[") || !strings.HasSuffix(k, "]") {
		return that.handlePlainField(k, v, where, res)
	}

	// column strips op from k and returns the dialect-processed column name
	// followed by a space.
	column := func(op string) string {
		return processor.ProcessColumn(strings.Replace(k, op, "", -1)) + " "
	}

	// Every operator token includes its leading '[', so no token is a suffix
	// of another and the case order does not matter.
	switch {
	case strings.HasSuffix(k, "[>]"):
		where += column("[>]") + ">? "
		res = append(res, v)
	case strings.HasSuffix(k, "[<]"):
		where += column("[<]") + "<? "
		res = append(res, v)
	case strings.HasSuffix(k, "[>=]"):
		where += column("[>=]") + ">=? "
		res = append(res, v)
	case strings.HasSuffix(k, "[<=]"):
		where += column("[<=]") + "<=? "
		res = append(res, v)
	case strings.HasSuffix(k, "[!]"):
		where, res = that.notIn(column("[!]"), v, where, res)
	case strings.HasSuffix(k, "[#]"):
		where += " " + column("[#]") + "=" + ObjToStr(v) + " "
	case strings.HasSuffix(k, "[#!]"):
		where += " " + column("[#!]") + "!=" + ObjToStr(v) + " "
	case strings.HasSuffix(k, "[!#]"):
		where += " " + column("[!#]") + "!=" + ObjToStr(v) + " "
	case strings.HasSuffix(k, "[##]"):
		// Raw SQL fragment; the column part of the key is not used.
		where += " " + ObjToStr(v)
	case strings.HasSuffix(k, "[~]"):
		where += column("[~]") + " LIKE ? "
		res = append(res, "%"+ObjToStr(v)+"%")
	case strings.HasSuffix(k, "[!~]"):
		where += column("[!~]") + " LIKE ? "
		res = append(res, "%"+ObjToStr(v))
	case strings.HasSuffix(k, "[~!]"):
		where += column("[~!]") + " LIKE ? "
		res = append(res, ObjToStr(v)+"%")
	case strings.HasSuffix(k, "[~~]"):
		where += column("[~~]") + " LIKE ? "
		res = append(res, v)
	case strings.HasSuffix(k, "[<>]"):
		where += column("[<>]") + " BETWEEN ? AND ? "
		vs := ObjToSlice(v)
		res = append(res, vs[0], vs[1])
	case strings.HasSuffix(k, "[><]"):
		where += column("[><]") + " NOT BETWEEN ? AND ? "
		vs := ObjToSlice(v)
		res = append(res, vs[0], vs[1])
	default:
		// Bracketed key with an unknown suffix.
		where, res = that.handleDefaultCondition(k, v, where, res)
	}
	return where, res
}
// handleDefaultCondition handles a bracketed key whose suffix is not a
// recognised operator. The key is passed to the column processor unchanged
// (brackets included), and the value is treated exactly as handlePlainField
// treats it:
//   - a scalar becomes "col =?";
//   - a one-element slice becomes "col =?";
//   - an empty slice becomes "1=0";
//   - a longer slice goes to optimizeInCondition.
//
// Delegating removes a duplicated copy of that logic. It also means a nil
// value now yields "IS NULL" instead of panicking inside
// reflect.ValueOf(nil).Type().
func (that *HoTimeDB) handleDefaultCondition(k string, v interface{}, where string, res []interface{}) (string, []interface{}) {
	return that.handlePlainField(k, v, where, res)
}
// handlePlainField renders a condition for a key without an operator suffix
// and appends it to where, with its arguments appended to res:
//   - nil becomes "col IS NULL";
//   - a slice-like value (type common.Slice, or any type whose name contains
//     "[]") becomes "1=0" when empty, "col =?" for a single element, and is
//     otherwise handed to optimizeInCondition;
//   - any other value becomes "col =?".
func (that *HoTimeDB) handlePlainField(k string, v interface{}, where string, res []interface{}) (string, []interface{}) {
	column := that.GetProcessor().ProcessColumn(k) + " "
	if v == nil {
		return where + column + " IS NULL ", res
	}

	typeName := reflect.ValueOf(v).Type().String()
	if typeName != "common.Slice" && !strings.Contains(typeName, "[]") {
		return where + column + "=? ", append(res, v)
	}

	values := ObjToSlice(v)
	switch len(values) {
	case 0:
		// IN over an empty set is always false.
		return where + "1=0 ", res
	case 1:
		return where + column + "=? ", append(res, values[0])
	default:
		return that.optimizeInCondition(column, values, where, res)
	}
}
// optimizeInCondition 优化 IN 条件(连续整数转为 BETWEEN
func (that *HoTimeDB) optimizeInCondition(k string, vs Slice, where string, res []interface{}) (string, []interface{}) {
min := int64(0)
isMin := true
IsRange := true
num := int64(0)
isNum := true
where1 := ""
res1 := Slice{}
where2 := k + " IN ("
res2 := Slice{}
for kvs := 0; kvs <= len(vs); kvs++ {
vsv := int64(0)
if kvs < len(vs) {
vsv = vs.GetCeilInt64(kvs)
// 确保是全部是int类型
if ObjToStr(vsv) != vs.GetString(kvs) {
IsRange = false
break
}
}
if isNum {
isNum = false
num = vsv
} else {
num++
}
if isMin {
isMin = false
min = vsv
}
// 不等于则到了分路口
if num != vsv {
// between
if num-min > 1 {
if where1 != "" {
where1 += " OR " + k + " BETWEEN ? AND ? "
} else {
where1 += k + " BETWEEN ? AND ? "
}
res1 = append(res1, min)
res1 = append(res1, num-1)
} else {
where2 += "?,"
res2 = append(res2, min)
}
min = vsv
num = vsv
}
}
if IsRange {
where3 := ""
if where1 != "" {
where3 += where1
res = append(res, res1...)
}
if len(res2) == 1 {
if where3 == "" {
where3 += k + " = ? "
} else {
where3 += " OR " + k + " = ? "
}
res = append(res, res2...)
} else if len(res2) > 1 {
where2 = where2[:len(where2)-1]
if where3 == "" {
where3 += where2 + ")"
} else {
where3 += " OR " + where2 + ")"
}
res = append(res, res2...)
}
if where3 != "" {
where += "(" + where3 + ")"
}
return where, res
}
// 非连续整数,使用普通 IN
where += k + " IN ("
res = append(res, vs...)
for i := 0; i < len(vs); i++ {
if i+1 != len(vs) {
where += "?,"
} else {
where += "?) "
}
}
return where, res
}
// notIn renders a negated condition for column k (which already carries a
// trailing space), appending SQL to where and arguments to res:
//   - nil becomes "k IS NOT NULL";
//   - a slice-like value becomes "k NOT IN (?,...)";
//   - an empty slice adds nothing, because NOT IN over an empty set is always
//     true;
//   - any other value becomes "k !=?".
func (that *HoTimeDB) notIn(k string, v interface{}, where string, res []interface{}) (string, []interface{}) {
	if v == nil {
		return where + k + " IS NOT NULL ", res
	}

	typeName := reflect.ValueOf(v).Type().String()
	if typeName != "common.Slice" && !strings.Contains(typeName, "[]") {
		return where + k + " !=? ", append(res, v)
	}

	values := ObjToSlice(v)
	if len(values) == 0 {
		return where, res
	}
	placeholders := strings.Repeat("?,", len(values)-1) + "?"
	return where + k + " NOT IN (" + placeholders + ") ", append(res, values...)
}
// cond renders the entries of data joined by tag ("AND" or "OR") and returns
// the fragment, with a leading space, plus its placeholder arguments in
// order. Keys are visited in sorted order.
//
// An entry whose key is itself AND/OR (any case) is a nested group: it is
// rendered recursively and wrapped in parentheses. Every other entry goes
// through varCond.
//
// Entries that render to nothing (e.g. NOT IN over an empty slice, or a nested
// group whose entries all render to nothing) are dropped before joining. This
// means no dangling "AND"/"OR" or empty "( )" is produced, which the previous
// running-counter version could emit.
func (that *HoTimeDB) cond(tag string, data Map) (string, []interface{}) {
	res := make([]interface{}, 0)

	keys := make([]string, 0, len(data))
	for key := range data {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	parts := make([]string, 0, len(keys))
	lastIsGroup := false
	for _, k := range keys {
		v := data[k]
		if isConditionKey(k) {
			tw, ts := that.cond(strings.ToUpper(k), v.(Map))
			if strings.TrimSpace(tw) == "" {
				continue
			}
			parts = append(parts, "("+tw+")")
			lastIsGroup = true
			res = append(res, ts...)
			continue
		}
		tv, vv := that.varCond(k, v)
		if strings.TrimSpace(tv) == "" {
			continue
		}
		parts = append(parts, tv)
		lastIsGroup = false
		res = append(res, vv...)
	}

	where := " " + strings.Join(parts, " "+tag+" ")
	// Keep the trailing space a closing group has always carried.
	if lastIsGroup {
		where += " "
	}
	return where, res
}