tcp-proxy/pkg/proxy/proxy.go
2020-01-13 18:05:18 +08:00

164 lines
2.9 KiB
Go

package proxy
import (
"context"
"crypto/tls"
"io"
"net"
"ssh-proxy/pkg/logger"
)
// Proxy relays bytes between a local connection and a remote endpoint,
// optionally inspecting (Matcher) or rewriting (Replacer) every chunk that
// passes through it.
type Proxy struct {
	// Per-direction byte counters, updated by the pipe goroutines.
	sendBytes     uint64 // bytes read from lconn and written to the remote side
	receivedBytes uint64 // bytes read from rconn and written to the local side

	laddr *net.TCPAddr       // local address, logged by Start
	raddr *net.TCPAddr       // remote address, dialled by Start in plain-TCP mode
	lconn io.ReadWriteCloser // local side of the proxy
	rconn io.ReadWriteCloser // remote side, set when Start dials

	erred  bool       // set once err has reported a failure
	errsig chan error // carries the error that ends a pipe to Start

	tlsUnwrapp bool   // when true Start dials tlsAddress over TLS instead of raddr over TCP
	tlsAddress string // host:port dialled in TLS mode

	// Matcher, if non-nil, is called with every chunk read (before Replacer).
	Matcher func([]byte)
	// Replacer, if non-nil, returns the bytes actually forwarded for a chunk.
	Replacer func([]byte) []byte

	// setting
	Nagles    bool          // when true, Start calls SetNoDelay(true) on ends that support it
	Log       logger.Logger // receives connection and traffic log lines
	OutputHex bool          // trace chunk contents with %x instead of %s
}
// New returns a Proxy whose local side is lconn. laddr and raddr describe the
// two ends; in plain-TCP mode Start dials raddr. Log defaults to
// logger.NullLogger{} until the caller replaces it.
func New(lconn *net.TCPConn, laddr, raddr *net.TCPAddr) *Proxy {
	p := new(Proxy)
	p.lconn = lconn
	p.laddr, p.raddr = laddr, raddr
	p.errsig = make(chan error)
	p.Log = logger.NullLogger{}
	return p
}
// NewTLSUnwarpped is like New, but makes Start reach the remote end through a
// TLS client connection to addr instead of plain TCP to raddr. In Start, raddr
// is then only used for log output.
func NewTLSUnwarpped(lconn *net.TCPConn, laddr, raddr *net.TCPAddr, addr string) *Proxy {
	proxy := New(lconn, laddr, raddr)
	proxy.tlsUnwrapp, proxy.tlsAddress = true, addr
	return proxy
}
// setNoDelayer is implemented by connections whose TCP_NODELAY option can be
// toggled (for example *net.TCPConn). Start uses it when Nagles is set.
type setNoDelayer interface {
	SetNoDelay(noDelay bool) error
}
// Start dials the remote end and proxies data between it and lconn until
// either direction stops. It blocks until both copy goroutines have exited.
//
// In TLS-unwrap mode (see NewTLSUnwarpped) the remote is reached with
// tls.Dial to tlsAddress using a nil *tls.Config. Otherwise a plain TCP
// connection is made to raddr. lconn is always closed before Start returns,
// and so is the remote connection once it has been dialled.
//
// The returned error is the dial error, or the first error reported by either
// copy direction through p.err (io.EOF when a peer closed cleanly).
func (p *Proxy) Start() (err error) {
	defer p.lconn.Close()
	if p.tlsUnwrapp {
		p.rconn, err = tls.Dial("tcp", p.tlsAddress, nil)
	} else {
		p.rconn, err = net.DialTCP("tcp", nil, p.raddr)
	}
	if err != nil {
		p.Log.Warn("Remote connection failed: %s", err)
		return err
	}
	defer p.rconn.Close()

	// Nagles set means "disable Nagle's algorithm" (SetNoDelay(true)). Only
	// ends that implement SetNoDelay are affected. A failure here only costs
	// latency, so the error is deliberately ignored.
	if p.Nagles {
		if conn, ok := p.lconn.(setNoDelayer); ok {
			_ = conn.SetNoDelay(true)
		}
		if conn, ok := p.rconn.(setNoDelayer); ok {
			_ = conn.SetNoDelay(true)
		}
	}

	// display both ends
	p.Log.Info("Opened %s >>>> %s", p.laddr.String(), p.raddr.String())

	// Each pipe reports at most one error through p.err before it exits.
	// Buffering errsig for both means the direction that finishes second can
	// never block forever on a send that Start will not receive. With an
	// unbuffered channel that goroutine leaked.
	p.errsig = make(chan error, 2)

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	// bidirectional copy
	go func() {
		p.pipe(ctx, p.lconn, p.rconn)
		done <- struct{}{}
	}()
	go func() {
		p.pipe(ctx, p.rconn, p.lconn)
		done <- struct{}{}
	}()

	// Wait for the first direction to stop, then close both ends so the other
	// direction's blocked Read/Write returns as well. The deferred Close calls
	// above then hit already-closed connections, and their errors are ignored.
	err = <-p.errsig
	cancel()
	_ = p.lconn.Close()
	_ = p.rconn.Close()
	<-done
	<-done

	// Both pipes have exited, so reading the byte counters is race-free now.
	p.Log.Info("Closed (%d bytes send, %d bytes received)", p.sendBytes, p.receivedBytes)
	return err
}
// err hands a pipe's terminating error to Start through errsig and then marks
// the proxy as failed. Once erred is set, further calls do nothing. io.EOF (a
// clean close) is forwarded without a warning. Any other error is first
// logged with the format string s.
//
// NOTE(review): erred is read and written without synchronization while both
// pipe goroutines may call err. The send on errsig blocks until Start (or free
// buffer space) accepts it. With an unbuffered errsig, a second caller that
// passes the erred check before it is set can block forever.
func (p *Proxy) err(s string, err error) {
	switch {
	case p.erred:
		return
	case err != io.EOF:
		p.Log.Warn(s, err)
	}
	p.errsig <- err
	p.erred = true
}
// pipe copies data from src to dst until a read or write fails or ctx is
// cancelled. When src is p.lconn, its bytes are logged as "sent" and counted
// in sendBytes. Otherwise they are logged as "received" and counted in
// receivedBytes.
//
// Each chunk read goes through these steps:
//   - it is passed to Matcher (if set);
//   - it is replaced by Replacer's result (if set);
//   - it is logged at Debug (size) and Trace (contents, %x when OutputHex);
//   - it is written to dst.
//
// The read or write error that ends the copy (io.EOF on a clean close) is
// handed to p.err. A cancelled ctx returns without reporting anything.
func (p *Proxy) pipe(ctx context.Context, src, dst io.ReadWriter) {
	dataDirection := "<<<< %d bytes received%s"
	counter := &p.receivedBytes
	if src == p.lconn {
		dataDirection = ">>>> %d bytes sent%s"
		counter = &p.sendBytes
	}
	byteFormat := "%s"
	if p.OutputHex {
		byteFormat = "%x"
	}

	// directional copy (64k buffer)
	buff := make([]byte, 0xffff)
	for ctx.Err() == nil {
		n, readErr := src.Read(buff)

		// io.Reader allows Read to return n > 0 together with an error.
		// crypto/tls does exactly that when a close_notify alert directly
		// follows application data. Forward those bytes before handling the
		// error so the tail of the stream is not silently dropped.
		if n > 0 {
			b := buff[:n]
			// execute match
			if p.Matcher != nil {
				p.Matcher(b)
			}
			// execute replace
			if p.Replacer != nil {
				b = p.Replacer(b)
			}
			// show output
			p.Log.Debug(dataDirection, n, "")
			p.Log.Trace(byteFormat, b)

			// write to result
			written, writeErr := dst.Write(b)
			if writeErr != nil {
				p.err("Write failed '%s'\n", writeErr)
				return
			}
			*counter += uint64(written)
		}

		if readErr != nil {
			p.err("Read failed '%s'\n", readErr)
			return
		}
	}
}