// Package proxy implements a bidirectional TCP proxy that can optionally
// unwrap TLS on the remote side and inspect or rewrite traffic in flight.
package proxy

import (
	"context"
	"crypto/tls"
	"io"
	"net"
	"sync"
	"sync/atomic"

	"tcp-proxy/pkg/logger"
)

// Proxy relays bytes between an accepted local TCP connection and a remote
// endpoint. Create it with New or NewTLSUnwarpped and run it with Start.
type Proxy struct {
	// sendBytes (local -> remote) and receivedBytes (remote -> local) are
	// written by the pipe goroutines and read by Start, so they are only
	// accessed through sync/atomic. Keep them as the first fields: 64-bit
	// atomic operations need 64-bit alignment on 32-bit platforms, which Go
	// only guarantees for the first word of an allocated struct.
	sendBytes     uint64
	receivedBytes uint64

	laddr, raddr *net.TCPAddr
	lconn, rconn io.ReadWriteCloser

	// errOnce makes sure only the first pipe failure is logged and reported.
	errOnce sync.Once

	// errsig hands that first failure to Start. New gives it a buffer of one
	// so the reporting goroutine never blocks, even after Start returned.
	errsig chan error

	tlsUnwrapp bool
	tlsAddress string

	// Matcher, if set, is called with every chunk read (before Replacer).
	Matcher func([]byte)

	// Replacer, if set, returns the bytes to forward in place of each chunk.
	Replacer func([]byte) []byte

	// Nagles, when true, calls SetNoDelay(true) on every connection that
	// supports it, i.e. disables Nagle's algorithm. Go TCP connections
	// already default to no-delay, so this mostly makes that explicit.
	Nagles bool

	// Log receives connection and traffic messages; New sets it to
	// logger.NullLogger{}.
	Log logger.Logger

	// OutputHex makes Trace output print payloads as hex instead of text.
	OutputHex bool
}

// New returns a Proxy that, when started, forwards traffic between lconn and
// a plain TCP connection dialed to raddr. laddr is only used for logging.
func New(lconn *net.TCPConn, laddr, raddr *net.TCPAddr) *Proxy {
	return &Proxy{
		lconn:  lconn,
		laddr:  laddr,
		raddr:  raddr,
		errsig: make(chan error, 1),
		Log:    logger.NullLogger{},
	}
}

// NewTLSUnwarpped is like New, but Start dials addr with TLS (using a nil,
// i.e. default, *tls.Config) instead of dialing raddr in the clear, so
// plaintext from lconn is sent to the remote over TLS. raddr is still shown
// in the "Opened" log line.
func NewTLSUnwarpped(lconn *net.TCPConn, laddr, raddr *net.TCPAddr, addr string) *Proxy {
	p := New(lconn, laddr, raddr)
	p.tlsUnwrapp = true
	p.tlsAddress = addr
	return p
}

// setNoDelayer is implemented by connections that can toggle Nagle's
// algorithm, such as *net.TCPConn.
type setNoDelayer interface {
	SetNoDelay(bool) error
}

// Start dials the remote side and then copies data in both directions until
// ctx is cancelled or either direction stops. Both connections are closed
// before Start returns.
//
// It returns the dial error if the remote cannot be reached, nil if ctx is
// cancelled, and otherwise the first error seen by either copy direction
// (io.EOF when a peer closed its side).
func (p *Proxy) Start(ctx context.Context) error {
	defer p.lconn.Close()

	rconn, err := p.dialRemote()
	if err != nil {
		p.Log.Warn("Remote connection failed: %s", err)
		return err
	}
	p.rconn = rconn
	defer p.rconn.Close()

	if p.Nagles {
		for _, c := range []io.ReadWriteCloser{p.lconn, p.rconn} {
			if nd, ok := c.(setNoDelayer); ok {
				if nerr := nd.SetNoDelay(true); nerr != nil {
					p.Log.Warn("SetNoDelay failed: %s", nerr)
				}
			}
		}
	}

	// display both ends
	p.Log.Info("Opened %s >>>> %s", p.laddr.String(), p.raddr.String())

	// Bidirectional copy. Each goroutine exits once ctx is done or its Read
	// or Write fails; the deferred Closes above force the latter when Start
	// returns, and err never blocks, so neither goroutine can leak.
	go p.pipe(ctx, p.lconn, p.rconn)
	go p.pipe(ctx, p.rconn, p.lconn)

	select {
	case <-ctx.Done():
		p.Log.Info("Context Done (%d bytes send, %d bytes received)",
			atomic.LoadUint64(&p.sendBytes), atomic.LoadUint64(&p.receivedBytes))
		return nil
	case perr := <-p.errsig:
		p.Log.Info("Closed (%d bytes send, %d bytes received)",
			atomic.LoadUint64(&p.sendBytes), atomic.LoadUint64(&p.receivedBytes))
		return perr
	}
}

// dialRemote opens the remote side: a TLS connection to tlsAddress when TLS
// unwrapping is enabled, otherwise a plain TCP connection to raddr. On error
// it returns a nil interface rather than one wrapping a typed nil pointer.
func (p *Proxy) dialRemote() (io.ReadWriteCloser, error) {
	if p.tlsUnwrapp {
		c, err := tls.Dial("tcp", p.tlsAddress, nil)
		if err != nil {
			return nil, err
		}
		return c, nil
	}
	c, err := net.DialTCP("tcp", nil, p.raddr)
	if err != nil {
		return nil, err
	}
	return c, nil
}

// err reports the first pipe failure: it logs it with format s (unless it is
// io.EOF) and hands it to Start via errsig. Later calls are ignored; those
// are typically caused by Start closing the connections on its way out.
func (p *Proxy) err(s string, err error) {
	p.errOnce.Do(func() {
		if err != io.EOF {
			p.Log.Warn(s, err)
		}
		// errsig has capacity 1 (see New) and receives at most this one
		// value, so the send cannot block even if Start already returned.
		p.errsig <- err
	})
}

// pipe copies data from src to dst until ctx is cancelled or a Read or Write
// fails, in which case the failure is reported through p.err. Each chunk is
// passed to Matcher, possibly rewritten by Replacer, and logged before being
// written. Bytes written are added to sendBytes when src is the local
// connection and to receivedBytes otherwise.
func (p *Proxy) pipe(ctx context.Context, src, dst io.ReadWriter) {
	dataDirection := "<<<< %d bytes received%s"
	counter := &p.receivedBytes
	if src == p.lconn {
		dataDirection = ">>>> %d bytes sent%s"
		counter = &p.sendBytes
	}

	byteFormat := "%s"
	if p.OutputHex {
		byteFormat = "%x"
	}

	// directional copy (0xffff-byte buffer)
	buff := make([]byte, 0xffff)
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		n, rerr := src.Read(buff)
		// io.Reader may return data together with an error (crypto/tls can
		// return the last bytes alongside io.EOF at close_notify), so forward
		// what was read before acting on the error.
		if n > 0 {
			b := buff[:n]

			// execute match
			if p.Matcher != nil {
				p.Matcher(b)
			}
			// execute replace
			if p.Replacer != nil {
				b = p.Replacer(b)
			}

			// show output
			p.Log.Debug(dataDirection, n, "")
			p.Log.Trace(byteFormat, b)

			written, werr := dst.Write(b)
			if werr != nil {
				p.err("Write failed '%s'\n", werr)
				return
			}
			atomic.AddUint64(counter, uint64(written))
		}
		if rerr != nil {
			p.err("Read failed '%s'\n", rerr)
			return
		}
	}
}