From ae33d3da00b68d3f72456db914fcf4875e25c314 Mon Sep 17 00:00:00 2001 From: Jay Date: Tue, 14 Jan 2020 09:58:05 +0000 Subject: [PATCH] first version, unittest not fin --- .gitignore | 1 + config/config.yml | 6 ++ go.mod | 4 + go.sum | 9 ++ main.go | 121 +++++++++++++--------- pkg/common/common.go | 62 +++++++++++- pkg/common/common_test.go | 6 +- pkg/config/config.go | 64 ++++++++++++ pkg/listener/listener.go | 207 ++++++++++++++++++++++++++++++++++++++ pkg/option/option.go | 21 ++++ pkg/proxy/proxy.go | 19 ++-- 11 files changed, 458 insertions(+), 62 deletions(-) create mode 100644 .gitignore create mode 100644 config/config.yml create mode 100644 pkg/config/config.go create mode 100644 pkg/listener/listener.go create mode 100644 pkg/option/option.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a0ec596 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/config.yml diff --git a/config/config.yml b/config/config.yml new file mode 100644 index 0000000..afa00ff --- /dev/null +++ b/config/config.yml @@ -0,0 +1,6 @@ +# verbose level 0=info, 1=verbose, 2=very verbose +verbose: 2 +listen: + - target_host: 10.0.1.2-250 + target_port: 22 + listen_port: 22002 diff --git a/go.mod b/go.mod index 30e9a47..15b6149 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,10 @@ module ssh-proxy go 1.13 require ( + git.trj.tw/golang/utils v0.0.0-20190225142552-b019626f0349 + github.com/google/uuid v1.1.1 github.com/mattn/go-colorable v0.1.4 // indirect github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b + github.com/otakukaze/envconfig v1.0.0 + gopkg.in/yaml.v2 v2.2.7 ) diff --git a/go.sum b/go.sum index de1e071..e61e2ad 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,17 @@ +git.trj.tw/golang/utils v0.0.0-20190225142552-b019626f0349 h1:V6ifeiJ3ExnjaUylTOz37n6z5uLwm6fjKjnztbTCaQI= +git.trj.tw/golang/utils v0.0.0-20190225142552-b019626f0349/go.mod h1:yE+qbsUsijCTdwsaQRkPT1CXYk7ftMzXsCaaYx/0QI0= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/otakukaze/envconfig v1.0.0 h1:VGu7NoDFbReFO72m5Z15h5PH5gVk5fVLZ6iHF58ggFk= +github.com/otakukaze/envconfig v1.0.0/go.mod h1:v2dNv5NX1Lakw3FTAkbxYURyaiOy68M8QpMTZz+ogfs= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223 h1:DH4skfRX4EBpamg7iV4ZlCpblAHI6s6TDM39bFZumv8= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= +gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/main.go b/main.go index 319674e..777c9cc 100644 --- a/main.go +++ b/main.go @@ -1,55 +1,82 @@ package main import ( + "errors" "fmt" "log" - "net" - "ssh-proxy/pkg/logger" - "ssh-proxy/pkg/proxy" + "ssh-proxy/pkg/config" + "ssh-proxy/pkg/listener" + "ssh-proxy/pkg/option" ) -func main() { - fmt.Println("vim-go") - - laddr, err := net.ResolveTCPAddr("tcp", ":9999") - if err != nil { - log.Fatal(err) - } - - raddr, err := net.ResolveTCPAddr("tcp", "localhost:22") - if err != nil { - log.Fatal(err) - } - - listener, err := net.ListenTCP("tcp", laddr) - if err != nil { - log.Fatal(err) - } - - var connid uint64 = 0 - for { - conn, err := listener.AcceptTCP() - if err != nil { - log.Println(err) - continue - } - connid++ - var p *proxy.Proxy - - p = proxy.New(conn, laddr, raddr) - p.Nagles = false - p.OutputHex = true - p.Log = logger.ColorLogger{ - VeryVerbose: true, - Verbose: true, - Prefix: fmt.Sprintf("Connection: #%03d ", connid), - Color: true, - } - go func() { - err := p.Start() - if err != nil { - log.Println("conn err ::: ", err) - } - }() - } +func init() { + option.Parse() +} + +func main() { + var err error + fmt.Println("tcp proxy") + opts := option.Get() + if opts == nil { + log.Fatal(errors.New("no flag parse")) + } + + err = config.Load(opts.Config) + if err != nil { + log.Fatal(err) + } + + conf := config.Get() + + listener, err := listener.New(conf.Listen) + if err != nil { + log.Fatal(err) + } + + err = listener.Listen() + if err != nil { + log.Fatal(err) + } + + // laddr, err := net.ResolveTCPAddr("tcp", ":9999") + // if err != nil { + // log.Fatal(err) + // } + + // raddr, err := net.ResolveTCPAddr("tcp", "localhost:22") + // if err != nil { + // log.Fatal(err) + // } + + // listener, err := net.ListenTCP("tcp", laddr) + // if err != nil { + // log.Fatal(err) + // } + + // var connid uint64 = 0 + // for { + // conn, err := listener.AcceptTCP() + // if err != nil { + // log.Println(err) + // continue + // } + // connid++ + // var p *proxy.Proxy + + // p = proxy.New(conn, laddr, raddr) + // p.Nagles = false + // p.OutputHex = true + // p.Log = logger.ColorLogger{ + // VeryVerbose: true, + // Verbose: true, + // Prefix: fmt.Sprintf("Connection: #%03d ", connid), + // Color: true, + // } + // go func() { + // err := p.Start() + // if err != nil { + // log.Println("conn err ::: ", err) + // } + // }() + // } } diff --git a/pkg/common/common.go b/pkg/common/common.go index c672485..d3758d7 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -1,13 +1,15 @@ package common import ( + "fmt" "net" "ssh-proxy/pkg/logger" "strconv" "time" + "unicode/utf8" ) -func CheckRemoteSSH(host string, port int, timeout int) bool { +func CheckRemoteTCP(host string, port int, timeout int) error { log := logger.ColorLogger{} connTimeout := time.Duration(timeout) * time.Second if connTimeout <= 0 { @@ -17,8 +19,62 @@ func CheckRemoteSSH(host string, port int, timeout int) bool { conn, err := net.DialTimeout("tcp", net.JoinHostPort(host, strconv.Itoa(port)), connTimeout) if err != nil { log.Trace("Check Remote Host SSH Fail: Host(%s:%d) %s", host, port, err) - return false + return err } defer conn.Close() - return true + return nil +} + +// checkDomain returns an error if the domain name is not valid +// See https://tools.ietf.org/html/rfc1034#section-3.5 and +// https://tools.ietf.org/html/rfc1123#section-2. +func CheckDomain(name string) error { + switch { + case len(name) == 0: + return nil // an empty domain name will result in a cookie without a domain restriction + case len(name) > 255: + return fmt.Errorf("cookie domain: name length is %d, can't exceed 255", len(name)) + } + var l int + for i := 0; i < len(name); i++ { + b := name[i] + if b == '.' { + // check domain labels validity + switch { + case i == l: + return fmt.Errorf("cookie domain: invalid character '%c' at offset %d: label can't begin with a period", b, i) + case i-l > 63: + return fmt.Errorf("cookie domain: byte length of label '%s' is %d, can't exceed 63", name[l:i], i-l) + case name[l] == '-': + return fmt.Errorf("cookie domain: label '%s' at offset %d begins with a hyphen", name[l:i], l) + case name[i-1] == '-': + return fmt.Errorf("cookie domain: label '%s' at offset %d ends with a hyphen", name[l:i], l) + } + l = i + 1 + continue + } + // test label character validity, note: tests are ordered by decreasing validity frequency + if !(b >= 'a' && b <= 'z' || b >= '0' && b <= '9' || b == '-' || b >= 'A' && b <= 'Z') { + // show the printable unicode character starting at byte offset i + c, _ := utf8.DecodeRuneInString(name[i:]) + if c == utf8.RuneError { + return fmt.Errorf("cookie domain: invalid rune at offset %d", i) + } + return fmt.Errorf("cookie domain: invalid character '%c' at offset %d", c, i) + } + } + // check top level domain validity + switch { + case l == len(name): + return fmt.Errorf("cookie domain: missing top level domain, domain can't end with a period") + case len(name)-l > 63: + return fmt.Errorf("cookie domain: byte length of top level domain '%s' is %d, can't exceed 63", name[l:], len(name)-l) + case name[l] == '-': + return fmt.Errorf("cookie domain: top level domain '%s' at offset %d begins with a hyphen", name[l:], l) + case name[len(name)-1] == '-': + return fmt.Errorf("cookie domain: top level domain '%s' at offset %d ends with a hyphen", name[l:], l) + case name[l] >= '0' && name[l] <= '9': + return fmt.Errorf("cookie domain: top level domain '%s' at offset %d begins with a digit", name[l:], l) + } + return nil } diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go index 0d81013..a5750d1 100644 --- a/pkg/common/common_test.go +++ b/pkg/common/common_test.go @@ -11,7 +11,7 @@ func TestCheckRemoteSSH(t *testing.T) { tests := []struct { name string args args - want bool + want error }{ { name: "test check localhost ssh port", @@ -20,12 +20,12 @@ func TestCheckRemoteSSH(t *testing.T) { port: 22, timeout: 5, }, - want: true, + want: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := CheckRemoteSSH(tt.args.host, tt.args.port, tt.args.timeout); got != tt.want { + if got := CheckRemoteTCP(tt.args.host, tt.args.port, tt.args.timeout); got != tt.want { t.Errorf("CheckRemoteSSH() = %v, want %v", got, tt.want) } }) diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..634a6ce --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,64 @@ +package config + +import ( + "errors" + "io/ioutil" + "os" + "path" + + "git.trj.tw/golang/utils" + "github.com/otakukaze/envconfig" + "gopkg.in/yaml.v2" +) + +// Listen - +type Listen struct { + TargetHost string `yaml:"target_host"` + TargetPort int `yaml:"target_port"` + ListenPort int `yaml:"listen_port"` +} + +// Config - +type Config struct { + Verbose int `yaml:"verbose" env:"PROXY_VERBOSE"` + Listen []Listen `yaml:"listen"` +} + +var conf *Config + +// Load config from file and env +func Load(p ...string) (err error) { + var fp string + if len(p) > 0 && len(p[0]) > 0 { + fp = p[0] + } else { + wd, err := os.Getwd() + if err != nil { + return err + } + fp = path.Join(wd, "config.yml") + } + + fp = utils.ParsePath(fp) + + if exists := utils.CheckExists(fp, false); !exists { + return errors.New("config file not exists") + } + + b, err := ioutil.ReadFile(fp) + if err != nil { + return + } + + conf = &Config{} + err = yaml.Unmarshal(b, conf) + if err != nil { + return + } + + envconfig.Parse(conf) + return +} + +// Get config +func Get() *Config { return conf } diff --git a/pkg/listener/listener.go b/pkg/listener/listener.go new file mode 100644 index 0000000..ca42939 --- /dev/null +++ b/pkg/listener/listener.go @@ -0,0 +1,207 @@ +package listener + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "net" + "regexp" + "ssh-proxy/pkg/common" + "ssh-proxy/pkg/config" + "ssh-proxy/pkg/logger" + "ssh-proxy/pkg/proxy" + "strconv" + "strings" + + "github.com/google/uuid" +) + +type Target struct { + Addr string + Listen string +} + +type Listener struct { + Targets []Target + ctx context.Context + cancel context.CancelFunc +} + +type conn struct { + net.Conn +} + +// New listener +func New(list []config.Listen) (listener *Listener, err error) { + if len(list) == 0 { + return nil, errors.New("no listen data") + } + check := make(map[int]interface{}) + listener = &Listener{} + listener.ctx, listener.cancel = context.WithCancel(context.Background()) + listener.Targets = make([]Target, 0) + + for k, v := range list { + hosts, err := parseHost(v.TargetHost) + if err != nil || len(hosts) == 0 { + return nil, fmt.Errorf("host rule error: index: %d, %v", k, err) + } + startBindPort := v.ListenPort + for _, it := range hosts { + if _, ok := check[startBindPort]; ok { + return nil, errors.New("bind port is used") + } + target := Target{} + target.Addr = fmt.Sprintf("%s:%d", it, v.TargetPort) + target.Listen = fmt.Sprintf(":%d", startBindPort) + listener.Targets = append(listener.Targets, target) + check[startBindPort] = nil + startBindPort++ + } + + } + + check = nil + + return +} + +func parseHost(s string) (ss []string, err error) { + ip := net.ParseIP(s) + if ip != nil { + ss = append(ss, ip.String()) + return ss, nil + } + + if match, err := regexp.MatchString(`^(?:(25[0-5]|2[0-4][0-9]|1?[0-9]{2}|[0-9])\.){3}(?:(2[0-5][0-9]|1?[0-9]{2}|[0-9]))\-(?:(25[0-9]|2[0-4][0-9]|1?[0-9]{2}|[0-9]))$`, s); match && err == nil { + iprange := strings.Split(s, "-") + start, err := strconv.Atoi(strings.Split(iprange[0], ".")[3]) + if err != nil { + return nil, err + } + end, err := strconv.Atoi(iprange[1]) + if err != nil { + return nil, err + } + + if start > end { + return nil, errors.New("start > end") + } + + count := end - start + 1 + ss = make([]string, 0, count) + regex := regexp.MustCompile(`\.[0-9]{1,3}$`) + for i := 0; i < count; i++ { + nip := regex.ReplaceAllString(iprange[0], fmt.Sprintf(".%d", start+i)) + ss = append(ss, nip) + } + return ss, nil + } + + err = common.CheckDomain(s) + if err != nil { + return nil, err + } + ss = append(ss, s) + + return +} + +func listen(ctx context.Context, bindPort, addr string) error { + laddr, err := net.ResolveTCPAddr("tcp", bindPort) + if err != nil { + return err + } + + listener, err := net.ListenTCP("tcp", laddr) + if err != nil { + return err + } + defer listener.Close() + + connChan := make(chan *net.TCPConn) + + go accept(ctx, listener, connChan) + go handleConn(ctx, connChan, bindPort, addr) + + for { + select { + case <-ctx.Done(): + return nil + } + } + +} + +func accept(ctx context.Context, listener *net.TCPListener, connChan chan *net.TCPConn) { + for { + conn, err := listener.AcceptTCP() + if err != nil { + log.Println("accept error : ", err) + } else { + connChan <- conn + } + select { + case <-ctx.Done(): + return + } + } +} +func handleConn(ctx context.Context, connChan chan *net.TCPConn, bindPort, addr string) { + conf := config.Get() + for { + select { + case <-ctx.Done(): + return + case conn := <-connChan: + laddr, err := net.ResolveTCPAddr("tcp", bindPort) + if err != nil { + log.Println("resolve local bind fail : ", err) + conn.Close() + break + } + + raddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + log.Println("resolve remote addr fail : ", err) + conn.Close() + break + } + err = common.CheckRemoteTCP(raddr.IP.String(), raddr.Port, 2) + if err != nil { + log.Println("check remote tcp fail : ", err) + conn.Close() + break + } + + connid := uuid.New().String() + p := proxy.New(conn, laddr, raddr) + p.Nagles = false + p.OutputHex = true + p.Log = logger.ColorLogger{ + VeryVerbose: conf.Verbose >= 2, + Verbose: conf.Verbose >= 1, + Prefix: fmt.Sprintf("Connection: %s ", connid), + Color: true, + } + go func() { + defer conn.Close() + err := p.Start(ctx) + if err != nil || err != io.EOF { + log.Println("conn err :: ", err) + } + }() + } + } +} + +func (l *Listener) Listen() (err error) { + for _, it := range l.Targets { + go listen(l.ctx, it.Listen, it.Addr) + } + + <-l.ctx.Done() + return +} diff --git a/pkg/option/option.go b/pkg/option/option.go new file mode 100644 index 0000000..77cd8ad --- /dev/null +++ b/pkg/option/option.go @@ -0,0 +1,21 @@ +package option + +import "flag" + +// Option - +type Option struct { + Config string +} + +var opts *Option + +// Parse flags +func Parse() { + opts = &Option{} + flag.StringVar(&opts.Config, "config", "", "config file path, default: `pwd`/config.yml") + flag.StringVar(&opts.Config, "f", "", "config file path, default: `pwd`/config.yml") + flag.Parse() +} + +// Get parsed options +func Get() *Option { return opts } diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 61ecabe..fac6dab 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -49,7 +49,7 @@ type setNoDelayer interface { SetNoDelay(bool) error } -func (p *Proxy) Start() (err error) { +func (p *Proxy) Start(ctx context.Context) (err error) { defer p.lconn.Close() if p.tlsUnwrapp { @@ -76,18 +76,19 @@ func (p *Proxy) Start() (err error) { //display both ends p.Log.Info("Opened %s >>>> %s", p.laddr.String(), p.raddr.String()) - ctx, cancel := context.WithCancel(context.Background()) - //bidirectional copy go p.pipe(ctx, p.lconn, p.rconn) go p.pipe(ctx, p.rconn, p.lconn) - // wait for close - err = <-p.errsig - p.Log.Info("Closed (%d bytes send, %d bytes received)", p.sendBytes, p.receivedBytes) - cancel() - - return + select { + case <-ctx.Done(): + p.Log.Info("Context Done (%d bytes send, %d bytes received)", p.sendBytes, p.receivedBytes) + return + // wait for close + case err = <-p.errsig: + p.Log.Info("Closed (%d bytes send, %d bytes received)", p.sendBytes, p.receivedBytes) + return err + } } func (p *Proxy) err(s string, err error) {