hotime/db/crud.go
hoteas c2955d2500 feat(db): 实现数据库查询中的数组参数展开和空数组处理
- 在 Get 方法中添加无参数时的默认字段和 LIMIT 1 处理
- 实现 expandArrayPlaceholder 方法,自动展开 IN (?) 和 NOT IN (?) 中的数组参数
- 为空数组的 IN 条件生成 1=0 永假条件,NOT IN 生成 1=1 永真条件
- 在 queryWithRetry 和 execWithRetry 中集成数组占位符预处理
- 修复 where.go 中空切片条件的处理逻辑
- 添加完整的 IN/NOT IN 数组查询测试用例
- 更新 .gitignore 规则格式
2026-01-22 07:16:42 +08:00

662 lines
17 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package db
import (
"reflect"
"sort"
"strings"
. "code.hoteas.com/golang/hotime/common"
)
// Page stores the pagination window consumed by the next PageSelect call.
//
// page is 1-based; any value below 1 is treated as page 1.
// pageRow is the page size; any value below 1 falls back to 10.
// The stored limit is Slice{offset, pageRow}. The receiver is returned so the
// call can be chained, e.g. db.Page(2, 20).PageSelect("user", "*").
func (that *HoTimeDB) Page(page, pageRow int) *HoTimeDB {
	if pageRow < 1 {
		pageRow = 10
	}
	if page < 1 {
		page = 1
	}
	window := Slice{(page - 1) * pageRow, pageRow}

	that.limitMu.Lock()
	defer that.limitMu.Unlock()
	that.limit = window
	return that
}
// PageSelect runs Select with the limit previously stored by Page injected as
// the "LIMIT" entry of the where condition.
//
// The pending limit is consumed (reset to nil) by this call, so it affects a
// single query only. Without a pending limit this is exactly Select.
//
// Supported argument shapes (mirroring Select):
//
//	PageSelect(table)
//	PageSelect(table, fields)              // fields: string or column list
//	PageSelect(table, where)               // where map; fields default to "*"
//	PageSelect(table, fields, where)
//	PageSelect(table, join, fields)
//	PageSelect(table, join, fields, where)
//
// The caller's where map and argument slice are never modified.
func (that *HoTimeDB) PageSelect(table string, qu ...interface{}) []Map {
	that.limitMu.Lock()
	limit := that.limit
	that.limit = nil // consume the limit so it cannot leak into the next query
	that.limitMu.Unlock()

	if limit == nil {
		return that.Select(table, qu...)
	}

	// withLimit returns a copy of the where condition with LIMIT set.
	withLimit := func(w interface{}) Map {
		if w == nil {
			return Map{"LIMIT": limit}
		}
		m := DeepCopyMap(w).(Map)
		m["LIMIT"] = limit
		return m
	}
	// isWhere reports whether an argument occupies the where slot.
	isWhere := func(v interface{}) bool {
		return v == nil || reflect.ValueOf(v).Kind() == reflect.Map
	}

	// Work on a copy so a caller passing args... never sees its slice modified.
	args := make([]interface{}, len(qu), len(qu)+2)
	copy(args, qu)

	switch len(args) {
	case 0:
		args = append(args, "*", Map{"LIMIT": limit})
	case 1:
		if isWhere(args[0]) {
			// Select reads a lone argument as the field list, so a where map
			// must be preceded by an explicit "*" to be applied at all.
			args = []interface{}{"*", withLimit(args[0])}
		} else {
			args = append(args, Map{"LIMIT": limit})
		}
	case 2:
		if isWhere(args[1]) {
			args[1] = withLimit(args[1])
		} else {
			// (join, fields): append the where condition carrying the limit.
			args = append(args, Map{"LIMIT": limit})
		}
	case 3:
		args[2] = withLimit(args[2])
	}
	return that.Select(table, args...)
}
// Select queries table and returns every matching row (never nil).
//
// Argument shapes:
//
//	Select(table)                       // SELECT *
//	Select(table, fields)               // fields: raw string or list of column names
//	Select(table, fields, where)        // where: condition map handed to that.where
//	Select(table, join, fields, where)  // join: see buildJoin
//
// Column names containing neither " AS " nor "." are backtick-quoted. The table
// name gets that.Prefix and is backtick-quoted unless it contains "." or " AS ".
// A nil fields argument selects "*", and a nil where means "no condition".
// When HoTimeCache is configured (and table is not "cached") results are read
// from and written to the cache under "<table>:<md5 of query and args>".
func (that *HoTimeDB) Select(table string, qu ...interface{}) []Map {
	fieldIdx, whereIdx := 0, 1
	join := len(qu) == 3
	if join {
		fieldIdx, whereIdx = 1, 2
	}

	query := "SELECT"
	if len(qu) == 0 {
		query += " *"
	} else {
		switch fields := qu[fieldIdx].(type) {
		case nil:
			// Previously panicked inside reflect; treat as "all columns".
			query += " *"
		case string:
			query += " " + fields
		default:
			cols := ObjToSlice(fields)
			for i := 0; i < len(cols); i++ {
				name := cols.GetString(i)
				if strings.Contains(name, " AS ") || strings.Contains(name, ".") {
					query += " " + name + " "
				} else {
					query += " `" + name + "` "
				}
				if i+1 != len(cols) {
					query += ", "
				}
			}
		}
	}

	if strings.Contains(table, ".") || strings.Contains(table, " AS ") {
		query += " FROM " + that.Prefix + table + " "
	} else {
		query += " FROM `" + that.Prefix + table + "` "
	}
	if join {
		query += that.buildJoin(qu[0])
	}

	where := Map{}
	if len(qu) > 1 {
		switch w := qu[whereIdx].(type) {
		case nil:
			// no condition
		case map[string]interface{}:
			// A plain map has Map's underlying type; accept it instead of panicking.
			where = w
		default:
			// Any other type is still a programming error, as before.
			where = w.(Map)
		}
	}
	cond, condArgs := that.where(where)
	query += cond + ";"
	qs := make([]interface{}, 0, len(condArgs))
	qs = append(qs, condArgs...)

	md5 := that.md5(query, qs...)
	useCache := that.HoTimeCache != nil && table != "cached"
	if useCache {
		if cached := that.HoTimeCache.Db(table + ":" + md5); cached != nil && cached.Data != nil {
			return cached.ToMapArray()
		}
	}

	res := that.Query(query, qs...)
	if res == nil {
		res = []Map{}
	}
	if useCache {
		_ = that.HoTimeCache.Db(table+":"+md5, res)
	}
	return res
}
// buildJoin renders the JOIN clauses used by Select.
//
// joinData is a single map, or a slice of maps, whose keys are
// "<marker><table>" and whose values are the ON condition, for example
//
//	Map{"[>]order": "order.uid = user.id"}  ->  LEFT JOIN `order` ON order.uid = user.id
//
// Markers: [>] LEFT, [<] RIGHT, [<>] FULL, [><] INNER. Keys with any other
// prefix are ignored. Clauses are emitted in sorted key order. When the same key
// appears in several slice elements the last one wins and the join is emitted
// once. A table name without a space (no alias) is wrapped in backticks. Join
// table names are not prefixed with that.Prefix. nil yields "".
func (that *HoTimeDB) buildJoin(joinData interface{}) string {
	clauses := Map{}
	switch jd := joinData.(type) {
	case nil:
		return ""
	case Map:
		clauses = jd
	case map[string]interface{}:
		// Previously ignored silently: its type string is neither "common.Map"
		// nor contains "[]".
		clauses = jd
	default:
		if _, isSlice := joinData.(Slice); !isSlice && !strings.Contains(reflect.TypeOf(joinData).String(), "[]") {
			return ""
		}
		list := ObjToSlice(joinData)
		for i := range list {
			for k, v := range list.GetMap(i) {
				clauses[k] = v
			}
		}
	}

	keys := make([]string, 0, len(clauses))
	for k := range clauses {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// No marker is a prefix of another, so the check order does not matter.
	joinKinds := []struct{ marker, keyword string }{
		{"[>]", "LEFT JOIN"},
		{"[<]", "RIGHT JOIN"},
		{"[<>]", "FULL JOIN"},
		{"[><]", "INNER JOIN"},
	}

	var sb strings.Builder
	for _, k := range keys {
		for _, jk := range joinKinds {
			if !strings.HasPrefix(k, jk.marker) {
				continue
			}
			table := strings.TrimPrefix(k, jk.marker)
			if !strings.Contains(table, " ") {
				table = "`" + table + "`"
			}
			sb.WriteString(" " + jk.keyword + " " + table + " ON " + ObjToStr(clauses[k]) + " ")
			break
		}
	}
	return sb.String()
}
// Get returns the first row matching the query, or nil when nothing matches.
//
// It accepts the same argument shapes as Select and forces LIMIT 1 onto the
// where condition. As a convenience, a lone map argument is treated as the
// where condition with all fields selected: Get("user", Map{"id": 1}).
// The caller's where map and argument slice are copied, never modified.
func (that *HoTimeDB) Get(table string, qu ...interface{}) Map {
	// limitOne returns a shallow copy of where with LIMIT forced to 1.
	limitOne := func(where interface{}) Map {
		var src Map
		switch w := where.(type) {
		case nil:
			// no condition; only LIMIT is added
		case map[string]interface{}:
			src = w
		default:
			src = w.(Map) // unsupported types still panic, as before
		}
		res := make(Map, len(src)+1)
		for k, v := range src {
			res[k] = v
		}
		res["LIMIT"] = 1
		return res
	}

	args := make([]interface{}, len(qu), len(qu)+2)
	copy(args, qu)

	switch len(args) {
	case 0:
		args = append(args, "*", limitOne(nil))
	case 1:
		switch args[0].(type) {
		case Map, map[string]interface{}:
			// A lone map would otherwise be read by Select as the field list.
			args = []interface{}{"*", limitOne(args[0])}
		default:
			args = append(args, limitOne(nil))
		}
	case 2:
		args[1] = limitOne(args[1])
	case 3:
		args[2] = limitOne(args[2])
	}

	rows := that.Select(table, args...)
	if len(rows) == 0 {
		return nil
	}
	return rows[0]
}
// Insert inserts one row built from data into table and returns the ID
// reported by the driver's LastInsertId (0 on failure, or when the
// driver/table provides no insert ID).
//
// A key ending in "[#]" is written verbatim as SQL with the marker stripped
// from the column name, e.g. {"created[#]": "NOW()"}. Every other value is
// bound as a "?" parameter. Empty data is a no-op returning 0.
//
// The table cache is invalidated whenever the statement executed successfully.
func (that *HoTimeDB) Insert(table string, data map[string]interface{}) int64 {
	if len(data) == 0 {
		// An empty column list would produce malformed SQL.
		return 0
	}

	cols := make([]string, 0, len(data))
	vals := make([]string, 0, len(data))
	args := make([]interface{}, 0, len(data))
	for k, v := range data {
		if strings.HasSuffix(k, "[#]") {
			cols = append(cols, "`"+strings.TrimSuffix(k, "[#]")+"`")
			vals = append(vals, ObjToStr(v))
			continue
		}
		cols = append(cols, "`"+k+"`")
		vals = append(vals, "?")
		args = append(args, v)
	}
	query := "INSERT INTO `" + that.Prefix + table + "` (" + strings.Join(cols, ",") +
		") VALUES (" + strings.Join(vals, ",") + ");"

	res, err := that.Exec(query, args...)
	if err.GetError() != nil || res == nil {
		return 0
	}
	id, idErr := res.LastInsertId()
	that.LastErr.SetError(idErr)

	// The row was written even when no ID is reported (tables without an
	// auto-increment key, or drivers that do not support LastInsertId), so
	// invalidate based on Exec's success rather than on id != 0.
	if that.HoTimeCache != nil && table != "cached" {
		_ = that.HoTimeCache.Db(table+"*", nil)
	}
	return id
}
// BatchInsert inserts all rows of dataList into table with a single
// multi-row INSERT and returns the number of affected rows (0 on failure or
// when dataList is empty).
//
// The column set is taken from the FIRST row only (sorted for a stable
// statement). Keys present only in later rows are not inserted. A column
// missing from a later row is written as NULL. For every row, a key "col[#]"
// is written verbatim as SQL and takes precedence over a plain "col" value.
//
// Example:
//
//	affected := db.BatchInsert("user", []Map{
//		{"name": "张三", "age": 25, "email": "zhang@example.com"},
//		{"name": "李四", "age": 30, "email": "li@example.com"},
//		{"name": "王五", "age": 28, "email": "wang@example.com"},
//	})
func (that *HoTimeDB) BatchInsert(table string, dataList []Map) int64 {
	if len(dataList) == 0 {
		return 0
	}

	// Deduplicate so a first row holding both "x" and "x[#]" yields one column.
	seen := make(map[string]bool, len(dataList[0]))
	columns := make([]string, 0, len(dataList[0]))
	for k := range dataList[0] {
		col := strings.TrimSuffix(k, "[#]")
		if !seen[col] {
			seen[col] = true
			columns = append(columns, col)
		}
	}
	sort.Strings(columns)

	quoted := make([]string, len(columns))
	for i, col := range columns {
		quoted[i] = "`" + col + "`"
	}

	rows := make([]string, len(dataList))
	values := make([]interface{}, 0, len(dataList)*len(columns))
	for i, data := range dataList {
		cells := make([]string, len(columns))
		for j, col := range columns {
			if raw, ok := data[col+"[#]"]; ok {
				cells[j] = ObjToStr(raw)
			} else if val, ok := data[col]; ok {
				cells[j] = "?"
				values = append(values, val)
			} else {
				cells[j] = "NULL"
			}
		}
		rows[i] = "(" + strings.Join(cells, ", ") + ")"
	}

	query := "INSERT INTO `" + that.Prefix + table + "` (" + strings.Join(quoted, ", ") +
		") VALUES " + strings.Join(rows, ", ")
	res, err := that.Exec(query, values...)
	affected := int64(0)
	if err.GetError() == nil && res != nil {
		// A RowsAffected error is reported as 0 rows, as before.
		affected, _ = res.RowsAffected()
	}
	if affected != 0 && that.HoTimeCache != nil && table != "cached" {
		_ = that.HoTimeCache.Db(table+"*", nil)
	}
	return affected
}
// Upsert inserts data into table, or updates the existing row when the
// insert conflicts on uniqueKeys. It returns the affected-row count reported
// by the driver (0 on failure or when data/uniqueKeys is empty).
//
//   - uniqueKeys: conflict columns, e.g. Slice{"id"} or Slice{"col1", "col2"}.
//   - updateColumns: columns to update on conflict. They may be given as one
//     Slice / []string / []interface{}, or as individual names
//     (Upsert(t, d, Slice{"id"}, "name", "email")). When empty, every
//     non-unique-key column in data is updated.
//
// Keys ending in "[#]" are written verbatim as SQL (marker stripped) both in
// the VALUES list and in the update assignment. The SQL dialect follows
// that.Type: postgres/postgresql, sqlite3/sqlite, anything else is MySQL.
//
// Example:
//
//	affected := db.Upsert("user",
//		Map{"id": 1, "name": "张三", "email": "zhang@example.com"},
//		Slice{"id"},            // unique key
//		Slice{"name", "email"}, // columns updated on conflict
//	)
func (that *HoTimeDB) Upsert(table string, data Map, uniqueKeys Slice, updateColumns ...interface{}) int64 {
	if len(data) == 0 || len(uniqueKeys) == 0 {
		return 0
	}

	keyNames := make([]string, len(uniqueKeys))
	for i, uk := range uniqueKeys {
		keyNames[i] = ObjToStr(uk)
	}

	var updateNames []string
	if len(updateColumns) > 0 {
		switch first := updateColumns[0].(type) {
		case Slice:
			for _, col := range first {
				updateNames = append(updateNames, ObjToStr(col))
			}
		case []interface{}:
			for _, col := range first {
				updateNames = append(updateNames, ObjToStr(col))
			}
		case []string:
			updateNames = append(updateNames, first...)
		default:
			for _, col := range updateColumns {
				updateNames = append(updateNames, ObjToStr(col))
			}
		}
	}

	// columns and values are appended in the same loop so that the bound
	// values line up with the non-raw placeholders the builders emit.
	columns := make([]string, 0, len(data))
	values := make([]interface{}, 0, len(data))
	rawValues := make(map[string]string)
	for k, v := range data {
		if strings.HasSuffix(k, "[#]") {
			col := strings.TrimSuffix(k, "[#]")
			columns = append(columns, col)
			rawValues[col] = ObjToStr(v)
			continue
		}
		columns = append(columns, k)
		values = append(values, v)
	}

	if len(updateNames) == 0 {
		isKey := make(map[string]bool, len(keyNames))
		for _, k := range keyNames {
			isKey[k] = true
		}
		for _, col := range columns {
			if !isKey[col] {
				updateNames = append(updateNames, col)
			}
		}
	}

	var query string
	switch that.Type {
	case "postgres", "postgresql":
		query = that.buildPostgresUpsert(table, columns, keyNames, updateNames, rawValues)
	case "sqlite3", "sqlite":
		query = that.buildSQLiteUpsert(table, columns, keyNames, updateNames, rawValues)
	default: // "" or "mysql"
		query = that.buildMySQLUpsert(table, columns, keyNames, updateNames, rawValues)
	}

	res, err := that.Exec(query, values...)
	rows := int64(0)
	if err.GetError() == nil && res != nil {
		rows, _ = res.RowsAffected()
	}
	if rows != 0 && that.HoTimeCache != nil && table != "cached" {
		_ = that.HoTimeCache.Db(table+"*", nil)
	}
	return rows
}
// buildMySQLUpsert builds a MySQL upsert:
//
//	INSERT INTO `t` (`a`, `b`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `b` = VALUES(`b`)
//
// Columns present in rawValues use their raw SQL instead of "?" (and in the
// update assignment). When updateColumns is empty, the clause would have no
// assignments, which is a syntax error. A self-assignment on the first unique
// key (or first column) is emitted instead, so a conflict becomes a no-op.
func (that *HoTimeDB) buildMySQLUpsert(table string, columns []string, uniqueKeys []string, updateColumns []string, rawValues map[string]string) string {
	quotedCols := make([]string, len(columns))
	valueParts := make([]string, len(columns))
	for i, col := range columns {
		quotedCols[i] = "`" + col + "`"
		if raw, ok := rawValues[col]; ok {
			valueParts[i] = raw
		} else {
			valueParts[i] = "?"
		}
	}

	updateParts := make([]string, 0, len(updateColumns))
	for _, col := range updateColumns {
		if raw, ok := rawValues[col]; ok {
			updateParts = append(updateParts, "`"+col+"` = "+raw)
		} else {
			updateParts = append(updateParts, "`"+col+"` = VALUES(`"+col+"`)")
		}
	}

	query := "INSERT INTO `" + that.Prefix + table + "` (" + strings.Join(quotedCols, ", ") +
		") VALUES (" + strings.Join(valueParts, ", ") + ")"

	if len(updateParts) == 0 {
		noop := ""
		switch {
		case len(uniqueKeys) > 0:
			noop = uniqueKeys[0]
		case len(columns) > 0:
			noop = columns[0]
		}
		if noop == "" {
			return query
		}
		updateParts = append(updateParts, "`"+noop+"` = `"+noop+"`")
	}
	return query + " ON DUPLICATE KEY UPDATE " + strings.Join(updateParts, ", ")
}
// buildPostgresUpsert builds a PostgreSQL upsert:
//
//	INSERT INTO "t" ("a", "b") VALUES ($1, $2) ON CONFLICT ("a") DO UPDATE SET "b" = EXCLUDED."b"
//
// Non-raw columns get sequential $n placeholders in column order. Columns in
// rawValues use their raw SQL instead. When there is nothing to update (e.g.
// every column is a conflict key), "DO NOTHING" is emitted, because an empty
// "DO UPDATE SET" is a syntax error.
func (that *HoTimeDB) buildPostgresUpsert(table string, columns []string, uniqueKeys []string, updateColumns []string, rawValues map[string]string) string {
	quote := func(name string) string { return "\"" + name + "\"" }

	quotedCols := make([]string, len(columns))
	valueParts := make([]string, len(columns))
	paramIndex := 1
	for i, col := range columns {
		quotedCols[i] = quote(col)
		if raw, ok := rawValues[col]; ok {
			valueParts[i] = raw
			continue
		}
		valueParts[i] = "$" + ObjToStr(paramIndex)
		paramIndex++
	}

	quotedKeys := make([]string, len(uniqueKeys))
	for i, key := range uniqueKeys {
		quotedKeys[i] = quote(key)
	}

	updateParts := make([]string, len(updateColumns))
	for i, col := range updateColumns {
		if raw, ok := rawValues[col]; ok {
			updateParts[i] = quote(col) + " = " + raw
		} else {
			updateParts[i] = quote(col) + " = EXCLUDED." + quote(col)
		}
	}

	query := "INSERT INTO " + quote(that.Prefix+table) + " (" + strings.Join(quotedCols, ", ") +
		") VALUES (" + strings.Join(valueParts, ", ") + ") ON CONFLICT"
	if len(quotedKeys) > 0 {
		query += " (" + strings.Join(quotedKeys, ", ") + ")"
	}
	if len(updateParts) == 0 {
		return query + " DO NOTHING"
	}
	return query + " DO UPDATE SET " + strings.Join(updateParts, ", ")
}
// buildSQLiteUpsert builds a SQLite upsert:
//
//	INSERT INTO "t" ("a", "b") VALUES (?, ?) ON CONFLICT ("a") DO UPDATE SET "b" = excluded."b"
//
// Columns in rawValues use their raw SQL instead of "?". When there is nothing
// to update (e.g. every column is a conflict key), "DO NOTHING" is emitted,
// because an empty "DO UPDATE SET" is a syntax error.
func (that *HoTimeDB) buildSQLiteUpsert(table string, columns []string, uniqueKeys []string, updateColumns []string, rawValues map[string]string) string {
	quote := func(name string) string { return "\"" + name + "\"" }

	quotedCols := make([]string, len(columns))
	valueParts := make([]string, len(columns))
	for i, col := range columns {
		quotedCols[i] = quote(col)
		if raw, ok := rawValues[col]; ok {
			valueParts[i] = raw
		} else {
			valueParts[i] = "?"
		}
	}

	quotedKeys := make([]string, len(uniqueKeys))
	for i, key := range uniqueKeys {
		quotedKeys[i] = quote(key)
	}

	updateParts := make([]string, len(updateColumns))
	for i, col := range updateColumns {
		if raw, ok := rawValues[col]; ok {
			updateParts[i] = quote(col) + " = " + raw
		} else {
			updateParts[i] = quote(col) + " = excluded." + quote(col)
		}
	}

	query := "INSERT INTO " + quote(that.Prefix+table) + " (" + strings.Join(quotedCols, ", ") +
		") VALUES (" + strings.Join(valueParts, ", ") + ") ON CONFLICT"
	if len(quotedKeys) > 0 {
		query += " (" + strings.Join(quotedKeys, ", ") + ")"
	}
	if len(updateParts) == 0 {
		return query + " DO NOTHING"
	}
	return query + " DO UPDATE SET " + strings.Join(updateParts, ", ")
}
// Update assigns the values in data to the rows of table matching where and
// returns the affected-row count (0 on failure or when data is empty).
//
// Keys ending in "[#]" are written verbatim as SQL with the marker stripped,
// e.g. {"count[#]": "count+1"}. Other values are bound as "?" parameters.
// The table cache is invalidated when at least one row was affected.
// NOTE(review): if that.where yields no clause for an empty/nil map, this
// updates every row of the table — confirm callers always pass a condition.
func (that *HoTimeDB) Update(table string, data Map, where Map) int64 {
	if len(data) == 0 {
		// "UPDATE ... SET <nothing>" is malformed SQL.
		return 0
	}

	sets := make([]string, 0, len(data))
	qs := make([]interface{}, 0, len(data))
	for k, v := range data {
		if strings.HasSuffix(k, "[#]") {
			sets = append(sets, "`"+strings.TrimSuffix(k, "[#]")+"`="+ObjToStr(v))
			continue
		}
		sets = append(sets, "`"+k+"`=?")
		qs = append(qs, v)
	}

	cond, condArgs := that.where(where)
	qs = append(qs, condArgs...)
	query := "UPDATE `" + that.Prefix + table + "` SET " + strings.Join(sets, ", ") + " " + cond + ";"

	res, err := that.Exec(query, qs...)
	rows := int64(0)
	if err.GetError() == nil && res != nil {
		rows, _ = res.RowsAffected()
	}
	if rows != 0 && that.HoTimeCache != nil && table != "cached" {
		_ = that.HoTimeCache.Db(table+"*", nil)
	}
	return rows
}
// Delete removes the rows of table matching the condition map data and
// returns the affected-row count (0 when Exec fails or returns no result).
// When at least one row was removed, the table cache is invalidated
// (except for the "cached" table).
// NOTE(review): if that.where yields no clause for an empty map, every row is
// deleted — confirm callers never pass an empty condition unintentionally.
func (that *HoTimeDB) Delete(table string, data map[string]interface{}) int64 {
	clause, args := that.where(data)
	res, err := that.Exec("DELETE FROM `"+that.Prefix+table+"` "+clause+";", args...)
	if err.GetError() != nil || res == nil {
		return 0
	}

	affected, _ := res.RowsAffected()
	if affected != 0 && that.HoTimeCache != nil && table != "cached" {
		_ = that.HoTimeCache.Db(table+"*", nil)
	}
	return affected
}