用Go写一个内网穿透工具

系统架构

系统分为两个部分,client 和 server,client运行在内网服务器中,server运行在公网服务器中,当我们想访问内网中的服务,我们通过公网服务器做一个中继。

下面是展示我灵魂画手的时刻了

user发送请求给 server,server和client建立连接,将请求发给client,client再将请求发给本地程序处理(内网中),然后本地程序将处理结果返回给client,client将结果返回给server,server再将结果返回给用户,这样用户就访问到了内网中的程序了。

代码流程

  1. server端监听两个端口,一个用来和user通信,一个和client通信
  2. client启动时连接server端,并启动一个端口监听本地某程序
  3. 当User连接到server端口,将User请求内容发给client
  4. client将从server收到的请求发给本地程序
  5. client将从本地程序收到的内容发给server
  6. server将从client收到的内容发给User即可

  1. 当Server与client没有消息通信,连接会断开
  2. client断开后,再启动会连接不到Server
  3. Server端会因为client断开而引发panic

为了解决这种坑点,加入了心跳包机制,通过5s发送一次心跳包,保持client与server的连接,同时建立一个重连通道,监听该通道,如果当Client被断开后,则往重连通道放一个值,告诉Server端,等待新的Client连接,而避免引发Panic

代码

更详细的我就不说了,直接看代码,代码里面有详细的注释

代码仓库地址: https://github.com/pibigstar/go-proxy

Server端

运行在具有公网IP地址的服务器端

package main
import (
	"flag"
	"fmt"
	"io"
	"net"
	"runtime"
	"strings"
	"time"
)

var (
	localPort  int
	remotePort int
)

func init() {
	flag.IntVar(&localPort, "l", 5200, "the user link port")
	flag.IntVar(&remotePort, "r", 3333, "client listen port")
}

type client struct {
	conn net.Conn
	// 数据传输通道
	read  chan []byte
	write chan []byte
	// 异常退出通道
	exit chan error
	// 重连通道
	reConn chan bool
}

// 从Client端读取数据
func (c *client) Read() {
	// 如果10秒钟内没有消息传输,则Read函数会返回一个timeout的错误
	_ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 10))
	for {
		data := make([]byte, 10240)
		n, err := c.conn.Read(data)
		if err != nil && err != io.EOF {
			if strings.Contains(err.Error(), "timeout") {
				// 设置读取时间为3秒,3秒后若读取不到, 则err会抛出timeout,然后发送心跳
				_ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
				c.conn.Write([]byte("pi"))
				continue
			}
			fmt.Println("读取出现错误...")
			c.exit <- err
		}

		// 收到心跳包,则跳过
		if data[0] == 'p' && data[1] == 'i' {
			fmt.Println("server收到心跳包")
			continue
		}
		c.read <- data[:n]
	}
}

// 将数据写入到Client端
func (c *client) Write() {
	for {
		select {
		case data := <-c.write:
			_, err := c.conn.Write(data)
			if err != nil && err != io.EOF {
				c.exit <- err
			}
		}
	}
}

type user struct {
	conn net.Conn
	// 数据传输通道
	read  chan []byte
	write chan []byte
	// 异常退出通道
	exit chan error
}

// 从User端读取数据
func (u *user) Read() {
	_ = u.conn.SetReadDeadline(time.Now().Add(time.Second * 200))
	for {
		data := make([]byte, 10240)
		n, err := u.conn.Read(data)
		if err != nil && err != io.EOF {
			u.exit <- err
		}
		u.read <- data[:n]
	}
}

// 将数据写给User端
func (u *user) Write() {
	for {
		select {
		case data := <-u.write:
			_, err := u.conn.Write(data)
			if err != nil && err != io.EOF {
				u.exit <- err
			}
		}
	}
}

func main() {
	flag.Parse()

	defer func() {
		err := recover()
		if err != nil {
			fmt.Println(err)
		}
	}()

	clientListener, err := net.Listen("tcp", fmt.Sprintf(":%d", remotePort))
	if err != nil {
		panic(err)
	}
	fmt.Printf("监听:%d端口, 等待client连接... \n", remotePort)
	// 监听User来连接
	userListener, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort))
	if err != nil {
		panic(err)
	}
	fmt.Printf("监听:%d端口, 等待user连接.... \n", localPort)

	for {
		// 有Client来连接了
		clientConn, err := clientListener.Accept()
		if err != nil {
			panic(err)
		}

		fmt.Printf("有Client连接: %s \n", clientConn.RemoteAddr())

		client := &client{
			conn:   clientConn,
			read:   make(chan []byte),
			write:  make(chan []byte),
			exit:   make(chan error),
			reConn: make(chan bool),
		}

		userConnChan := make(chan net.Conn)
		go AcceptUserConn(userListener, userConnChan)

		go HandleClient(client, userConnChan)

		<-client.reConn
		fmt.Println("重新等待新的client连接..")
	}
}

func HandleClient(client *client, userConnChan chan net.Conn) {

	go client.Read()
	go client.Write()

	for {
		select {
		case err := <-client.exit:
			fmt.Printf("client出现错误, 开始重试, err: %s \n", err.Error())
			client.reConn <- true
			runtime.Goexit()

		case userConn := <-userConnChan:
			user := &user{
				conn:  userConn,
				read:  make(chan []byte),
				write: make(chan []byte),
				exit:  make(chan error),
			}
			go user.Read()
			go user.Write()

			go handle(client, user)
		}
	}
}

// 将两个Socket通道链接
// 1. 将从user收到的信息发给client
// 2. 将从client收到信息发给user
func handle(client *client, user *user) {
	for {
		select {
		case userRecv := <-user.read:
			// 收到从user发来的信息
			client.write <- userRecv
		case clientRecv := <-client.read:
			// 收到从client发来的信息
			user.write <- clientRecv

		case err := <-client.exit:
			fmt.Println("client出现错误, 关闭连接", err.Error())
			_ = client.conn.Close()
			_ = user.conn.Close()
			client.reConn <- true
			// 结束当前goroutine
			runtime.Goexit()

		case err := <-user.exit:
			fmt.Println("user出现错误,关闭连接", err.Error())
			_ = user.conn.Close()
		}
	}
}

// 等待user连接
func AcceptUserConn(userListener net.Listener, connChan chan net.Conn) {
	userConn, err := userListener.Accept()
	if err != nil {
		panic(err)
	}
	fmt.Printf("user connect: %s \n", userConn.RemoteAddr())
	connChan <- userConn
}

Client端

运行在需要内网穿透的客户端中

package main

import (
	"flag"
	"fmt"
	"io"
	"net"
	"runtime"
	"strings"
	"time"
)

var (
	host       string
	localPort  int
	remotePort int
)

func init() {
	flag.StringVar(&host, "h", "127.0.0.1", "remote server ip")
	flag.IntVar(&localPort, "l", 8080, "the local port")
	flag.IntVar(&remotePort, "r", 3333, "remote server port")
}

type server struct {
	conn net.Conn
	// 数据传输通道
	read  chan []byte
	write chan []byte
	// 异常退出通道
	exit chan error
	// 重连通道
	reConn chan bool
}

// 从Server端读取数据
func (s *server) Read() {
	// 如果10秒钟内没有消息传输,则Read函数会返回一个timeout的错误
	_ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 10))
	for {
		data := make([]byte, 10240)
		n, err := s.conn.Read(data)
		if err != nil && err != io.EOF {
			// 读取超时,发送一个心跳包过去
			if strings.Contains(err.Error(), "timeout") {
				// 3秒发一次心跳
				_ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
				s.conn.Write([]byte("pi"))
				continue
			}
			fmt.Println("从server读取数据失败, ", err.Error())
			s.exit <- err
			runtime.Goexit()
		}

		// 如果收到心跳包, 则跳过
		if data[0] == 'p' && data[1] == 'i' {
			fmt.Println("client收到心跳包")
			continue
		}
		s.read <- data[:n]
	}
}

// 将数据写入到Server端
func (s *server) Write() {
	for {
		select {
		case data := <-s.write:
			_, err := s.conn.Write(data)
			if err != nil && err != io.EOF {
				s.exit <- err
			}
		}
	}
}

type local struct {
	conn net.Conn
	// 数据传输通道
	read  chan []byte
	write chan []byte
	// 有异常退出通道
	exit chan error
}

func (l *local) Read() {

	for {
		data := make([]byte, 10240)
		n, err := l.conn.Read(data)
		if err != nil {
			l.exit <- err
		}
		l.read <- data[:n]
	}
}

func (l *local) Write() {
	for {
		select {
		case data := <-l.write:
			_, err := l.conn.Write(data)
			if err != nil {
				l.exit <- err
			}
		}
	}
}

func main() {
	flag.Parse()

	target := net.JoinHostPort(host, fmt.Sprintf("%d", remotePort))
	for {
		serverConn, err := net.Dial("tcp", target)
		if err != nil {
			panic(err)
		}

		fmt.Printf("已连接server: %s \n", serverConn.RemoteAddr())
		server := &server{
			conn:   serverConn,
			read:   make(chan []byte),
			write:  make(chan []byte),
			exit:   make(chan error),
			reConn: make(chan bool),
		}

		go server.Read()
		go server.Write()

		go handle(server)

		<-server.reConn
		_ = server.conn.Close()
	}

}

func handle(server *server) {
	// 等待server端发来的信息,也就是说user来请求server了
	data := <-server.read

	localConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
	if err != nil {
		panic(err)
	}

	local := &local{
		conn:  localConn,
		read:  make(chan []byte),
		write: make(chan []byte),
		exit:  make(chan error),
	}

	go local.Read()
	go local.Write()

	local.write <- data

	for {
		select {
		case data := <-server.read:
			local.write <- data

		case data := <-local.read:
			server.write <- data

		case err := <-server.exit:
			fmt.Printf("server have err: %s", err.Error())
			_ = server.conn.Close()
			_ = local.conn.Close()
			server.reConn <- true

		case err := <-local.exit:
			fmt.Printf("server have err: %s", err.Error())
			_ = local.conn.Close()
		}
	}
}
发布了237 篇原创文章 · 获赞 215 · 访问量 39万+

猜你喜欢

转载自blog.csdn.net/junmoxi/article/details/103481058