package tcp

import (
	"encoding/binary"
	"errors"
	"io"
	"net"
	"sync"

	"mpbl3p/proxy"
)

// ErrNotEnoughBytes is returned when the connection ends part-way through a
// length-prefixed frame (either its 4-byte header or its payload).
var ErrNotEnoughBytes = errors.New("not enough bytes")

// Conn is the minimal byte-stream interface a Flow reads frames from and
// writes frames to.
type Conn interface {
	Read(b []byte) (n int, err error)
	Write(b []byte) (n int, err error)
}

// InitiatedFlow is a Flow whose TCP connection is dialled by this side, from
// Local to Remote, and can be re-established with Reconnect.
//
// mu is held exclusively by Reconnect while it replaces the connection, and
// shared by Consume/Produce while they use it.
//
// NOTE(review): Flow.Consume and Flow.Produce set isDead while the caller
// only holds mu.RLock, and IsAlive reads it without any lock, so concurrent
// use is a data race under the race detector. Fixing it needs isDead to
// become an atomic (or be guarded by a separate lock), which affects every
// place in the package that builds a Flow literal.
type InitiatedFlow struct {
	Local  string
	Remote string

	mu sync.RWMutex
	Flow
}

// Flow carries proxy packets over a byte stream. Each marshalled packet is
// framed with a 4-byte little-endian length prefix followed by the payload.
type Flow struct {
	conn   Conn
	isDead bool // set after a read/write error; cleared by a successful Reconnect
}

// InitiateFlow records the local and remote TCP addresses for a new flow.
// It does not dial: the flow starts dead, so the first call to Reconnect
// establishes the connection. Consume/Produce must not be called before
// that succeeds, as the connection is still nil. The returned error is
// currently always nil.
func InitiateFlow(local, remote string) (*InitiatedFlow, error) {
	f := InitiatedFlow{
		Local:  local,
		Remote: remote,
		Flow:   Flow{isDead: true},
	}

	return &f, nil
}

// Reconnect dials a fresh TCP connection from Local to Remote if the flow is
// currently dead; it is a no-op on a live flow. It returns any address
// resolution or dial error, in which case the flow stays dead.
func (f *InitiatedFlow) Reconnect() error {
	f.mu.Lock()
	defer f.mu.Unlock()

	if !f.isDead {
		return nil
	}

	// Release the previous, already-failed connection before dialing. Left
	// open it leaks its file descriptor, and it may keep a fixed local
	// address bound so that the new dial from Local fails.
	if c, ok := f.conn.(io.Closer); ok {
		_ = c.Close() // best effort: the connection is already known to be broken
	}

	localAddr, err := net.ResolveTCPAddr("tcp", f.Local)
	if err != nil {
		return err
	}

	remoteAddr, err := net.ResolveTCPAddr("tcp", f.Remote)
	if err != nil {
		return err
	}

	// Dial into a local so a failed dial never stores a typed-nil
	// *net.TCPConn in the Conn interface.
	conn, err := net.DialTCP("tcp", localAddr, remoteAddr)
	if err != nil {
		return err
	}

	f.conn = conn
	f.isDead = false
	return nil
}

// IsAlive reports whether the flow is usable: it has not observed a
// read/write error since it was created or last successfully reconnected.
func (f *Flow) IsAlive() bool {
	return !f.isDead
}

// Consume sends p over the current connection. It holds the shared lock so
// the connection cannot be swapped out by Reconnect mid-write.
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
	f.mu.RLock()
	defer f.mu.RUnlock()

	return f.Flow.Consume(p, g)
}

// Consume marshals p with g and writes it as one length-prefixed frame.
// Any write error marks the flow dead and is returned to the caller.
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
	if err := f.consumeMarshalled(p.Marshal(g)); err != nil {
		f.isDead = true
		return err
	}
	return nil
}

// consumeMarshalled writes data prefixed with its length as a little-endian
// uint32. Header and payload go out in a single Write so a frame is never
// split across two write calls.
func (f *Flow) consumeMarshalled(data []byte) error {
	prefixedData := make([]byte, len(data)+4)
	binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
	copy(prefixedData[4:], data)

	_, err := f.conn.Write(prefixedData)
	return err
}

// Produce reads the next packet from the current connection. It holds the
// shared lock so the connection cannot be swapped out by Reconnect
// mid-read.
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	return f.Flow.Produce(v)
}

// Produce reads one length-prefixed frame and unmarshals it with v.
// A read error marks the flow dead; an unmarshal error does not, since the
// stream is still correctly framed.
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
	data, err := f.produceMarshalled()
	if err != nil {
		f.isDead = true
		return proxy.Packet{}, err
	}

	return proxy.UnmarshalPacket(data, v)
}

// produceMarshalled reads one length-prefixed frame and returns its payload.
//
// io.ReadFull is used for both the header and the payload because a stream
// Read may legitimately return fewer bytes than requested. A clean EOF
// before any header byte is returned as io.EOF. A frame cut off part-way is
// reported as ErrNotEnoughBytes.
func (f *Flow) produceMarshalled() ([]byte, error) {
	var header [4]byte
	if _, err := io.ReadFull(f.conn, header[:]); err != nil {
		if errors.Is(err, io.ErrUnexpectedEOF) {
			return nil, ErrNotEnoughBytes
		}
		return nil, err
	}

	// NOTE(review): length comes straight off the wire and is trusted (up to
	// 4 GiB) before any MAC verification. Consider capping it at the largest
	// packet the proxy can carry.
	length := binary.LittleEndian.Uint32(header[:])

	data := make([]byte, length)
	if _, err := io.ReadFull(f.conn, data); err != nil {
		if errors.Is(err, io.ErrUnexpectedEOF) {
			return nil, ErrNotEnoughBytes
		}
		return nil, err
	}

	return data, nil
}