go语言gorm实战——工具方法

前言:
一个 Go 项目结束后,其中既有值得沉淀的精华部分,也有容易违反 Go 语言惯例的部分。在此总结记录,供后续学习、参考,也欢迎大家探讨。

1 启动类

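本节演示跨域中间件与权限中间件的用法:跨域中间件 `Cors` 在 `main` 中全局注册,权限中间件 `Authorize` 在 `router` 中注册。由于 gin 的 `Use` 只对其后注册的路由生效,写在它前面的注册、登录接口无需鉴权。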

package main

import (
	"fmt"
	"github.com/gin-gonic/gin"
	"imcs-designer/controller"
	"imcs-designer/middleware"
	"imcs-designer/model"
	"imcs-designer/utils/config"
	"imcs-designer/utils/orm"
	"log"
)

// main creates the database tables, then starts the gin HTTP server on the
// port read from "application.http-server-port".
func main() {
	createTable()

	engine := gin.Default()
	engine.Use(middleware.Cors()) // CORS headers; attached before router() so every route gets it
	port := fmt.Sprintf(":%d", config.GetInt("application.http-server-port"))
	router(engine)

	// Run only returns when serving fails (e.g. the port is already in use).
	// That error used to be discarded, so the process exited with status 0 and
	// no message; report it and exit non-zero instead.
	if err := engine.Run(port); err != nil {
		log.Fatalf("http server stopped: %v", err)
	}
}

// router registers every HTTP route on engine.
//
// The register and login endpoints are added before middleware.Authorize is
// attached, presumably so they stay public: gin's Use only applies to routes
// registered after the call. Everything registered below it requires a token.
func router(engine *gin.Engine) {
	// Public endpoints.
	engine.POST("/user/register", controller.UserRegister)
	engine.POST("/user/login", controller.UserLogin)

	// Permission check for all routes registered from here on.
	engine.Use(middleware.Authorize())

	engine.GET("/user/list", controller.UserList)
	engine.POST("/user/logout", controller.UserLogout)

	// Feature-specific route groups, registered in a fixed order.
	for _, register := range []func(*gin.Engine){
		formRouter,
		menuRouter,
		buildRouter,
		ruleRouter,
		projectRouter,
	} {
		register(engine)
	}
}

2 解析 yml 配置(支持 `${环境变量名:默认值}` 形式的环境变量占位符)

package config

import (
	"fmt"
	"github.com/spf13/viper"
	"os"
	"regexp"
	"strconv"
	"strings"
)

// init loads the "application" config file (any extension viper supports),
// searching the working directory and up to four parent directories —
// presumably so that code run from nested package directories (e.g. tests)
// still finds it. It panics if no config file can be read.
func init() {
	viper.SetConfigName("application")
	for _, dir := range []string{".", "../", "../../", "../../../", "../../../../"} {
		viper.AddConfigPath(dir)
	}
	if err := viper.ReadInConfig(); err != nil {
		// Wrap with %w so a recover()-ing caller can still inspect the cause.
		panic(fmt.Errorf("reading config file: %w", err))
	}
}

// envPlaceholder matches ${NAME} and ${NAME:default} references inside config
// values. It is compiled once here rather than on every lookup.
var envPlaceholder = regexp.MustCompile(`\${(.*?)}`)

// GetString returns the config value for key with environment placeholders
// expanded. Each ${NAME} or ${NAME:default} is replaced by the value of the
// environment variable NAME when that is non-empty, otherwise by default (the
// text after the first ':'), or by "" when no default is given.
//
// Expansion is a single pass over the raw value, so text that comes from an
// environment variable or a default is never itself re-expanded.
func GetString(key string) string {
	return envPlaceholder.ReplaceAllStringFunc(viper.GetString(key), func(match string) string {
		// match is "${...}"; strip the delimiters to get NAME[:default].
		parts := strings.SplitN(match[2:len(match)-1], ":", 2)
		if env := os.Getenv(parts[0]); env != "" {
			return env
		}
		if len(parts) == 2 {
			return parts[1]
		}
		return ""
	})
}

// reportParseError writes a diagnostic for a config value that could not be
// converted, naming the offending key so the misconfiguration can be found.
// The getters below still return whatever strconv produced with the error.
func reportParseError(key string, err error) {
	fmt.Fprintf(os.Stderr, "config: key %q: %v\n", key, err)
}

// GetInt returns the config value for key (after placeholder expansion) as an
// int. On a parse error it reports the key and returns strconv.Atoi's result
// (0 for malformed or missing values).
func GetInt(key string) int {
	value, err := strconv.Atoi(GetString(key))
	if err != nil {
		reportParseError(key, err)
	}
	return value
}

// GetInt64 returns the config value for key as a base-10 int64. Parse errors
// are reported and strconv.ParseInt's result is returned.
func GetInt64(key string) int64 {
	value, err := strconv.ParseInt(GetString(key), 10, 64)
	if err != nil {
		reportParseError(key, err)
	}
	return value
}

// GetInt32 returns the config value for key as a base-10 int32. Parse errors
// are reported and strconv.ParseInt's (32-bit-ranged) result is returned.
func GetInt32(key string) int32 {
	value, err := strconv.ParseInt(GetString(key), 10, 32)
	if err != nil {
		reportParseError(key, err)
	}
	return int32(value)
}

// GetFloat64 returns the config value for key as a float64. Parse errors are
// reported and strconv.ParseFloat's result is returned.
func GetFloat64(key string) float64 {
	value, err := strconv.ParseFloat(GetString(key), 64)
	if err != nil {
		reportParseError(key, err)
	}
	return value
}

// GetFloat32 returns the config value for key as a float32. Parse errors are
// reported and strconv.ParseFloat's (32-bit-precision) result is returned.
func GetFloat32(key string) float32 {
	value, err := strconv.ParseFloat(GetString(key), 32)
	if err != nil {
		reportParseError(key, err)
	}
	return float32(value)
}

// GetBool reports whether the config value for key is "true" in any letter
// case. Every other value, including an empty or missing one, yields false.
func GetBool(key string) bool {
	return strings.EqualFold(GetString(key), "TRUE")
}

3 跨域

package middleware

import (
	"fmt"
	"github.com/gin-gonic/gin"
	"net/http"
	"strings"
)

// Cors returns a gin middleware that adds permissive CORS headers to every
// response whose request carries an Origin header, and answers OPTIONS
// (preflight) requests itself without running the rest of the chain.
func Cors() gin.HandlerFunc {
	return func(c *gin.Context) {
		method := c.Request.Method
		origin := c.Request.Header.Get("Origin")

		// Build a list of the request's header names.
		// NOTE(review): headerStr is never written to the response (the
		// Access-Control-Allow-Headers value below is "*"). TODO: either send
		// it or delete it together with this file's fmt/strings imports.
		var headerKeys []string
		for k := range c.Request.Header {
			headerKeys = append(headerKeys, k)
		}
		headerStr := strings.Join(headerKeys, ", ")
		if headerStr != "" {
			headerStr = fmt.Sprintf("access-control-allow-origin, access-control-allow-headers, %s", headerStr)
		} else {
			headerStr = "access-control-allow-origin, access-control-allow-headers"
		}

		if origin != "" {
			c.Header("Access-Control-Allow-Origin", "*") // allow every origin
			// Advertise all supported methods at once to limit repeated preflights.
			c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE,UPDATE")
			c.Header("Access-Control-Allow-Headers", "*")
			c.Header("Access-Control-Expose-Headers", "*") // let the browser read response headers
			c.Header("Access-Control-Max-Age", "172800")   // preflight cache lifetime, in seconds
			// Cross-origin requests may not carry credentials (cookies).
			c.Header("Access-Control-Allow-Credentials", "false")
			// NOTE(review): c.Set stores a key in the gin context; it does not
			// set the Content-Type response header.
			c.Set("content-type", "application/json")
		}

		// Answer preflight requests here and stop the chain. Continuing with
		// c.Next() after writing this response let later middleware (e.g. an
		// auth check on a token-less preflight) write a second status and body.
		if method == http.MethodOptions {
			c.AbortWithStatusJSON(http.StatusOK, "Options Request!")
			return
		}
		c.Next()
	}
}

4 权限设置

package middleware

import (
	"github.com/gin-gonic/gin"
	"imcs-designer/response"
	"imcs-designer/utils/redis"
	"net/http"
)

// Authorize returns a gin middleware that lets a request through only when
// the Redis value stored under "token:" + <token header> (via redis.Store)
// equals the "user_id" header. Otherwise the request is aborted:
//   - missing token header          -> 401
//   - missing user_id header        -> 400
//   - lookup error or user mismatch -> 403
//
// NOTE(review): a Redis connectivity error is also reported as 403; confirm
// that is intended rather than a 5xx.
func Authorize() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		tok := ctx.GetHeader("token")
		if tok == "" {
			ctx.AbortWithStatusJSON(http.StatusUnauthorized, response.NewBaseResponse(response.ResponseCodeUnauthorized))
			return
		}

		uid := ctx.GetHeader("user_id")
		if uid == "" {
			ctx.AbortWithStatusJSON(http.StatusBadRequest, response.NewBaseResponse(response.ResponseCodeBadRequest))
			return
		}

		stored, err := redis.Store.GetString("token:" + tok)
		if err == nil && stored == uid {
			ctx.Next()
			return
		}
		ctx.AbortWithStatusJSON(http.StatusForbidden, response.NewBaseResponse(response.ResponseCodeForbidden))
	}
}

5 创建数据库连接

package orm

import (
	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/mysql"
	"imcs-designer/utils/config"
)

var gdb *gorm.DB

// CreateDB returns the shared *gorm.DB, opening it on the first call and
// returning the cached handle afterwards.
//
// The driver comes from "database.driver" and the DSN from the matching
// "database.<driver>-dsn" key ("mysql", "postgres" or "mssql"). Any other
// value, including an empty one, falls back to MySQL for both the DSN and the
// driver name. Pool limits and SQL logging come from
// "database.max-idle-conns", "database.max-open-conns" and "application.debug".
//
// NOTE(review): the nil check on gdb is unsynchronized; this assumes the
// first call completes before concurrent use begins — confirm.
func CreateDB() (*gorm.DB, error) {
	if gdb != nil {
		return gdb, nil
	}

	driver := config.GetString("database.driver")
	var dsn string
	switch driver {
	case "mysql":
		dsn = config.GetString("database.mysql-dsn")
	case "postgres":
		dsn = config.GetString("database.postgres-dsn")
	case "mssql":
		dsn = config.GetString("database.mssql-dsn")
	default:
		// The fallback uses the MySQL DSN, so open it with the MySQL driver
		// too; keeping the unknown driver name paired it with a MySQL DSN.
		driver = "mysql"
		dsn = config.GetString("database.mysql-dsn")
	}

	db, err := gorm.Open(driver, dsn)
	if err != nil {
		return nil, err
	}
	db.DB().SetMaxIdleConns(config.GetInt("database.max-idle-conns"))
	db.DB().SetMaxOpenConns(config.GetInt("database.max-open-conns"))
	db.LogMode(config.GetBool("application.debug"))
	gdb = db
	return gdb, nil
}

// CreateTx begins a transaction on the shared connection returned by
// CreateDB. The caller is responsible for Commit or Rollback on the result.
func CreateTx() (*gorm.DB, error) {
	db, err := CreateDB()
	if err != nil {
		return nil, err
	}
	tx := db.Begin()
	// gorm v1 reports a failed Begin through tx.Error rather than a second
	// return value; surface it instead of handing back an unusable tx.
	if tx.Error != nil {
		return nil, tx.Error
	}
	return tx, nil
}

func AutoMigrate(models ...interface{}) error {
	driver := config.GetString("database.driver")

	dbConn, err := CreateDB()
	if err != nil {
		return err
	}

	if driver == "mysql" {
		dbConn = dbConn.Set("gorm:table_options", "ENGINE=InnoDB DEFAULT CHARSET = utf8mb4")
	}
	return dbConn.AutoMigrate(models...).Error
}

package orm

import (
	"github.com/jinzhu/gorm"
	"imcs-designer/utils/config"
)

func init() {
	// Prefix every table name gorm derives as "<prefix>_<name>". The prefix
	// is re-read from "database.table-prefix" on each call; when it is empty
	// the default name is used unchanged.
	gorm.DefaultTableNameHandler = func(_ *gorm.DB, defaultName string) string {
		prefix := config.GetString("database.table-prefix")
		if prefix == "" {
			return defaultName
		}
		return prefix + "_" + defaultName
	}
}

6 配置路由


// router wires all routes onto r. The two user-auth endpoints are registered
// first, before middleware.Authorize is attached, so only the routes added
// after the Use call go through the permission check.
func router(r *gin.Engine) {
	r.POST("/user/register", controller.UserRegister)
	r.POST("/user/login", controller.UserLogin)

	r.Use(middleware.Authorize())

	r.GET("/user/list", controller.UserList)
	r.POST("/user/logout", controller.UserLogout)
	formRouter(r)
	menuRouter(r)
	buildRouter(r)
	ruleRouter(r)
	projectRouter(r)
}

// projectRouter registers the project endpoints on e: listing, icon lookup,
// create/edit/delete, and upgrade.
func projectRouter(e *gin.Engine) {
	e.GET("/projects", controller.ProjectList)     // list all projects
	e.GET("/project/icon", controller.ProjectIcon) // project icon

	// Single-project CRUD shares one path, distinguished by HTTP method.
	e.POST("/project", controller.ProjectCreate)
	e.PUT("/project", controller.ProjectEdit)
	e.DELETE("/project", controller.ProjectDelete)

	e.PUT("/project/upgrade", controller.ProjectUpgrade)
}

7 ID生成器

package utils

import (
	"github.com/satori/go.uuid"
	"strings"
)

//生成UUID,去掉中间"-"
// UUID returns a new version-4 UUID in its canonical text form with every
// "-" removed, i.e. 32 hexadecimal characters.
func UUID() string {
	canonical := uuid.NewV4().String()
	return strings.Replace(canonical, "-", "", -1)
}

8 token 生成器

package utils

import "imcs-designer/utils/redis"

//生成token
// TokenGen creates a fresh token (a hyphen-less UUID) for userID and records
// the mapping "token:<token>" -> userID in redis.Store. It returns the token,
// or "" if the Redis write fails.
//
// NOTE(review): Store.Set issues a plain SET, so tokens never expire here;
// consider Store.SetEx with a TTL.
func TokenGen(userID string) string {
	token := UUID()
	if err := redis.Store.Set("token:"+token, userID); err != nil {
		return ""
	}
	return token
}


9 MD5 摘要(注:MD5 是哈希算法而非加密,且不适合直接用于存储密码)

package utils

import (
	"crypto/md5"
	"encoding/hex"
	"strings"
)

//MD5加密,加密两次并转为大写
// MD5 hashes key with MD5 twice. The upper-case hex digest of the first round
// is the input to the second, and the second digest is returned as upper-case
// hex (32 characters).
//
// NOTE(review): if this is used for password storage, unsalted MD5 is weak;
// consider bcrypt/argon2 (existing stored hashes would need migrating).
func MD5(key string) string {
	digest := key
	for round := 0; round < 2; round++ {
		sum := md5.Sum([]byte(digest))
		digest = strings.ToUpper(hex.EncodeToString(sum[:]))
	}
	return digest
}

10 redis 工具

package redis

import (
	"fmt"
	"github.com/gomodule/redigo/redis"
	"imcs-designer/utils/config"
	"time"
)

var Store RedisDataStore

func init() {
	server := config.GetString("redis.server")
	db := config.GetString("redis.db")
	password := config.GetString("redis.password")
	timeout := config.GetInt64("redis.timeout")
	prefix := config.GetString("redis.prefix")
	applicationName := config.GetString("application.name")

	Store = RedisDataStore{
		RedisHost: server,
		RedisDB:   db,
		RedisPwd:  password,
		Timeout:   timeout,
		RedisPool: nil,
		Prefix:    fmt.Sprintf("%s:%s", prefix, applicationName),
	}
	Store.RedisPool = Store.NewPool()
}

// RedisDataStore wraps a redigo connection pool with simple key/value helpers.
// Its Get/Set/SetEx/Del methods prepend Prefix + ":" to every key when Prefix
// is non-empty.
type RedisDataStore struct {
	// RedisHost is the address RedisConnect dials over TCP.
	RedisHost string
	// RedisDB is the database index selected on each new connection.
	RedisDB   string
	// RedisPwd is the password used to authenticate new connections.
	RedisPwd  string
	// Timeout is multiplied by time.Second in RedisConnect, i.e. seconds.
	Timeout   int64
	// RedisPool supplies connections to every helper method.
	RedisPool *redis.Pool

	// Prefix namespaces all keys handled by this store.
	Prefix string
}

// NewPool returns a redigo pool that dials connections with r.RedisConnect.
// It keeps at most 10 idle connections, closes connections idle for longer
// than one second, and places no limit on active connections.
//
// NOTE(review): with MaxActive == 0, Wait never blocks, and a 1s idle timeout
// means frequent reconnects. Confirm both settings are intended.
func (r *RedisDataStore) NewPool() *redis.Pool {
	pool := &redis.Pool{Dial: r.RedisConnect}
	pool.MaxIdle = 10
	pool.MaxActive = 0 // 0 = unlimited
	pool.IdleTimeout = time.Second
	pool.Wait = true
	return pool
}

// RedisConnect dials r.RedisHost over TCP and prepares the connection. It
// authenticates with r.RedisPwd when a password is configured and selects
// database r.RedisDB when one is configured. r.Timeout (seconds) bounds
// connecting, reading and writing; zero leaves them unbounded. It serves as
// the pool's Dial function.
func (r *RedisDataStore) RedisConnect() (redis.Conn, error) {
	timeout := time.Duration(r.Timeout) * time.Second
	// The timeout options only take effect when passed to Dial; building them
	// and discarding the result had no effect. DialPassword skips AUTH when
	// the password is empty (Redis rejects AUTH with no password configured)
	// and closes the socket itself if AUTH fails.
	c, err := redis.Dial("tcp", r.RedisHost,
		redis.DialConnectTimeout(timeout),
		redis.DialReadTimeout(timeout),
		redis.DialWriteTimeout(timeout),
		redis.DialPassword(r.RedisPwd),
	)
	if err != nil {
		return nil, err
	}
	if r.RedisDB != "" {
		if _, err := c.Do("SELECT", r.RedisDB); err != nil {
			c.Close() // don't leak the connection on a failed setup
			return nil, err
		}
	}
	return c, nil
}

// Get runs GET for k (namespaced with Prefix when set) and returns the raw
// reply. On error the reply is always nil.
func (r *RedisDataStore) Get(k string) (interface{}, error) {
	key := k
	if r.Prefix != "" {
		key = r.Prefix + ":" + k
	}

	conn := r.RedisPool.Get()
	defer conn.Close()

	reply, err := conn.Do("GET", key)
	if err != nil {
		// Discard any partial/error reply so callers see exactly (nil, err).
		return nil, err
	}
	return reply, nil
}

// GetString fetches k via Get and converts the reply with redis.String.
func (r *RedisDataStore) GetString(k string) (string, error) {
	reply, err := r.Get(k)
	return redis.String(reply, err)
}

// GetInt fetches k via Get and converts the reply with redis.Int.
func (r *RedisDataStore) GetInt(k string) (int, error) {
	reply, err := r.Get(k)
	return redis.Int(reply, err)
}

// GetInt64 fetches k via Get and converts the reply with redis.Int64.
func (r *RedisDataStore) GetInt64(k string) (int64, error) {
	reply, err := r.Get(k)
	return redis.Int64(reply, err)
}

// GetFloat64 fetches k via Get and converts the reply with redis.Float64.
func (r *RedisDataStore) GetFloat64(k string) (float64, error) {
	reply, err := r.Get(k)
	return redis.Float64(reply, err)
}

// GetBool fetches k via Get and converts the reply with redis.Bool.
func (r *RedisDataStore) GetBool(k string) (bool, error) {
	reply, err := r.Get(k)
	return redis.Bool(reply, err)
}

// Set stores v under k (namespaced with Prefix when set) using a plain SET,
// so the key has no expiry. v is passed to redigo unchanged.
func (r *RedisDataStore) Set(k string, v interface{}) error {
	conn := r.RedisPool.Get()
	defer conn.Close()

	key := k
	if r.Prefix != "" {
		key = r.Prefix + ":" + key
	}
	_, err := conn.Do("SET", key, v)
	return err
}

// SetEx stores v under k (namespaced with Prefix when set) with an expiry of
// ex seconds, via SET key value EX ex.
func (r *RedisDataStore) SetEx(k string, v interface{}, ex int64) error {
	key := k
	if r.Prefix != "" {
		key = r.Prefix + ":" + key
	}

	conn := r.RedisPool.Get()
	defer conn.Close()

	_, err := conn.Do("SET", key, v, "EX", ex)
	return err
}

// Del removes k (namespaced with Prefix when set) with DEL and returns any
// error from the command.
func (r *RedisDataStore) Del(k string) error {
	key := k
	if r.Prefix != "" {
		key = r.Prefix + ":" + key
	}

	conn := r.RedisPool.Get()
	defer conn.Close()

	_, err := conn.Do("DEL", key)
	return err
}
发布了93 篇原创文章 · 获赞 20 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/leinminna/article/details/104589170