first version, unittest not fin

This commit is contained in:
Jay 2020-01-14 09:58:05 +00:00
parent c15c1c0421
commit ae33d3da00
11 changed files with 458 additions and 62 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/config.yml

6
config/config.yml Normal file
View File

@ -0,0 +1,6 @@
# verbose level 0=info, 1=verbose, 2=very verbose
verbose: 2
listen:
- target_host: 10.0.1.2-250
target_port: 22
listen_port: 22002

4
go.mod
View File

@ -3,6 +3,10 @@ module ssh-proxy
go 1.13 go 1.13
require ( require (
git.trj.tw/golang/utils v0.0.0-20190225142552-b019626f0349
github.com/google/uuid v1.1.1
github.com/mattn/go-colorable v0.1.4 // indirect github.com/mattn/go-colorable v0.1.4 // indirect
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b
github.com/otakukaze/envconfig v1.0.0
gopkg.in/yaml.v2 v2.2.7
) )

9
go.sum
View File

@ -1,8 +1,17 @@
git.trj.tw/golang/utils v0.0.0-20190225142552-b019626f0349 h1:V6ifeiJ3ExnjaUylTOz37n6z5uLwm6fjKjnztbTCaQI=
git.trj.tw/golang/utils v0.0.0-20190225142552-b019626f0349/go.mod h1:yE+qbsUsijCTdwsaQRkPT1CXYk7ftMzXsCaaYx/0QI0=
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/otakukaze/envconfig v1.0.0 h1:VGu7NoDFbReFO72m5Z15h5PH5gVk5fVLZ6iHF58ggFk=
github.com/otakukaze/envconfig v1.0.0/go.mod h1:v2dNv5NX1Lakw3FTAkbxYURyaiOy68M8QpMTZz+ogfs=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223 h1:DH4skfRX4EBpamg7iV4ZlCpblAHI6s6TDM39bFZumv8= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223 h1:DH4skfRX4EBpamg7iV4ZlCpblAHI6s6TDM39bFZumv8=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

121
main.go
View File

@ -1,55 +1,82 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"log" "log"
"net" "ssh-proxy/pkg/config"
"ssh-proxy/pkg/logger" "ssh-proxy/pkg/listener"
"ssh-proxy/pkg/proxy" "ssh-proxy/pkg/option"
) )
func main() { func init() {
fmt.Println("vim-go") option.Parse()
}
laddr, err := net.ResolveTCPAddr("tcp", ":9999")
if err != nil { func main() {
log.Fatal(err) var err error
} fmt.Println("tcp proxy")
opts := option.Get()
raddr, err := net.ResolveTCPAddr("tcp", "localhost:22") if opts == nil {
if err != nil { log.Fatal(errors.New("no flag parse"))
log.Fatal(err) }
}
err = config.Load(opts.Config)
listener, err := net.ListenTCP("tcp", laddr) if err != nil {
if err != nil { log.Fatal(err)
log.Fatal(err) }
}
conf := config.Get()
var connid uint64 = 0
for { listener, err := listener.New(conf.Listen)
conn, err := listener.AcceptTCP() if err != nil {
if err != nil { log.Fatal(err)
log.Println(err) }
continue
} err = listener.Listen()
connid++ if err != nil {
var p *proxy.Proxy log.Fatal(err)
}
p = proxy.New(conn, laddr, raddr)
p.Nagles = false // laddr, err := net.ResolveTCPAddr("tcp", ":9999")
p.OutputHex = true // if err != nil {
p.Log = logger.ColorLogger{ // log.Fatal(err)
VeryVerbose: true, // }
Verbose: true,
Prefix: fmt.Sprintf("Connection: #%03d ", connid), // raddr, err := net.ResolveTCPAddr("tcp", "localhost:22")
Color: true, // if err != nil {
} // log.Fatal(err)
go func() { // }
err := p.Start()
if err != nil { // listener, err := net.ListenTCP("tcp", laddr)
log.Println("conn err ::: ", err) // if err != nil {
} // log.Fatal(err)
}() // }
}
// var connid uint64 = 0
// for {
// conn, err := listener.AcceptTCP()
// if err != nil {
// log.Println(err)
// continue
// }
// connid++
// var p *proxy.Proxy
// p = proxy.New(conn, laddr, raddr)
// p.Nagles = false
// p.OutputHex = true
// p.Log = logger.ColorLogger{
// VeryVerbose: true,
// Verbose: true,
// Prefix: fmt.Sprintf("Connection: #%03d ", connid),
// Color: true,
// }
// go func() {
// err := p.Start()
// if err != nil {
// log.Println("conn err ::: ", err)
// }
// }()
// }
} }

View File

@ -1,13 +1,15 @@
package common package common
import ( import (
"fmt"
"net" "net"
"ssh-proxy/pkg/logger" "ssh-proxy/pkg/logger"
"strconv" "strconv"
"time" "time"
"unicode/utf8"
) )
func CheckRemoteSSH(host string, port int, timeout int) bool { func CheckRemoteTCP(host string, port int, timeout int) error {
log := logger.ColorLogger{} log := logger.ColorLogger{}
connTimeout := time.Duration(timeout) * time.Second connTimeout := time.Duration(timeout) * time.Second
if connTimeout <= 0 { if connTimeout <= 0 {
@ -17,8 +19,62 @@ func CheckRemoteSSH(host string, port int, timeout int) bool {
conn, err := net.DialTimeout("tcp", net.JoinHostPort(host, strconv.Itoa(port)), connTimeout) conn, err := net.DialTimeout("tcp", net.JoinHostPort(host, strconv.Itoa(port)), connTimeout)
if err != nil { if err != nil {
log.Trace("Check Remote Host SSH Fail: Host(%s:%d) %s", host, port, err) log.Trace("Check Remote Host SSH Fail: Host(%s:%d) %s", host, port, err)
return false return err
} }
defer conn.Close() defer conn.Close()
return true return nil
}
// checkDomain returns an error if the domain name is not valid
// See https://tools.ietf.org/html/rfc1034#section-3.5 and
// https://tools.ietf.org/html/rfc1123#section-2.
func CheckDomain(name string) error {
switch {
case len(name) == 0:
return nil // an empty domain name will result in a cookie without a domain restriction
case len(name) > 255:
return fmt.Errorf("cookie domain: name length is %d, can't exceed 255", len(name))
}
var l int
for i := 0; i < len(name); i++ {
b := name[i]
if b == '.' {
// check domain labels validity
switch {
case i == l:
return fmt.Errorf("cookie domain: invalid character '%c' at offset %d: label can't begin with a period", b, i)
case i-l > 63:
return fmt.Errorf("cookie domain: byte length of label '%s' is %d, can't exceed 63", name[l:i], i-l)
case name[l] == '-':
return fmt.Errorf("cookie domain: label '%s' at offset %d begins with a hyphen", name[l:i], l)
case name[i-1] == '-':
return fmt.Errorf("cookie domain: label '%s' at offset %d ends with a hyphen", name[l:i], l)
}
l = i + 1
continue
}
// test label character validity, note: tests are ordered by decreasing validity frequency
if !(b >= 'a' && b <= 'z' || b >= '0' && b <= '9' || b == '-' || b >= 'A' && b <= 'Z') {
// show the printable unicode character starting at byte offset i
c, _ := utf8.DecodeRuneInString(name[i:])
if c == utf8.RuneError {
return fmt.Errorf("cookie domain: invalid rune at offset %d", i)
}
return fmt.Errorf("cookie domain: invalid character '%c' at offset %d", c, i)
}
}
// check top level domain validity
switch {
case l == len(name):
return fmt.Errorf("cookie domain: missing top level domain, domain can't end with a period")
case len(name)-l > 63:
return fmt.Errorf("cookie domain: byte length of top level domain '%s' is %d, can't exceed 63", name[l:], len(name)-l)
case name[l] == '-':
return fmt.Errorf("cookie domain: top level domain '%s' at offset %d begins with a hyphen", name[l:], l)
case name[len(name)-1] == '-':
return fmt.Errorf("cookie domain: top level domain '%s' at offset %d ends with a hyphen", name[l:], l)
case name[l] >= '0' && name[l] <= '9':
return fmt.Errorf("cookie domain: top level domain '%s' at offset %d begins with a digit", name[l:], l)
}
return nil
} }

View File

@ -11,7 +11,7 @@ func TestCheckRemoteSSH(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want bool want error
}{ }{
{ {
name: "test check localhost ssh port", name: "test check localhost ssh port",
@ -20,12 +20,12 @@ func TestCheckRemoteSSH(t *testing.T) {
port: 22, port: 22,
timeout: 5, timeout: 5,
}, },
want: true, want: nil,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := CheckRemoteSSH(tt.args.host, tt.args.port, tt.args.timeout); got != tt.want { if got := CheckRemoteTCP(tt.args.host, tt.args.port, tt.args.timeout); got != tt.want {
t.Errorf("CheckRemoteSSH() = %v, want %v", got, tt.want) t.Errorf("CheckRemoteSSH() = %v, want %v", got, tt.want)
} }
}) })

64
pkg/config/config.go Normal file
View File

@ -0,0 +1,64 @@
package config
import (
"errors"
"io/ioutil"
"os"
"path"
"git.trj.tw/golang/utils"
"github.com/otakukaze/envconfig"
"gopkg.in/yaml.v2"
)
// Listen -
type Listen struct {
TargetHost string `yaml:"target_host"`
TargetPort int `yaml:"target_port"`
ListenPort int `yaml:"listen_port"`
}
// Config -
type Config struct {
Verbose int `yaml:"verbose" env:"PROXY_VERBOSE"`
Listen []Listen `yaml:"listen"`
}
var conf *Config
// Load config from file and env
func Load(p ...string) (err error) {
var fp string
if len(p) > 0 && len(p[0]) > 0 {
fp = p[0]
} else {
wd, err := os.Getwd()
if err != nil {
return err
}
fp = path.Join(wd, "config.yml")
}
fp = utils.ParsePath(fp)
if exists := utils.CheckExists(fp, false); !exists {
return errors.New("config file not exists")
}
b, err := ioutil.ReadFile(fp)
if err != nil {
return
}
conf = &Config{}
err = yaml.Unmarshal(b, conf)
if err != nil {
return
}
envconfig.Parse(conf)
return
}
// Get config
func Get() *Config { return conf }

207
pkg/listener/listener.go Normal file
View File

@ -0,0 +1,207 @@
package listener
import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"regexp"
"ssh-proxy/pkg/common"
"ssh-proxy/pkg/config"
"ssh-proxy/pkg/logger"
"ssh-proxy/pkg/proxy"
"strconv"
"strings"
"github.com/google/uuid"
)
type Target struct {
Addr string
Listen string
}
type Listener struct {
Targets []Target
ctx context.Context
cancel context.CancelFunc
}
type conn struct {
net.Conn
}
// New listener
func New(list []config.Listen) (listener *Listener, err error) {
if len(list) == 0 {
return nil, errors.New("no listen data")
}
check := make(map[int]interface{})
listener = &Listener{}
listener.ctx, listener.cancel = context.WithCancel(context.Background())
listener.Targets = make([]Target, 0)
for k, v := range list {
hosts, err := parseHost(v.TargetHost)
if err != nil || len(hosts) == 0 {
return nil, fmt.Errorf("host rule error: index: %d, %v", k, err)
}
startBindPort := v.ListenPort
for _, it := range hosts {
if _, ok := check[startBindPort]; ok {
return nil, errors.New("bind port is used")
}
target := Target{}
target.Addr = fmt.Sprintf("%s:%d", it, v.TargetPort)
target.Listen = fmt.Sprintf(":%d", startBindPort)
listener.Targets = append(listener.Targets, target)
check[startBindPort] = nil
startBindPort++
}
}
check = nil
return
}
func parseHost(s string) (ss []string, err error) {
ip := net.ParseIP(s)
if ip != nil {
ss = append(ss, ip.String())
return ss, nil
}
if match, err := regexp.MatchString(`^(?:(25[0-5]|2[0-4][0-9]|1?[0-9]{2}|[0-9])\.){3}(?:(2[0-5][0-9]|1?[0-9]{2}|[0-9]))\-(?:(25[0-9]|2[0-4][0-9]|1?[0-9]{2}|[0-9]))$`, s); match && err == nil {
iprange := strings.Split(s, "-")
start, err := strconv.Atoi(strings.Split(iprange[0], ".")[3])
if err != nil {
return nil, err
}
end, err := strconv.Atoi(iprange[1])
if err != nil {
return nil, err
}
if start > end {
return nil, errors.New("start > end")
}
count := end - start + 1
ss = make([]string, 0, count)
regex := regexp.MustCompile(`\.[0-9]{1,3}$`)
for i := 0; i < count; i++ {
nip := regex.ReplaceAllString(iprange[0], fmt.Sprintf(".%d", start+i))
ss = append(ss, nip)
}
return ss, nil
}
err = common.CheckDomain(s)
if err != nil {
return nil, err
}
ss = append(ss, s)
return
}
func listen(ctx context.Context, bindPort, addr string) error {
laddr, err := net.ResolveTCPAddr("tcp", bindPort)
if err != nil {
return err
}
listener, err := net.ListenTCP("tcp", laddr)
if err != nil {
return err
}
defer listener.Close()
connChan := make(chan *net.TCPConn)
go accept(ctx, listener, connChan)
go handleConn(ctx, connChan, bindPort, addr)
for {
select {
case <-ctx.Done():
return nil
}
}
}
func accept(ctx context.Context, listener *net.TCPListener, connChan chan *net.TCPConn) {
for {
conn, err := listener.AcceptTCP()
if err != nil {
log.Println("accept error : ", err)
} else {
connChan <- conn
}
select {
case <-ctx.Done():
return
}
}
}
func handleConn(ctx context.Context, connChan chan *net.TCPConn, bindPort, addr string) {
conf := config.Get()
for {
select {
case <-ctx.Done():
return
case conn := <-connChan:
laddr, err := net.ResolveTCPAddr("tcp", bindPort)
if err != nil {
log.Println("resolve local bind fail : ", err)
conn.Close()
break
}
raddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
log.Println("resolve remote addr fail : ", err)
conn.Close()
break
}
err = common.CheckRemoteTCP(raddr.IP.String(), raddr.Port, 2)
if err != nil {
log.Println("check remote tcp fail : ", err)
conn.Close()
break
}
connid := uuid.New().String()
p := proxy.New(conn, laddr, raddr)
p.Nagles = false
p.OutputHex = true
p.Log = logger.ColorLogger{
VeryVerbose: conf.Verbose >= 2,
Verbose: conf.Verbose >= 1,
Prefix: fmt.Sprintf("Connection: %s ", connid),
Color: true,
}
go func() {
defer conn.Close()
err := p.Start(ctx)
if err != nil || err != io.EOF {
log.Println("conn err :: ", err)
}
}()
}
}
}
func (l *Listener) Listen() (err error) {
for _, it := range l.Targets {
go listen(l.ctx, it.Listen, it.Addr)
}
<-l.ctx.Done()
return
}

21
pkg/option/option.go Normal file
View File

@ -0,0 +1,21 @@
package option
import "flag"
// Option -
type Option struct {
Config string
}
var opts *Option
// Parse flags
func Parse() {
opts = &Option{}
flag.StringVar(&opts.Config, "config", "", "config file path, default: `pwd`/config.yml")
flag.StringVar(&opts.Config, "f", "", "config file path, default: `pwd`/config.yml")
flag.Parse()
}
// Get parsed options
func Get() *Option { return opts }

View File

@ -49,7 +49,7 @@ type setNoDelayer interface {
SetNoDelay(bool) error SetNoDelay(bool) error
} }
func (p *Proxy) Start() (err error) { func (p *Proxy) Start(ctx context.Context) (err error) {
defer p.lconn.Close() defer p.lconn.Close()
if p.tlsUnwrapp { if p.tlsUnwrapp {
@ -76,18 +76,19 @@ func (p *Proxy) Start() (err error) {
//display both ends //display both ends
p.Log.Info("Opened %s >>>> %s", p.laddr.String(), p.raddr.String()) p.Log.Info("Opened %s >>>> %s", p.laddr.String(), p.raddr.String())
ctx, cancel := context.WithCancel(context.Background())
//bidirectional copy //bidirectional copy
go p.pipe(ctx, p.lconn, p.rconn) go p.pipe(ctx, p.lconn, p.rconn)
go p.pipe(ctx, p.rconn, p.lconn) go p.pipe(ctx, p.rconn, p.lconn)
// wait for close select {
err = <-p.errsig case <-ctx.Done():
p.Log.Info("Closed (%d bytes send, %d bytes received)", p.sendBytes, p.receivedBytes) p.Log.Info("Context Done (%d bytes send, %d bytes received)", p.sendBytes, p.receivedBytes)
cancel()
return return
// wait for close
case err = <-p.errsig:
p.Log.Info("Closed (%d bytes send, %d bytes received)", p.sendBytes, p.receivedBytes)
return err
}
} }
func (p *Proxy) err(s string, err error) { func (p *Proxy) err(s string, err error) {