123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- package proxy
- import (
- "crypto/tls"
- "errors"
- "fmt"
- "io"
- "log"
- "net"
- "net/http"
- "net/url"
- "os"
- "strings"
- )
// URL schemes accepted by NewWebsocketProxy for the upstream address.
const (
	// WsScheme selects a plain TCP connection to the upstream.
	WsScheme = "ws"
	// WssScheme selects a TLS connection to the upstream.
	WssScheme = "wss"
)

// BufSize is the size in bytes (32 KiB) of each copy buffer that Proxy
// requests from ByteSliceGet.
const BufSize = 32 << 10

// ErrFormatAddr is returned by NewWebsocketProxy when the remote address is
// not a ws:// or wss:// URL with an explicit host:port.
var ErrFormatAddr = errors.New("remote websockets addr format error")
type (
	// WebsocketProxy forwards WebSocket upgrade requests to one fixed
	// upstream and relays traffic in both directions. Construct it with
	// NewWebsocketProxy.
	WebsocketProxy struct {
		// scheme is "ws" or "wss"; it selects plain TCP or TLS dialing.
		scheme string
		// remoteAddr is the upstream host:port. It is dialed and also used
		// as the Host of the forwarded handshake request.
		remoteAddr string
		// defaultPath replaces the path of every forwarded request.
		defaultPath string

		// tls is the client configuration used when dialing a wss upstream.
		tls *tls.Config
		// logger receives the proxy's diagnostic messages.
		logger *log.Logger

		// beforeHandshake, if non-nil, is handed the outgoing handshake
		// request before it is sent upstream; a non-nil error aborts it.
		beforeHandshake func(r *http.Request) error
	}

	// Options configures a WebsocketProxy during NewWebsocketProxy.
	Options func(wp *WebsocketProxy)
)
- // You must carry a port number,ws://ip:80/ssss, wss://ip:443/aaaa
- // ex: ws://ip:port/ajaxchattest
- func NewWebsocketProxy(addr string, beforeHandshake func(r *http.Request) error, options ...Options) (*WebsocketProxy, error) {
- u, err := url.Parse(addr)
- if err != nil {
- return nil, ErrFormatAddr
- }
- host, port, err := net.SplitHostPort(u.Host)
- if err != nil {
- return nil, ErrFormatAddr
- }
- if u.Scheme != WsScheme && u.Scheme != WssScheme {
- return nil, ErrFormatAddr
- }
- wp := &WebsocketProxy{
- scheme: u.Scheme,
- remoteAddr: fmt.Sprintf("%s:%s", host, port),
- defaultPath: u.Path,
- beforeHandshake: beforeHandshake,
- logger: log.New(os.Stderr, "", log.LstdFlags),
- }
- if u.Scheme == WssScheme {
- wp.tls = &tls.Config{InsecureSkipVerify: true}
- }
- for op := range options {
- options[op](wp)
- }
- return wp, nil
- }
- func (wp *WebsocketProxy) Proxy(writer http.ResponseWriter, request *http.Request) bool {
- if strings.ToLower(request.Header.Get("Connection")) != "upgrade" ||
- strings.ToLower(request.Header.Get("Upgrade")) != "websocket" {
- _, _ = writer.Write([]byte(`Must be a websocket request`))
- return false
- }
- hijacker, ok := writer.(http.Hijacker)
- if !ok {
- return false
- }
- conn, _, err := hijacker.Hijack()
- if err != nil {
- return false
- }
- defer conn.Close()
- req := request.Clone(request.Context())
- req.URL.Path, req.URL.RawPath, req.RequestURI = wp.defaultPath, wp.defaultPath, wp.defaultPath
- req.Host = wp.remoteAddr
- if wp.beforeHandshake != nil {
- // Add headers, permission authentication + masquerade sources
- err = wp.beforeHandshake(req)
- if err != nil {
- _, _ = writer.Write([]byte(err.Error()))
- return false
- }
- }
- var remoteConn net.Conn
- switch wp.scheme {
- case WsScheme:
- remoteConn, err = net.Dial("tcp", wp.remoteAddr)
- case WssScheme:
- remoteConn, err = tls.Dial("tcp", wp.remoteAddr, wp.tls)
- }
- if err != nil {
- _, _ = writer.Write([]byte(err.Error()))
- return false
- }
- defer remoteConn.Close()
- err = req.Write(remoteConn)
- if err != nil {
- wp.logger.Println("remote write err:", err)
- return false
- }
- errChan := make(chan error, 2)
- copyConn := func(a, b net.Conn) {
- buf := ByteSliceGet(BufSize)
- defer ByteSlicePut(buf)
- _, err := io.CopyBuffer(a, b, buf)
- errChan <- err
- }
- go copyConn(conn, remoteConn) // response
- go copyConn(remoteConn, conn) // request
- select {
- case err = <-errChan:
- if err != nil {
- log.Println(err)
- }
- }
- return true
- }
- func SetTLSConfig(tls *tls.Config) Options {
- return func(wp *WebsocketProxy) {
- wp.tls = tls
- }
- }
- func SetLogger(l *log.Logger) Options {
- return func(wp *WebsocketProxy) {
- if l != nil {
- wp.logger = l
- }
- }
- }
|