tcp-proxy/pkg/listener/listener.go
2020-01-14 14:11:31 +00:00

209 lines
4.2 KiB
Go

package listener
import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"regexp"
"strconv"
"strings"
"tcp-proxy/pkg/common"
"tcp-proxy/pkg/config"
"tcp-proxy/pkg/logger"
"tcp-proxy/pkg/proxy"
"github.com/google/uuid"
)
type (
	// Target is a single proxy route: client connections accepted on Listen
	// are forwarded to Addr.
	Target struct {
		Addr   string // upstream "host:port" accepted connections are proxied to
		Listen string // local bind address in ":port" form
	}

	// Listener holds every configured Target together with the context that
	// bounds how long their listeners run.
	Listener struct {
		Targets []Target
		ctx     context.Context
		cancel  context.CancelFunc
	}

	// conn embeds a net.Conn.
	// NOTE(review): this type is not referenced anywhere in this file.
	conn struct {
		net.Conn
	}
)
// New builds a Listener from the configured listen rules.
//
// Each rule's TargetHost is expanded by parseHost into one or more hosts: a
// literal IP, an IPv4 last-octet range such as "10.0.0.1-5", or a name that
// passes common.CheckDomain. Every resulting host gets its own local bind
// port. Ports start at the rule's ListenPort and count up by one per host.
// Target addresses are joined with net.JoinHostPort so that IPv6 hosts are
// bracketed ("[::1]:80") and remain dialable.
//
// It returns an error when list is empty, when a host rule cannot be parsed,
// when an assigned bind port falls outside 0-65535, or when two targets would
// share a bind port. The Listener's context is only created once every rule
// has validated, so failed calls leave nothing to cancel.
func New(list []config.Listen) (*Listener, error) {
	if len(list) == 0 {
		return nil, errors.New("no listen data")
	}

	used := make(map[int]struct{})
	targets := make([]Target, 0, len(list))
	for k, v := range list {
		hosts, err := parseHost(v.TargetHost)
		if err != nil || len(hosts) == 0 {
			return nil, fmt.Errorf("host rule error: index: %d, %v", k, err)
		}

		port := v.ListenPort
		for _, host := range hosts {
			// A range of N hosts consumes N consecutive ports; catch overflow
			// here rather than in a listener goroutine whose error is lost.
			if port < 0 || port > 65535 {
				return nil, fmt.Errorf("bind port %d out of range: index: %d", port, k)
			}
			if _, dup := used[port]; dup {
				return nil, fmt.Errorf("bind port is used: %d", port)
			}
			used[port] = struct{}{}

			targets = append(targets, Target{
				Addr:   net.JoinHostPort(host, fmt.Sprint(v.TargetPort)),
				Listen: ":" + strconv.Itoa(port),
			})
			port++
		}
	}

	l := &Listener{Targets: targets}
	l.ctx, l.cancel = context.WithCancel(context.Background())
	return l, nil
}
// ipRangePattern matches an IPv4 range over the last octet, e.g.
// "192.168.1.10-20". It is compiled once here instead of on every parseHost
// call. The alternatives used for the fourth octet and for the range end also
// admit 256-259, so expandLastOctetRange re-checks the numeric bound.
var ipRangePattern = regexp.MustCompile(`^(?:(25[0-5]|2[0-4][0-9]|1?[0-9]{2}|[0-9])\.){3}(?:(2[0-5][0-9]|1?[0-9]{2}|[0-9]))\-(?:(25[0-9]|2[0-4][0-9]|1?[0-9]{2}|[0-9]))$`)

// parseHost expands a target-host rule into the list of hosts it denotes.
//
// Accepted forms, tried in order:
//   - a literal IP address (IPv4 or IPv6), returned in canonical form;
//   - an IPv4 last-octet range "a.b.c.start-end", expanded to every address
//     from start to end inclusive (requires start <= end <= 255);
//   - anything else, which must pass common.CheckDomain and is returned as-is.
//
// It returns an error for an inverted or out-of-bounds range, or when the
// domain check fails.
func parseHost(s string) (ss []string, err error) {
	if ip := net.ParseIP(s); ip != nil {
		return []string{ip.String()}, nil
	}
	if ipRangePattern.MatchString(s) {
		return expandLastOctetRange(s)
	}
	if err := common.CheckDomain(s); err != nil {
		return nil, err
	}
	return []string{s}, nil
}

// expandLastOctetRange turns "a.b.c.start-end" (already matched by
// ipRangePattern) into "a.b.c.start", "a.b.c.start+1", ..., "a.b.c.end".
func expandLastOctetRange(s string) ([]string, error) {
	// ipRangePattern guarantees exactly one '-' and three dots before it.
	dash := strings.IndexByte(s, '-')
	base, endStr := s[:dash], s[dash+1:]
	dot := strings.LastIndexByte(base, '.')
	prefix, startStr := base[:dot+1], base[dot+1:]

	start, err := strconv.Atoi(startStr)
	if err != nil {
		return nil, err
	}
	end, err := strconv.Atoi(endStr)
	if err != nil {
		return nil, err
	}
	if start > end {
		return nil, errors.New("start > end")
	}
	// With start <= end, bounding end also bounds start.
	if end > 255 {
		return nil, fmt.Errorf("ip range end %d exceeds 255", end)
	}

	hosts := make([]string, 0, end-start+1)
	for octet := start; octet <= end; octet++ {
		hosts = append(hosts, prefix+strconv.Itoa(octet))
	}
	return hosts, nil
}
// listen binds a TCP listener on bindPort (in the ":port" form New produces).
// Every accepted connection is proxied to addr.
//
// An accept goroutine feeds accepted connections to a handleConn goroutine
// over an unbuffered channel. listen blocks until ctx is cancelled, then
// closes the listener via defer. That close also unblocks a pending
// AcceptTCP call in the accept goroutine. It returns an error only if
// bindPort cannot be resolved or bound.
func listen(ctx context.Context, bindPort, addr string) error {
	laddr, err := net.ResolveTCPAddr("tcp", bindPort)
	if err != nil {
		return err
	}
	ln, err := net.ListenTCP("tcp", laddr)
	if err != nil {
		return err
	}
	defer ln.Close()

	conns := make(chan *net.TCPConn)
	go accept(ctx, ln, conns)
	go handleConn(ctx, conns, bindPort, addr)

	// A single-case select inside a for loop is just a channel receive.
	<-ctx.Done()
	return nil
}
// accept runs the accept loop for listener and hands each accepted
// connection to connChan. It returns once ctx is cancelled.
//
// The hand-off selects on ctx.Done() as well as the send. A connection
// accepted after the consumer has stopped is therefore closed and the
// goroutine exits. A bare send would block here forever, leaking both the
// goroutine and the socket. Accept errors observed after cancellation are
// expected, because the owner closes the listener on shutdown, so they are
// not logged.
func accept(ctx context.Context, listener *net.TCPListener, connChan chan *net.TCPConn) {
	for {
		if ctx.Err() != nil {
			return
		}

		conn, err := listener.AcceptTCP()
		if err != nil {
			if ctx.Err() != nil {
				return
			}
			log.Println("accept error : ", err)
			continue
		}

		select {
		case connChan <- conn:
		case <-ctx.Done():
			conn.Close()
			return
		}
	}
}
// handleConn receives accepted client connections from connChan and proxies
// each one to addr, until ctx is cancelled.
//
// config.Get() is read once when the loop starts. Its Verbose level selects
// the per-connection logger verbosity. Each connection is set up in its own
// goroutine, so a slow lookup or unreachable upstream does not stall the
// hand-off of the next accepted connection. Setup means:
//   - resolving bindPort and addr;
//   - probing the target with common.CheckRemoteTCP (the literal 2 is passed
//     through unchanged; presumably a timeout, TODO confirm);
//   - starting a proxy session.
//
// The client connection is closed when any setup step fails or when the
// proxy session ends.
func handleConn(ctx context.Context, connChan chan *net.TCPConn, bindPort, addr string) {
	conf := config.Get()
	for {
		select {
		case <-ctx.Done():
			return
		case conn := <-connChan:
			go func(conn *net.TCPConn) {
				defer conn.Close()

				laddr, err := net.ResolveTCPAddr("tcp", bindPort)
				if err != nil {
					log.Println("resolve local bind fail : ", err)
					return
				}
				// Resolved per connection so DNS changes for addr are picked up.
				raddr, err := net.ResolveTCPAddr("tcp", addr)
				if err != nil {
					log.Println("resolve remote addr fail : ", err)
					return
				}
				if err := common.CheckRemoteTCP(raddr.IP.String(), raddr.Port, 2); err != nil {
					log.Println("check remote tcp fail : ", err)
					return
				}

				connid := uuid.New().String()
				p := proxy.New(conn, laddr, raddr)
				p.Nagles = false
				p.OutputHex = true
				p.Log = logger.ColorLogger{
					VeryVerbose: conf.Verbose >= 2,
					Verbose:     conf.Verbose >= 1,
					Prefix:      fmt.Sprintf("Connection: %s ", connid),
					Color:       true,
				}
				// Report only real failures: nil and io.EOF both mean the
				// session ended normally.
				if err := p.Start(ctx); err != nil && err != io.EOF {
					log.Println("conn err :: ", err)
				}
			}(conn)
		}
	}
}
// Listen starts one TCP listener per configured Target, each in its own
// goroutine, and blocks until the Listener's context is cancelled.
//
// A listener that fails to start, for example because its bind port cannot
// be bound, is logged instead of being dropped silently. The other targets
// keep running. Listen currently always returns nil.
func (l *Listener) Listen() error {
	for _, t := range l.Targets {
		// Pass t by value so each goroutine keeps its own Target
		// (loop variables are shared before Go 1.22).
		go func(t Target) {
			if err := listen(l.ctx, t.Listen, t.Addr); err != nil {
				log.Printf("listen on %s -> %s failed: %v", t.Listen, t.Addr, err)
			}
		}(t)
	}
	<-l.ctx.Done()
	return nil
}