websocket.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package proxy
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net"
  9. "net/http"
  10. "net/url"
  11. "os"
  12. "strings"
  13. )
  // Backend URL schemes accepted by NewWebsocketProxy, plus the size of the
  // scratch buffer Proxy uses when relaying data between the two connections.
  const (
  	// WsScheme selects a plain TCP connection to the backend.
  	WsScheme = "ws"
  	// WssScheme selects a TLS connection to the backend.
  	WssScheme = "wss"
  	// BufSize is the copy buffer length in bytes (32 KiB).
  	BufSize = 32 << 10
  )
  var (
  	// ErrFormatAddr is returned by NewWebsocketProxy when the backend address
  	// is not a parsable ws:// or wss:// URL with an explicit host:port.
  	ErrFormatAddr = errors.New("remote websockets addr format error")
  )
  // WebsocketProxy relays hijacked client websocket connections to a single,
  // fixed backend. Build one with NewWebsocketProxy.
  type WebsocketProxy struct {
  	scheme      string      // backend scheme: WsScheme or WssScheme
  	remoteAddr  string      // backend address in host:port form
  	defaultPath string      // path written into every upstream handshake
  	tls         *tls.Config // dial settings used for wss backends
  	logger      *log.Logger // sink for proxy diagnostics

  	// beforeHandshake, when non-nil, is invoked with the rewritten upstream
  	// handshake request before it is sent; a non-nil error aborts the proxy.
  	beforeHandshake func(r *http.Request) error
  }
  32. type Options func(wp *WebsocketProxy)
  33. // You must carry a port number,ws://ip:80/ssss, wss://ip:443/aaaa
  34. // ex: ws://ip:port/ajaxchattest
  35. func NewWebsocketProxy(addr string, beforeHandshake func(r *http.Request) error, options ...Options) (*WebsocketProxy, error) {
  36. u, err := url.Parse(addr)
  37. if err != nil {
  38. return nil, ErrFormatAddr
  39. }
  40. host, port, err := net.SplitHostPort(u.Host)
  41. if err != nil {
  42. return nil, ErrFormatAddr
  43. }
  44. if u.Scheme != WsScheme && u.Scheme != WssScheme {
  45. return nil, ErrFormatAddr
  46. }
  47. wp := &WebsocketProxy{
  48. scheme: u.Scheme,
  49. remoteAddr: fmt.Sprintf("%s:%s", host, port),
  50. defaultPath: u.Path,
  51. beforeHandshake: beforeHandshake,
  52. logger: log.New(os.Stderr, "", log.LstdFlags),
  53. }
  54. if u.Scheme == WssScheme {
  55. wp.tls = &tls.Config{InsecureSkipVerify: true}
  56. }
  57. for op := range options {
  58. options[op](wp)
  59. }
  60. return wp, nil
  61. }
  62. func (wp *WebsocketProxy) Proxy(writer http.ResponseWriter, request *http.Request) bool {
  63. if strings.ToLower(request.Header.Get("Connection")) != "upgrade" ||
  64. strings.ToLower(request.Header.Get("Upgrade")) != "websocket" {
  65. _, _ = writer.Write([]byte(`Must be a websocket request`))
  66. return false
  67. }
  68. hijacker, ok := writer.(http.Hijacker)
  69. if !ok {
  70. return false
  71. }
  72. conn, _, err := hijacker.Hijack()
  73. if err != nil {
  74. return false
  75. }
  76. defer conn.Close()
  77. req := request.Clone(request.Context())
  78. req.URL.Path, req.URL.RawPath, req.RequestURI = wp.defaultPath, wp.defaultPath, wp.defaultPath
  79. req.Host = wp.remoteAddr
  80. if wp.beforeHandshake != nil {
  81. // Add headers, permission authentication + masquerade sources
  82. err = wp.beforeHandshake(req)
  83. if err != nil {
  84. _, _ = writer.Write([]byte(err.Error()))
  85. return false
  86. }
  87. }
  88. var remoteConn net.Conn
  89. switch wp.scheme {
  90. case WsScheme:
  91. remoteConn, err = net.Dial("tcp", wp.remoteAddr)
  92. case WssScheme:
  93. remoteConn, err = tls.Dial("tcp", wp.remoteAddr, wp.tls)
  94. }
  95. if err != nil {
  96. _, _ = writer.Write([]byte(err.Error()))
  97. return false
  98. }
  99. defer remoteConn.Close()
  100. err = req.Write(remoteConn)
  101. if err != nil {
  102. wp.logger.Println("remote write err:", err)
  103. return false
  104. }
  105. errChan := make(chan error, 2)
  106. copyConn := func(a, b net.Conn) {
  107. buf := ByteSliceGet(BufSize)
  108. defer ByteSlicePut(buf)
  109. _, err := io.CopyBuffer(a, b, buf)
  110. errChan <- err
  111. }
  112. go copyConn(conn, remoteConn) // response
  113. go copyConn(remoteConn, conn) // request
  114. select {
  115. case err = <-errChan:
  116. if err != nil {
  117. log.Println(err)
  118. }
  119. }
  120. return true
  121. }
  122. func SetTLSConfig(tls *tls.Config) Options {
  123. return func(wp *WebsocketProxy) {
  124. wp.tls = tls
  125. }
  126. }
  127. func SetLogger(l *log.Logger) Options {
  128. return func(wp *WebsocketProxy) {
  129. if l != nil {
  130. wp.logger = l
  131. }
  132. }
  133. }