diff --git a/db/crud.go b/db/crud.go index d8acf4d..be9d82a 100644 --- a/db/crud.go +++ b/db/crud.go @@ -87,17 +87,26 @@ func (that *HoTimeDB) Select(table string, qu ...interface{}) []Map { join = true } + processor := that.GetProcessor() + if len(qu) > 0 { if reflect.ValueOf(qu[intQs]).Type().String() == "string" { - query += " " + qu[intQs].(string) + // 字段列表字符串,使用处理器处理 table.column 格式 + fieldStr := qu[intQs].(string) + if fieldStr != "*" { + fieldStr = processor.ProcessFieldList(fieldStr) + } + query += " " + fieldStr } else { data := ObjToSlice(qu[intQs]) for i := 0; i < len(data); i++ { k := data.GetString(i) if strings.Contains(k, " AS ") || strings.Contains(k, ".") { - query += " " + k + " " + // 处理 table.column 格式 + query += " " + processor.ProcessFieldList(k) + " " } else { - query += " `" + k + "` " + // 单独的列名 + query += " " + processor.ProcessColumnNoPrefix(k) + " " } if i+1 != len(data) { @@ -109,11 +118,8 @@ func (that *HoTimeDB) Select(table string, qu ...interface{}) []Map { query += " *" } - if !strings.Contains(table, ".") && !strings.Contains(table, " AS ") { - query += " FROM `" + that.Prefix + table + "` " - } else { - query += " FROM " + that.Prefix + table + " " - } + // 处理表名(添加前缀和正确的引号) + query += " FROM " + processor.ProcessTableName(table) + " " if join { query += that.buildJoin(qu[0]) @@ -157,6 +163,7 @@ func (that *HoTimeDB) buildJoin(joinData interface{}) string { query := "" var testQu = []string{} testQuData := Map{} + processor := that.GetProcessor() if reflect.ValueOf(joinData).Type().String() == "common.Map" { testQuData = joinData.(Map) @@ -184,36 +191,34 @@ func (that *HoTimeDB) buildJoin(joinData interface{}) string { case "[>]": func() { table := Substr(k, 3, len(k)-3) - if !strings.Contains(table, " ") { - table = "`" + table + "`" - } - query += " LEFT JOIN " + table + " ON " + v.(string) + " " + // 处理表名(添加前缀和正确的引号) + table = processor.ProcessTableName(table) + // 处理 ON 条件中的 table.column + onCondition := processor.ProcessConditionString(v.(string)) + query += " LEFT JOIN " + table + " ON " + onCondition + " " }() case "[<]": func() { table := Substr(k, 3, len(k)-3) - if !strings.Contains(table, " ") { - table = "`" + table + "`" - } - query += " RIGHT JOIN " + table + " ON " + v.(string) + " " + table = processor.ProcessTableName(table) + onCondition := processor.ProcessConditionString(v.(string)) + query += " RIGHT JOIN " + table + " ON " + onCondition + " " }() } switch Substr(k, 0, 4) { case "[<>]": func() { table := Substr(k, 4, len(k)-4) - if !strings.Contains(table, " ") { - table = "`" + table + "`" - } - query += " FULL JOIN " + table + " ON " + v.(string) + " " + table = processor.ProcessTableName(table) + onCondition := processor.ProcessConditionString(v.(string)) + query += " FULL JOIN " + table + " ON " + onCondition + " " }() case "[><]": func() { table := Substr(k, 4, len(k)-4) - if !strings.Contains(table, " ") { - table = "`" + table + "`" - } - query += " INNER JOIN " + table + " ON " + v.(string) + " " + table = processor.ProcessTableName(table) + onCondition := processor.ProcessConditionString(v.(string)) + query += " INNER JOIN " + table + " ON " + onCondition + " " }() } } @@ -250,6 +255,7 @@ func (that *HoTimeDB) Insert(table string, data map[string]interface{}) int64 { values := make([]interface{}, 0) queryString := " (" valueString := " (" + processor := that.GetProcessor() lens := len(data) tempLen := 0 @@ -262,25 +268,25 @@ func (that *HoTimeDB) Insert(table string, data map[string]interface{}) int64 { k = strings.Replace(k, "[#]", "", -1) vstr = ObjToStr(v) if tempLen < lens { - queryString += "`" + k + "`," + queryString += processor.ProcessColumnNoPrefix(k) + "," valueString += vstr + "," } else { - queryString += "`" + k + "`) " + queryString += processor.ProcessColumnNoPrefix(k) + ") " valueString += vstr + ");" } } else { values = append(values, v) if tempLen < lens { - queryString += "`" + k + "`," + queryString += processor.ProcessColumnNoPrefix(k) + "," valueString += "?," } else { - queryString += "`" + k + "`) " + queryString += processor.ProcessColumnNoPrefix(k) + ") " valueString += "?);" } } } - query := "INSERT INTO `" + that.Prefix + table + "` " + queryString + "VALUES" + valueString + query := "INSERT INTO " + processor.ProcessTableName(table) + " " + queryString + "VALUES" + valueString res, err := that.Exec(query, values...) @@ -334,10 +340,12 @@ func (that *HoTimeDB) BatchInsert(table string, dataList []Map) int64 { // 排序列名以确保一致性 sort.Strings(columns) + processor := that.GetProcessor() + // 构建列名部分 quotedCols := make([]string, len(columns)) for i, col := range columns { - quotedCols[i] = "`" + col + "`" + quotedCols[i] = processor.ProcessColumnNoPrefix(col) } colStr := strings.Join(quotedCols, ", ") @@ -368,7 +376,7 @@ func (that *HoTimeDB) BatchInsert(table string, dataList []Map) int64 { placeholders[i] = "(" + strings.Join(rowPlaceholders, ", ") + ")" } - query := "INSERT INTO `" + that.Prefix + table + "` (" + colStr + ") VALUES " + strings.Join(placeholders, ", ") + query := "INSERT INTO " + processor.ProcessTableName(table) + " (" + colStr + ") VALUES " + strings.Join(placeholders, ", ") res, err := that.Exec(query, values...) @@ -495,11 +503,12 @@ func (that *HoTimeDB) Upsert(table string, data Map, uniqueKeys Slice, updateCol func (that *HoTimeDB) buildMySQLUpsert(table string, columns []string, uniqueKeys []string, updateColumns []string, rawValues map[string]string) string { // INSERT INTO table (col1, col2) VALUES (?, ?) // ON DUPLICATE KEY UPDATE col1 = VALUES(col1), col2 = VALUES(col2) + processor := that.GetProcessor() quotedCols := make([]string, len(columns)) valueParts := make([]string, len(columns)) for i, col := range columns { - quotedCols[i] = "`" + col + "`" + quotedCols[i] = processor.ProcessColumnNoPrefix(col) if raw, ok := rawValues[col]; ok { valueParts[i] = raw } else { @@ -509,14 +518,15 @@ func (that *HoTimeDB) buildMySQLUpsert(table string, columns []string, uniqueKey updateParts := make([]string, len(updateColumns)) for i, col := range updateColumns { + quotedCol := processor.ProcessColumnNoPrefix(col) if raw, ok := rawValues[col]; ok { - updateParts[i] = "`" + col + "` = " + raw + updateParts[i] = quotedCol + " = " + raw } else { - updateParts[i] = "`" + col + "` = VALUES(`" + col + "`)" + updateParts[i] = quotedCol + " = VALUES(" + quotedCol + ")" } } - return "INSERT INTO `" + that.Prefix + table + "` (" + strings.Join(quotedCols, ", ") + + return "INSERT INTO " + processor.ProcessTableName(table) + " (" + strings.Join(quotedCols, ", ") + ") VALUES (" + strings.Join(valueParts, ", ") + ") ON DUPLICATE KEY UPDATE " + strings.Join(updateParts, ", ") } @@ -525,12 +535,14 @@ func (that *HoTimeDB) buildMySQLUpsert(table string, columns []string, uniqueKey func (that *HoTimeDB) buildPostgresUpsert(table string, columns []string, uniqueKeys []string, updateColumns []string, rawValues map[string]string) string { // INSERT INTO table (col1, col2) VALUES ($1, $2) // ON CONFLICT (unique_key) DO UPDATE SET col1 = EXCLUDED.col1 + processor := that.GetProcessor() + dialect := that.GetDialect() quotedCols := make([]string, len(columns)) valueParts := make([]string, len(columns)) paramIndex := 1 for i, col := range columns { - quotedCols[i] = "\"" + col + "\"" + quotedCols[i] = dialect.QuoteIdentifier(col) if raw, ok := rawValues[col]; ok { valueParts[i] = raw } else { @@ -541,19 +553,20 @@ func (that *HoTimeDB) buildPostgresUpsert(table string, columns []string, unique quotedUniqueKeys := make([]string, len(uniqueKeys)) for i, key := range uniqueKeys { - quotedUniqueKeys[i] = "\"" + key + "\"" + quotedUniqueKeys[i] = dialect.QuoteIdentifier(key) } updateParts := make([]string, len(updateColumns)) for i, col := range updateColumns { + quotedCol := dialect.QuoteIdentifier(col) if raw, ok := rawValues[col]; ok { - updateParts[i] = "\"" + col + "\" = " + raw + updateParts[i] = quotedCol + " = " + raw } else { - updateParts[i] = "\"" + col + "\" = EXCLUDED.\"" + col + "\"" + updateParts[i] = quotedCol + " = EXCLUDED." + quotedCol } } - return "INSERT INTO \"" + that.Prefix + table + "\" (" + strings.Join(quotedCols, ", ") + + return "INSERT INTO " + processor.ProcessTableName(table) + " (" + strings.Join(quotedCols, ", ") + ") VALUES (" + strings.Join(valueParts, ", ") + ") ON CONFLICT (" + strings.Join(quotedUniqueKeys, ", ") + ") DO UPDATE SET " + strings.Join(updateParts, ", ") @@ -563,11 +576,13 @@ func (that *HoTimeDB) buildPostgresUpsert(table string, columns []string, unique func (that *HoTimeDB) buildSQLiteUpsert(table string, columns []string, uniqueKeys []string, updateColumns []string, rawValues map[string]string) string { // INSERT INTO table (col1, col2) VALUES (?, ?) // ON CONFLICT (unique_key) DO UPDATE SET col1 = excluded.col1 + processor := that.GetProcessor() + dialect := that.GetDialect() quotedCols := make([]string, len(columns)) valueParts := make([]string, len(columns)) for i, col := range columns { - quotedCols[i] = "\"" + col + "\"" + quotedCols[i] = dialect.QuoteIdentifier(col) if raw, ok := rawValues[col]; ok { valueParts[i] = raw } else { @@ -577,19 +592,20 @@ func (that *HoTimeDB) buildSQLiteUpsert(table string, columns []string, uniqueKe quotedUniqueKeys := make([]string, len(uniqueKeys)) for i, key := range uniqueKeys { - quotedUniqueKeys[i] = "\"" + key + "\"" + quotedUniqueKeys[i] = dialect.QuoteIdentifier(key) } updateParts := make([]string, len(updateColumns)) for i, col := range updateColumns { + quotedCol := dialect.QuoteIdentifier(col) if raw, ok := rawValues[col]; ok { - updateParts[i] = "\"" + col + "\" = " + raw + updateParts[i] = quotedCol + " = " + raw } else { - updateParts[i] = "\"" + col + "\" = excluded.\"" + col + "\"" + updateParts[i] = quotedCol + " = excluded." + quotedCol } } - return "INSERT INTO \"" + that.Prefix + table + "\" (" + strings.Join(quotedCols, ", ") + + return "INSERT INTO " + processor.ProcessTableName(table) + " (" + strings.Join(quotedCols, ", ") + ") VALUES (" + strings.Join(valueParts, ", ") + ") ON CONFLICT (" + strings.Join(quotedUniqueKeys, ", ") + ") DO UPDATE SET " + strings.Join(updateParts, ", ") @@ -597,7 +613,8 @@ func (that *HoTimeDB) buildSQLiteUpsert(table string, columns []string, uniqueKe // Update 更新数据 func (that *HoTimeDB) Update(table string, data Map, where Map) int64 { - query := "UPDATE `" + that.Prefix + table + "` SET " + processor := that.GetProcessor() + query := "UPDATE " + processor.ProcessTableName(table) + " SET " qs := make([]interface{}, 0) tp := len(data) @@ -609,7 +626,7 @@ func (that *HoTimeDB) Update(table string, data Map, where Map) int64 { } else { qs = append(qs, v) } - query += "`" + k + "`=" + vstr + " " + query += processor.ProcessColumnNoPrefix(k) + "=" + vstr + " " if tp--; tp != 0 { query += ", " } @@ -639,7 +656,8 @@ func (that *HoTimeDB) Update(table string, data Map, where Map) int64 { // Delete 删除数据 func (that *HoTimeDB) Delete(table string, data map[string]interface{}) int64 { - query := "DELETE FROM `" + that.Prefix + table + "` " + processor := that.GetProcessor() + query := "DELETE FROM " + processor.ProcessTableName(table) + " " temp, resWhere := that.where(data) query += temp + ";" diff --git a/db/db.go b/db/db.go index c4b3834..25a5ad2 100644 --- a/db/db.go +++ b/db/db.go @@ -4,10 +4,12 @@ import ( "code.hoteas.com/golang/hotime/cache" . "code.hoteas.com/golang/hotime/common" "database/sql" + "strings" + "sync" + _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" "github.com/sirupsen/logrus" - "sync" ) // HoTimeDB 数据库操作核心结构体 @@ -98,3 +100,48 @@ func (that *HoTimeDB) GetType() string { func (that *HoTimeDB) GetPrefix() string { return that.Prefix } + +// GetProcessor 获取标识符处理器 +// 用于处理表名、字段名的前缀添加和引号转换 +func (that *HoTimeDB) GetProcessor() *IdentifierProcessor { + return NewIdentifierProcessor(that.GetDialect(), that.Prefix) +} + +// T 辅助方法:获取带前缀和引号的表名 +// 用于手动构建 SQL 时使用 +// 示例: db.T("order") 返回 "`app_order`" (MySQL) 或 "\"app_order\"" (PostgreSQL) +func (that *HoTimeDB) T(table string) string { + return that.GetProcessor().ProcessTableName(table) +} + +// C 辅助方法:获取带前缀和引号的 table.column +// 支持两种调用方式: +// - db.C("order", "name") 返回 "`app_order`.`name`" +// - db.C("order.name") 返回 "`app_order`.`name`" +func (that *HoTimeDB) C(args ...string) string { + if len(args) == 0 { + return "" + } + if len(args) == 1 { + return that.GetProcessor().ProcessColumn(args[0]) + } + // 两个参数: table, column + dialect := that.GetDialect() + table := args[0] + column := args[1] + // 去除已有引号 + table = trimQuotes(table) + column = trimQuotes(column) + return dialect.QuoteIdentifier(that.Prefix+table) + "." + dialect.QuoteIdentifier(column) +} + +// trimQuotes 去除字符串两端的引号 +func trimQuotes(s string) string { + s = strings.TrimSpace(s) + if len(s) >= 2 { + if (s[0] == '`' && s[len(s)-1] == '`') || (s[0] == '"' && s[len(s)-1] == '"') { + return s[1 : len(s)-1] + } + } + return s +} diff --git a/db/dialect.go b/db/dialect.go index 4743381..fb2a9f4 100644 --- a/db/dialect.go +++ b/db/dialect.go @@ -14,6 +14,15 @@ type Dialect interface { // SQLite 使用双引号或方括号 "name" 或 [name] Quote(name string) string + // QuoteIdentifier 处理单个标识符(去除已有引号,添加正确引号) + // 输入可能带有反引号或双引号,会先去除再添加正确格式 + QuoteIdentifier(name string) string + + // QuoteChar 返回引号字符 + // MySQL: ` + // PostgreSQL/SQLite: " + QuoteChar() string + // Placeholder 生成占位符 // MySQL/SQLite 使用 ? // PostgreSQL 使用 $1, $2, $3... @@ -54,6 +63,16 @@ func (d *MySQLDialect) Quote(name string) string { return "`" + name + "`" } +func (d *MySQLDialect) QuoteIdentifier(name string) string { + // 去除已有的引号(反引号和双引号) + name = strings.Trim(name, "`\"") + return "`" + name + "`" +} + +func (d *MySQLDialect) QuoteChar() string { + return "`" +} + func (d *MySQLDialect) Placeholder(index int) string { return "?" } @@ -121,6 +140,16 @@ func (d *PostgreSQLDialect) Quote(name string) string { return "\"" + name + "\"" } +func (d *PostgreSQLDialect) QuoteIdentifier(name string) string { + // 去除已有的引号(反引号和双引号) + name = strings.Trim(name, "`\"") + return "\"" + name + "\"" +} + +func (d *PostgreSQLDialect) QuoteChar() string { + return "\"" +} + func (d *PostgreSQLDialect) Placeholder(index int) string { return fmt.Sprintf("$%d", index) } @@ -192,6 +221,16 @@ func (d *SQLiteDialect) Quote(name string) string { return "\"" + name + "\"" } +func (d *SQLiteDialect) QuoteIdentifier(name string) string { + // 去除已有的引号(反引号和双引号) + name = strings.Trim(name, "`\"") + return "\"" + name + "\"" +} + +func (d *SQLiteDialect) QuoteChar() string { + return "\"" +} + func (d *SQLiteDialect) Placeholder(index int) string { return "?" } diff --git a/db/dialect_test.go b/db/dialect_test.go new file mode 100644 index 0000000..9d5892d --- /dev/null +++ b/db/dialect_test.go @@ -0,0 +1,276 @@ +package db + +import ( + "fmt" + "strings" + "testing" +) + +// TestDialectQuoteIdentifier 测试方言的 QuoteIdentifier 方法 +func TestDialectQuoteIdentifier(t *testing.T) { + tests := []struct { + name string + dialect Dialect + input string + expected string + }{ + // MySQL 方言测试 + {"MySQL simple", &MySQLDialect{}, "name", "`name`"}, + {"MySQL with backticks", &MySQLDialect{}, "`name`", "`name`"}, + {"MySQL with quotes", &MySQLDialect{}, "\"name\"", "`name`"}, + + // PostgreSQL 方言测试 + {"PostgreSQL simple", &PostgreSQLDialect{}, "name", "\"name\""}, + {"PostgreSQL with backticks", &PostgreSQLDialect{}, "`name`", "\"name\""}, + {"PostgreSQL with quotes", &PostgreSQLDialect{}, "\"name\"", "\"name\""}, + + // SQLite 方言测试 + {"SQLite simple", &SQLiteDialect{}, "name", "\"name\""}, + {"SQLite with backticks", &SQLiteDialect{}, "`name`", "\"name\""}, + {"SQLite with quotes", &SQLiteDialect{}, "\"name\"", "\"name\""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.dialect.QuoteIdentifier(tt.input) + if result != tt.expected { + t.Errorf("QuoteIdentifier(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// TestDialectQuoteChar 测试方言的 QuoteChar 方法 +func TestDialectQuoteChar(t *testing.T) { + tests := []struct { + name string + dialect Dialect + expected string + }{ + {"MySQL", &MySQLDialect{}, "`"}, + {"PostgreSQL", &PostgreSQLDialect{}, "\""}, + {"SQLite", &SQLiteDialect{}, "\""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.dialect.QuoteChar() + if result != tt.expected { + t.Errorf("QuoteChar() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestIdentifierProcessorTableName 测试表名处理 +func TestIdentifierProcessorTableName(t *testing.T) { + tests := []struct { + name string + dialect Dialect + prefix string + input string + expected string + }{ + // MySQL 无前缀 + {"MySQL no prefix", &MySQLDialect{}, "", "order", "`order`"}, + {"MySQL no prefix with backticks", &MySQLDialect{}, "", "`order`", "`order`"}, + + // MySQL 有前缀 + {"MySQL with prefix", &MySQLDialect{}, "app_", "order", "`app_order`"}, + {"MySQL with prefix and backticks", &MySQLDialect{}, "app_", "`order`", "`app_order`"}, + + // PostgreSQL 无前缀 + {"PostgreSQL no prefix", &PostgreSQLDialect{}, "", "order", "\"order\""}, + + // PostgreSQL 有前缀 + {"PostgreSQL with prefix", &PostgreSQLDialect{}, "app_", "order", "\"app_order\""}, + {"PostgreSQL with prefix and quotes", &PostgreSQLDialect{}, "app_", "\"order\"", "\"app_order\""}, + + // SQLite 有前缀 + {"SQLite with prefix", &SQLiteDialect{}, "app_", "user", "\"app_user\""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + processor := NewIdentifierProcessor(tt.dialect, tt.prefix) + result := processor.ProcessTableName(tt.input) + if result != tt.expected { + t.Errorf("ProcessTableName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// TestIdentifierProcessorColumn 测试列名处理(包括 table.column 格式) +func TestIdentifierProcessorColumn(t *testing.T) { + tests := []struct { + name string + dialect Dialect + prefix string + input string + expected string + }{ + // 单独列名 + {"MySQL simple column", &MySQLDialect{}, "", "name", "`name`"}, + {"MySQL simple column with prefix", &MySQLDialect{}, "app_", "name", "`name`"}, + + // table.column 格式 + {"MySQL table.column no prefix", &MySQLDialect{}, "", "order.name", "`order`.`name`"}, + {"MySQL table.column with prefix", &MySQLDialect{}, "app_", "order.name", "`app_order`.`name`"}, + {"MySQL table.column with backticks", &MySQLDialect{}, "app_", "`order`.name", "`app_order`.`name`"}, + + // PostgreSQL + {"PostgreSQL table.column with prefix", &PostgreSQLDialect{}, "app_", "order.name", "\"app_order\".\"name\""}, + {"PostgreSQL table.column with quotes", &PostgreSQLDialect{}, "app_", "\"order\".name", "\"app_order\".\"name\""}, + + // SQLite + {"SQLite table.column with prefix", &SQLiteDialect{}, "app_", "user.email", "\"app_user\".\"email\""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + processor := NewIdentifierProcessor(tt.dialect, tt.prefix) + result := processor.ProcessColumn(tt.input) + if result != tt.expected { + t.Errorf("ProcessColumn(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// TestIdentifierProcessorConditionString 测试条件字符串处理 +func TestIdentifierProcessorConditionString(t *testing.T) { + tests := []struct { + name string + dialect Dialect + prefix string + input string + contains []string // 结果应该包含这些字符串 + }{ + // MySQL 简单条件 + { + "MySQL simple condition", + &MySQLDialect{}, + "app_", + "user.id = order.user_id", + []string{"`app_user`", "`app_order`"}, + }, + // MySQL 复杂条件 + { + "MySQL complex condition", + &MySQLDialect{}, + "app_", + "user.id = order.user_id AND order.status = 1", + []string{"`app_user`", "`app_order`"}, + }, + // PostgreSQL + { + "PostgreSQL condition", + &PostgreSQLDialect{}, + "app_", + "user.id = order.user_id", + []string{"\"app_user\"", "\"app_order\""}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + processor := NewIdentifierProcessor(tt.dialect, tt.prefix) + result := processor.ProcessConditionString(tt.input) + for _, expected := range tt.contains { + if !strings.Contains(result, expected) { + t.Errorf("ProcessConditionString(%q) = %q, should contain %q", tt.input, result, expected) + } + } + }) + } +} + +// TestHoTimeDBHelperMethods 测试 HoTimeDB 的辅助方法 T() 和 C() +func TestHoTimeDBHelperMethods(t *testing.T) { + // 创建 MySQL 数据库实例 + mysqlDB := &HoTimeDB{ + Type: "mysql", + Prefix: "app_", + } + mysqlDB.initDialect() + + // 测试 T() 方法 + t.Run("MySQL T() method", func(t *testing.T) { + result := mysqlDB.T("order") + expected := "`app_order`" + if result != expected { + t.Errorf("T(\"order\") = %q, want %q", result, expected) + } + }) + + // 测试 C() 方法(两个参数) + t.Run("MySQL C() method with two args", func(t *testing.T) { + result := mysqlDB.C("order", "name") + expected := "`app_order`.`name`" + if result != expected { + t.Errorf("C(\"order\", \"name\") = %q, want %q", result, expected) + } + }) + + // 测试 C() 方法(一个参数,点号格式) + t.Run("MySQL C() method with dot notation", func(t *testing.T) { + result := mysqlDB.C("order.name") + expected := "`app_order`.`name`" + if result != expected { + t.Errorf("C(\"order.name\") = %q, want %q", result, expected) + } + }) + + // 创建 PostgreSQL 数据库实例 + pgDB := &HoTimeDB{ + Type: "postgres", + Prefix: "app_", + } + pgDB.initDialect() + + // 测试 PostgreSQL 的 T() 方法 + t.Run("PostgreSQL T() method", func(t *testing.T) { + result := pgDB.T("order") + expected := "\"app_order\"" + if result != expected { + t.Errorf("T(\"order\") = %q, want %q", result, expected) + } + }) + + // 测试 PostgreSQL 的 C() 方法 + t.Run("PostgreSQL C() method", func(t *testing.T) { + result := pgDB.C("order", "name") + expected := "\"app_order\".\"name\"" + if result != expected { + t.Errorf("C(\"order\", \"name\") = %q, want %q", result, expected) + } + }) +} + +// 打印测试结果(用于调试) +func ExampleIdentifierProcessor() { + // MySQL 示例 + mysqlProcessor := NewIdentifierProcessor(&MySQLDialect{}, "app_") + fmt.Println("MySQL:") + fmt.Println(" Table:", mysqlProcessor.ProcessTableName("order")) + fmt.Println(" Column:", mysqlProcessor.ProcessColumn("order.name")) + fmt.Println(" Condition:", mysqlProcessor.ProcessConditionString("user.id = order.user_id")) + + // PostgreSQL 示例 + pgProcessor := NewIdentifierProcessor(&PostgreSQLDialect{}, "app_") + fmt.Println("PostgreSQL:") + fmt.Println(" Table:", pgProcessor.ProcessTableName("order")) + fmt.Println(" Column:", pgProcessor.ProcessColumn("order.name")) + fmt.Println(" Condition:", pgProcessor.ProcessConditionString("user.id = order.user_id")) + + // Output: + // MySQL: + // Table: `app_order` + // Column: `app_order`.`name` + // Condition: `app_user`.`id` = `app_order`.`user_id` + // PostgreSQL: + // Table: "app_order" + // Column: "app_order"."name" + // Condition: "app_user"."id" = "app_order"."user_id" +} diff --git a/db/identifier.go b/db/identifier.go new file mode 100644 index 0000000..9dc23cd --- /dev/null +++ b/db/identifier.go @@ -0,0 +1,239 @@ +package db + +import ( + "regexp" + "strings" +) + +// IdentifierProcessor 标识符处理器 +// 用于处理表名、字段名的前缀添加和引号转换 +type IdentifierProcessor struct { + dialect Dialect + prefix string +} + +// NewIdentifierProcessor 创建标识符处理器 +func NewIdentifierProcessor(dialect Dialect, prefix string) *IdentifierProcessor { + return &IdentifierProcessor{ + dialect: dialect, + prefix: prefix, + } +} + +// ProcessTableName 处理表名(添加前缀+引号) +// 输入: "order" 或 "`order`" 或 "\"order\"" +// 输出: "`app_order`" (MySQL) 或 "\"app_order\"" (PostgreSQL/SQLite) +func (p *IdentifierProcessor) ProcessTableName(name string) string { + // 去除已有的引号 + name = p.stripQuotes(name) + + // 检查是否包含空格(别名情况,如 "order AS o") + if strings.Contains(name, " ") { + // 处理别名情况 + parts := strings.SplitN(name, " ", 2) + tableName := p.stripQuotes(parts[0]) + alias := parts[1] + return p.dialect.QuoteIdentifier(p.prefix+tableName) + " " + alias + } + + // 添加前缀和引号 + return p.dialect.QuoteIdentifier(p.prefix + name) +} + +// ProcessTableNameNoPrefix 处理表名(只添加引号,不添加前缀) +// 用于已经包含前缀的情况 +func (p *IdentifierProcessor) ProcessTableNameNoPrefix(name string) string { + name = p.stripQuotes(name) + if strings.Contains(name, " ") { + parts := strings.SplitN(name, " ", 2) + tableName := p.stripQuotes(parts[0]) + alias := parts[1] + return p.dialect.QuoteIdentifier(tableName) + " " + alias + } + return p.dialect.QuoteIdentifier(name) +} + +// ProcessColumn 处理 table.column 格式 +// 输入: "name" 或 "order.name" 或 "`order`.name" 或 "`order`.`name`" +// 输出: "`name`" 或 "`app_order`.`name`" (MySQL) +func (p *IdentifierProcessor) ProcessColumn(name string) string { + // 检查是否包含点号 + if !strings.Contains(name, ".") { + // 单独的列名,只加引号 + return p.dialect.QuoteIdentifier(p.stripQuotes(name)) + } + + // 处理 table.column 格式 + parts := p.splitTableColumn(name) + if len(parts) == 2 { + tableName := p.stripQuotes(parts[0]) + columnName := p.stripQuotes(parts[1]) + // 表名添加前缀 + return p.dialect.QuoteIdentifier(p.prefix+tableName) + "." + p.dialect.QuoteIdentifier(columnName) + } + + // 无法解析,返回原样但转换引号 + return p.convertQuotes(name) +} + +// ProcessColumnNoPrefix 处理 table.column 格式(不添加前缀) +func (p *IdentifierProcessor) ProcessColumnNoPrefix(name string) string { + if !strings.Contains(name, ".") { + return p.dialect.QuoteIdentifier(p.stripQuotes(name)) + } + + parts := p.splitTableColumn(name) + if len(parts) == 2 { + tableName := p.stripQuotes(parts[0]) + columnName := p.stripQuotes(parts[1]) + return p.dialect.QuoteIdentifier(tableName) + "." + p.dialect.QuoteIdentifier(columnName) + } + + return p.convertQuotes(name) +} + +// ProcessConditionString 智能解析条件字符串(如 ON 条件) +// 输入: "user.id = order.user_id AND order.status = 1" +// 输出: "`app_user`.`id` = `app_order`.`user_id` AND `app_order`.`status` = 1" (MySQL) +func (p *IdentifierProcessor) ProcessConditionString(condition string) string { + if condition == "" { + return condition + } + + result := condition + + // 首先处理已有完整引号的情况 `table`.`column` 或 "table"."column" + // 这些需要先处理,因为它们的格式最明确 + fullyQuotedPattern := regexp.MustCompile("[`\"]([a-zA-Z_][a-zA-Z0-9_]*)[`\"]\\.[`\"]([a-zA-Z_][a-zA-Z0-9_]*)[`\"]") + result = fullyQuotedPattern.ReplaceAllStringFunc(result, func(match string) string { + parts := fullyQuotedPattern.FindStringSubmatch(match) + if len(parts) == 3 { + tableName := parts[1] + colName := parts[2] + return p.dialect.QuoteIdentifier(p.prefix+tableName) + "." + p.dialect.QuoteIdentifier(colName) + } + return match + }) + + // 然后处理部分引号的情况 `table`.column 或 "table".column + // 注意:需要避免匹配已处理的内容(已经是双引号包裹的) + quotedTablePattern := regexp.MustCompile("[`\"]([a-zA-Z_][a-zA-Z0-9_]*)[`\"]\\.([a-zA-Z_][a-zA-Z0-9_]*)(?:[^`\"]|$)") + result = quotedTablePattern.ReplaceAllStringFunc(result, func(match string) string { + parts := quotedTablePattern.FindStringSubmatch(match) + if len(parts) >= 3 { + tableName := parts[1] + colName := parts[2] + // 保留末尾字符(如果有) + suffix := "" + if len(match) > len(parts[0])-1 { + lastChar := match[len(match)-1] + if lastChar != '`' && lastChar != '"' && !isIdentChar(lastChar) { + suffix = string(lastChar) + } + } + return p.dialect.QuoteIdentifier(p.prefix+tableName) + "." + p.dialect.QuoteIdentifier(colName) + suffix + } + return match + }) + + // 最后处理无引号的情况 table.column + // 使用更精确的正则,确保不匹配已处理的内容 + unquotedPattern := regexp.MustCompile(`([^` + "`" + `"\w]|^)([a-zA-Z_][a-zA-Z0-9_]*)\.([a-zA-Z_][a-zA-Z0-9_]*)([^` + "`" + `"\w(]|$)`) + result = unquotedPattern.ReplaceAllStringFunc(result, func(match string) string { + parts := unquotedPattern.FindStringSubmatch(match) + if len(parts) >= 5 { + prefix := parts[1] // 前面的边界字符 + tableName := parts[2] + colName := parts[3] + suffix := parts[4] // 后面的边界字符 + return prefix + p.dialect.QuoteIdentifier(p.prefix+tableName) + "." + p.dialect.QuoteIdentifier(colName) + suffix + } + return match + }) + + return result +} + +// isIdentChar 判断是否是标识符字符 +func isIdentChar(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' +} + +// ProcessFieldList 处理字段列表字符串 +// 输入: "order.id, user.name AS uname, COUNT(*)" +// 输出: "`app_order`.`id`, `app_user`.`name` AS uname, COUNT(*)" (MySQL) +func (p *IdentifierProcessor) ProcessFieldList(fields string) string { + if fields == "" || fields == "*" { + return fields + } + + // 使用与 ProcessConditionString 相同的逻辑 + return p.ProcessConditionString(fields) +} + +// stripQuotes 去除标识符两端的引号(反引号或双引号) +func (p *IdentifierProcessor) stripQuotes(name string) string { + name = strings.TrimSpace(name) + // 去除反引号 + if strings.HasPrefix(name, "`") && strings.HasSuffix(name, "`") { + return name[1 : len(name)-1] + } + // 去除双引号 + if strings.HasPrefix(name, "\"") && strings.HasSuffix(name, "\"") { + return name[1 : len(name)-1] + } + return name +} + +// splitTableColumn 分割 table.column 格式 +// 支持: table.column, `table`.column, `table`.`column`, "table".column 等 +func (p *IdentifierProcessor) splitTableColumn(name string) []string { + // 先尝试按点号分割 + dotIndex := -1 + + // 查找不在引号内的点号 + inQuote := false + quoteChar := byte(0) + for i := 0; i < len(name); i++ { + c := name[i] + if c == '`' || c == '"' { + if !inQuote { + inQuote = true + quoteChar = c + } else if c == quoteChar { + inQuote = false + } + } else if c == '.' && !inQuote { + dotIndex = i + break + } + } + + if dotIndex == -1 { + return []string{name} + } + + return []string{name[:dotIndex], name[dotIndex+1:]} +} + +// convertQuotes 将已有的引号转换为当前方言的引号格式 +func (p *IdentifierProcessor) convertQuotes(name string) string { + quoteChar := p.dialect.QuoteChar() + // 替换反引号 + name = strings.ReplaceAll(name, "`", quoteChar) + // 如果目标是反引号,需要替换双引号 + if quoteChar == "`" { + name = strings.ReplaceAll(name, "\"", quoteChar) + } + return name +} + +// GetDialect 获取方言 +func (p *IdentifierProcessor) GetDialect() Dialect { + return p.dialect +} + +// GetPrefix 获取前缀 +func (p *IdentifierProcessor) GetPrefix() string { + return p.prefix +} diff --git a/db/where.go b/db/where.go index 270aadf..050ab76 100644 --- a/db/where.go +++ b/db/where.go @@ -206,6 +206,7 @@ func (that *HoTimeDB) varCond(k string, v interface{}) (string, []interface{}) { where := "" res := make([]interface{}, 0) length := len(k) + processor := that.GetProcessor() if k == "[#]" { k = strings.Replace(k, "[#]", "", -1) @@ -219,73 +220,53 @@ func (that *HoTimeDB) varCond(k string, v interface{}) (string, []interface{}) { switch Substr(k, length-3, 3) { case "[>]": k = strings.Replace(k, "[>]", "", -1) - if !strings.Contains(k, ".") { - k = "`" + k + "` " - } + k = processor.ProcessColumn(k) + " " where += k + ">? " res = append(res, v) case "[<]": k = strings.Replace(k, "[<]", "", -1) - if !strings.Contains(k, ".") { - k = "`" + k + "` " - } + k = processor.ProcessColumn(k) + " " where += k + "=]": k = strings.Replace(k, "[>=]", "", -1) - if !strings.Contains(k, ".") { - k = "`" + k + "` " - } + k = processor.ProcessColumn(k) + " " where += k + ">=? " res = append(res, v) case "[<=]": k = strings.Replace(k, "[<=]", "", -1) - if !strings.Contains(k, ".") { - k = "`" + k + "` " - } + k = processor.ProcessColumn(k) + " " where += k + "<=? " res = append(res, v) case "[><]": k = strings.Replace(k, "[><]", "", -1) - if !strings.Contains(k, ".") { - k = "`" + k + "` " - } + k = processor.ProcessColumn(k) + " " where += k + " NOT BETWEEN ? AND ? " vs := ObjToSlice(v) res = append(res, vs[0]) res = append(res, vs[1]) case "[<>]": k = strings.Replace(k, "[<>]", "", -1) - if !strings.Contains(k, ".") { - k = "`" + k + "` " - } + k = processor.ProcessColumn(k) + " " where += k + " BETWEEN ? AND ? " vs := ObjToSlice(v) res = append(res, vs[0]) @@ -339,9 +312,8 @@ func (that *HoTimeDB) varCond(k string, v interface{}) (string, []interface{}) { // handleDefaultCondition 处理默认条件(带方括号但不是特殊操作符) func (that *HoTimeDB) handleDefaultCondition(k string, v interface{}, where string, res []interface{}) (string, []interface{}) { - if !strings.Contains(k, ".") { - k = "`" + k + "` " - } + processor := that.GetProcessor() + k = processor.ProcessColumn(k) + " " if reflect.ValueOf(v).Type().String() == "common.Slice" || strings.Contains(reflect.ValueOf(v).Type().String(), "[]") { vs := ObjToSlice(v) @@ -369,9 +341,8 @@ func (that *HoTimeDB) handleDefaultCondition(k string, v interface{}, where stri // handlePlainField 处理普通字段(无方括号) func (that *HoTimeDB) handlePlainField(k string, v interface{}, where string, res []interface{}) (string, []interface{}) { - if !strings.Contains(k, ".") { - k = "`" + k + "` " - } + processor := that.GetProcessor() + k = processor.ProcessColumn(k) + " " if v == nil { where += k + " IS NULL " diff --git a/example/main.go b/example/main.go index 2bf1f89..4d81972 100644 --- a/example/main.go +++ b/example/main.go @@ -1,10 +1,10 @@ package main import ( + . "code.hoteas.com/golang/hotime" "fmt" "time" - . "code.hoteas.com/golang/hotime" . "code.hoteas.com/golang/hotime/common" . "code.hoteas.com/golang/hotime/db" ) diff --git a/example/test_server.exe b/example/test_server.exe index d57aea6..53c6319 100644 Binary files a/example/test_server.exe and b/example/test_server.exe differ