fix(db): 修复事务中SQL错误时的回滚机制

- 添加txFailed标志跟踪事务内的SQL错误状态
- 在事务模式下遇到SQL错误时立即标记事务失败并阻止重试
- 修改Action方法确保SQL出错时强制回滚事务
- 为事务失败场景添加完整的单元测试覆盖
- 防止在事务已回滚的情况下继续执行后续操作导致数据不一致
This commit is contained in:
hoteas 2026-04-11 21:53:28 +08:00
parent 7ab803c5cc
commit 2f1b05f3d7
4 changed files with 197 additions and 17 deletions

View File

@ -31,9 +31,10 @@ type HoTimeDB struct {
Mode int // mode为0生产模式,1为测试模式,2为开发模式
mu sync.RWMutex
limitMu sync.Mutex
Dialect Dialect // 数据库方言适配器
testTx *sql.Tx // 测试事务:设置后所有操作都在此事务内,测试结束回滚
Dialect Dialect // 数据库方言适配器
testTx *sql.Tx // 测试事务:设置后所有操作都在此事务内,测试结束回滚
testMu *sync.Mutex // 保护 testTx 单连接不被并发访问
txFailed *bool // 事务内是否有SQL出错出错则事务必须回滚指针确保按值传递 HoTimeDB 时状态共享)
}
// SetConnect 设置数据库配置连接

View File

@ -62,15 +62,21 @@ func (that *HoTimeDB) queryWithRetry(query string, retried bool, args ...interfa
processedArgs := that.processArgs(args)
if that.testTx != nil {
if that.testMu != nil { that.testMu.Lock() }
if that.testMu != nil {
that.testMu.Lock()
}
resl, err = that.testTx.Query(query, processedArgs...)
that.LastErr.SetError(err)
if err != nil {
if that.testMu != nil { that.testMu.Unlock() }
if that.testMu != nil {
that.testMu.Unlock()
}
return nil
}
result := that.Row(resl)
if that.testMu != nil { that.testMu.Unlock() }
if that.testMu != nil {
that.testMu.Unlock()
}
return result
} else if that.Tx != nil {
resl, err = that.Tx.Query(query, processedArgs...)
@ -80,7 +86,12 @@ func (that *HoTimeDB) queryWithRetry(query string, retried bool, args ...interfa
that.LastErr.SetError(err)
if err != nil {
if !retried {
if that.Tx != nil {
if that.txFailed != nil {
*that.txFailed = true
}
}
if !retried && that.Tx == nil {
if pingErr := db.Ping(); pingErr == nil {
return that.queryWithRetry(query, true, args...)
}
@ -128,9 +139,13 @@ func (that *HoTimeDB) execWithRetry(query string, retried bool, args ...interfac
processedArgs := that.processArgs(args)
if that.testTx != nil {
if that.testMu != nil { that.testMu.Lock() }
if that.testMu != nil {
that.testMu.Lock()
}
resl, e = that.testTx.Exec(query, processedArgs...)
if that.testMu != nil { that.testMu.Unlock() }
if that.testMu != nil {
that.testMu.Unlock()
}
} else if that.Tx != nil {
resl, e = that.Tx.Exec(query, processedArgs...)
} else {
@ -143,6 +158,14 @@ func (that *HoTimeDB) execWithRetry(query string, retried bool, args ...interfac
if that.testTx != nil {
return resl, that.LastErr
}
// 事务内不做连接级重试:死锁等错误会导致 MySQL 自动回滚事务,
// 在已回滚的 Tx 上重试会以 auto-commit 模式执行,造成数据不一致
if that.Tx != nil {
if that.txFailed != nil {
*that.txFailed = true
}
return resl, that.LastErr
}
if !retried {
if pingErr := that.DB.Ping(); pingErr == nil {
return that.execWithRetry(query, true, args...)

View File

@ -35,21 +35,36 @@ func (that *HoTimeDB) Action(action func(db HoTimeDB) (isSuccess bool)) (isSucce
testMu: that.testMu,
}
txFailed := false
db.txFailed = &txFailed
if that.testTx != nil {
spName := fmt.Sprintf("sp_%d", atomic.AddUint64(&savepointCounter, 1))
if that.testMu != nil { that.testMu.Lock() }
if that.testMu != nil {
that.testMu.Lock()
}
_, _ = that.testTx.Exec("SAVEPOINT " + spName)
if that.testMu != nil { that.testMu.Unlock() }
if that.testMu != nil {
that.testMu.Unlock()
}
db.Tx = that.testTx
isSuccess = action(db)
if !isSuccess {
if that.testMu != nil { that.testMu.Lock() }
if txFailed || !isSuccess {
if that.testMu != nil {
that.testMu.Lock()
}
_, _ = that.testTx.Exec("ROLLBACK TO SAVEPOINT " + spName)
if that.testMu != nil { that.testMu.Unlock() }
} else {
if that.testMu != nil { that.testMu.Lock() }
_, _ = that.testTx.Exec("RELEASE SAVEPOINT " + spName)
if that.testMu != nil { that.testMu.Unlock() }
if that.testMu != nil {
that.testMu.Unlock()
}
return false
}
if that.testMu != nil {
that.testMu.Lock()
}
_, _ = that.testTx.Exec("RELEASE SAVEPOINT " + spName)
if that.testMu != nil {
that.testMu.Unlock()
}
return isSuccess
}
@ -65,6 +80,12 @@ func (that *HoTimeDB) Action(action func(db HoTimeDB) (isSuccess bool)) (isSucce
isSuccess = action(db)
// SQL 出过错 → 事务必须回滚,不管回调返回什么
if txFailed {
_ = db.Tx.Rollback()
return false
}
if !isSuccess {
err = db.Tx.Rollback()
if err != nil {

135
db/transaction_test.go Normal file
View File

@ -0,0 +1,135 @@
package db
import (
. "code.hoteas.com/golang/hotime/common"
"database/sql"
"testing"
"github.com/sirupsen/logrus"
)
func newTestDB(t *testing.T) *HoTimeDB {
t.Helper()
sqlDB, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
logger := logrus.New()
logger.SetLevel(logrus.WarnLevel)
db := &HoTimeDB{
DB: sqlDB,
Type: "sqlite3",
Dialect: &SQLiteDialect{},
LastErr: &Error{Logger: logger},
Log: logger,
}
_, execErr := sqlDB.Exec("CREATE TABLE test_item (id INTEGER PRIMARY KEY, name TEXT, value INTEGER)")
if execErr != nil {
t.Fatal(execErr)
}
_, execErr = sqlDB.Exec("INSERT INTO test_item (id, name, value) VALUES (1, 'foo', 100)")
if execErr != nil {
t.Fatal(execErr)
}
return db
}
func TestAction_TxFailedForceRollback(t *testing.T) {
db := newTestDB(t)
defer db.DB.Close()
result := db.Action(func(tx HoTimeDB) (isSuccess bool) {
tx.Update("test_item", Map{"value": 999}, Map{"id": 1})
tx.Exec("THIS IS INVALID SQL THAT WILL FAIL")
return true
})
if result != false {
t.Fatalf("Action should return false when SQL error occurred, got true")
}
row := db.Get("test_item", "value", Map{"id": 1})
if row == nil {
t.Fatal("failed to read test_item")
}
val := row.GetCeilInt64("value")
if val != 100 {
t.Fatalf("value should be rolled back to 100, got %d", val)
}
}
func TestAction_NormalCommit(t *testing.T) {
db := newTestDB(t)
defer db.DB.Close()
result := db.Action(func(tx HoTimeDB) (isSuccess bool) {
tx.Update("test_item", Map{"value": 200}, Map{"id": 1})
return true
})
if result != true {
t.Fatalf("Action should return true on success, got false")
}
row := db.Get("test_item", "value", Map{"id": 1})
if row == nil {
t.Fatal("failed to read test_item")
}
val := row.GetCeilInt64("value")
if val != 200 {
t.Fatalf("value should be committed to 200, got %d", val)
}
}
func TestAction_NormalRollback(t *testing.T) {
db := newTestDB(t)
defer db.DB.Close()
result := db.Action(func(tx HoTimeDB) (isSuccess bool) {
tx.Update("test_item", Map{"value": 300}, Map{"id": 1})
return false
})
if result != false {
t.Fatalf("Action should return false, got true")
}
row := db.Get("test_item", "value", Map{"id": 1})
if row == nil {
t.Fatal("failed to read test_item")
}
val := row.GetCeilInt64("value")
if val != 100 {
t.Fatalf("value should be rolled back to 100, got %d", val)
}
}
func TestAction_SqlErrorThenReturnTrue_MustRollback(t *testing.T) {
db := newTestDB(t)
defer db.DB.Close()
result := db.Action(func(tx HoTimeDB) (isSuccess bool) {
tx.Update("test_item", Map{"value": 500}, Map{"id": 1})
tx.Exec("INSERT INTO nonexistent_table (x) VALUES (1)")
tx.Update("test_item", Map{"value": 600}, Map{"id": 1})
return true
})
if result != false {
t.Fatalf("Action should return false when SQL error occurred mid-transaction, got true")
}
row := db.Get("test_item", "value", Map{"id": 1})
if row == nil {
t.Fatal("failed to read test_item")
}
val := row.GetCeilInt64("value")
if val != 100 {
t.Fatalf("value should be rolled back to 100 (original), got %d", val)
}
}