hotime/db/crud.go

662 lines
17 KiB
Go
Raw Normal View History

package db
import (
"reflect"
"sort"
"strings"
. "code.hoteas.com/golang/hotime/common"
)
// Page stores the pagination window that the next PageSelect call consumes.
//
// page is the 1-based page number; anything below 1 is treated as page 1.
// pageRow is the number of rows per page; anything below 1 falls back to 10.
// The receiver is returned so the call can be chained.
func (that *HoTimeDB) Page(page, pageRow int) *HoTimeDB {
	pageNo := page
	if pageNo < 1 {
		pageNo = 1
	}
	size := pageRow
	if size < 1 {
		size = 10
	}

	// The window is kept as {offset, row count}.
	window := Slice{(pageNo - 1) * size, size}

	that.limitMu.Lock()
	that.limit = window
	that.limitMu.Unlock()

	return that
}
// PageSelect runs Select with the pagination window stored by Page.
//
// The stored window is read and cleared under limitMu, so one Page call
// affects exactly one PageSelect call. When no window is pending the call is
// forwarded to Select unchanged. Otherwise a "LIMIT" entry is injected into the
// where Map according to the argument layout:
//
//	PageSelect(table)                      -> Select(table, "*", {LIMIT})
//	PageSelect(table, fields)              -> Select(table, fields, {LIMIT})
//	PageSelect(table, where)               -> Select(table, "*", where+LIMIT)
//	PageSelect(table, fields, where)       -> Select(table, fields, where+LIMIT)
//	PageSelect(table, join, fields)        -> Select(table, join, fields, {LIMIT})
//	PageSelect(table, join, fields, where) -> Select(table, join, fields, where+LIMIT)
//
// With more than three query arguments the window is consumed but not applied.
// The caller's where Map and argument slice are never modified; the where Map
// is copied with DeepCopyMap before LIMIT is added.
func (that *HoTimeDB) PageSelect(table string, qu ...interface{}) []Map {
	that.limitMu.Lock()
	limit := that.limit
	that.limit = nil // consume the window so it cannot leak into a later call
	that.limitMu.Unlock()

	if limit == nil {
		return that.Select(table, qu...)
	}

	// Work on a private copy so a slice passed as args... is left untouched.
	args := make([]interface{}, len(qu), len(qu)+2)
	copy(args, qu)

	// withLimit returns a copy of the where Map carrying the LIMIT entry.
	withLimit := func(where interface{}) Map {
		scoped := DeepCopyMap(where).(Map)
		scoped["LIMIT"] = limit
		return scoped
	}

	switch len(args) {
	case 0:
		args = append(args, "*", Map{"LIMIT": limit})
	case 1:
		if reflect.ValueOf(args[0]).Kind() == reflect.Map {
			// Select treats a lone argument as the field list and only reads a
			// where Map from the second slot, so the conditions must be moved
			// there; otherwise both the conditions and the LIMIT are ignored.
			args = []interface{}{"*", withLimit(args[0])}
		} else {
			args = append(args, Map{"LIMIT": limit})
		}
	case 2:
		if reflect.ValueOf(args[1]).Kind() == reflect.Map {
			args[1] = withLimit(args[1])
		} else {
			// (join, fields): append the where Map so Select sees three arguments.
			args = append(args, Map{"LIMIT": limit})
		}
	case 3:
		args[2] = withLimit(args[2])
	}

	return that.Select(table, args...)
}
// Select queries multiple rows from table.
//
// Argument layouts:
//
//	Select(table)                      -> SELECT * with no conditions
//	Select(table, fields)              -> fields only
//	Select(table, fields, where)       -> fields and conditions
//	Select(table, join, fields, where) -> exactly three arguments enable JOIN
//
// fields is either a string, used verbatim, or a list of column names. Listed
// names are wrapped in backticks unless they contain " AS " or ".". A nil
// fields value selects "*". where may be a Map, a map[string]interface{} or
// nil (no conditions); any other type panics on the Map assertion.
//
// When HoTimeCache is set and table is not "cached", the result is looked up
// under table + ":" + md5(query, args) first and stored there after a database
// read. The return value is never nil: a nil Query result becomes an empty
// slice.
func (that *HoTimeDB) Select(table string, qu ...interface{}) []Map {
	// Exactly three arguments means (join, fields, where); otherwise the
	// layout is (fields[, where]).
	fieldsAt, whereAt := 0, 1
	join := len(qu) == 3
	if join {
		fieldsAt, whereAt = 1, 2
	}

	var sb strings.Builder
	sb.WriteString("SELECT")
	if len(qu) == 0 || qu[fieldsAt] == nil {
		// A nil field spec used to panic inside reflect; treat it as "all columns".
		sb.WriteString(" *")
	} else if raw, ok := qu[fieldsAt].(string); ok {
		sb.WriteString(" " + raw)
	} else {
		fields := ObjToSlice(qu[fieldsAt])
		for i := 0; i < len(fields); i++ {
			name := fields.GetString(i)
			if strings.Contains(name, " AS ") || strings.Contains(name, ".") {
				sb.WriteString(" " + name + " ")
			} else {
				sb.WriteString(" `" + name + "` ")
			}
			if i+1 != len(fields) {
				sb.WriteString(", ")
			}
		}
	}

	// Qualified or aliased table names are used as written; plain names are quoted.
	if strings.Contains(table, ".") || strings.Contains(table, " AS ") {
		sb.WriteString(" FROM " + that.Prefix + table + " ")
	} else {
		sb.WriteString(" FROM `" + that.Prefix + table + "` ")
	}

	if join {
		sb.WriteString(that.buildJoin(qu[0]))
	}

	where := Map{}
	if len(qu) > 1 {
		switch w := qu[whereAt].(type) {
		case nil:
			// No conditions supplied: keep the empty Map.
		case Map:
			where = w
		case map[string]interface{}:
			where = Map(w)
		default:
			// Unsupported where type is a caller error; fail loudly as before
			// rather than silently dropping the conditions.
			where = qu[whereAt].(Map)
		}
	}

	clause, whereArgs := that.where(where)
	sb.WriteString(clause + ";")
	query := sb.String()

	qs := make([]interface{}, 0)
	qs = append(qs, whereArgs...)

	sum := that.md5(query, qs...)
	useCache := that.HoTimeCache != nil && table != "cached"

	if useCache {
		// Serve from cache when an entry with data exists.
		cacheData := that.HoTimeCache.Db(table + ":" + sum)
		if cacheData != nil && cacheData.Data != nil {
			return cacheData.ToMapArray()
		}
	}

	// Cache miss (or caching disabled): read from the database.
	res := that.Query(query, qs...)
	if res == nil {
		res = []Map{}
	}

	if useCache {
		_ = that.HoTimeCache.Db(table+":"+sum, res)
	}
	return res
}
// buildJoin renders the JOIN clauses described by joinData.
//
// joinData is either a common.Map or a slice of maps (type name
// "common.Slice" or any type whose name contains "[]"). Each key is a marker
// followed by the table expression; each value is the ON condition and must be
// a string:
//
//	"[>]table"  -> LEFT JOIN
//	"[<]table"  -> RIGHT JOIN
//	"[<>]table" -> FULL JOIN
//	"[><]table" -> INNER JOIN
//
// A table expression without a space is wrapped in backticks. Keys without a
// known marker are skipped.
//
// NOTE(review): keys are sorted alphabetically before rendering, so clauses are
// emitted in key order rather than in the order supplied by a slice.
func (that *HoTimeDB) buildJoin(joinData interface{}) string {
	typeName := reflect.ValueOf(joinData).Type().String()

	keys := []string{}
	clauses := Map{}

	if typeName == "common.Map" {
		clauses = joinData.(Map)
		for name := range clauses {
			keys = append(keys, name)
		}
	}

	if typeName == "common.Slice" || strings.Contains(typeName, "[]") {
		items := ObjToSlice(joinData)
		for idx := range items {
			for name, on := range items.GetMap(idx) {
				keys = append(keys, name)
				clauses[name] = on
			}
		}
	}

	sort.Strings(keys)

	var out strings.Builder

	// emit appends one JOIN clause; markerLen is the length of the marker
	// that precedes the table expression inside key.
	emit := func(keyword, key string, markerLen int, on interface{}) {
		target := Substr(key, markerLen, len(key)-markerLen)
		if !strings.Contains(target, " ") {
			target = "`" + target + "`"
		}
		out.WriteString(" " + keyword + " JOIN " + target + " ON " + on.(string) + " ")
	}

	for _, key := range keys {
		on := clauses[key]

		switch Substr(key, 0, 3) {
		case "[>]":
			emit("LEFT", key, 3, on)
		case "[<]":
			emit("RIGHT", key, 3, on)
		}

		switch Substr(key, 0, 4) {
		case "[<>]":
			emit("FULL", key, 4, on)
		case "[><]":
			emit("INNER", key, 4, on)
		}
	}

	return out.String()
}
// Get fetches a single row from table, or nil when nothing matches.
//
// It accepts the same argument layouts as Select and forces "LIMIT": 1 into
// the where Map (the last argument when two or three arguments are given),
// overriding any LIMIT the caller supplied:
//
//	Get(table)                      -> Select(table, "*", {LIMIT: 1})
//	Get(table, fields)              -> Select(table, fields, {LIMIT: 1})
//	Get(table, fields, where)       -> Select(table, fields, where+LIMIT)
//	Get(table, join, fields, where) -> Select(table, join, fields, where+LIMIT)
//
// The caller's where Map and argument slice are never modified: LIMIT is added
// to a shallow copy. A nil where is treated as "no conditions"; a where value
// that is not a Map panics on the type assertion.
func (that *HoTimeDB) Get(table string, qu ...interface{}) Map {
	// Private copy so a slice passed as args... is left untouched.
	args := make([]interface{}, len(qu), len(qu)+2)
	copy(args, qu)

	switch len(args) {
	case 0:
		args = append(args, "*", Map{"LIMIT": 1})
	case 1:
		args = append(args, Map{"LIMIT": 1})
	case 2, 3:
		last := len(args) - 1

		var src Map
		if args[last] != nil {
			src = args[last].(Map)
		}

		// Shallow copy is enough: only a top-level key is added, and the
		// caller's Map must not grow a LIMIT entry as a side effect.
		scoped := make(Map, len(src)+1)
		for k, v := range src {
			scoped[k] = v
		}
		scoped["LIMIT"] = 1
		args[last] = scoped
	}

	rows := that.Select(table, args...)
	if len(rows) == 0 {
		return nil
	}
	return rows[0]
}
// Insert writes one row into table and returns the id reported by
// LastInsertId, or 0 when data is empty or the statement fails.
//
// Keys of data are column names. A key ending in "[#]" marks its value as a
// raw SQL expression: the marker is stripped from the column name and the
// value is written into the statement verbatim instead of being bound as a
// parameter. Columns are emitted in sorted key order so the statement text is
// the same on every call with the same keys.
//
// After a successful Exec the table's cache entries are dropped (when
// HoTimeCache is set and table is not "cached"), whether or not an id was
// generated. Any error from LastInsertId is stored in LastErr.
func (that *HoTimeDB) Insert(table string, data map[string]interface{}) int64 {
	// Nothing to insert; an empty column list would only produce invalid SQL.
	if len(data) == 0 {
		return 0
	}

	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	columns := make([]string, 0, len(keys))
	holders := make([]string, 0, len(keys))
	values := make([]interface{}, 0, len(keys))
	for _, k := range keys {
		v := data[k]
		if strings.HasSuffix(k, "[#]") {
			// Raw SQL expression: goes into the statement text, not the args.
			columns = append(columns, "`"+strings.Replace(k, "[#]", "", -1)+"`")
			holders = append(holders, ObjToStr(v))
			continue
		}
		columns = append(columns, "`"+k+"`")
		holders = append(holders, "?")
		values = append(values, v)
	}

	query := "INSERT INTO `" + that.Prefix + table + "` (" + strings.Join(columns, ",") +
		") VALUES (" + strings.Join(holders, ",") + ");"

	res, err := that.Exec(query, values...)
	if err.GetError() != nil || res == nil {
		return 0
	}

	id, idErr := res.LastInsertId()
	that.LastErr.SetError(idErr)

	// Invalidate on every successful write. The id alone is not a reliable
	// success signal: it stays 0 whenever the driver reports no generated key
	// (presumably e.g. tables without an auto-increment column), and the
	// cached rows would otherwise go stale.
	if that.HoTimeCache != nil && table != "cached" {
		_ = that.HoTimeCache.Db(table+"*", nil)
	}
	return id
}
// BatchInsert inserts several rows with a single INSERT statement and returns
// the number of affected rows (0 when dataList is empty or Exec fails).
//
// The column list is taken from the keys of the first row only and sorted by
// name. For every row and column the cell is chosen in this order:
//
//  1. the row has the key column+"[#]": its value is written verbatim as SQL;
//  2. the first row declared the column through a "[#]" key (first row only):
//     that raw value is written verbatim;
//  3. the row has the plain column key: a "?" placeholder is bound to it;
//  4. otherwise the literal NULL is written.
//
// Keys that appear only in later rows are ignored. When at least one row was
// affected, the table's cache entries are dropped (HoTimeCache set and table
// not "cached").
//
// Example:
//
//	affected := db.BatchInsert("user", []Map{
//		{"name": "Zhang San", "age": 25, "email": "zhang@example.com"},
//		{"name": "Li Si", "age": 30, "email": "li@example.com"},
//		{"name": "Wang Wu", "age": 28, "email": "wang@example.com"},
//	})
func (that *HoTimeDB) BatchInsert(table string, dataList []Map) int64 {
	if len(dataList) == 0 {
		return 0
	}

	// Derive the column set from the first row, remembering raw expressions.
	head := dataList[0]
	columns := make([]string, 0)
	headRaw := make(map[string]string)
	for key, val := range head {
		name := key
		if Substr(key, len(key)-3, 3) == "[#]" {
			name = strings.Replace(key, "[#]", "", -1)
			headRaw[name] = ObjToStr(val)
		}
		columns = append(columns, name)
	}
	sort.Strings(columns) // map order is random; sorting keeps cells aligned

	quoted := make([]string, 0, len(columns))
	for _, col := range columns {
		quoted = append(quoted, "`"+col+"`")
	}

	tuples := make([]string, 0, len(dataList))
	args := make([]interface{}, 0, len(dataList)*len(columns))
	for rowIdx, row := range dataList {
		cells := make([]string, 0, len(columns))
		for _, col := range columns {
			cell := "NULL"
			if expr, found := row[col+"[#]"]; found {
				cell = ObjToStr(expr)
			} else if expr, found := headRaw[col]; found && rowIdx == 0 {
				cell = expr
			} else if val, found := row[col]; found {
				cell = "?"
				args = append(args, val)
			}
			cells = append(cells, cell)
		}
		tuples = append(tuples, "("+strings.Join(cells, ", ")+")")
	}

	query := "INSERT INTO `" + that.Prefix + table + "` (" + strings.Join(quoted, ", ") +
		") VALUES " + strings.Join(tuples, ", ")

	res, err := that.Exec(query, args...)

	var affected int64
	if err.GetError() == nil && res != nil {
		affected, _ = res.RowsAffected()
	}

	// Rows were written: drop the table's cached results.
	if affected != 0 && that.HoTimeCache != nil && table != "cached" {
		_ = that.HoTimeCache.Db(table+"*", nil)
	}
	return affected
}
// Upsert inserts a row, or updates the existing one when a unique key
// conflicts, and returns the number of affected rows.
//
// table is the table name without prefix. data holds the row; a key ending in
// "[#]" marks its value as a raw SQL expression written verbatim.
// uniqueKeys lists the conflict columns, e.g. Slice{"id"} or
// Slice{"col1", "col2"}. updateColumns names the columns to overwrite on
// conflict, either as a single Slice or as individual values; when omitted,
// every column of data that is not a unique key is updated.
//
// The statement dialect follows that.Type: "postgres"/"postgresql",
// "sqlite3"/"sqlite", and MySQL for anything else including the empty string.
// 0 is returned when data or uniqueKeys is empty or Exec fails. When rows were
// affected the table's cache entries are dropped (HoTimeCache set and table not
// "cached").
//
// Example:
//
//	affected := db.Upsert("user",
//		Map{"id": 1, "name": "Zhang San", "email": "zhang@example.com"},
//		Slice{"id"},            // unique key
//		Slice{"name", "email"}, // columns updated on conflict
//	)
func (that *HoTimeDB) Upsert(table string, data Map, uniqueKeys Slice, updateColumns ...interface{}) int64 {
	if len(data) == 0 || len(uniqueKeys) == 0 {
		return 0
	}

	// Conflict columns as plain strings.
	conflictCols := make([]string, 0, len(uniqueKeys))
	for _, key := range uniqueKeys {
		conflictCols = append(conflictCols, ObjToStr(key))
	}

	// Update columns: Upsert(t, d, keys, Slice{"a", "b"}) or Upsert(t, d, keys, "a", "b").
	var refreshCols []string
	if len(updateColumns) > 0 {
		if listed, ok := updateColumns[0].(Slice); ok {
			refreshCols = make([]string, 0, len(listed))
			for _, col := range listed {
				refreshCols = append(refreshCols, ObjToStr(col))
			}
		} else {
			refreshCols = make([]string, 0, len(updateColumns))
			for _, col := range updateColumns {
				refreshCols = append(refreshCols, ObjToStr(col))
			}
		}
	}

	// Split data into column names, bound values and raw SQL expressions.
	columns := make([]string, 0, len(data))
	values := make([]interface{}, 0, len(data))
	rawValues := make(map[string]string)
	for key, val := range data {
		if Substr(key, len(key)-3, 3) != "[#]" {
			columns = append(columns, key)
			values = append(values, val)
			continue
		}
		name := strings.Replace(key, "[#]", "", -1)
		columns = append(columns, name)
		rawValues[name] = ObjToStr(val)
	}

	// Default: update every column that is not part of the unique key.
	if len(refreshCols) == 0 {
		isUnique := make(map[string]bool)
		for _, key := range conflictCols {
			isUnique[key] = true
		}
		for _, col := range columns {
			if isUnique[col] {
				continue
			}
			refreshCols = append(refreshCols, col)
		}
	}

	// Pick the dialect; an empty or unknown type falls through to MySQL.
	var query string
	switch that.Type {
	case "postgres", "postgresql":
		query = that.buildPostgresUpsert(table, columns, conflictCols, refreshCols, rawValues)
	case "sqlite3", "sqlite":
		query = that.buildSQLiteUpsert(table, columns, conflictCols, refreshCols, rawValues)
	default:
		query = that.buildMySQLUpsert(table, columns, conflictCols, refreshCols, rawValues)
	}

	res, err := that.Exec(query, values...)

	var affected int64
	if err.GetError() == nil && res != nil {
		affected, _ = res.RowsAffected()
	}

	// Rows changed: drop the table's cached results.
	if affected != 0 && that.HoTimeCache != nil && table != "cached" {
		_ = that.HoTimeCache.Db(table+"*", nil)
	}
	return affected
}
// buildMySQLUpsert builds the MySQL form of an upsert:
//
//	INSERT INTO `t` (`c1`, `c2`) VALUES (?, ?)
//	ON DUPLICATE KEY UPDATE `c1` = VALUES(`c1`), `c2` = VALUES(`c2`)
//
// columns are the inserted columns in placeholder order. A column present in
// rawValues gets its raw SQL expression in place of "?", both in the VALUES
// list and in the update assignment. updateColumns are the columns overwritten
// on conflict.
//
// When updateColumns is empty a no-op assignment on the first unique key (or
// the first column) is emitted, so a conflicting row is left unchanged instead
// of producing a dangling "ON DUPLICATE KEY UPDATE". With no column to name at
// all, the clause is omitted.
func (that *HoTimeDB) buildMySQLUpsert(table string, columns []string, uniqueKeys []string, updateColumns []string, rawValues map[string]string) string {
	quotedCols := make([]string, len(columns))
	valueParts := make([]string, len(columns))
	for i, col := range columns {
		quotedCols[i] = "`" + col + "`"
		if raw, ok := rawValues[col]; ok {
			valueParts[i] = raw
		} else {
			valueParts[i] = "?"
		}
	}

	updateParts := make([]string, 0, len(updateColumns)+1)
	for _, col := range updateColumns {
		if raw, ok := rawValues[col]; ok {
			updateParts = append(updateParts, "`"+col+"` = "+raw)
		} else {
			updateParts = append(updateParts, "`"+col+"` = VALUES(`"+col+"`)")
		}
	}

	if len(updateParts) == 0 {
		// Nothing to overwrite: assign a column to itself so the statement
		// stays valid and a duplicate row is simply kept as it is.
		noop := ""
		switch {
		case len(uniqueKeys) > 0:
			noop = uniqueKeys[0]
		case len(columns) > 0:
			noop = columns[0]
		}
		if noop != "" {
			updateParts = append(updateParts, "`"+noop+"` = `"+noop+"`")
		}
	}

	stmt := "INSERT INTO `" + that.Prefix + table + "` (" + strings.Join(quotedCols, ", ") +
		") VALUES (" + strings.Join(valueParts, ", ") + ")"
	if len(updateParts) == 0 {
		return stmt
	}
	return stmt + " ON DUPLICATE KEY UPDATE " + strings.Join(updateParts, ", ")
}
// buildPostgresUpsert builds the PostgreSQL form of an upsert:
//
//	INSERT INTO "t" ("c1", "c2") VALUES ($1, $2)
//	ON CONFLICT ("key") DO UPDATE SET "c1" = EXCLUDED."c1"
//
// columns are the inserted columns. Placeholders are numbered $1, $2, ... in
// column order, skipping columns present in rawValues, whose raw SQL expression
// is written instead (in the VALUES list and in the update assignment).
// uniqueKeys form the conflict target and updateColumns are the columns
// overwritten on conflict.
//
// When updateColumns is empty the statement ends in "DO NOTHING", because
// "DO UPDATE SET" without assignments is not valid SQL.
func (that *HoTimeDB) buildPostgresUpsert(table string, columns []string, uniqueKeys []string, updateColumns []string, rawValues map[string]string) string {
	quotedCols := make([]string, len(columns))
	valueParts := make([]string, len(columns))
	paramIndex := 1
	for i, col := range columns {
		quotedCols[i] = "\"" + col + "\""
		if raw, ok := rawValues[col]; ok {
			valueParts[i] = raw
			continue
		}
		valueParts[i] = "$" + ObjToStr(paramIndex)
		paramIndex++
	}

	quotedUniqueKeys := make([]string, len(uniqueKeys))
	for i, key := range uniqueKeys {
		quotedUniqueKeys[i] = "\"" + key + "\""
	}

	// Without columns to overwrite, keep the existing row untouched.
	action := "DO NOTHING"
	if len(updateColumns) > 0 {
		updateParts := make([]string, len(updateColumns))
		for i, col := range updateColumns {
			if raw, ok := rawValues[col]; ok {
				updateParts[i] = "\"" + col + "\" = " + raw
			} else {
				updateParts[i] = "\"" + col + "\" = EXCLUDED.\"" + col + "\""
			}
		}
		action = "DO UPDATE SET " + strings.Join(updateParts, ", ")
	}

	return "INSERT INTO \"" + that.Prefix + table + "\" (" + strings.Join(quotedCols, ", ") +
		") VALUES (" + strings.Join(valueParts, ", ") +
		") ON CONFLICT (" + strings.Join(quotedUniqueKeys, ", ") +
		") " + action
}
// buildSQLiteUpsert builds the SQLite form of an upsert:
//
//	INSERT INTO "t" ("c1", "c2") VALUES (?, ?)
//	ON CONFLICT ("key") DO UPDATE SET "c1" = excluded."c1"
//
// columns are the inserted columns in placeholder order. A column present in
// rawValues gets its raw SQL expression in place of "?", both in the VALUES
// list and in the update assignment. uniqueKeys form the conflict target and
// updateColumns are the columns overwritten on conflict.
//
// When updateColumns is empty the statement ends in "DO NOTHING", because
// "DO UPDATE SET" without assignments is not valid SQL.
func (that *HoTimeDB) buildSQLiteUpsert(table string, columns []string, uniqueKeys []string, updateColumns []string, rawValues map[string]string) string {
	quotedCols := make([]string, len(columns))
	valueParts := make([]string, len(columns))
	for i, col := range columns {
		quotedCols[i] = "\"" + col + "\""
		if raw, ok := rawValues[col]; ok {
			valueParts[i] = raw
		} else {
			valueParts[i] = "?"
		}
	}

	quotedUniqueKeys := make([]string, len(uniqueKeys))
	for i, key := range uniqueKeys {
		quotedUniqueKeys[i] = "\"" + key + "\""
	}

	// Without columns to overwrite, keep the existing row untouched.
	action := "DO NOTHING"
	if len(updateColumns) > 0 {
		updateParts := make([]string, len(updateColumns))
		for i, col := range updateColumns {
			if raw, ok := rawValues[col]; ok {
				updateParts[i] = "\"" + col + "\" = " + raw
			} else {
				updateParts[i] = "\"" + col + "\" = excluded.\"" + col + "\""
			}
		}
		action = "DO UPDATE SET " + strings.Join(updateParts, ", ")
	}

	return "INSERT INTO \"" + that.Prefix + table + "\" (" + strings.Join(quotedCols, ", ") +
		") VALUES (" + strings.Join(valueParts, ", ") +
		") ON CONFLICT (" + strings.Join(quotedUniqueKeys, ", ") +
		") " + action
}
// Update modifies the rows of table selected by where and returns the number
// of affected rows (0 when Exec fails).
//
// Keys of data are column names. A key ending in "[#]" marks its value as a
// raw SQL expression: the marker is removed from the column name and the value
// is written into the statement verbatim; every other value is bound to a "?"
// placeholder. The condition clause and its arguments come from that.where.
//
// When rows were affected the table's cache entries are dropped (HoTimeCache
// set and table not "cached").
func (that *HoTimeDB) Update(table string, data Map, where Map) int64 {
	assignments := make([]string, 0, len(data))
	args := make([]interface{}, 0)

	for key, val := range data {
		column, placeholder := key, "?"
		if Substr(key, len(key)-3, 3) == "[#]" {
			// Raw SQL expression: part of the statement text, not bound.
			column = strings.Replace(key, "[#]", "", -1)
			placeholder = ObjToStr(val)
		} else {
			args = append(args, val)
		}
		assignments = append(assignments, "`"+column+"`="+placeholder+" ")
	}

	clause, whereArgs := that.where(where)
	query := "UPDATE `" + that.Prefix + table + "` SET " + strings.Join(assignments, ", ") + clause + ";"
	args = append(args, whereArgs...)

	res, err := that.Exec(query, args...)

	var affected int64
	if err.GetError() == nil && res != nil {
		affected, _ = res.RowsAffected()
	}

	// Rows changed: drop the table's cached results.
	if affected != 0 && that.HoTimeCache != nil && table != "cached" {
		_ = that.HoTimeCache.Db(table+"*", nil)
	}
	return affected
}
// Delete removes the rows of table selected by data and returns the number of
// affected rows (0 when Exec fails).
//
// data is the condition map handed to that.where, which supplies the condition
// clause and its bound arguments. When rows were removed the table's cache
// entries are dropped (HoTimeCache set and table not "cached").
//
// NOTE(review): presumably an empty condition map yields no WHERE clause and
// therefore deletes every row; verify against that.where before relying on it.
func (that *HoTimeDB) Delete(table string, data map[string]interface{}) int64 {
	head := "DELETE FROM `" + that.Prefix + table + "` "
	clause, args := that.where(data)

	res, err := that.Exec(head+clause+";", args...)

	var affected int64
	if err.GetError() == nil && res != nil {
		affected, _ = res.RowsAffected()
	}

	// Rows removed: drop the table's cached results.
	if affected != 0 && that.HoTimeCache != nil && table != "cached" {
		_ = that.HoTimeCache.Db(table+"*", nil)
	}
	return affected
}