// Package tcp implements packet flows over TCP connections. Each packet is
// sent on the wire as a 4-byte little-endian length followed by the payload
// (marshalled packet plus MAC).
package tcp

import (
	"bufio"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"mpbl3p/proxy"
	"mpbl3p/shared"
)

// Conn is the subset of connection behaviour that a Flow relies on.
type Conn interface {
	Read(b []byte) (n int, err error)
	Write(b []byte) (n int, err error)

	SetWriteDeadline(time.Time) error

	// For printing
	LocalAddr() net.Addr
	RemoteAddr() net.Addr
}

// InitiatedFlow is a Flow whose connection is dialled by this side, from the
// address returned by Local to Remote. It starts without a connection; call
// Reconnect to establish one.
type InitiatedFlow struct {
	Local  func() string
	Remote string

	// mu is write-locked by Reconnect and read-locked by Consume and Produce.
	mu sync.RWMutex

	Flow
}

// String describes the flow as its local and remote address strings.
func (f *InitiatedFlow) String() string {
	return fmt.Sprintf("TcpOutbound{%v -> %v}", f.Local(), f.Remote)
}

// Flow exchanges length-prefixed packets over a Conn. Two goroutines
// (consumeMarshalled and produceMarshalled) perform the connection I/O and
// talk to Consume and Produce through the unbuffered channels below.
type Flow struct {
	conn    Conn
	isAlive bool

	toConsume, produced          chan []byte
	consumeErrors, produceErrors chan error
}

// NewFlow returns a Flow with its channels allocated but no connection. The
// returned flow is not alive and no I/O goroutines are started.
func NewFlow() Flow {
	return Flow{
		toConsume:     make(chan []byte),
		produced:      make(chan []byte),
		consumeErrors: make(chan error),
		produceErrors: make(chan error),
	}
}

// NewFlowConn returns a live Flow over conn and starts its reader and writer
// goroutines.
func NewFlowConn(conn Conn) Flow {
	f := Flow{
		conn:          conn,
		isAlive:       true,
		toConsume:     make(chan []byte),
		produced:      make(chan []byte),
		consumeErrors: make(chan error),
		produceErrors: make(chan error),
	}

	// The goroutines only touch conn and the channels, all of which are shared
	// with the copy of f that is returned.
	go f.produceMarshalled()
	go f.consumeMarshalled()

	return f
}

// String describes the flow as the remote and local addresses of its
// connection.
func (f Flow) String() string {
	return fmt.Sprintf("TcpInbound{%v -> %v}", f.conn.RemoteAddr(), f.conn.LocalAddr())
}

// IsAlive reports whether the flow is currently marked as alive.
func (f *Flow) IsAlive() bool {
	return f.isAlive
}

// InitiateFlow builds an InitiatedFlow for the given endpoints. No connection
// is made here; the returned error is always nil.
func InitiateFlow(local func() string, remote string) (*InitiatedFlow, error) {
	f := InitiatedFlow{
		Local:  local,
		Remote: remote,
		Flow:   NewFlow(),
	}

	return &f, nil
}

// Reconnect dials Remote from Local and starts the I/O goroutines on the new
// connection. It does nothing if the flow is already alive.
//
// NOTE(review): goroutines started for an earlier connection are not stopped
// and the earlier connection is not closed here — confirm whether that
// lifecycle is handled elsewhere.
func (f *InitiatedFlow) Reconnect() error {
	f.mu.Lock()
	defer f.mu.Unlock()

	if f.isAlive {
		return nil
	}

	localAddr, err := net.ResolveTCPAddr("tcp", f.Local())
	if err != nil {
		return err
	}

	remoteAddr, err := net.ResolveTCPAddr("tcp", f.Remote)
	if err != nil {
		return err
	}

	conn, err := net.DialTCP("tcp", localAddr, remoteAddr)
	if err != nil {
		return err
	}

	if err := conn.SetWriteBuffer(0); err != nil {
		// Do not leak the socket that was just opened; the configuration
		// error is the one worth reporting, so the Close error is dropped.
		_ = conn.Close()
		return err
	}

	f.conn = conn
	f.isAlive = true

	go f.produceMarshalled()
	go f.consumeMarshalled()

	return nil
}

// Consume sends p over the flow while holding the read lock, so it cannot
// overlap with Reconnect.
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
	f.mu.RLock()
	defer f.mu.RUnlock()

	return f.Flow.Consume(p, g)
}

// Produce receives the next packet from the flow while holding the read lock,
// so it cannot overlap with Reconnect.
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	return f.Flow.Produce(v)
}

// Consume marshals p, appends a MAC from g, prefixes the result with its
// 4-byte little-endian length and hands it to the writer goroutine.
//
// It returns shared.ErrDeadConnection if the flow is not alive, or the write
// error reported by the writer goroutine, in which case the flow is marked
// dead.
//
// NOTE(review): isAlive is written here while InitiatedFlow holds only the
// read lock, so concurrent Consume/Produce calls can race on it — confirm
// how callers serialise these.
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
	if !f.isAlive {
		return shared.ErrDeadConnection
	}

	// Surface an already-pending write error before doing any work.
	select {
	case err := <-f.consumeErrors:
		f.isAlive = false
		return err
	default:
	}

	marshalled := p.Marshal()
	data := proxy.AppendMac(marshalled, g)

	prefixedData := make([]byte, len(data)+4)
	binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
	copy(prefixedData[4:], data)

	// The writer goroutine may fail on the previous packet while this call is
	// waiting to hand over the next one. Waiting on both channels avoids the
	// deadlock where each side is blocked sending to the other.
	select {
	case f.toConsume <- prefixedData:
		return nil
	case err := <-f.consumeErrors:
		f.isAlive = false
		return err
	}
}

// Produce blocks until the reader goroutine delivers a frame or an error. A
// frame has its MAC checked and stripped using v and is returned as a
// proxy.SimplePacket. A read error marks the flow dead and is returned; a MAC
// error is returned without changing the flow's state.
//
// It returns shared.ErrDeadConnection if the flow is not alive.
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
	if !f.isAlive {
		return nil, shared.ErrDeadConnection
	}

	var data []byte

	select {
	case data = <-f.produced:
	case err := <-f.produceErrors:
		f.isAlive = false
		return nil, err
	}

	b, err := proxy.StripMac(data, v)
	if err != nil {
		return nil, err
	}

	return proxy.SimplePacket(b), nil
}

// consumeMarshalled is the writer goroutine: it takes frames from toConsume
// and writes each to the connection with a five second write deadline. On the
// first failure it reports the error on consumeErrors and exits.
func (f *Flow) consumeMarshalled() {
	const writeTimeout = 5 * time.Second

	for {
		data := <-f.toConsume

		if err := f.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
			f.consumeErrors <- err
			return
		}

		if _, err := f.conn.Write(data); err != nil {
			f.consumeErrors <- err
			return
		}
	}
}

// produceMarshalled is the reader goroutine: it reads length-prefixed frames
// from the connection and delivers each payload on produced. On the first
// failure it reports the error on produceErrors and exits.
//
// NOTE(review): the payload length comes straight from the wire and is
// allocated without an upper bound — consider capping it if the peer is not
// trusted.
func (f *Flow) produceMarshalled() {
	const prefixLen = 4

	buf := bufio.NewReader(f.conn)

	var prefix [prefixLen]byte

	for {
		// A single Read may legitimately return fewer bytes than requested on
		// a stream, so read until the whole prefix has arrived.
		if _, err := io.ReadFull(buf, prefix[:]); err != nil {
			if err == io.ErrUnexpectedEOF {
				// The stream ended part-way through the length prefix.
				err = shared.ErrNotEnoughBytes
			}
			f.produceErrors <- err
			return
		}

		length := binary.LittleEndian.Uint32(prefix[:])
		dataBytes := make([]byte, length)

		if _, err := io.ReadFull(buf, dataBytes); err != nil {
			if err == io.ErrUnexpectedEOF {
				// Keep reporting a mid-payload close as io.EOF, as before.
				err = io.EOF
			}
			f.produceErrors <- err
			return
		}

		f.produced <- dataBytes
	}
}