hotime/db/dialect.go

289 lines
7.8 KiB
Go
Raw Normal View History

package db
import (
"fmt"
"strings"
)
// Dialect abstracts the SQL syntax differences between the supported
// database engines so that statement builders can stay engine-agnostic.
type Dialect interface {
	// Quote wraps a table or column name in the engine's identifier quotes:
	//   - MySQL uses backticks: `name`
	//   - PostgreSQL uses double quotes: "name"
	//   - SQLite uses double quotes or square brackets: "name" or [name]
	Quote(name string) string

	// QuoteIdentifier quotes a single identifier. The input may already be
	// wrapped in backticks or double quotes; those are removed first and the
	// engine's own quoting is then applied.
	QuoteIdentifier(name string) string

	// QuoteChar returns the identifier quote character: ` for MySQL and
	// " for PostgreSQL and SQLite.
	QuoteChar() string

	// Placeholder returns the bind-parameter marker for the given index.
	// MySQL and SQLite use ?, PostgreSQL uses $1, $2, $3, ...
	Placeholder(index int) string

	// Placeholders returns count bind-parameter markers separated by commas,
	// numbered from startIndex where the engine numbers its parameters.
	Placeholders(count, startIndex int) string

	// SupportsLastInsertId reports whether LastInsertId can be used.
	// PostgreSQL does not support it and needs RETURNING instead.
	SupportsLastInsertId() bool

	// ReturningClause builds a RETURNING clause for column (used for
	// PostgreSQL).
	ReturningClause(column string) string

	// UpsertSQL builds an upsert statement:
	//   - MySQL: INSERT ... ON DUPLICATE KEY UPDATE ...
	//   - PostgreSQL: INSERT ... ON CONFLICT ... DO UPDATE SET ...
	//   - SQLite: INSERT OR REPLACE / INSERT ... ON CONFLICT ...
	UpsertSQL(table string, columns, uniqueKeys, updateColumns []string) string

	// GetName returns the name of the dialect.
	GetName() string
}
// MySQLDialect implements the dialect methods for MySQL: identifiers are
// wrapped in backticks and every bind parameter is the positional "?" marker.
type MySQLDialect struct{}

// GetName returns the dialect name, "mysql".
func (d *MySQLDialect) GetName() string {
	return "mysql"
}

// Quote wraps name in backticks.
//
// A name that contains a dot (table.column) or a space (alias expression) is
// returned unchanged. A backtick inside the name is emitted doubled so that it
// cannot terminate the quoted identifier early.
func (d *MySQLDialect) Quote(name string) string {
	if strings.ContainsAny(name, ". ") {
		return name
	}
	return d.enclose(name)
}

// QuoteIdentifier strips any leading/trailing backticks and double quotes
// from name and then wraps the result in backticks, escaping embedded
// backticks the same way Quote does.
func (d *MySQLDialect) QuoteIdentifier(name string) string {
	return d.enclose(strings.Trim(name, "`\""))
}

// enclose wraps name in backticks with every embedded backtick doubled.
func (d *MySQLDialect) enclose(name string) string {
	const q = "`"
	// Collapse pairs that are already escaped first, so input that was
	// well-formed before keeps producing exactly the same output.
	name = strings.ReplaceAll(name, q+q, q)
	return q + strings.ReplaceAll(name, q, q+q) + q
}

// QuoteChar returns the backtick used to quote identifiers.
func (d *MySQLDialect) QuoteChar() string {
	return "`"
}

// Placeholder returns "?"; index is ignored because the marker is positional.
func (d *MySQLDialect) Placeholder(index int) string {
	return "?"
}

// Placeholders returns count "?" markers joined by commas, or "" when count
// is zero or negative. startIndex is ignored.
func (d *MySQLDialect) Placeholders(count int, startIndex int) string {
	if count <= 0 {
		return ""
	}
	return strings.Repeat("?,", count-1) + "?"
}

// SupportsLastInsertId reports true for MySQL.
func (d *MySQLDialect) SupportsLastInsertId() bool {
	return true
}

// ReturningClause returns "" because MySQL does not support RETURNING.
func (d *MySQLDialect) ReturningClause(column string) string {
	return ""
}

// UpsertSQL builds
//
//	INSERT INTO table (col1, col2) VALUES (?, ?)
//	ON DUPLICATE KEY UPDATE col1 = VALUES(col1), col2 = VALUES(col2)
//
// uniqueKeys is not used by this dialect. An updateColumns entry ending in
// "[#]" marks raw SQL and is copied into the statement verbatim; per the
// original note, the caller is responsible for processing it.
//
// When updateColumns is empty the update list falls back to a no-op
// self-assignment of the first insert column, so the statement stays valid
// instead of ending in a dangling "ON DUPLICATE KEY UPDATE". If columns is
// empty as well, no valid statement can be produced and the clause is left
// empty.
func (d *MySQLDialect) UpsertSQL(table string, columns []string, uniqueKeys []string, updateColumns []string) string {
	quotedCols := make([]string, len(columns))
	for i, col := range columns {
		quotedCols[i] = d.Quote(col)
	}

	assignments := make([]string, 0, len(updateColumns)+1)
	for _, col := range updateColumns {
		if strings.HasSuffix(col, "[#]") {
			assignments = append(assignments, col)
			continue
		}
		quoted := d.Quote(col)
		assignments = append(assignments, quoted+" = VALUES("+quoted+")")
	}
	if len(assignments) == 0 && len(quotedCols) > 0 {
		assignments = append(assignments, quotedCols[0]+" = "+quotedCols[0])
	}

	return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE %s",
		d.Quote(table),
		strings.Join(quotedCols, ", "),
		d.Placeholders(len(columns), 1),
		strings.Join(assignments, ", "))
}
// PostgreSQLDialect implements the dialect methods for PostgreSQL:
// identifiers are wrapped in double quotes and bind parameters are numbered
// ($1, $2, ...).
type PostgreSQLDialect struct{}

// GetName returns the dialect name, "postgres".
func (d *PostgreSQLDialect) GetName() string {
	return "postgres"
}

// Quote wraps name in double quotes.
//
// A name that contains a dot (table.column) or a space (alias expression) is
// returned unchanged. A double quote inside the name is emitted doubled so
// that it cannot terminate the quoted identifier early.
func (d *PostgreSQLDialect) Quote(name string) string {
	if strings.ContainsAny(name, ". ") {
		return name
	}
	return d.enclose(name)
}

// QuoteIdentifier strips any leading/trailing backticks and double quotes
// from name and then wraps the result in double quotes, escaping embedded
// double quotes the same way Quote does.
func (d *PostgreSQLDialect) QuoteIdentifier(name string) string {
	return d.enclose(strings.Trim(name, "`\""))
}

// enclose wraps name in double quotes with every embedded double quote
// doubled.
func (d *PostgreSQLDialect) enclose(name string) string {
	const q = `"`
	// Collapse pairs that are already escaped first, so input that was
	// well-formed before keeps producing exactly the same output.
	name = strings.ReplaceAll(name, q+q, q)
	return q + strings.ReplaceAll(name, q, q+q) + q
}

// QuoteChar returns the double quote used to quote identifiers.
func (d *PostgreSQLDialect) QuoteChar() string {
	return "\""
}

// Placeholder returns the numbered marker "$<index>".
func (d *PostgreSQLDialect) Placeholder(index int) string {
	return fmt.Sprintf("$%d", index)
}

// Placeholders returns count numbered markers joined by commas, starting at
// $startIndex, or "" when count is zero or negative.
func (d *PostgreSQLDialect) Placeholders(count int, startIndex int) string {
	if count <= 0 {
		return ""
	}
	marks := make([]string, 0, count)
	for n := startIndex; n < startIndex+count; n++ {
		marks = append(marks, fmt.Sprintf("$%d", n))
	}
	return strings.Join(marks, ",")
}

// SupportsLastInsertId reports false: PostgreSQL needs RETURNING instead.
func (d *PostgreSQLDialect) SupportsLastInsertId() bool {
	return false
}

// ReturningClause returns " RETURNING <quoted column>", including the
// leading space.
func (d *PostgreSQLDialect) ReturningClause(column string) string {
	return " RETURNING " + d.Quote(column)
}

// UpsertSQL builds
//
//	INSERT INTO table (col1, col2) VALUES ($1, $2)
//	ON CONFLICT (unique_key) DO UPDATE SET col1 = EXCLUDED.col1, col2 = EXCLUDED.col2
//
// An updateColumns entry ending in "[#]" marks raw SQL and is copied into the
// statement verbatim.
//
// When updateColumns is empty the statement ends in DO NOTHING rather than a
// dangling "DO UPDATE SET", and in that case the conflict target is omitted
// if uniqueKeys is empty too. DO UPDATE needs a conflict target, so callers
// must pass uniqueKeys whenever updateColumns is non-empty.
func (d *PostgreSQLDialect) UpsertSQL(table string, columns []string, uniqueKeys []string, updateColumns []string) string {
	quotedCols := make([]string, len(columns))
	for i, col := range columns {
		quotedCols[i] = d.Quote(col)
	}

	quotedKeys := make([]string, len(uniqueKeys))
	for i, key := range uniqueKeys {
		quotedKeys[i] = d.Quote(key)
	}

	assignments := make([]string, 0, len(updateColumns))
	for _, col := range updateColumns {
		if strings.HasSuffix(col, "[#]") {
			assignments = append(assignments, col)
			continue
		}
		quoted := d.Quote(col)
		assignments = append(assignments, quoted+" = EXCLUDED."+quoted)
	}

	target := " (" + strings.Join(quotedKeys, ", ") + ")"
	action := "DO UPDATE SET " + strings.Join(assignments, ", ")
	if len(assignments) == 0 {
		action = "DO NOTHING"
		if len(quotedKeys) == 0 {
			target = ""
		}
	}

	return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT%s %s",
		d.Quote(table),
		strings.Join(quotedCols, ", "),
		d.Placeholders(len(columns), 1),
		target,
		action)
}
// SQLiteDialect implements the dialect methods for SQLite: identifiers are
// wrapped in double quotes and every bind parameter is the positional "?"
// marker.
type SQLiteDialect struct{}

// GetName returns the dialect name, "sqlite3".
func (d *SQLiteDialect) GetName() string {
	return "sqlite3"
}

// Quote wraps name in double quotes.
//
// A name that contains a dot (table.column) or a space (alias expression) is
// returned unchanged. A double quote inside the name is emitted doubled so
// that it cannot terminate the quoted identifier early.
func (d *SQLiteDialect) Quote(name string) string {
	if strings.ContainsAny(name, ". ") {
		return name
	}
	return d.enclose(name)
}

// QuoteIdentifier strips any leading/trailing backticks and double quotes
// from name and then wraps the result in double quotes, escaping embedded
// double quotes the same way Quote does.
func (d *SQLiteDialect) QuoteIdentifier(name string) string {
	return d.enclose(strings.Trim(name, "`\""))
}

// enclose wraps name in double quotes with every embedded double quote
// doubled.
func (d *SQLiteDialect) enclose(name string) string {
	const q = `"`
	// Collapse pairs that are already escaped first, so input that was
	// well-formed before keeps producing exactly the same output.
	name = strings.ReplaceAll(name, q+q, q)
	return q + strings.ReplaceAll(name, q, q+q) + q
}

// QuoteChar returns the double quote used to quote identifiers.
func (d *SQLiteDialect) QuoteChar() string {
	return "\""
}

// Placeholder returns "?"; index is ignored because the marker is positional.
func (d *SQLiteDialect) Placeholder(index int) string {
	return "?"
}

// Placeholders returns count "?" markers joined by commas, or "" when count
// is zero or negative. startIndex is ignored.
func (d *SQLiteDialect) Placeholders(count int, startIndex int) string {
	if count <= 0 {
		return ""
	}
	return strings.Repeat("?,", count-1) + "?"
}

// SupportsLastInsertId reports true for SQLite.
func (d *SQLiteDialect) SupportsLastInsertId() bool {
	return true
}

// ReturningClause returns "". SQLite 3.35+ supports RETURNING, but it is not
// used here for compatibility.
func (d *SQLiteDialect) ReturningClause(column string) string {
	return ""
}

// UpsertSQL builds
//
//	INSERT INTO table (col1, col2) VALUES (?, ?)
//	ON CONFLICT (unique_key) DO UPDATE SET col1 = excluded.col1, col2 = excluded.col2
//
// An updateColumns entry ending in "[#]" marks raw SQL and is copied into the
// statement verbatim.
//
// When updateColumns is empty the statement ends in DO NOTHING rather than a
// dangling "DO UPDATE SET", and in that case the conflict target is omitted
// if uniqueKeys is empty too. Callers should pass uniqueKeys whenever
// updateColumns is non-empty; otherwise the conflict target is emitted as an
// empty pair of parentheses.
func (d *SQLiteDialect) UpsertSQL(table string, columns []string, uniqueKeys []string, updateColumns []string) string {
	quotedCols := make([]string, len(columns))
	for i, col := range columns {
		quotedCols[i] = d.Quote(col)
	}

	quotedKeys := make([]string, len(uniqueKeys))
	for i, key := range uniqueKeys {
		quotedKeys[i] = d.Quote(key)
	}

	assignments := make([]string, 0, len(updateColumns))
	for _, col := range updateColumns {
		if strings.HasSuffix(col, "[#]") {
			assignments = append(assignments, col)
			continue
		}
		quoted := d.Quote(col)
		assignments = append(assignments, quoted+" = excluded."+quoted)
	}

	target := " (" + strings.Join(quotedKeys, ", ") + ")"
	action := "DO UPDATE SET " + strings.Join(assignments, ", ")
	if len(assignments) == 0 {
		action = "DO NOTHING"
		if len(quotedKeys) == 0 {
			target = ""
		}
	}

	return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT%s %s",
		d.Quote(table),
		strings.Join(quotedCols, ", "),
		d.Placeholders(len(columns), 1),
		target,
		action)
}