165 lines
2.9 KiB
Go
165 lines
2.9 KiB
Go
package proxy
|
|
|
|
import (
	"context"
	"crypto/tls"
	"errors"
	"io"
	"net"
	"sync/atomic"

	"tcp-proxy/pkg/logger"
)
|
|
|
|
type Proxy struct {
|
|
sendBytes uint64
|
|
receivedBytes uint64
|
|
laddr, raddr *net.TCPAddr
|
|
lconn, rconn io.ReadWriteCloser
|
|
erred bool
|
|
errsig chan error
|
|
tlsUnwrapp bool
|
|
tlsAddress string
|
|
|
|
Matcher func([]byte)
|
|
Replacer func([]byte) []byte
|
|
|
|
// setting
|
|
Nagles bool
|
|
Log logger.Logger
|
|
OutputHex bool
|
|
}
|
|
|
|
func New(lconn *net.TCPConn, laddr, raddr *net.TCPAddr) *Proxy {
|
|
return &Proxy{
|
|
lconn: lconn,
|
|
laddr: laddr,
|
|
raddr: raddr,
|
|
erred: false,
|
|
errsig: make(chan error),
|
|
Log: logger.NullLogger{},
|
|
}
|
|
}
|
|
|
|
func NewTLSUnwarpped(lconn *net.TCPConn, laddr, raddr *net.TCPAddr, addr string) *Proxy {
|
|
p := New(lconn, laddr, raddr)
|
|
p.tlsUnwrapp = true
|
|
p.tlsAddress = addr
|
|
return p
|
|
}
|
|
|
|
// setNoDelayer is implemented by connections that can toggle Nagle's
// algorithm, such as *net.TCPConn. Start type-asserts against it so the
// Nagles setting works without depending on a concrete connection type.
type setNoDelayer interface {
	SetNoDelay(noDelay bool) error
}
|
|
|
|
func (p *Proxy) Start(ctx context.Context) (err error) {
|
|
defer p.lconn.Close()
|
|
|
|
if p.tlsUnwrapp {
|
|
p.rconn, err = tls.Dial("tcp", p.tlsAddress, nil)
|
|
} else {
|
|
p.rconn, err = net.DialTCP("tcp", nil, p.raddr)
|
|
}
|
|
if err != nil {
|
|
p.Log.Warn("Remote connection failed: %s", err)
|
|
return
|
|
}
|
|
defer p.rconn.Close()
|
|
|
|
// nagles
|
|
if p.Nagles {
|
|
if conn, ok := p.lconn.(setNoDelayer); ok {
|
|
conn.SetNoDelay(true)
|
|
}
|
|
if conn, ok := p.rconn.(setNoDelayer); ok {
|
|
conn.SetNoDelay(true)
|
|
}
|
|
}
|
|
|
|
//display both ends
|
|
p.Log.Info("Opened %s >>>> %s", p.laddr.String(), p.raddr.String())
|
|
|
|
//bidirectional copy
|
|
go p.pipe(ctx, p.lconn, p.rconn)
|
|
go p.pipe(ctx, p.rconn, p.lconn)
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
p.Log.Info("Context Done (%d bytes send, %d bytes received)", p.sendBytes, p.receivedBytes)
|
|
return
|
|
// wait for close
|
|
case err = <-p.errsig:
|
|
p.Log.Info("Closed (%d bytes send, %d bytes received)", p.sendBytes, p.receivedBytes)
|
|
return err
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) err(s string, err error) {
|
|
if p.erred {
|
|
return
|
|
}
|
|
|
|
if err != io.EOF {
|
|
p.Log.Warn(s, err)
|
|
}
|
|
p.errsig <- err
|
|
p.erred = true
|
|
}
|
|
|
|
func (p *Proxy) pipe(ctx context.Context, src, dst io.ReadWriter) {
|
|
islocal := src == p.lconn
|
|
|
|
var dataDirection string
|
|
if islocal {
|
|
dataDirection = ">>>> %d bytes sent%s"
|
|
} else {
|
|
dataDirection = "<<<< %d bytes received%s"
|
|
}
|
|
|
|
var byteFormat string
|
|
if p.OutputHex {
|
|
byteFormat = "%x"
|
|
} else {
|
|
byteFormat = "%s"
|
|
}
|
|
|
|
// directional copy (64k buffer)
|
|
buff := make([]byte, 0xffff)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
n, err := src.Read(buff)
|
|
if err != nil {
|
|
p.err("Read failed '%s'\n", err)
|
|
return
|
|
}
|
|
b := buff[:n]
|
|
|
|
// execute match
|
|
if p.Matcher != nil {
|
|
p.Matcher(b)
|
|
}
|
|
// execute replace
|
|
if p.Replacer != nil {
|
|
b = p.Replacer(b)
|
|
}
|
|
|
|
// show output
|
|
p.Log.Debug(dataDirection, n, "")
|
|
p.Log.Trace(byteFormat, b)
|
|
|
|
// write to result
|
|
n, err = dst.Write(b)
|
|
if err != nil {
|
|
p.err("Write failed '%s'\n", err)
|
|
return
|
|
}
|
|
|
|
if islocal {
|
|
p.sendBytes += uint64(n)
|
|
} else {
|
|
p.receivedBytes += uint64(n)
|
|
}
|
|
}
|
|
}
|
|
}
|