// Go 实现 TCP 反向代理/转发 (a TCP reverse proxy / port forwarder written in Go)

package main

import (
    "bytes"
    "flag"
    "io"
    "log"
    "net"
    "os"
    "strings"
    "time"
)

// formatter is an io.Writer that renders the bytes written to it as
// hexdump-style lines (offset, hex bytes, ASCII column) and writes each
// finished line to w.
type formatter struct {
    // w receives the formatted lines.
    w      io.Writer
    // offset is the running byte count shown at the start of each line;
    // Write advances it by the number of bytes formatted.
    offset uint
    // prefix is written at the start of every line, after the timestamp.
    prefix string
    // tstamp, when true, starts every line with time.Now() in time.Stamp layout.
    tstamp bool
}

// min returns the smaller of the two ints a and b.
func min(a, b int) int {
    if b < a {
        return b
    }
    return a
}

// Write implements io.Writer. It hands p to format in consecutive 16-byte
// chunks (the final chunk may be shorter), advancing f.offset by each chunk's
// length once that chunk has been formatted. If format fails, Write returns
// the number of bytes in the chunks that were already formatted, together
// with the error. Otherwise it reports all of p as written.
func (f *formatter) Write(p []byte) (int, error) {
    done := 0
    for done < len(p) {
        size := min(len(p)-done, 16)
        if err := f.format(p[done : done+size]); err != nil {
            return done, err
        }
        done += size
        f.offset += uint(size)
    }
    return done, nil
}

// hex maps a nibble value (0-15) to its lowercase hexadecimal digit.
var hex = []byte{
    '0', '1', '2', '3', '4', '5', '6', '7',
    '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
}

// format renders buf as a single hexdump-style line and writes it to f.w
// with one Write call:
//
//     [timestamp][prefix]OOOOOOOO:  hh hh hh hh  hh hh hh hh  ...   |ascii|
//
// The timestamp (time.Stamp layout) is present only when f.tstamp is set.
// OOOOOOOO is f.offset as exactly 8 lowercase hex digits. Offsets above
// 0xffffffff are shown modulo 2^32, so the field never spills leftwards into
// the prefix/timestamp or past the start of the line. In the ASCII column,
// bytes outside 0x21..0x7e (including space) are shown as '.'.
//
// buf must hold at most 16 bytes; Write guarantees this. A short buf is
// padded with blanks so the ASCII column always starts at the same column.
// Any error from f.w.Write is returned.
func (f *formatter) format(buf []byte) error {
    // Fixed part of a line for a full 16-byte chunk: 8 offset digits, ':',
    // ' ', 16 "hh " triplets plus 4 group separators, "  |", 16 chars, "|\n".
    const body = 83

    plen := len(f.prefix)
    if f.tstamp {
        plen += len(time.Stamp)
    }

    line := make([]byte, body+plen)
    ptr := 0

    if f.tstamp {
        // time.Stamp renders at a fixed width (day is space-padded), so the
        // layout string's length is also the rendered length.
        copy(line, time.Now().Format(time.Stamp))
        ptr += len(time.Stamp)
    }
    ptr += copy(line[ptr:], f.prefix)

    // Offset field: always exactly 8 hex digits, most significant first.
    offs := f.offset
    for i := 7; i >= 0; i-- {
        line[ptr+i] = hex[offs&0x0f]
        offs >>= 4
    }
    line[ptr+8] = ':'
    line[ptr+9] = ' '
    ptr += 10

    // Hex column: one extra space before each group of four bytes, then
    // "hh " per byte; positions past len(buf) are filled with blanks.
    for i := 0; i < 16; i++ {
        if i%4 == 0 {
            line[ptr] = ' '
            ptr++
        }
        if i < len(buf) {
            b := buf[i]
            line[ptr] = hex[b>>4]
            line[ptr+1] = hex[b&0x0f]
        } else {
            line[ptr] = ' '
            line[ptr+1] = ' '
        }
        line[ptr+2] = ' '
        ptr += 3
    }

    ptr += copy(line[ptr:], "  |")

    // ASCII column: visible, non-space ASCII as-is, everything else '.'.
    for _, v := range buf {
        if v > 32 && v < 127 {
            line[ptr] = v
        } else {
            line[ptr] = '.'
        }
        ptr++
    }

    ptr += copy(line[ptr:], "|\n")

    _, err := f.w.Write(line[:ptr])
    return err
}

// copyStream pumps everything from r into conn until r reports EOF or an
// error, and then closes conn outright (not just its write side). Failures
// of either step are logged, not returned, so it can run as a goroutine.
func copyStream(conn io.WriteCloser, r io.Reader) {
    if _, copyErr := io.Copy(conn, r); copyErr != nil {
        log.Printf("error during copy: %v\n", copyErr)
    }
    if closeErr := conn.Close(); closeErr != nil {
        log.Printf("error during close: %v\n", closeErr)
    }
}

// main runs in one of two modes:
//
//   - when -p is non-empty (the default), it runs a TCP forwarder and never
//     returns;
//   - with -p "", it hex-dumps the file named by the first argument (or
//     stdin when no argument is given) to stdout.
//
// Flags:
//
//     -p <lport>:<rhost>:<rport>  listen on lport, forward to rhost:rport
//     -t                          prefix proxied dump lines with a timestamp
//     -d                          hex-dump proxied traffic to stdout
func main() {
    proxy := flag.String("p", "9999:blog.bbzhh.com:80", "proxy line -- <lport>:<rhost>:<rport>")
    tstamps := flag.Bool("t", false, "add time-stamps when proxying")
    debug := flag.Bool("d", false, "print debug information when proxying")

    flag.Parse()

    // provided a proxy line: serve forever (runProxy only exits via log.Fatal)
    if *proxy != "" {
        runProxy(*proxy, *tstamps, *debug)
    }

    if *tstamps {
        log.Println("-t only applies when proxying, ignoring")
    }

    fout := &formatter{os.Stdout, 0, "", false}

    // process stdin unless a file name was given
    var fin io.Reader = os.Stdin
    if flag.NArg() > 0 {
        file, err := os.Open(flag.Arg(0))
        if err != nil {
            log.Fatal(err)
        }
        // Read-only file: a Close error carries nothing worth acting on.
        defer file.Close()
        fin = file
    }

    if _, err := io.Copy(fout, fin); err != nil {
        log.Printf("error during copy: %v\n", err)
    }
}

// runProxy listens on the local port from spec ("<lport>:<rhost>:<rport>")
// and forwards each accepted connection to rhost:rport in its own goroutine.
// It never returns. A malformed spec or a listen failure is fatal.
func runProxy(spec string, tstamps, debug bool) {
    pieces := strings.Split(spec, ":")
    if len(pieces) != 3 {
        // Guard the pieces[0..2] indexing below, which would otherwise panic.
        log.Fatalf("invalid proxy line %q, want <lport>:<rhost>:<rport>", spec)
    }
    dst := pieces[1] + ":" + pieces[2]

    ln, e := net.Listen("tcp", ":"+pieces[0])
    if e != nil {
        log.Fatal("listen error:", e)
    }

    log.Println("tcp server starting")

    for {
        lconn, err := ln.Accept()
        if err != nil {
            log.Println(err)
            continue
        }
        go forward(lconn, dst, tstamps, debug)
    }
}

// forward dials dst and copies bytes between lconn and the new connection in
// both directions until either side finishes. If the dial fails, lconn is
// closed.
//
// With debug set, each direction is hex-dumped to stdout through its own
// dumpWriter. Offsets therefore start at 0 per connection and per direction,
// and no formatter state is shared between goroutines. Without debug,
// nothing is teed, so no traffic is retained in memory.
func forward(lconn net.Conn, dst string, tstamps, debug bool) {
    rconn, err := net.Dial("tcp", dst)
    if err != nil {
        log.Println("error connecting to", dst, ":", err)
        if err := lconn.Close(); err != nil {
            log.Printf("error closing connection: %v\n", err)
        }
        return
    }

    var up io.Reader = lconn   // client -> remote
    var down io.Reader = rconn // remote -> client
    if debug {
        fprefix, tprefix := "<= ", "=> "
        if tstamps {
            // Keep a space between the timestamp and the direction marker.
            fprefix, tprefix = " <= ", " => "
        }
        up = io.TeeReader(lconn, newDumpWriter(os.Stdout, tprefix, tstamps))
        down = io.TeeReader(rconn, newDumpWriter(os.Stdout, fprefix, tstamps))
    }

    go copyStream(rconn, up)
    copyStream(lconn, down)
}

// dumpWriter hex-dumps everything written to it to out. Each Write is first
// formatted into a private, reused buffer and then emitted with a single
// out.Write call. As a result, one chunk's dump lines are not interleaved
// with lines from other connections writing to the same out.
type dumpWriter struct {
    out  io.Writer
    buf  bytes.Buffer
    hexf formatter
}

// newDumpWriter returns a dumpWriter whose lines carry prefix and, when
// tstamp is set, a timestamp. Its offset counter starts at 0.
func newDumpWriter(out io.Writer, prefix string, tstamp bool) *dumpWriter {
    d := &dumpWriter{out: out}
    d.hexf = formatter{w: &d.buf, prefix: prefix, tstamp: tstamp}
    return d
}

// Write formats p and forwards the resulting lines to d.out. It reports all
// of p as consumed on success, or 0 and the first error encountered.
func (d *dumpWriter) Write(p []byte) (int, error) {
    d.buf.Reset()
    if _, err := d.hexf.Write(p); err != nil {
        return 0, err
    }
    if _, err := d.out.Write(d.buf.Bytes()); err != nil {
        return 0, err
    }
    return len(p), nil
}

// 标签 (tags): golang