Go语言实现ORM框架

一、项目结构

beego:请求访问控制器

controller:控制器结构体

models:定义需要映射建表的结构体

orm:实现models中结构体

routers:初始化路由入口

二、控制器接口设计

ControllerInterface:控制器接口,提供Get和Post函数。

Controller:公共控制器,实现ControllerInterface接口,并实现了Get和Post函数。

自定义Controller:继承Controller,根据业务需求处理不同请求。

上面三个Controller的关系如下所示:

自定义Controller(我们程序员写的)----------->Controller(公共Controller)------------->ControllerInterface(接口)

(1)定义ControllerInterface接口

// 控制器接口
type ControllerInterface interface {
	Get()
	Post()
}

(2)定义Controller结构体,并实现ControllerInterface的Get和Post方法。

// 定义一个控制器结构,实现ControllerInterface接口
type Controller struct {}

func (c *Controller) Get() {
	fmt.Println("GET空实现")
}

func (c *Controller) Post() {
	fmt.Println("Post空实现")
}

(3)定义一个LoginController和RegistController,继承Controller。

type LoginController struct {
	beego.Controller
}

func (this *LoginController) Get(){
	fmt.Println("LoginController get method is executing....")
}

func (this *LoginController) Post(){
	fmt.Println("LoginController post method is executing....")
}

// 自己定义方法,处理Post或者Get请求
func (this *LoginController) LoginGet(){
	fmt.Println("LoginController LoginGet method is executing...")
}

func (this *LoginController) LoginPost(){
	fmt.Println("LoginController LoginPost method is executing...")
}


type RegisterController struct {
	beego.Controller
}

func (this *RegisterController) Get(){
	fmt.Println("RegistController get method is executing....")
}

func (this *RegisterController) Post(){
	fmt.Println("RegistController post method is executing....")
}

func (this *RegisterController) RegistGet(){
	fmt.Println("RegistController RegistGet method is executing....")
}

三、路由设计

定义ControllerInfo结构体,用于存储路由信息 。

// 该结构体存储路由信息
type ControllerInfo struct {
	pattern        string                     // 存储请求url
	controllerType reflect.Type               // 存储控制器类型
	methods        map[string]string          // 存储请求方法
	initialize     func() ControllerInterface // 初始化方法
}

在Routers.go文件中定义init方法,程序启动时候调用该方法初始化路由信息。

func init() {
	//将自己编写的控制器,匹配到路由上
	beego.Router("/login",&controller.LoginController{},"get:LoginGet;post:LoginPost")
	beego.Router("/register",&controller.RegisterController{})
}

在main方法中导入routers包。

import (
	......
	_ "Project012/routers"
)

定义Router方法,把路由信息解析成ControllerInfo对象,并保存在上下文中。

// 该切片存储所有的路由信息
var routers []*ControllerInfo
var ctx context.Context // context代表​

// 定义路由方法
// 参数一:访问地址
// 参数二:控制器接口对象的地址
// 参数三:访问方式和方法的映射关系
func Router(pattern string, c ControllerInterface, mappingMethods ...string) {
	// 获取控制器接口的Value对象
	cValue := reflect.ValueOf(c)
	//获取控制器接口的Type对象
	cType := reflect.Indirect(cValue).Type()
	// 该map存储请求方式和方法的映射关系
	methods := make(map[string]string)
	if len(mappingMethods) > 0 {
		// 映射格式一:get:Login;post:HandleLogin
		mappings := strings.Split(mappingMethods[0], ";")
		for _, v := range mappings {
			mapping := strings.Split(v, ":")
			if len(mapping) != 2 {
				panic("method mapping format is invalid")
			}
			// 映射格式二:get,post:login
			requestMethods := strings.Split(mapping[0], ",")
			for _, requestMethod := range requestMethods {
				// 判断请求方式是否正确和映射方法是否存在,只有两个条件同时成立,则保存在methods中
				if requestMethod == "*" || HTTPMETHOD[strings.ToUpper(requestMethod)] {
					// 判断映射方法是否在Controller中存在
					if val := cValue.MethodByName(mapping[1]); val.IsValid() {
						// 把方法名保存在methods中
						methods[strings.ToUpper(requestMethod)] = mapping[1]
					} else {
						//异常
						panic("'" + mapping[1] + "' method doesn't exist in the controller " + cType.Name())
					}
				} else {
					panic(v + " is an invalid method mapping. Method doesn't exist " + requestMethod)
				}
			}
		}
	}
	// 创建ControllerInfo对象
	route := &ControllerInfo{}
	// 初始化ControllerInfo对象
	route.pattern = pattern
	route.methods = methods
	route.controllerType = cType
	route.initialize = func() ControllerInterface {
		// 使用放射创建Controller的Type对象
		o := reflect.New(route.controllerType)
		// 把o转换成ControllerInterface类型
		controllerInterface, ok := o.Interface().(ControllerInterface)
		// 判断控制器的类型是否是ControllerInterface类型,如果不是抛出异常
		if !ok {
			panic("controller is not ControllerInterface")
		}
		// ----------把控制器c中的数据赋值到控制器controllerInterface中---------
		// 获取传入控制器接口的值
		controllerValueEl := reflect.ValueOf(c).Elem()
		// 获取控制器的类型
		controllerTypeEl := reflect.TypeOf(c).Elem()
		// 获取新创建控制器的值
		controllerInterfaceValue := reflect.ValueOf(controllerInterface).Elem()
		// 获取控制器中字段个数
		numOfFields := controllerValueEl.NumField()
		for i := 0; i < numOfFields; i++ {
			// 根据索引获取字段类型
			fieldType := controllerTypeEl.Field(i)
			// 判断字段是否可以修改,只有该字段可以修改的情况下才设置字段的值
			elemField := controllerInterfaceValue.FieldByName(fieldType.Name)
			if elemField.CanSet() {
				// 设置字段的值
				fieldVal := controllerValueEl.Field(i) //赋值给elemField
				elemField.Set(fieldVal)
			}
		}
		return controllerInterface
	}
	routers = append(routers, route)
	// 将每个ControllerInfo中信息存储在routes中后,需要将routes中信息存储在上下文环境中,便于后续获取
	ctx = context.WithValue(context.Background(), "routers", routers)
}

​

定义请求结构体。

var (
	HTTPMETHOD = map[string]bool{
		"GET":  true,
		"POST": true,
	}
)

 

四、监听浏览器请求

(1)定义Run函数,项目启动时候调用该函数。

func Run() {
	//监听网络请求
	http.ListenAndServe(":8080", HttpHandler{})
}

(2)定义HttpHandler,该对象提供了ServeHTTP方法处理请求。

type HttpHandler struct {
	// 实现Handler接口
}

func (HttpHandler) ServeHTTP(respWriter http.ResponseWriter, request *http.Request) {
	//在此处去获取请求方式,请求链接,根据这2部分信息,匹配routes中的路由信息,调用相应方法
	//获取Get或者Post方法,拿到方法以后,首先匹配路由,然后看看此路由,是否重新定义Get或者Post方法
	method := request.Method     // 请求方式
	path := request.URL.Path // 请求路径
	// 过滤favicon.icon请求,google浏览器默认会请求这个图片,所以过滤掉
	if path != "/favicon.ico" {
		var controllerInterface ControllerInterface
		// 从上下文中获取routers
		routers := ctx.Value("routers").([]*ControllerInfo)
		if routers != nil {
			// 遍历routers
			for _, controllerInfo := range routers {
				// 判断请求路径是否匹配
				if controllerInfo.pattern == path {
					// 如果路径匹配,再判断initialize方法是否存在。如果存在则调用该方法执行初始化操作
					if controllerInfo.initialize != nil{
						controllerInterface = controllerInfo.initialize()
						// 遍历controllerInfo中的methods
						for key, value := range controllerInfo.methods {
							if key == method {
								method = value //获取控制器中处理请求的方法名
								break
							}
						}
						break
					}
				}
			}
		}
		//fmt.Println("controllerInterface = ", controllerInterface)
		switch method {
			case http.MethodGet:
				controllerInterface.Get()
			case http.MethodPost:
				controllerInterface.Post()
			default:
				// 调用控制器的处理方法
				vc := reflect.ValueOf(controllerInterface)
				method := vc.MethodByName(method)
				in := make([]reflect.Value, 0)
				method.Call(in)
		}
	}
}

(3)修改main函数。

func main() {
	beego.Run()
}

 

五、建表

(1)表设计

type User struct {
   Id       int
   Name     string `orm:"size(20)"` //用户名
   Age        int    `orm:"size(10)"`        //登陆密码
   Email    string `orm:"size(50)"`        //邮箱
}

type Goods struct { //商品SPU表
   Id        int
   Name   string `orm:"size(20)"`  //商品名称
   Detail     string `orm:"size(200)"` //详细描述
}

(2)在models包下定义init方法,程序启动时候自动调用该方法进行建表操作。

func init(){
   orm.RegisterDataBase("mysql", "root:root@tcp(127.0.0.1:3306)/user?charset=utf8")
   orm.RegisterModel(new(User),new(Goods))
}

(3)修改main.go文件,导入models包。

import (
	......
	_ "Project012/models"
)

(4)在orm包下定义RegisterDataBase和RegisterModel函数。

var db *sql.DB
var err error

// 连接数据库
func RegisterDataBase(driverName ,dataSourceName string) {
	db, err := sql.Open(driverName,dataSourceName)
	if err != nil {
		fmt.Println("数据库开启失败")
		return
	}
	errPing := db.Ping()
	if errPing!=nil {
		fmt.Println("数据库链接失败")
		return
	}
}

//在数据库中建表方法
func RegisterModel(models ...interface{}){
	//根据传递进来的多个结构体的对象,进行建表操作
	//1.因为models中有多个对象,所以在此处需要对models进行循环遍历,每遍历一次,就需要创建一个sql
	//2.根据遍历的sql,进行sql语句触发,相应进行建表操作
	sqlInfo := getModelsInfo(models)
	for _, sql := range sqlInfo {
		pstmt, err := db.Prepare(sql)
		if err != nil {
			return
		}
		pstmt.Exec()
	}
	defer db.Close()
}

var sqlStr []string

// 通过反射获取结构体中字段,拼接sql
func getModelsInfo(models []interface{}) []string {
	for index, model := range models {
		modelType := reflect.TypeOf(model) //*models.User结构体指针
		tableName := modelType.Elem().Name()
		sql := ""
		for i := 0; i < modelType.Elem().NumField(); i++ {
			curFieldName := modelType.Elem().Field(i).Name
			curFieldTag := modelType.Elem().Field(i).Tag
			curFieldType := modelType.Elem().Field(i).Type.String()
			tagValue := modelType.Elem().Field(i).Tag.Get("orm") // 获取指定名字的标签的内容

			if strings.ToLower(curFieldName) == "id" {
				//得到id
				sql = "CREATE TABLE "+strings.ToLower(tableName)+" (id int(10) unsigned NOT NULL AUTO_INCREMENT,"
				continue
			}
			sql = strings.Join([]string{sql,strings.ToLower(curFieldName)},"")

			split := strings.Split(tagValue, ";")
			sizeNumber := "10"
			fileType := "undefined Type"
			for _, value := range split {
				indexStart := strings.Index(value,"(")
				indexEnd := strings.Index(value,")")

				if strings.Contains(value,"size") {
					sizeNumber = value[indexStart+1 : indexEnd]
					fmt.Println("sizeNumber = ",sizeNumber)
				}
				if curFieldType == "int" {
					fileType ="int("+sizeNumber+")"
				}

				if curFieldType == "string"{
					fileType ="varchar("+sizeNumber+")"
				}
			}
			sql = strings.Join([]string{sql," ",fileType},"")
			if i == modelType.Elem().NumField()-1 {
				sql = strings.Join([]string{sql," ","DEFAULT NULL,PRIMARY KEY (`id`)"},"")
			}else{
				sql = strings.Join([]string{sql," ","DEFAULT NULL,"},"")
			}
		}
		sql = strings.Join([]string{sql,") ENGINE=InnoDB DEFAULT CHARSET=utf8;"},"")
		sqlStr = append(sqlStr, sql)
	}
	return sqlStr
}

(5)导入mysql驱动。

import (
	......
	_ "github.com/go-sql-driver/mysql"
)

 

 

猜你喜欢

转载自blog.csdn.net/zhongliwen1981/article/details/89643938