package db

import (
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"reflect"
	"strings"

	. "code.hoteas.com/golang/hotime/common"
)

// md5 builds a cache key for a query: Md5 of the SQL text, a ":" separator and the JSON
// encoding of args.
func (that *HoTimeDB) md5(query string, args ...interface{}) string {
	argBytes, err := json.Marshal(args)
	if err != nil {
		// json.Marshal rejects values such as NaN/±Inf floats, channels and funcs. Ignoring the
		// error made every such call hash Md5(query+":"), so different arguments shared one key.
		// Fall back to Go-syntax formatting so distinct arguments keep distinct keys.
		argBytes = []byte(fmt.Sprintf("%#v", args))
	}
	return Md5(query + ":" + string(argBytes))
}

// Query runs a SQL statement that returns rows and converts them with Row.
// It returns nil on failure; the error is recorded in that.LastErr.
func (that *HoTimeDB) Query(query string, args ...interface{}) []Map {
	return that.queryWithRetry(query, false, args...)
}

// queryWithRetry is the implementation of Query. retried is true on the second attempt, so
// at most one retry is made.
//
// The statement runs on the active transaction (that.Tx) if there is one, otherwise on
// SlaveDB when configured, else on DB. After a failure outside a transaction, the query is
// retried once if the database still answers Ping.
func (that *HoTimeDB) queryWithRetry(query string, retried bool, args ...interface{}) []Map {
	// Expand array placeholders such as IN (?) before anything else sees the query.
	query, args = that.expandArrayPlaceholder(query, args)

	// Save debugging information; mu guards LastQuery/LastData.
	that.mu.Lock()
	that.LastQuery = query
	that.LastData = args
	that.mu.Unlock()

	defer func() {
		if that.Mode != 0 {
			// Log this call's own statement and arguments: the shared LastQuery/LastData may
			// already have been overwritten by a concurrent call.
			that.Log.Info("SQL:"+query, " DATA:", args, " ERROR:", that.LastErr.GetError())
		}
	}()

	// Reads go to the slave database when one is configured.
	db := that.DB
	if that.SlaveDB != nil {
		db = that.SlaveDB
	}
	if db == nil {
		that.LastErr.SetError(errors.New("没有初始化数据库"))
		return nil
	}

	processedArgs := that.processArgs(args)

	var (
		rows *sql.Rows
		err  error
	)
	if that.Tx != nil {
		rows, err = that.Tx.Query(query, processedArgs...)
	} else {
		rows, err = db.Query(query, processedArgs...)
	}
	that.LastErr.SetError(err)
	if err != nil {
		// Retry once if the pool still answers a ping (the failed connection may have been
		// stale). Never inside a transaction: the transaction is pinned to its own connection,
		// which pinging the pool cannot repair, and on some engines (e.g. PostgreSQL) the failed
		// statement aborts the transaction, so a retry would only overwrite the real error.
		if !retried && that.Tx == nil && db.Ping() == nil {
			return that.queryWithRetry(query, true, args...)
		}
		return nil
	}
	// Release the connection even if Row stops early on an error; Close is idempotent.
	defer rows.Close()
	return that.Row(rows)
}

// Exec runs a SQL statement that does not return rows.
// The second result is always that.LastErr, which holds this call's error, if any.
func (that *HoTimeDB) Exec(query string, args ...interface{}) (sql.Result, *Error) {
	return that.execWithRetry(query, false, args...)
}

// execWithRetry is the implementation of Exec. retried is true on the second attempt, so
// at most one retry is made.
//
// The statement runs on the active transaction (that.Tx) if there is one, otherwise on DB.
// After a failure outside a transaction, the statement is retried once if DB still answers Ping.
func (that *HoTimeDB) execWithRetry(query string, retried bool, args ...interface{}) (sql.Result, *Error) {
	// Expand array placeholders such as IN (?) before anything else sees the query.
	query, args = that.expandArrayPlaceholder(query, args)

	// Save debugging information; mu guards LastQuery/LastData.
	that.mu.Lock()
	that.LastQuery = query
	that.LastData = args
	that.mu.Unlock()

	defer func() {
		if that.Mode != 0 {
			// Log this call's own statement and arguments: the shared LastQuery/LastData may
			// already have been overwritten by a concurrent call.
			that.Log.Info("SQL: "+query, " DATA: ", args, " ERROR: ", that.LastErr.GetError())
		}
	}()

	if that.DB == nil {
		that.LastErr.SetError(errors.New("没有初始化数据库"))
		return nil, that.LastErr
	}

	processedArgs := that.processArgs(args)

	var (
		res sql.Result
		err error
	)
	if that.Tx != nil {
		res, err = that.Tx.Exec(query, processedArgs...)
	} else {
		res, err = that.DB.Exec(query, processedArgs...)
	}
	that.LastErr.SetError(err)

	// The connection may have been dropped: retry once if the pool still answers a ping.
	// Never inside a transaction: its pinned connection is not repaired by pinging the pool, and
	// on some engines (e.g. PostgreSQL) the failed statement aborts the transaction, so a retry
	// would only overwrite the real error.
	if err != nil && !retried && that.Tx == nil && that.DB.Ping() == nil {
		return that.execWithRetry(query, true, args...)
	}
	return res, that.LastErr
}

// processArgs returns a copy of args in which every slice argument, except byte slices, is
// replaced by its elements, each converted with ObjToStr, joined with ",". This covers slices
// that expandArrayPlaceholder left in place because their placeholder is not inside IN (...).
func (that *HoTimeDB) processArgs(args []interface{}) []interface{} {
	processedArgs := make([]interface{}, len(args))
	copy(processedArgs, args)
	for key, arg := range processedArgs {
		if arg == nil {
			continue
		}
		rv := reflect.ValueOf(arg)
		// Decide by reflect kind rather than by the type's name. Named slice types
		// (type IDs []int) are included, while maps such as map[string][]int are not.
		// []byte is a native driver value (binary/BLOB data). The old name test matched its
		// "[]uint8" type name and routed such data through the list join, so skip byte slices.
		if rv.Kind() != reflect.Slice || rv.Type().Elem().Kind() == reflect.Uint8 {
			continue
		}
		parts := make([]string, rv.Len())
		for i := range parts {
			parts[i] = ObjToStr(rv.Index(i).Interface())
		}
		processedArgs[key] = strings.Join(parts, ",")
	}
	return processedArgs
}
参数: [1, 2, 3] // // db.Query("SELECT * FROM user WHERE id IN (?)", []int{}) // // 展开为: SELECT * FROM user WHERE 1=0 参数: [] (空集合的IN永假) // // db.Query("SELECT * FROM user WHERE id NOT IN (?)", []int{}) // // 展开为: SELECT * FROM user WHERE 1=1 参数: [] (空集合的NOT IN永真) // // db.Query("SELECT * FROM user WHERE id = ?", 1) // // 保持不变: SELECT * FROM user WHERE id = ? 参数: [1] func (that *HoTimeDB) expandArrayPlaceholder(query string, args []interface{}) (string, []interface{}) { if len(args) == 0 || !strings.Contains(query, "?") { return query, args } // 检查是否有数组参数 hasArray := false for _, arg := range args { if arg == nil { continue } argType := reflect.ValueOf(arg).Type().String() if strings.Contains(argType, "[]") || strings.Contains(argType, "Slice") { hasArray = true break } } if !hasArray { return query, args } newArgs := make([]interface{}, 0, len(args)) result := strings.Builder{} argIndex := 0 for i := 0; i < len(query); i++ { if query[i] == '?' && argIndex < len(args) { arg := args[argIndex] argIndex++ if arg == nil { result.WriteByte('?') newArgs = append(newArgs, nil) continue } argType := reflect.ValueOf(arg).Type().String() if strings.Contains(argType, "[]") || strings.Contains(argType, "Slice") { // 是数组参数,检查是否在 IN (...) 或 NOT IN (...) 
中 prevPart := result.String() prevUpper := strings.ToUpper(prevPart) // 查找最近的 NOT IN ( 模式 notInIndex := strings.LastIndex(prevUpper, " NOT IN (") notInIndex2 := strings.LastIndex(prevUpper, " NOT IN(") if notInIndex2 > notInIndex { notInIndex = notInIndex2 } // 查找最近的 IN ( 模式(但要排除 NOT IN 的情况) inIndex := strings.LastIndex(prevUpper, " IN (") inIndex2 := strings.LastIndex(prevUpper, " IN(") if inIndex2 > inIndex { inIndex = inIndex2 } // 判断是 NOT IN 还是 IN // 注意:" NOT IN (" 包含 " IN (",所以如果找到的 IN 位置在 NOT IN 范围内,应该优先判断为 NOT IN isNotIn := false matchIndex := -1 if notInIndex != -1 { // 检查 inIndex 是否在 notInIndex 范围内(即 NOT IN 的 IN 部分) // NOT IN ( 的 IN ( 部分从 notInIndex + 4 开始 if inIndex != -1 && inIndex >= notInIndex && inIndex <= notInIndex+5 { // inIndex 是 NOT IN 的一部分,使用 NOT IN isNotIn = true matchIndex = notInIndex } else if inIndex == -1 || notInIndex > inIndex { // 没有独立的 IN,或 NOT IN 在 IN 之后 isNotIn = true matchIndex = notInIndex } else { // 有独立的 IN 且在 NOT IN 之后 matchIndex = inIndex } } else if inIndex != -1 { matchIndex = inIndex } // 检查 IN ( 后面是否只有空格(即当前 ? 紧跟在 IN ( 后面) isInPattern := false if matchIndex != -1 { afterIn := prevPart[matchIndex:] // 找到 ( 的位置 parenIdx := strings.Index(afterIn, "(") if parenIdx != -1 { afterParen := strings.TrimSpace(afterIn[parenIdx+1:]) if afterParen == "" { isInPattern = true } } } if isInPattern { // 在 IN (...) 或 NOT IN (...) 
模式中 argList := ObjToSlice(arg) if len(argList) == 0 { // 空数组处理:需要找到字段名的开始位置 // 往前找最近的 AND/OR/WHERE/(,以确定条件的开始位置 truncateIndex := matchIndex searchPart := prevUpper[:matchIndex] // 找最近的分隔符位置 andIdx := strings.LastIndex(searchPart, " AND ") orIdx := strings.LastIndex(searchPart, " OR ") whereIdx := strings.LastIndex(searchPart, " WHERE ") parenIdx := strings.LastIndex(searchPart, "(") // 取最靠后的分隔符 sepIndex := -1 sepLen := 0 if andIdx > sepIndex { sepIndex = andIdx sepLen = 5 // " AND " } if orIdx > sepIndex { sepIndex = orIdx sepLen = 4 // " OR " } if whereIdx > sepIndex { sepIndex = whereIdx sepLen = 7 // " WHERE " } if parenIdx > sepIndex { sepIndex = parenIdx sepLen = 1 // "(" } if sepIndex != -1 { truncateIndex = sepIndex + sepLen } result.Reset() result.WriteString(prevPart[:truncateIndex]) if isNotIn { // NOT IN 空集合 = 永真 result.WriteString(" 1=1 ") } else { // IN 空集合 = 永假 result.WriteString(" 1=0 ") } // 跳过后面的 ) for j := i + 1; j < len(query); j++ { if query[j] == ')' { i = j break } } } else if len(argList) == 1 { // 单元素数组 result.WriteByte('?') newArgs = append(newArgs, argList[0]) } else { // 多元素数组,展开为多个 ? for j := 0; j < len(argList); j++ { if j > 0 { result.WriteString(", ") } result.WriteByte('?') newArgs = append(newArgs, argList[j]) } } } else { // 不在 IN 模式中,保持原有行为(数组会被 processArgs 转为逗号字符串) result.WriteByte('?') newArgs = append(newArgs, arg) } } else { // 非数组参数 result.WriteByte('?') newArgs = append(newArgs, arg) } } else { result.WriteByte(query[i]) } } return result.String(), newArgs } // Row 数据库数据解析 func (that *HoTimeDB) Row(resl *sql.Rows) []Map { dest := make([]Map, 0) strs, _ := resl.Columns() for i := 0; resl.Next(); i++ { lis := make(Map, 0) a := make([]interface{}, len(strs)) b := make([]interface{}, len(a)) for j := 0; j < len(a); j++ { b[j] = &a[j] } err := resl.Scan(b...) 
if err != nil { that.LastErr.SetError(err) return nil } for j := 0; j < len(a); j++ { if a[j] != nil && reflect.ValueOf(a[j]).Type().String() == "[]uint8" { lis[strs[j]] = string(a[j].([]byte)) } else { lis[strs[j]] = a[j] // 取实际类型 } } dest = append(dest, lis) } return dest }