commit c15c1c0421214ce96022ebae975070063584d725 Author: JayChen Date: Mon Jan 13 18:05:18 2020 +0800 proxy ver1 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..30e9a47 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module ssh-proxy + +go 1.13 + +require ( + github.com/mattn/go-colorable v0.1.4 // indirect + github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..de1e071 --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223 h1:DH4skfRX4EBpamg7iV4ZlCpblAHI6s6TDM39bFZumv8= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/main.go b/main.go new file mode 100644 index 0000000..319674e --- /dev/null +++ b/main.go @@ -0,0 +1,55 @@ +package main + +import ( + "fmt" + "log" + "net" + "ssh-proxy/pkg/logger" + "ssh-proxy/pkg/proxy" +) + +func main() { + fmt.Println("vim-go") + + laddr, err := net.ResolveTCPAddr("tcp", ":9999") + if err != nil { + log.Fatal(err) + } + + raddr, err := net.ResolveTCPAddr("tcp", "localhost:22") + if err != nil { + log.Fatal(err) + } + + listener, err := net.ListenTCP("tcp", laddr) + if err != nil { + log.Fatal(err) + } + + var connid uint64 = 0 + for { + conn, err := listener.AcceptTCP() + if err != nil { + log.Println(err) + continue + } + connid++ + var p *proxy.Proxy + + p = proxy.New(conn, laddr, raddr) + p.Nagles = false + p.OutputHex = true + p.Log = logger.ColorLogger{ + VeryVerbose: true, + Verbose: true, + Prefix: fmt.Sprintf("Connection: #%03d ", connid), + Color: true, + } + go func() { + err := p.Start() + if err != nil { + log.Println("conn err ::: ", err) + } + }() + } +} diff --git a/pkg/common/common.go b/pkg/common/common.go new file mode 100644 index 0000000..c672485 --- /dev/null +++ b/pkg/common/common.go @@ -0,0 +1,24 @@ +package common + +import ( + "net" + "ssh-proxy/pkg/logger" + "strconv" + "time" +) + +func CheckRemoteSSH(host string, port int, timeout int) bool { + log := logger.ColorLogger{} + connTimeout := time.Duration(timeout) * time.Second + if connTimeout <= 0 { + connTimeout = time.Second * 5 + } + + conn, err := net.DialTimeout("tcp", net.JoinHostPort(host, strconv.Itoa(port)), connTimeout) + if err != nil { + log.Trace("Check Remote Host SSH Fail: Host(%s:%d) %s", host, port, err) + return false + } + defer conn.Close() + return true +} diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go new file mode 100644 index 0000000..0d81013 --- /dev/null +++ b/pkg/common/common_test.go @@ -0,0 +1,33 @@ +package common + +import "testing" + +func TestCheckRemoteSSH(t *testing.T) { + type args struct { + host string + port int + timeout int + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "test check localhost ssh port", + args: args{ + host: "localhost", + port: 22, + timeout: 5, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CheckRemoteSSH(tt.args.host, tt.args.port, tt.args.timeout); got != tt.want { + t.Errorf("CheckRemoteSSH() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go new file mode 100644 index 0000000..f67d99b --- /dev/null +++ b/pkg/logger/logger.go @@ -0,0 +1,60 @@ +package logger + +import ( + "fmt" + + "github.com/mgutz/ansi" +) + +type Logger interface { + Trace(f string, args ...interface{}) + Debug(f string, args ...interface{}) + Info(f string, args ...interface{}) + Warn(f string, args ...interface{}) +} + +type NullLogger struct{} + +func (n NullLogger) Trace(f string, args ...interface{}) {} + +func (n NullLogger) Debug(f string, args ...interface{}) {} + +func (n NullLogger) Info(f string, args ...interface{}) {} + +func (n NullLogger) Warn(f string, args ...interface{}) {} + +type ColorLogger struct { + VeryVerbose bool + Verbose bool + Prefix string + Color bool +} + +func (c ColorLogger) Trace(f string, args ...interface{}) { + if !c.VeryVerbose { + return + } + c.output("blue", f, args...) +} + +func (c ColorLogger) Debug(f string, args ...interface{}) { + if !c.Verbose { + return + } + c.output("green", f, args...) +} + +func (c ColorLogger) Info(f string, args ...interface{}) { + c.output("green", f, args...) +} + +func (c ColorLogger) Warn(f string, args ...interface{}) { + c.output("red", f, args...) +} + +func (c ColorLogger) output(color, f string, args ...interface{}) { + if c.Color && color != "" { + f = ansi.Color(f, color) + } + fmt.Printf(fmt.Sprintf("%s%s\n", c.Prefix, f), args...) +} diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go new file mode 100644 index 0000000..61ecabe --- /dev/null +++ b/pkg/proxy/proxy.go @@ -0,0 +1,163 @@ +package proxy + +import ( + "context" + "crypto/tls" + "io" + "net" + "ssh-proxy/pkg/logger" +) + +type Proxy struct { + sendBytes uint64 + receivedBytes uint64 + laddr, raddr *net.TCPAddr + lconn, rconn io.ReadWriteCloser + erred bool + errsig chan error + tlsUnwrapp bool + tlsAddress string + + Matcher func([]byte) + Replacer func([]byte) []byte + + // setting + Nagles bool + Log logger.Logger + OutputHex bool +} + +func New(lconn *net.TCPConn, laddr, raddr *net.TCPAddr) *Proxy { + return &Proxy{ + lconn: lconn, + laddr: laddr, + raddr: raddr, + erred: false, + errsig: make(chan error), + Log: logger.NullLogger{}, + } +} + +func NewTLSUnwarpped(lconn *net.TCPConn, laddr, raddr *net.TCPAddr, addr string) *Proxy { + p := New(lconn, laddr, raddr) + p.tlsUnwrapp = true + p.tlsAddress = addr + return p +} + +type setNoDelayer interface { + SetNoDelay(bool) error +} + +func (p *Proxy) Start() (err error) { + defer p.lconn.Close() + + if p.tlsUnwrapp { + p.rconn, err = tls.Dial("tcp", p.tlsAddress, nil) + } else { + p.rconn, err = net.DialTCP("tcp", nil, p.raddr) + } + if err != nil { + p.Log.Warn("Remote connection failed: %s", err) + return + } + defer p.rconn.Close() + + // nagles + if p.Nagles { + if conn, ok := p.lconn.(setNoDelayer); ok { + conn.SetNoDelay(true) + } + if conn, ok := p.rconn.(setNoDelayer); ok { + conn.SetNoDelay(true) + } + } + + //display both ends + p.Log.Info("Opened %s >>>> %s", p.laddr.String(), p.raddr.String()) + + ctx, cancel := context.WithCancel(context.Background()) + + //bidirectional copy + go p.pipe(ctx, p.lconn, p.rconn) + go p.pipe(ctx, p.rconn, p.lconn) + + // wait for close + err = <-p.errsig + p.Log.Info("Closed (%d bytes send, %d bytes received)", p.sendBytes, p.receivedBytes) + cancel() + + return +} + +func (p *Proxy) err(s string, err error) { + if p.erred { + return + } + + if err != io.EOF { + p.Log.Warn(s, err) + } + p.errsig <- err + p.erred = true +} + +func (p *Proxy) pipe(ctx context.Context, src, dst io.ReadWriter) { + islocal := src == p.lconn + + var dataDirection string + if islocal { + dataDirection = ">>>> %d bytes sent%s" + } else { + dataDirection = "<<<< %d bytes received%s" + } + + var byteFormat string + if p.OutputHex { + byteFormat = "%x" + } else { + byteFormat = "%s" + } + + // directional copy (64k buffer) + buff := make([]byte, 0xffff) + for { + select { + case <-ctx.Done(): + return + default: + n, err := src.Read(buff) + if err != nil { + p.err("Read failed '%s'\n", err) + return + } + b := buff[:n] + + // execute match + if p.Matcher != nil { + p.Matcher(b) + } + // execute replace + if p.Replacer != nil { + b = p.Replacer(b) + } + + // show output + p.Log.Debug(dataDirection, n, "") + p.Log.Trace(byteFormat, b) + + // write to result + n, err = dst.Write(b) + if err != nil { + p.err("Write failed '%s'\n", err) + return + } + + if islocal { + p.sendBytes += uint64(n) + } else { + p.receivedBytes += uint64(n) + } + } + } +}