的今天开始学习数据库操作:

  • database/sql:标准库,偏底层,显示SQL
  • GORM: ORM 开发效率高,适合快速CRUD

database/sql

database/sql是Go标准库提供的通用SQL数据库接口,需要配合具体数据库driver使用。

安装依赖

1
go get github.com/go-sql-driver/mysql

创建pkg/db/db.go文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
package db

import (
"database/sql"
"log"
"time"

_ "github.com/go-sql-driver/mysql" // 注意这里,忘记引用的话会报错
)

func NewDB(dsn string) (*sql.DB, error) {
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, err
}

db.SetMaxOpenConns(20)
db.SetMaxIdleConns(10)
db.SetConnMaxLifetime(time.Hour)
// sql.Open不一定马上建立真实连接,需要执行一次Ping()命令才能建立真实链接,并检测是否可用。
if err := db.Ping(); err != nil {
return nil, err
}
log.Println("pong....")
return db, nil
}

sql.DB并不是一个链接,而是数据库连接池的句柄,内部管理多个链接,应该在应用启动时创建一次并复用,而不是每次请求都创建。

创建表结构

1
2
3
4
5
6
7
8
9
10
CREATE TABLE users (
id INT PRIMARY KEY AUTO_INCREMENT,
name VARCHAR(50) NOT NULL,
email VARCHAR(100) NOT NULL UNIQUE,
age INT NOT NULL,
password VARCHAR(255) NOT NULL,
status VARCHAR(20) NOT NULL,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL
);

编写repository

internal/repository/user_repository.go(重新调整了下目录)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package repository

import (
"context"
"database/sql"
"errors"
"log"
"sync"

"dev.net.cn/goweb/errs"
"dev.net.cn/goweb/model"
)

type UserRepository struct {
mu sync.RWMutex
users map[int]*model.User
db *sql.DB
}

func NewUserRepository(db *sql.DB) *UserRepository {
return &UserRepository{
users: make(map[int]*model.User),
db: db,
}
}

func (r *UserRepository) FindByID(ctx context.Context, id int) (*model.User, error) {
query := `
select id,name,age,email,password,create_at,update_at from users where id = ?
`

var user model.User
err := r.db.QueryRowContext(ctx, query, id).Scan(
&user.ID,
&user.Name,
&user.Age,
&user.Email,
&user.Status,
&user.CreateAt,
&user.UpdateAt,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, errs.ErrUserNotFound
}
return nil, err
}

return &user, nil
}

func (r *UserRepository) Create(ctx context.Context, user *model.User) (*model.User, error) {
query := `
insert into users(name,email,age,password,status,create_at,update_at) values (?,?,?,?,?,?,?)
`

result, err := r.db.ExecContext(
ctx,
query,
user.Name,
user.Email,
user.Age,
user.Password,
user.Status,
user.CreateAt,
user.UpdateAt,
)

if err != nil {
return nil, err
}

id, err := result.LastInsertId()
if err != nil {
return nil, err
}
user.ID = int(id)
log.Println("mysql ....")
return user, nil
}

func (r *UserRepository) ListUser(ctx context.Context) ([]*model.User, error) {
query := `select id,name,email,age,password,status,create_at,update_at from users`

rows, err := r.db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
/**
QueryContext执行后,需要关闭,避免因异常导致持续占有链接,导致资源耗尽。
*/
defer func() {
err := rows.Close()
if err != nil {
log.Fatalf("查询关闭失败 : %v", err)
}
}()

var users []*model.User

for rows.Next() {
var u model.User

err := rows.Scan(
&u.ID,
&u.Name,
&u.Email,
&u.Age,
&u.Password,
&u.Status,
&u.CreateAt,
&u.UpdateAt,
)
if err != nil {
return nil, err
}
users = append(users, &u)
}
/*
这里需要补一个Err()检查,因为rows.Next()返回false退出循环后,有可能是没数据了,也有可能时网络传输出现异常,
为了区分这个问题,就需要做一次rows.Err()检查
*/
if err = rows.Err(); err != nil {
return nil, err
}

return users, nil
}

注意,请使用带Context的方法:ExecContextQueryContextQueryRowContext,这样请求取消或者超时时,数据库也有机会取消。

QueryRow用于查询单行结果,通常配合ScanQuery用于查询多行结果,需要遍历rows.Next(),最后检查rows.Err()

注意点1:

对于QueryContext,查询完成后需要rows.Close(),因为它执行后会从底层得连接池sql.DB中独占一个有效的数据库链接,知道rows的数据全部读完。如果你得代码在rows.Next()循环中因为某种原因提前return,或者逻辑问题没处理完且没有调用Close(),那么这个TCP连接将会被永远挂起,无法放回连接池中。一单达到SetMaxOpenConns得上线就会彻底卡死。

注意点2:

循环完成后,还需要补一个rows.Err()检查。避免读取某一行是因网络波动发生异常,为了区分是读取完数据,还是发生异常,这里就需要在循环外再调用rows.Err()进行二次确认。避免将不完整得数据返回给上层。

注意点3:

还需要注意的是NULL值,Go语言的intstring等类型不能接收数据库得NULL值,所以建表的时候必须加上NOT NULL DEFAULT ''或者DEFAULT 0,如果表里就是有NULL,那就需要标准库自带的sql.NullInt32sql.NullString,或者将字段定义为指针类型(Age *int),这样遇到NULL时,Go会自动赋值为nil

修改service

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
package service

import (
"context"
"log"
"time"

"dev.net.cn/goweb/internal/repository"
"dev.net.cn/goweb/model"
)

type UserService struct {
userRepo *repository.UserRepository
}

func NewUserService(userRepo *repository.UserRepository) *UserService {
return &UserService{
userRepo: userRepo,
}
}

type CreateUserInput struct {
Name string
Email string
Age int
Password string
}

func (s *UserService) CreateUser(ctx context.Context, input CreateUserInput) (*model.User, error) {
user := &model.User{
Name: input.Name,
Email: input.Email,
Age: input.Age,
Password: input.Password,
Status: model.UserStatusActive,
CreateAt: time.Now(),
UpdateAt: time.Now(),
}
log.Println("service ....")
//调用repository
return s.userRepo.Create(ctx, user)
}

func (s *UserService) FindUserByID(ctx context.Context, id int) (*model.User, error) {
user, err := s.userRepo.FindByID(ctx, id)
if err != nil {
log.Fatalf("查询用户失败: %v", err)
return nil, err
}
return user, nil
}

func (s *UserService) ListUser(ctx context.Context) ([]*model.User, error) {
users, err := s.userRepo.ListUser(ctx)
if err != nil {
return nil, err
}
return users, nil
}

修改handler

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package handler

import (
"strconv"
"strings"

"dev.net.cn/goweb/errs"
"dev.net.cn/goweb/internal/service"
"dev.net.cn/goweb/model"
"dev.net.cn/goweb/response"
"github.com/gin-gonic/gin"
)

type UserHandler struct {
userService *service.UserService
}

func NewUserHandler(userService *service.UserService) *UserHandler {
return &UserHandler{
userService: userService,
}
}

func (h *UserHandler) GetUserListHandler(c *gin.Context) {
page, err := strconv.Atoi(c.DefaultQuery("page", "1"))
if err != nil {
response.BadRequest(c, err.Error())
return
}
pageSize, err := strconv.Atoi(c.DefaultQuery("page_size", "10"))
if err != nil {
response.BadRequest(c, err.Error())
return
}

users, err := h.userService.ListUser(c.Request.Context())
if err != nil {
response.HandleError(c, err)
return
}

response.PageOK(c, users, page, pageSize)
}

func (h *UserHandler) GetUserInfoByIdHandler(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
response.BadRequest(c, "id must be number")
return
}
if id < 0 {
response.BadRequest(c, "id is null")
return
}
user, err := h.userService.FindUserByID(c.Request.Context(), id)

response.OK(c, user)
}

func (h *UserHandler) CreateUserHandler(c *gin.Context) {
var user model.CreateUserReq
if err := c.ShouldBindJSON(&user); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := validateEmail(user.Email); err != nil {
response.HandleError(c, err)
return
}

user1, err := h.userService.CreateUser(c.Request.Context(), service.CreateUserInput{
Name: user.Name,
Age: user.Age,
Email: user.Email,
Password: user.Password,
})

if err != nil {
response.BadRequest(c, err.Error())
return
}

res := model.UserResponse{
ID: 1,
Name: user1.Name,
Age: user1.Age,
Email: user1.Email,
}
response.OK(c, res)
}

func UpdateUserHandler(c *gin.Context) {
id := c.Param("id")
if id == "" {
response.BadRequest(c, "id is null")
return
}
// 模拟根据ID获取用户

var user model.CreateUserReq

if err := c.ShouldBindJSON(&user); err != nil {
response.BadRequest(c, err.Error())
}

res := model.UserResponse{
ID: 1,
Name: user.Name,
Age: user.Age,
Email: user.Email,
}
response.OK(c, res)

}

func DeleteUserByIdHandler(c *gin.Context) {
id := c.Param("id")
if id == "" {
response.BadRequest(c, "id is null")
return
}

response.OK(c, "用户已被删除")
}

// FormLogin /*
func FormLogin(c *gin.Context) {
username := c.PostForm("username")
password := c.PostForm("password")

response.OK(c, gin.H{
"username": username,
"password": password,
})
}

func validateEmail(email string) error {
if strings.HasSuffix(email, "@qq.com") {
return nil
}
return errs.ErrEmailExists
}

修改router.go

修改router.go

1
2
3
4
5
6
7
8
9
10
userRepo := repository.NewUserRepository(db)

d1Group := r.Group("/api/d1/")
{
d1Group.GET("/users", userHandler.GetUserListHandler)
d1Group.GET("/users/:id", userHandler.GetUserInfoByIdHandler)
d1Group.POST("/users", userHandler.CreateUserHandler)
d1Group.PUT("/users/:id", handler.UpdateUserHandler)
d1Group.DELETE("users/:id", handler.DeleteUserByIdHandler)
}

修改main.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
dsn := "root:123456@tcp(127.0.0.1:3306)/goweb?charset=utf8mb4&parseTime=True&loc=Local"

dbInstance, err := db.NewDB(dsn)
if err != nil {
log.Fatalf("数据库连接失败: %v", err)
}

log.Println("数据库连接池初始化成功...")
defer func() {
err := dbInstance.Close()
if err != nil {
log.Fatalf("数据库关闭失败: %v", err)
}
}()
r := routers.SetupRouter(dbInstance)

GORM

应该是类似于Java的Hibernate

安装依赖

1
2
go get gorm.io/gorm
go get gorm.io/driver/mysql

修改Model

model/user_dto.go

1
2
3
4
5
6
7
8
9
10
type User struct {
ID int `gorm:"primaryKey"`
Name string `gorm:"size:50;not null"`
Email string `gorm:"size:100;not null"`
Age int `gorm:"not null"`
Password string `gorm:"size:255;not null"`
Status string `gorm:"size:20;not null"`
CreateAt time.Time
UpdateAt time.Time
}

修改repository

internal/repository/user_repository

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
// 实际用下来,应该新建一个user_gorm_resp.go好一点
func NewUserRepository(db *sql.DB) *UserRepository {
return &UserRepository{
users: make(map[int]*model.User),
db: db,
}
}

type GormUserRepository struct {
db *gorm.DB
}

func NewGormUserRepository(db *gorm.DB) *GormUserRepository {
return &GormUserRepository{db: db}
}

func (r *GormUserRepository) Create(ctx context.Context, user *model.User) (*model.User, error) {
if err := r.db.WithContext(ctx).Create(user).Error; err != nil {
return nil, err
}
return user, nil
}

func (r *GormUserRepository) FindByID(ctx context.Context, id int) (*model.User, error) {
var user model.User

err := r.db.WithContext(ctx).First(&user, id).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrUserNotFound
}
return nil, err
}
return &user, nil
}

修改service

internal/service/user_service.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
type UserService struct {
userRepo *repository.UserRepository
userGormRepo *repository.GormUserRepository
}

func NewUserService(userRepo *repository.UserRepository, gormRepo *repository.GormUserRepository) *UserService {
return &UserService{
userRepo: userRepo,
userGormRepo: gormRepo,
}
}

func (s *UserService) FindUserByID(ctx context.Context, id int) (*model.User, error) {
//user, err := s.userRepo.FindByID(ctx, id)
user, err := s.userGormRepo.FindByID(ctx, id)
if err != nil {
log.Fatalf("查询用户失败: %v", err)
return nil, err
}
return user, nil
}

修改db.go

pkg/db/db.go

1
2
3
4
5
6
7
8
9
10
11
12
13
// 新增
func NewGormDB(dsn string) (*gorm.DB, error) {
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
return nil, err
}

if err := db.AutoMigrate(&model.User{}); err != nil {
return nil, err
}

return db, nil
}

修改router.go

routers/router.go

1
2
userGormRepo := repository.NewGormUserRepository(gorm)
userService := service.NewUserService(userRepo, userGormRepo)

方法签名修改为

1
func SetupRouter(db *sql.DB, gorm *gorm.DB) *gin.Engine

修改main.go

1
2
3
4
5
gormInstance, err := db.NewGormDB(dsn)
if err != nil {
log.Fatalf("Gorm数据库连接失败: %v", err)
}
r := routers.SetupRouter(dbInstance, gormInstance)

使用GORM可以明显感觉到开发效率高很多,不需要自己去写SQL,然后做映射,像Hibernate那样,CRUD很快自动迁移方便关联查询方便Hook / Transaction / Preload功能完整

当然有优点,那就有缺点:

  • 复杂SQL可读性差
  • 需要理解ORM生成的SQL
  • 很难做到性能调优

选择database/sql的理由:

  • 需要自己掌控SQL
  • 需要性能调优
  • 需要复杂查询

除此之外,GORM开发效率更高。

事务

谈到DB,就绕不开事务。

database/sql事务

在Go语言中,事物的生命周期由sql.Tx对象管理。一个标准得事务包含以下四个阶段:

  1. 开启事务:调用db.BeginTx(ctx,opts)。此时,底层连接池会独占一格固定的TCP连接,并向MySQL发送BEGIN指令。
  2. 执行业务:后续所有的SQL语句(CRUD)必须调用tx.ExecContext或者tx.QueryContext,而不是db.ExceContext
  3. 异常回滚:如果中途任何一步报错,或者程序发生Panic,必须执行tx.Rollback()
  4. 成功提交: 所有步骤完美通过后,执行tx.Commit(),底层连接释放,重归连接池。

下面以一个比较通用的事务模板代码为例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
package repository

import (
"context"
"database/sql"
"errors"
"log"
)

type AccountRepository struct {
db *sql.DB
}

func (r *AccountRepository) Transfer(ctx context.Context, fromID, toID int, amount float64) (err error) {
// 开启事务(可以传入Isolation等级配置,默认使用的是Read Commited / Repeatable Read)
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelDefault})
if err != nil {
return err
}

// 利用defer确保事务绝对不会死锁
// 如果程序中图Panic奔溃了,或者函数因报错提前return,defer会自动触发Rollback
// 如果后面成功执行了 tx.Commit(),tx.Rollback() 会安全得返回sql.ErrTxDone错误,底层会自动忽略。
defer func() {
if p := recover(); p != nil {
// 发生Panic,强制回滚
_ = tx.Rollback()
// 继续排除Panic 让上层拦截
panic(p)
} else if err != nil {
// 发送业务错误,回滚事务
_ = tx.Rollback()
log.Printf("事务回滚成功,原因:%v", err)
}
}()

// 必须使用 tx.ExceContext,绝对不能用 r.db.ExecContext,否则就不是在同一个事务里了
// 扣钱
query1 := `UPDATE accounts SET balance=balance - ? WHERE id = ? AND balance >= ?`
res1, err := tx.ExecContext(ctx, query1, amount, fromID, amount)
if err != nil {
// err被赋值,出发defer中的 Rollback
return err
}

rowsAffected1, _ := res1.RowsAffected()
if rowsAffected1 == 0 {
err = errors.New("扣款失败,余额不足或账户不存在")
return err
}
// 执行加钱
query2 := `UPDATE accounts SET balance=balance + ? WHERE id = ?`
res2, err := tx.ExecContext(ctx, query2, amount, toID)
if err != nil {
return err
}
rowsAffected2, _ := res2.RowsAffected()
if rowsAffected2 == 0 {
err = errors.New("收款账号不存在")
return err
}
// 所有SQL都成功的话,他显示提交事务
// 只有这一步成功,MySQL才会真正持久化数据
if err = tx.Commit(); err != nil {
return err
}
return nil
}

对于事务,需要特别注意如下三个问题:

避免混用dbtx:

开启事务后,中间得更新语句如果写成r.db.ExecContext(),那么就会导致r.db从数据库连接池再去获取一个新的连接,这条语句就在事务之外独立运行。而tx连接依然在等待。不仅破坏了原子性,还极其容易导致数据库死锁(Deadlock),因为两个TCP连接可能在互相等待对方释放同一行数据得锁。

忘记写defer tx.Rollback()导致连接池枯竭:

如果执行中途报错,你没有写defer,也没有在if err != nil里手动调用Rollback(),那么这个事务在MySQL里保持PENDING状态,它占用得TCP连接永远不会释放。高并发下,几秒钟就能把项目的数据库连接池(SetMaxOpenConns)全部耗尽,导致整个系统瘫痪。

在事务里操作其他事情

事务开启后,MySQL会对相关的行加排他锁(X锁)。如果中间请求了第三方网络接口,就意味着请求的时间就加到MySQL的锁的时间。所以,事务内只做纯粹的数据库增删改查,其他的操作必须在开启事务之前。或者提交事务之后。

GORM事务

GORM提供了两种事务处理机制:自动事务(闭包机制:推荐)手动事务。可以类比于Spring框架中的@Transcational注解。

自动事务

自动事务的核心规则:

  • 如果闭包函数返回了nil,GORM会自动提交事务。
  • 如果闭包函数返回了任何error,或者内部发生了Panic,GORM会自动回滚事务,并把对应的错误或者Panic往外抛。

下面是一段比较模板化的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
type Profile struct {
UserID int `gorm:"uniqueIndex"`
Hobby string
}

func (r *GormUserRepository) RegisterUser(ctx context.Context, name, hobby string) error {
// 开启事务
// 在闭包内部,所有数据库操作必须使用参数传递进来的tx,不能使用 r.db
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 创建用户
user := model.User{Name: name, Status: model.UserStatusActive}

if err := tx.Create(&user).Error; err != nil {
// 返回错误,GORM自动回滚
return err
}
// 模拟业务不合规
if name == "hack" {
// 自动回滚
return errors.New("违规用户名,拒绝注册")
}
// 业务步骤2
profile := Profile{UserID: user.ID, Hobby: hobby}
if err := tx.Create(&profile).Error; err != nil {
// 返回错误,自动回滚
return err
}
// 返回nil,GORM底层挥拳自动执行commit
return nil
})
// 闭包完成后,外界能拿到最终的错误状态
return err
}

手动事务

适合精细化控制的方式,如果不想在闭包、或者事务链路非常长、需要根据复杂的跨系统业务接过来人为决定何时提交,GORM也保留了类似原生database/sql的手动事务机制。

手动事务的核心步骤:

  1. 开启:tx := db.Begin()
  2. 回滚:tx.Rollback()
  3. 提交:tx.Commit()

下面是一段手动事务的模板代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
func (r *GormUserRepository) RegisterUserManual(ctx context.Context, name string) (err error) {
// 开启事务
tx := r.db.WithContext(ctx).Begin()
if tx.Error != nil {
return tx.Error
}
// defer 拦截panic和Error
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
} else if err != nil {
tx.Rollback()
}
}()
// 执行业务
user := model.User{Name: name}
if err = tx.Create(&user).Error; err != nil {
return err
}
// 提交事务
if err = tx.Commit().Error; err != nil {
return err
}
return nil
}

全局禁用事务

GORM为了保证单条增删改语句的数据安全,默认情况下哪怕只调用一次db.Create(&user),它也会在底层隐式的开启BEGINCOMMIT

  • 如果追求极致的并发性能,可以通过在初始化数据库时,配置SkipDefaultTransaction来关掉这个隐式特性。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
func NewGormDB(dsn string) (*gorm.DB, error) {
//db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
SkipDefaultTransaction: true,
})
if err != nil {
return nil, err
}

if err := db.AutoMigrate(&model.User{}); err != nil {
return nil, err
}

return db, nil
}

事务传递

我查了下GORM不支持像Spring那样的事务传递(@Transactional(propagation=Propagation.REQUIRED)),但也可以依靠*gorm.DB上下文对象的显示传递和Transaction闭包机制,可以实现媲美于Spring的事务传播效果。

下面模拟一个PROPAGATION_REQUIRED

internal/service/user_service.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
func (s *UserService) UpdateUser(ctx context.Context, u model.User) (err error) {
// 开启外部事务,记得讲userGormRepo.Db大写,导出。
return s.userGormRepo.Db.Transaction(func(tx *gorm.DB) error {
// 调用更新,讲tx显示传递过去
if err := s.userGormRepo.UpdateInfo(ctx, tx, &u); err != nil {
// 外层回滚
return err
}
// 调用更新,讲tx显示传递过去(模拟嵌套事务)
if err := s.userGormRepo.Rename(ctx, tx, "Jerry", 1); err != nil {
// 子事务报错,外层也会一起回滚
return err
}

return nil
})
}

internal/repository/user_repository.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
func (r *GormUserRepository) UpdateInfo(ctx context.Context, tx *gorm.DB, user *model.User) error {
// 这里必须再次调用tx.Transaction() ,并且建议加上WithContext(ctx)
// 如果外部传入的tx已经是一个十五,GORM底层不会发送BEGIN,而是发送SAVEPOINT gorm_sp_1
return tx.WithContext(ctx).Transaction(func(nestedTX *gorm.DB) error {
updateSql := "update users set name=? where id=?"
if err := nestedTX.Exec(updateSql, user.Name, user.ID).Error; err != nil {
// 嵌套子十五失败,自动回滚到SAVEPOINT
return err
}
// 子事务成功,释放当前SAVEPOINT
return nil
})
}

func (r *GormUserRepository) Rename(ctx context.Context, tx *gorm.DB, name string, id int) error {
return tx.WithContext(ctx).Transaction(func(nestedTx *gorm.DB) error {
updateSql := "update users set name=? where id=?"
if err := nestedTx.Exec(updateSql, name, id).Error; err != nil {
return err
}
return nil
})
}