Go 中使用 gorm 适配自定义数据库驱动
GORM 官方支持的数据库类型有:MySQL, PostgreSQL, SQLite, SQL Server 和 TiD。但是我们有的时候需要使用 gorm 接入一些其他自定义的数据库,例如: Oracle 或者 Yashan 。
在本文中,我们将介绍如何在 Go 中使用 gorm 这个流行的 ORM 框架来连接 Yashan 数据库,并进行一些基本的增删改查操作。事实上这个方法适用于任何一个你想适配的数据库,如果其官方未适配 gorm 的话。
前提环境
在开始之前,我们需要准备以下内容:
- 一台安装了 Yashan 数据库的服务器,以及一个可以访问的数据库用户和密码。在本文中,我们假设服务器的 IP 地址是
192.168.1.100
,端口号是1688
,数据库名是yasdb
,用户名是sys
,密码是123456
。 - 一台安装了 Go 的开发环境,以及设置好了
GOPATH
和GOROOT
环境变量。在本文中,我们假设 Go 的版本是1.20
,并且使用了go mod
来管理依赖包。 - 一个可以编写和运行 Go 代码的编辑器或 IDE。
安装 gorm 和驱动
要使用 gorm 来链接 Yashan 数据库,我们需要安装 gorm 本身,以及一个适用于 Yashan 的驱动。由于当前 Yashan 官网上未给出 Go 语言的相关驱动,我们使用官网提供的 C 驱动然后利用 go 调用 C 实现。
go get -u gorm.io/gorm
执行上述命令后,我们在 Go 项目中成功安装了 gorm。
连接数据库
首先我们来看一下 gorm 内部支持的数据库是如何连接的。
main.go
pacakge main
import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
db, err := gorm.Open ( postgres.Open ( dsn ) , &gorm.Config{} )
从 gorm 官方文档-Open 中可以看到,func Open
接受一个 Dialector 对象。而我们现在需要做的就是根据 Yashan 数据库的 C 驱动,实现 Dialector
对象相关的所有接口即可。
dialect.go
package yasdb
import (
"database/sql"
"fmt"
"github.com/thoas/go-funk"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
"regexp"
"strconv"
"strings"
)
const (
driverType = "yasdb"
)
type Dialector struct {
*Config
}
func ( d Dialector ) DummyTableName ( ) string {
return "DUAL"
}
type Config struct {
DriverName string
DSN string
PreferSimpleProtocol bool
WithoutReturning bool
Conn gorm.ConnPool
DefaultStringSize uint
}
func Open ( dsn string ) gorm.Dialector {
return &Dialector{&Config{DSN: dsn, DefaultStringSize: 255}}
}
func ( d Dialector ) Name ( ) string {
return driverType
}
func ( d Dialector ) Initialize ( db *gorm.DB ) ( err error ) {
// register callbacks
if !d.WithoutReturning {
callbacks.RegisterDefaultCallbacks ( db, &callbacks.Config{} )
}
db.ConnPool, err = sql.Open ( d.Name ( ) , d.Config.DSN )
if err != nil {
return
}
//if err = db.Callback ( ) .Create ( ) .Replace ( "gorm: create", Create ) ; err != nil {
// return
//}
//for k, v := range d.ClauseBuilders ( ) {
// db.ClauseBuilders [k] = v
//}
return
}
var numericPlaceholder = regexp.MustCompile ( `\$( \d+ )` )
type Migrator struct {
migrator.Migrator
}
func ( d Dialector ) Migrator ( db *gorm.DB ) gorm.Migrator {
return Migrator{
Migrator: migrator.Migrator{
Config: migrator.Config{
DB: db,
Dialector: d,
CreateIndexAfterCreateTable: true,
},
},
}
}
func ( d Dialector ) DefaultValueOf ( *schema.Field ) clause.Expression {
return clause.Expr{SQL: "VALUES ( DEFAULT ) "}
}
func ( d Dialector ) BindVarTo ( writer clause.Writer, stmt *gorm.Statement, v interface{} ) {
_, err := writer.WriteString ( ":" )
if err != nil {
return
}
_, err2 := writer.WriteString ( strconv.Itoa ( len ( stmt.Vars )) )
if err2 != nil {
return
}
}
func ( d Dialector ) QuoteTo ( writer clause.Writer, str string ) {
var (
underQuoted, selfQuoted bool
continuousBacktick int8
shiftDelimiter int8
)
for _, v := range []byte ( str ) {
switch v {
case '"':
continuousBacktick++
if continuousBacktick == 2 {
_, err := writer.WriteString ( `""` )
if err != nil {
return
}
continuousBacktick = 0
}
case '.':
if continuousBacktick > 0 || !selfQuoted {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
err := writer.WriteByte ( '"' )
if err != nil {
return
}
}
err := writer.WriteByte ( v )
if err != nil {
return
}
continue
default:
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
err := writer.WriteByte ( '"' )
if err != nil {
return
}
underQuoted = true
if selfQuoted = continuousBacktick > 0; selfQuoted {
continuousBacktick -= 1
}
}
for ; continuousBacktick > 0; continuousBacktick -= 1 {
_, err := writer.WriteString ( `""` )
if err != nil {
return
}
}
err := writer.WriteByte ( v )
if err != nil {
return
}
}
shiftDelimiter++
}
if continuousBacktick > 0 && !selfQuoted {
_, err := writer.WriteString ( `""` )
if err != nil {
return
}
}
err := writer.WriteByte ( '"' )
if err != nil {
return
}
}
func ( d Dialector ) Explain ( sql string, vars ...interface{} ) string {
return logger.ExplainSQL ( sql, numericPlaceholder, `'`, funk.Map ( vars, func ( v interface{} ) interface{} {
switch v := v. ( type ) {
case bool:
if v {
return 1
}
return 0
default:
return v
}
}) . ( []interface{} ) ... )
}
func ( d Dialector ) DataTypeOf ( field *schema.Field ) string {
if _, found := field.TagSettings["RESTRICT"]; found {
delete ( field.TagSettings, "RESTRICT" )
}
var sqlType string
switch field.DataType {
case schema.Bool, schema.Int, schema.Uint, schema.Float:
sqlType = "INTEGER"
switch {
case field.DataType == schema.Float:
sqlType = "FLOAT"
case field.Size <= 8:
sqlType = "SMALLINT"
case field.Size >= 64:
sqlType = "BIGINT ( 8 ) "
}
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth ( val ) {
sqlType += " GENERATED BY DEFAULT AS IDENTITY"
}
case schema.String:
size := field.Size
defaultSize := d.DefaultStringSize
if size == 0 {
if defaultSize > 0 {
size = int ( defaultSize )
} else {
hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
// TEXT, GEOMETRY or JSON column can't have a default value
if field.PrimaryKey || field.HasDefaultValue || hasIndex {
size = 191 // utf8mb4
}
}
}
if size >= 2000 {
sqlType = "CLOB"
} else {
sqlType = fmt.Sprintf ( "VARCHAR2 ( %d ) ", size )
}
case schema.Time:
sqlType = "TIMESTAMP"
if field.NotNull || field.PrimaryKey {
sqlType += " NOT NULL"
}
case schema.Bytes:
sqlType = "BLOB"
default:
sqlType := string ( field.DataType )
if strings.EqualFold ( sqlType, "text" ) {
sqlType = "CLOB"
}
if sqlType == "" {
panic ( fmt.Sprintf ( "invalid sql type %s ( %s ) for oracle", field.FieldType.Name ( ) , field.FieldType.String ( )) )
}
notNull, _ := field.TagSettings["NOT NULL"]
unique, _ := field.TagSettings["UNIQUE"]
additionalType := fmt.Sprintf ( "%s %s", notNull, unique )
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = fmt.Sprintf ( "%s %s %s%s", "DEFAULT", value, additionalType, func ( ) string {
if value, ok := field.TagSettings["COMMENT"]; ok {
return " COMMENT " + value
}
return ""
} ( ))
}
sqlType = fmt.Sprintf ( "%v %v", sqlType, additionalType )
}
return sqlType
}
现在,我们即可以使用以下代码进行数据库的连接了。
main.go
package main
import (
"fmt"
// 此处需要引入驱动外部库
"gorm.io/gorm"
"strings"
)
const (
connectFormat = `%s/%s@%s:%s`
)
func InitYashan ( ) {
replacer := strings.NewReplacer ( "@", "\\@", "/", "\\/", "\\", "\\\\" )
datasource := config.Conf.Datasource
dsn := fmt.Sprintf ( connectFormat, replacer.Replace ( datasource.Username ) , replacer.Replace ( datasource.Password ) , datasource.Host, datasource.Port )
yasDB, err := gorm.Open ( Open ( dsn ) , &gorm.Config{} )
// 检查是否有错误
if err != nil {
fmt.Println ( "连接数据库失败: ", err )
panic ( err )
}
dataBaseModel, err := yasDB.DB ( )
if err != nil {
global.LOG.Error ( "连接数据库失败, error=" + err.Error ( ))
panic ( err )
}
dataBaseModel.SetMaxOpenConns ( datasource.MaxOpenConns )
dataBaseModel.SetMaxIdleConns ( datasource.MaxIdleConns )
}
贡献者
更新日志
2025/8/5 08:34
查看所有更新日志
eb6eb
-improve(docs): use pangu formatter于27a4a
-improve(docs): use chinese punctuation于fea7c
-improve(docs): delete extra whitespace and blank lines于3434c
-fix(docs): text typo于c1c02
-modify(docs): remanage folders and rename files于90e37
-docs: update docs于120cb
-整理tag于2d37d
-整理文章代码格式于c0924
-整理文章格式于4e139
-update于