// Package listener binds local TCP ports and hands every accepted connection
// to a proxy that forwards it to the configured target address.
package listener

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"regexp"
	"strconv"
	"strings"

	"tcp-proxy/pkg/common"
	"tcp-proxy/pkg/config"
	"tcp-proxy/pkg/logger"
	"tcp-proxy/pkg/proxy"

	"github.com/google/uuid"
)

// maxPort is the highest valid TCP port number.
const maxPort = 65535

// ipRangePattern matches an IPv4 address followed by "-<end>", for example
// "10.0.0.1-20". It is compiled once instead of on every parseHost call.
// The pattern tolerates 256-259 in the last octet and in the range end, so
// parseHost bounds-checks the parsed numbers itself.
var ipRangePattern = regexp.MustCompile(`^(?:(25[0-5]|2[0-4][0-9]|1?[0-9]{2}|[0-9])\.){3}(?:(2[0-5][0-9]|1?[0-9]{2}|[0-9]))\-(?:(25[0-9]|2[0-4][0-9]|1?[0-9]{2}|[0-9]))$`)

// Target pairs a local listen address with the remote address that
// connections accepted on it are forwarded to.
type Target struct {
	Addr   string // remote "host:port"
	Listen string // local bind address in ":port" form
}

// Listener owns the set of targets to serve and the context shared by all
// goroutines started from Listen.
type Listener struct {
	Targets []Target
	ctx     context.Context
	cancel  context.CancelFunc
}

// conn embeds net.Conn. It is not referenced in this file.
type conn struct {
	net.Conn
}

// New builds a Listener from the given listen rules.
//
// Every rule's TargetHost is expanded with parseHost; each resulting host is
// assigned its own local port, starting at the rule's ListenPort and
// increasing by one per host. An error is returned when list is empty, a host
// rule cannot be parsed, a bind port falls outside 0-65535, or the same bind
// port would be used twice.
func New(list []config.Listen) (listener *Listener, err error) {
	if len(list) == 0 {
		return nil, errors.New("no listen data")
	}

	used := make(map[int]struct{})
	targets := make([]Target, 0, len(list))
	for k, v := range list {
		hosts, herr := parseHost(v.TargetHost)
		if herr != nil || len(hosts) == 0 {
			return nil, fmt.Errorf("host rule error: index: %d, %v", k, herr)
		}
		port := v.ListenPort
		for _, host := range hosts {
			if port < 0 || port > maxPort {
				return nil, fmt.Errorf("bind port out of range: index: %d, port: %d", k, port)
			}
			if _, ok := used[port]; ok {
				return nil, errors.New("bind port is used")
			}
			targets = append(targets, Target{
				// JoinHostPort brackets IPv6 literals; plain "%s:%d"
				// yields an address that cannot be resolved for them.
				Addr:   net.JoinHostPort(host, fmt.Sprintf("%d", v.TargetPort)),
				Listen: fmt.Sprintf(":%d", port),
			})
			used[port] = struct{}{}
			port++
		}
	}

	// Create the context only after validation so that no cancel func is
	// leaked on the error paths above.
	ctx, cancel := context.WithCancel(context.Background())
	return &Listener{Targets: targets, ctx: ctx, cancel: cancel}, nil
}

// parseHost expands a target host rule into a list of hosts.
//
// Accepted forms, tried in this order:
//   - a single IP address, returned in net.IP.String() form;
//   - an IPv4 range "a.b.c.START-END", expanded to one address per value of
//     the last octet from START to END inclusive;
//   - anything else is validated with common.CheckDomain and returned as is.
func parseHost(s string) (ss []string, err error) {
	if ip := net.ParseIP(s); ip != nil {
		return []string{ip.String()}, nil
	}

	if ipRangePattern.MatchString(s) {
		// The pattern guarantees exactly one '-' and three '.' in s.
		dash := strings.LastIndex(s, "-")
		base := s[:dash]
		dot := strings.LastIndex(base, ".")
		prefix := base[:dot+1]

		start, err := strconv.Atoi(base[dot+1:])
		if err != nil {
			return nil, err
		}
		end, err := strconv.Atoi(s[dash+1:])
		if err != nil {
			return nil, err
		}
		if start > end {
			return nil, errors.New("start > end")
		}
		// start <= end, so checking end bounds both values.
		if end > 255 {
			return nil, errors.New("ip range end > 255")
		}

		ss = make([]string, 0, end-start+1)
		for i := start; i <= end; i++ {
			ss = append(ss, prefix+strconv.Itoa(i))
		}
		return ss, nil
	}

	if err := common.CheckDomain(s); err != nil {
		return nil, err
	}
	return []string{s}, nil
}

// listen binds bindPort, starts the accept and dispatch goroutines for it and
// blocks until ctx is cancelled. It returns an error only when the address
// cannot be resolved or bound; the listener is closed on return.
func listen(ctx context.Context, bindPort, addr string) error {
	laddr, err := net.ResolveTCPAddr("tcp", bindPort)
	if err != nil {
		return err
	}
	ln, err := net.ListenTCP("tcp", laddr)
	if err != nil {
		return err
	}
	// Closing the listener also unblocks AcceptTCP in accept.
	defer ln.Close()

	connChan := make(chan *net.TCPConn)
	go accept(ctx, ln, connChan)
	go handleConn(ctx, connChan, bindPort, addr)

	<-ctx.Done()
	return nil
}

// accept pushes every connection accepted on listener into connChan until ctx
// is cancelled.
func accept(ctx context.Context, listener *net.TCPListener, connChan chan *net.TCPConn) {
	for {
		c, err := listener.AcceptTCP()
		if err != nil {
			if ctx.Err() != nil {
				// The listener was closed as part of shutdown.
				return
			}
			log.Println("accept error : ", err)
			continue
		}
		// Do not block forever on the send if the receiver has already
		// stopped because of cancellation.
		select {
		case connChan <- c:
		case <-ctx.Done():
			c.Close()
			return
		}
	}
}

// handleConn receives accepted connections from connChan and serves each one
// in its own goroutine until ctx is cancelled.
func handleConn(ctx context.Context, connChan chan *net.TCPConn, bindPort, addr string) {
	conf := config.Get()
	verbose := conf.Verbose >= 1
	veryVerbose := conf.Verbose >= 2
	for {
		select {
		case <-ctx.Done():
			return
		case c := <-connChan:
			// Resolution and the remote check can block, so they must not
			// run in this loop or they would stall other connections.
			go serveConn(ctx, c, bindPort, addr, verbose, veryVerbose)
		}
	}
}

// serveConn proxies a single accepted connection to addr and closes it when
// done or when any preparation step fails.
func serveConn(ctx context.Context, c *net.TCPConn, bindPort, addr string, verbose, veryVerbose bool) {
	defer c.Close()

	laddr, err := net.ResolveTCPAddr("tcp", bindPort)
	if err != nil {
		log.Println("resolve local bind fail : ", err)
		return
	}
	raddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		log.Println("resolve remote addr fail : ", err)
		return
	}
	// NOTE(review): the trailing 2 is presumably a timeout; confirm against
	// common.CheckRemoteTCP.
	if err := common.CheckRemoteTCP(raddr.IP.String(), raddr.Port, 2); err != nil {
		log.Println("check remote tcp fail : ", err)
		return
	}

	connid := uuid.New().String()
	p := proxy.New(c, laddr, raddr)
	p.Nagles = false
	p.OutputHex = true
	p.Log = logger.ColorLogger{
		VeryVerbose: veryVerbose,
		Verbose:     verbose,
		Prefix:      fmt.Sprintf("Connection: %s ", connid),
		Color:       true,
	}

	// io.EOF is not reported as a failure.
	if err := p.Start(ctx); err != nil && !errors.Is(err, io.EOF) {
		log.Println("conn err :: ", err)
	}
}

// Listen starts one listener per target and blocks until the Listener's
// context is cancelled, in which case it returns nil.
//
// A target that cannot be bound is logged and the remaining targets keep
// running. If every target fails, Listen returns an error wrapping the last
// failure instead of blocking with nothing listening.
func (l *Listener) Listen() (err error) {
	// Buffered so a failing goroutine never blocks on the send, even after
	// Listen has returned.
	failed := make(chan error, len(l.Targets))
	for _, it := range l.Targets {
		go func(t Target) {
			if lerr := listen(l.ctx, t.Listen, t.Addr); lerr != nil {
				log.Printf("listen %s fail : %v", t.Listen, lerr)
				failed <- lerr
			}
		}(it)
	}

	failures := 0
	for {
		select {
		case <-l.ctx.Done():
			return nil
		case lerr := <-failed:
			failures++
			if failures == len(l.Targets) {
				return fmt.Errorf("all listeners failed, last error: %w", lerr)
			}
		}
	}
}