209 lines
4.2 KiB
Go
209 lines
4.2 KiB
Go
package listener
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"tcp-proxy/pkg/common"
|
|
"tcp-proxy/pkg/config"
|
|
"tcp-proxy/pkg/logger"
|
|
"tcp-proxy/pkg/proxy"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// Target pairs one local listen address with the remote address that
// connections accepted on it are proxied to.
type Target struct {
	// Addr is the remote "host:port" to connect to.
	Addr string
	// Listen is the local bind address, of the form ":port".
	Listen string
}
|
|
|
|
type Listener struct {
|
|
Targets []Target
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// conn wraps a net.Conn by embedding it, so all net.Conn methods are promoted.
//
// NOTE(review): nothing else in this file refers to this type (handleConn's
// local variable named conn shadows it); confirm whether it is still needed.
type conn struct {
	net.Conn // embedded connection; nil in the zero value
}
|
|
|
|
// New listener
|
|
func New(list []config.Listen) (listener *Listener, err error) {
|
|
if len(list) == 0 {
|
|
return nil, errors.New("no listen data")
|
|
}
|
|
check := make(map[int]interface{})
|
|
listener = &Listener{}
|
|
listener.ctx, listener.cancel = context.WithCancel(context.Background())
|
|
listener.Targets = make([]Target, 0)
|
|
|
|
for k, v := range list {
|
|
hosts, err := parseHost(v.TargetHost)
|
|
if err != nil || len(hosts) == 0 {
|
|
return nil, fmt.Errorf("host rule error: index: %d, %v", k, err)
|
|
}
|
|
startBindPort := v.ListenPort
|
|
for _, it := range hosts {
|
|
if _, ok := check[startBindPort]; ok {
|
|
return nil, errors.New("bind port is used")
|
|
}
|
|
target := Target{}
|
|
target.Addr = fmt.Sprintf("%s:%d", it, v.TargetPort)
|
|
target.Listen = fmt.Sprintf(":%d", startBindPort)
|
|
listener.Targets = append(listener.Targets, target)
|
|
check[startBindPort] = nil
|
|
startBindPort++
|
|
}
|
|
|
|
}
|
|
|
|
check = nil
|
|
|
|
return
|
|
}
|
|
|
|
func parseHost(s string) (ss []string, err error) {
|
|
ip := net.ParseIP(s)
|
|
if ip != nil {
|
|
ss = append(ss, ip.String())
|
|
return ss, nil
|
|
}
|
|
|
|
if match, err := regexp.MatchString(`^(?:(25[0-5]|2[0-4][0-9]|1?[0-9]{2}|[0-9])\.){3}(?:(2[0-5][0-9]|1?[0-9]{2}|[0-9]))\-(?:(25[0-9]|2[0-4][0-9]|1?[0-9]{2}|[0-9]))$`, s); match && err == nil {
|
|
iprange := strings.Split(s, "-")
|
|
start, err := strconv.Atoi(strings.Split(iprange[0], ".")[3])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
end, err := strconv.Atoi(iprange[1])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if start > end {
|
|
return nil, errors.New("start > end")
|
|
}
|
|
|
|
count := end - start + 1
|
|
ss = make([]string, 0, count)
|
|
regex := regexp.MustCompile(`\.[0-9]{1,3}$`)
|
|
for i := 0; i < count; i++ {
|
|
nip := regex.ReplaceAllString(iprange[0], fmt.Sprintf(".%d", start+i))
|
|
ss = append(ss, nip)
|
|
}
|
|
return ss, nil
|
|
}
|
|
|
|
err = common.CheckDomain(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ss = append(ss, s)
|
|
|
|
return
|
|
}
|
|
|
|
func listen(ctx context.Context, bindPort, addr string) error {
|
|
laddr, err := net.ResolveTCPAddr("tcp", bindPort)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
listener, err := net.ListenTCP("tcp", laddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer listener.Close()
|
|
|
|
connChan := make(chan *net.TCPConn)
|
|
|
|
go accept(ctx, listener, connChan)
|
|
go handleConn(ctx, connChan, bindPort, addr)
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
// accept accepts connections on listener and sends each one on connChan until
// ctx is done.
//
// Accept errors are logged and retried while ctx is live. Once ctx is done,
// the goroutine exits without logging. The caller is expected to close the
// listener at that point so a blocked AcceptTCP returns.
//
// A connection that cannot be delivered before ctx is done (for example
// because the receiver has already stopped) is closed. Without this, the
// send would block forever and leak both the goroutine and the socket.
func accept(ctx context.Context, listener *net.TCPListener, connChan chan *net.TCPConn) {
	for {
		conn, err := listener.AcceptTCP()
		if err != nil {
			if ctx.Err() != nil {
				return
			}
			log.Println("accept error : ", err)
			continue
		}

		select {
		case connChan <- conn:
		case <-ctx.Done():
			conn.Close()
			return
		}
	}
}
|
|
func handleConn(ctx context.Context, connChan chan *net.TCPConn, bindPort, addr string) {
|
|
conf := config.Get()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case conn := <-connChan:
|
|
laddr, err := net.ResolveTCPAddr("tcp", bindPort)
|
|
if err != nil {
|
|
log.Println("resolve local bind fail : ", err)
|
|
conn.Close()
|
|
break
|
|
}
|
|
|
|
raddr, err := net.ResolveTCPAddr("tcp", addr)
|
|
if err != nil {
|
|
log.Println("resolve remote addr fail : ", err)
|
|
conn.Close()
|
|
break
|
|
}
|
|
err = common.CheckRemoteTCP(raddr.IP.String(), raddr.Port, 2)
|
|
if err != nil {
|
|
log.Println("check remote tcp fail : ", err)
|
|
conn.Close()
|
|
break
|
|
}
|
|
|
|
connid := uuid.New().String()
|
|
p := proxy.New(conn, laddr, raddr)
|
|
p.Nagles = false
|
|
p.OutputHex = true
|
|
p.Log = logger.ColorLogger{
|
|
VeryVerbose: conf.Verbose >= 2,
|
|
Verbose: conf.Verbose >= 1,
|
|
Prefix: fmt.Sprintf("Connection: %s ", connid),
|
|
Color: true,
|
|
}
|
|
go func() {
|
|
defer conn.Close()
|
|
err := p.Start(ctx)
|
|
if err != nil || err != io.EOF {
|
|
log.Println("conn err :: ", err)
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (l *Listener) Listen() (err error) {
|
|
for _, it := range l.Targets {
|
|
go listen(l.ctx, it.Listen, it.Addr)
|
|
}
|
|
|
|
<-l.ctx.Done()
|
|
return
|
|
}
|