fix(db): 修复事务中SQL错误时的回滚机制
- 添加txFailed标志跟踪事务内的SQL错误状态 - 在事务模式下遇到SQL错误时立即标记事务失败并阻止重试 - 修改Action方法确保SQL出错时强制回滚事务 - 为事务失败场景添加完整的单元测试覆盖 - 防止在事务已回滚的情况下继续执行后续操作导致数据不一致
This commit is contained in:
parent
7ab803c5cc
commit
2f1b05f3d7
5
db/db.go
5
db/db.go
@ -31,9 +31,10 @@ type HoTimeDB struct {
|
||||
Mode int // mode为0生产模式,1为测试模式,2为开发模式
|
||||
mu sync.RWMutex
|
||||
limitMu sync.Mutex
|
||||
Dialect Dialect // 数据库方言适配器
|
||||
testTx *sql.Tx // 测试事务:设置后所有操作都在此事务内,测试结束回滚
|
||||
Dialect Dialect // 数据库方言适配器
|
||||
testTx *sql.Tx // 测试事务:设置后所有操作都在此事务内,测试结束回滚
|
||||
testMu *sync.Mutex // 保护 testTx 单连接不被并发访问
|
||||
txFailed *bool // 事务内是否有SQL出错,出错则事务必须回滚(指针确保按值传递 HoTimeDB 时状态共享)
|
||||
}
|
||||
|
||||
// SetConnect 设置数据库配置连接
|
||||
|
||||
35
db/query.go
35
db/query.go
@ -62,15 +62,21 @@ func (that *HoTimeDB) queryWithRetry(query string, retried bool, args ...interfa
|
||||
processedArgs := that.processArgs(args)
|
||||
|
||||
if that.testTx != nil {
|
||||
if that.testMu != nil { that.testMu.Lock() }
|
||||
if that.testMu != nil {
|
||||
that.testMu.Lock()
|
||||
}
|
||||
resl, err = that.testTx.Query(query, processedArgs...)
|
||||
that.LastErr.SetError(err)
|
||||
if err != nil {
|
||||
if that.testMu != nil { that.testMu.Unlock() }
|
||||
if that.testMu != nil {
|
||||
that.testMu.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
result := that.Row(resl)
|
||||
if that.testMu != nil { that.testMu.Unlock() }
|
||||
if that.testMu != nil {
|
||||
that.testMu.Unlock()
|
||||
}
|
||||
return result
|
||||
} else if that.Tx != nil {
|
||||
resl, err = that.Tx.Query(query, processedArgs...)
|
||||
@ -80,7 +86,12 @@ func (that *HoTimeDB) queryWithRetry(query string, retried bool, args ...interfa
|
||||
|
||||
that.LastErr.SetError(err)
|
||||
if err != nil {
|
||||
if !retried {
|
||||
if that.Tx != nil {
|
||||
if that.txFailed != nil {
|
||||
*that.txFailed = true
|
||||
}
|
||||
}
|
||||
if !retried && that.Tx == nil {
|
||||
if pingErr := db.Ping(); pingErr == nil {
|
||||
return that.queryWithRetry(query, true, args...)
|
||||
}
|
||||
@ -128,9 +139,13 @@ func (that *HoTimeDB) execWithRetry(query string, retried bool, args ...interfac
|
||||
processedArgs := that.processArgs(args)
|
||||
|
||||
if that.testTx != nil {
|
||||
if that.testMu != nil { that.testMu.Lock() }
|
||||
if that.testMu != nil {
|
||||
that.testMu.Lock()
|
||||
}
|
||||
resl, e = that.testTx.Exec(query, processedArgs...)
|
||||
if that.testMu != nil { that.testMu.Unlock() }
|
||||
if that.testMu != nil {
|
||||
that.testMu.Unlock()
|
||||
}
|
||||
} else if that.Tx != nil {
|
||||
resl, e = that.Tx.Exec(query, processedArgs...)
|
||||
} else {
|
||||
@ -143,6 +158,14 @@ func (that *HoTimeDB) execWithRetry(query string, retried bool, args ...interfac
|
||||
if that.testTx != nil {
|
||||
return resl, that.LastErr
|
||||
}
|
||||
// 事务内不做连接级重试:死锁等错误会导致 MySQL 自动回滚事务,
|
||||
// 在已回滚的 Tx 上重试会以 auto-commit 模式执行,造成数据不一致
|
||||
if that.Tx != nil {
|
||||
if that.txFailed != nil {
|
||||
*that.txFailed = true
|
||||
}
|
||||
return resl, that.LastErr
|
||||
}
|
||||
if !retried {
|
||||
if pingErr := that.DB.Ping(); pingErr == nil {
|
||||
return that.execWithRetry(query, true, args...)
|
||||
|
||||
@ -35,21 +35,36 @@ func (that *HoTimeDB) Action(action func(db HoTimeDB) (isSuccess bool)) (isSucce
|
||||
testMu: that.testMu,
|
||||
}
|
||||
|
||||
txFailed := false
|
||||
db.txFailed = &txFailed
|
||||
|
||||
if that.testTx != nil {
|
||||
spName := fmt.Sprintf("sp_%d", atomic.AddUint64(&savepointCounter, 1))
|
||||
if that.testMu != nil { that.testMu.Lock() }
|
||||
if that.testMu != nil {
|
||||
that.testMu.Lock()
|
||||
}
|
||||
_, _ = that.testTx.Exec("SAVEPOINT " + spName)
|
||||
if that.testMu != nil { that.testMu.Unlock() }
|
||||
if that.testMu != nil {
|
||||
that.testMu.Unlock()
|
||||
}
|
||||
db.Tx = that.testTx
|
||||
isSuccess = action(db)
|
||||
if !isSuccess {
|
||||
if that.testMu != nil { that.testMu.Lock() }
|
||||
if txFailed || !isSuccess {
|
||||
if that.testMu != nil {
|
||||
that.testMu.Lock()
|
||||
}
|
||||
_, _ = that.testTx.Exec("ROLLBACK TO SAVEPOINT " + spName)
|
||||
if that.testMu != nil { that.testMu.Unlock() }
|
||||
} else {
|
||||
if that.testMu != nil { that.testMu.Lock() }
|
||||
_, _ = that.testTx.Exec("RELEASE SAVEPOINT " + spName)
|
||||
if that.testMu != nil { that.testMu.Unlock() }
|
||||
if that.testMu != nil {
|
||||
that.testMu.Unlock()
|
||||
}
|
||||
return false
|
||||
}
|
||||
if that.testMu != nil {
|
||||
that.testMu.Lock()
|
||||
}
|
||||
_, _ = that.testTx.Exec("RELEASE SAVEPOINT " + spName)
|
||||
if that.testMu != nil {
|
||||
that.testMu.Unlock()
|
||||
}
|
||||
return isSuccess
|
||||
}
|
||||
@ -65,6 +80,12 @@ func (that *HoTimeDB) Action(action func(db HoTimeDB) (isSuccess bool)) (isSucce
|
||||
|
||||
isSuccess = action(db)
|
||||
|
||||
// SQL 出过错 → 事务必须回滚,不管回调返回什么
|
||||
if txFailed {
|
||||
_ = db.Tx.Rollback()
|
||||
return false
|
||||
}
|
||||
|
||||
if !isSuccess {
|
||||
err = db.Tx.Rollback()
|
||||
if err != nil {
|
||||
|
||||
135
db/transaction_test.go
Normal file
135
db/transaction_test.go
Normal file
@ -0,0 +1,135 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
. "code.hoteas.com/golang/hotime/common"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func newTestDB(t *testing.T) *HoTimeDB {
|
||||
t.Helper()
|
||||
sqlDB, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
logger := logrus.New()
|
||||
logger.SetLevel(logrus.WarnLevel)
|
||||
db := &HoTimeDB{
|
||||
DB: sqlDB,
|
||||
Type: "sqlite3",
|
||||
Dialect: &SQLiteDialect{},
|
||||
LastErr: &Error{Logger: logger},
|
||||
Log: logger,
|
||||
}
|
||||
_, execErr := sqlDB.Exec("CREATE TABLE test_item (id INTEGER PRIMARY KEY, name TEXT, value INTEGER)")
|
||||
if execErr != nil {
|
||||
t.Fatal(execErr)
|
||||
}
|
||||
_, execErr = sqlDB.Exec("INSERT INTO test_item (id, name, value) VALUES (1, 'foo', 100)")
|
||||
if execErr != nil {
|
||||
t.Fatal(execErr)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func TestAction_TxFailedForceRollback(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
defer db.DB.Close()
|
||||
|
||||
result := db.Action(func(tx HoTimeDB) (isSuccess bool) {
|
||||
tx.Update("test_item", Map{"value": 999}, Map{"id": 1})
|
||||
|
||||
tx.Exec("THIS IS INVALID SQL THAT WILL FAIL")
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if result != false {
|
||||
t.Fatalf("Action should return false when SQL error occurred, got true")
|
||||
}
|
||||
|
||||
row := db.Get("test_item", "value", Map{"id": 1})
|
||||
if row == nil {
|
||||
t.Fatal("failed to read test_item")
|
||||
}
|
||||
val := row.GetCeilInt64("value")
|
||||
if val != 100 {
|
||||
t.Fatalf("value should be rolled back to 100, got %d", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAction_NormalCommit(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
defer db.DB.Close()
|
||||
|
||||
result := db.Action(func(tx HoTimeDB) (isSuccess bool) {
|
||||
tx.Update("test_item", Map{"value": 200}, Map{"id": 1})
|
||||
return true
|
||||
})
|
||||
|
||||
if result != true {
|
||||
t.Fatalf("Action should return true on success, got false")
|
||||
}
|
||||
|
||||
row := db.Get("test_item", "value", Map{"id": 1})
|
||||
if row == nil {
|
||||
t.Fatal("failed to read test_item")
|
||||
}
|
||||
val := row.GetCeilInt64("value")
|
||||
if val != 200 {
|
||||
t.Fatalf("value should be committed to 200, got %d", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAction_NormalRollback(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
defer db.DB.Close()
|
||||
|
||||
result := db.Action(func(tx HoTimeDB) (isSuccess bool) {
|
||||
tx.Update("test_item", Map{"value": 300}, Map{"id": 1})
|
||||
return false
|
||||
})
|
||||
|
||||
if result != false {
|
||||
t.Fatalf("Action should return false, got true")
|
||||
}
|
||||
|
||||
row := db.Get("test_item", "value", Map{"id": 1})
|
||||
if row == nil {
|
||||
t.Fatal("failed to read test_item")
|
||||
}
|
||||
val := row.GetCeilInt64("value")
|
||||
if val != 100 {
|
||||
t.Fatalf("value should be rolled back to 100, got %d", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAction_SqlErrorThenReturnTrue_MustRollback(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
defer db.DB.Close()
|
||||
|
||||
result := db.Action(func(tx HoTimeDB) (isSuccess bool) {
|
||||
tx.Update("test_item", Map{"value": 500}, Map{"id": 1})
|
||||
|
||||
tx.Exec("INSERT INTO nonexistent_table (x) VALUES (1)")
|
||||
|
||||
tx.Update("test_item", Map{"value": 600}, Map{"id": 1})
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if result != false {
|
||||
t.Fatalf("Action should return false when SQL error occurred mid-transaction, got true")
|
||||
}
|
||||
|
||||
row := db.Get("test_item", "value", Map{"id": 1})
|
||||
if row == nil {
|
||||
t.Fatal("failed to read test_item")
|
||||
}
|
||||
val := row.GetCeilInt64("value")
|
||||
if val != 100 {
|
||||
t.Fatalf("value should be rolled back to 100 (original), got %d", val)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user