|
@@ -0,0 +1,142 @@
|
|
|
+package proxy
|
|
|
+
|
|
|
+import (
|
|
|
+ "crypto/tls"
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "log"
|
|
|
+ "net"
|
|
|
+ "net/http"
|
|
|
+ "net/url"
|
|
|
+ "os"
|
|
|
+ "strings"
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ WsScheme = "ws"
|
|
|
+ WssScheme = "wss"
|
|
|
+ BufSize = 1024 * 32
|
|
|
+)
|
|
|
+
|
|
|
+var ErrFormatAddr = errors.New("remote websockets addr format error")
|
|
|
+
|
|
|
+type WebsocketProxy struct {
|
|
|
+ // ws, wss
|
|
|
+ scheme string
|
|
|
+ // The target address: host:port
|
|
|
+ remoteAddr string
|
|
|
+ // path
|
|
|
+ defaultPath string
|
|
|
+ tls *tls.Config
|
|
|
+ logger *log.Logger
|
|
|
+ // Send handshake before callback
|
|
|
+ beforeHandshake func(r *http.Request) error
|
|
|
+}
|
|
|
+
|
|
|
+type Options func(wp *WebsocketProxy)
|
|
|
+
|
|
|
+// You must carry a port number,ws://ip:80/ssss, wss://ip:443/aaaa
|
|
|
+// ex: ws://ip:port/ajaxchattest
|
|
|
+func NewWebsocketProxy(addr string, beforeHandshake func(r *http.Request) error, options ...Options) (*WebsocketProxy, error) {
|
|
|
+ u, err := url.Parse(addr)
|
|
|
+ if err != nil {
|
|
|
+ return nil, ErrFormatAddr
|
|
|
+ }
|
|
|
+ host, port, err := net.SplitHostPort(u.Host)
|
|
|
+ if err != nil {
|
|
|
+ return nil, ErrFormatAddr
|
|
|
+ }
|
|
|
+ if u.Scheme != WsScheme && u.Scheme != WssScheme {
|
|
|
+ return nil, ErrFormatAddr
|
|
|
+ }
|
|
|
+ wp := &WebsocketProxy{
|
|
|
+ scheme: u.Scheme,
|
|
|
+ remoteAddr: fmt.Sprintf("%s:%s", host, port),
|
|
|
+ defaultPath: u.Path,
|
|
|
+ beforeHandshake: beforeHandshake,
|
|
|
+ logger: log.New(os.Stderr, "", log.LstdFlags),
|
|
|
+ }
|
|
|
+ if u.Scheme == WssScheme {
|
|
|
+ wp.tls = &tls.Config{InsecureSkipVerify: true}
|
|
|
+ }
|
|
|
+ for op := range options {
|
|
|
+ options[op](wp)
|
|
|
+ }
|
|
|
+ return wp, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (wp *WebsocketProxy) Proxy(writer http.ResponseWriter, request *http.Request) bool {
|
|
|
+ if strings.ToLower(request.Header.Get("Connection")) != "upgrade" ||
|
|
|
+ strings.ToLower(request.Header.Get("Upgrade")) != "websocket" {
|
|
|
+ _, _ = writer.Write([]byte(`Must be a websocket request`))
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ hijacker, ok := writer.(http.Hijacker)
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ conn, _, err := hijacker.Hijack()
|
|
|
+ if err != nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ defer conn.Close()
|
|
|
+ req := request.Clone(request.Context())
|
|
|
+ req.URL.Path, req.URL.RawPath, req.RequestURI = wp.defaultPath, wp.defaultPath, wp.defaultPath
|
|
|
+ req.Host = wp.remoteAddr
|
|
|
+ if wp.beforeHandshake != nil {
|
|
|
+ // Add headers, permission authentication + masquerade sources
|
|
|
+ err = wp.beforeHandshake(req)
|
|
|
+ if err != nil {
|
|
|
+ _, _ = writer.Write([]byte(err.Error()))
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+ var remoteConn net.Conn
|
|
|
+ switch wp.scheme {
|
|
|
+ case WsScheme:
|
|
|
+ remoteConn, err = net.Dial("tcp", wp.remoteAddr)
|
|
|
+ case WssScheme:
|
|
|
+ remoteConn, err = tls.Dial("tcp", wp.remoteAddr, wp.tls)
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ _, _ = writer.Write([]byte(err.Error()))
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ defer remoteConn.Close()
|
|
|
+ err = req.Write(remoteConn)
|
|
|
+ if err != nil {
|
|
|
+ wp.logger.Println("remote write err:", err)
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ errChan := make(chan error, 2)
|
|
|
+ copyConn := func(a, b net.Conn) {
|
|
|
+ buf := ByteSliceGet(BufSize)
|
|
|
+ defer ByteSlicePut(buf)
|
|
|
+ _, err := io.CopyBuffer(a, b, buf)
|
|
|
+ errChan <- err
|
|
|
+ }
|
|
|
+ go copyConn(conn, remoteConn) // response
|
|
|
+ go copyConn(remoteConn, conn) // request
|
|
|
+ select {
|
|
|
+ case err = <-errChan:
|
|
|
+ if err != nil {
|
|
|
+ log.Println(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+func SetTLSConfig(tls *tls.Config) Options {
|
|
|
+ return func(wp *WebsocketProxy) {
|
|
|
+ wp.tls = tls
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func SetLogger(l *log.Logger) Options {
|
|
|
+ return func(wp *WebsocketProxy) {
|
|
|
+ if l != nil {
|
|
|
+ wp.logger = l
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|