package proxy import ( "crypto/tls" "errors" "fmt" "io" "log" "net" "net/http" "net/url" "os" "strings" ) const ( WsScheme = "ws" WssScheme = "wss" BufSize = 1024 * 32 ) var ErrFormatAddr = errors.New("remote websockets addr format error") type WebsocketProxy struct { // ws, wss scheme string // The target address: host:port remoteAddr string // path defaultPath string tls *tls.Config logger *log.Logger // Send handshake before callback beforeHandshake func(r *http.Request) error } type Options func(wp *WebsocketProxy) // You must carry a port number,ws://ip:80/ssss, wss://ip:443/aaaa // ex: ws://ip:port/ajaxchattest func NewWebsocketProxy(addr string, beforeHandshake func(r *http.Request) error, options ...Options) (*WebsocketProxy, error) { u, err := url.Parse(addr) if err != nil { return nil, ErrFormatAddr } host, port, err := net.SplitHostPort(u.Host) if err != nil { return nil, ErrFormatAddr } if u.Scheme != WsScheme && u.Scheme != WssScheme { return nil, ErrFormatAddr } wp := &WebsocketProxy{ scheme: u.Scheme, remoteAddr: fmt.Sprintf("%s:%s", host, port), defaultPath: u.Path, beforeHandshake: beforeHandshake, logger: log.New(os.Stderr, "", log.LstdFlags), } if u.Scheme == WssScheme { wp.tls = &tls.Config{InsecureSkipVerify: true} } for op := range options { options[op](wp) } return wp, nil } func (wp *WebsocketProxy) Proxy(writer http.ResponseWriter, request *http.Request) bool { if strings.ToLower(request.Header.Get("Connection")) != "upgrade" || strings.ToLower(request.Header.Get("Upgrade")) != "websocket" { _, _ = writer.Write([]byte(`Must be a websocket request`)) return false } hijacker, ok := writer.(http.Hijacker) if !ok { return false } conn, _, err := hijacker.Hijack() if err != nil { return false } defer conn.Close() req := request.Clone(request.Context()) req.URL.Path, req.URL.RawPath, req.RequestURI = wp.defaultPath, wp.defaultPath, wp.defaultPath req.Host = wp.remoteAddr if wp.beforeHandshake != nil { // Add headers, permission authentication + masquerade sources err = wp.beforeHandshake(req) if err != nil { _, _ = writer.Write([]byte(err.Error())) return false } } var remoteConn net.Conn switch wp.scheme { case WsScheme: remoteConn, err = net.Dial("tcp", wp.remoteAddr) case WssScheme: remoteConn, err = tls.Dial("tcp", wp.remoteAddr, wp.tls) } if err != nil { _, _ = writer.Write([]byte(err.Error())) return false } defer remoteConn.Close() err = req.Write(remoteConn) if err != nil { wp.logger.Println("remote write err:", err) return false } errChan := make(chan error, 2) copyConn := func(a, b net.Conn) { buf := ByteSliceGet(BufSize) defer ByteSlicePut(buf) _, err := io.CopyBuffer(a, b, buf) errChan <- err } go copyConn(conn, remoteConn) // response go copyConn(remoteConn, conn) // request select { case err = <-errChan: if err != nil { log.Println(err) } } return true } func SetTLSConfig(tls *tls.Config) Options { return func(wp *WebsocketProxy) { wp.tls = tls } } func SetLogger(l *log.Logger) Options { return func(wp *WebsocketProxy) { if l != nil { wp.logger = l } } }