// Package tcp implements proxy flows that carry length-prefixed packets over
// TCP connections.
package tcp

import (
	"encoding/binary"
	"errors"
	"io"
	"net"
	"sync"
	"time"

	"mpbl3p/proxy"
	"mpbl3p/shared"
)

const (
	// frameHeaderLen is the size of the little-endian uint32 length prefix
	// written before every marshalled packet.
	frameHeaderLen = 4

	// writeTimeout bounds how long writing a single frame may block.
	writeTimeout = 5 * time.Second
)

// ErrNotEnoughBytes is returned when the stream ends part-way through a frame
// (inside the length prefix or before the full payload has arrived).
var ErrNotEnoughBytes = errors.New("not enough bytes")

// Conn is the subset of net.Conn that a Flow needs.
type Conn interface {
	Read(b []byte) (n int, err error)
	Write(b []byte) (n int, err error)
	SetWriteDeadline(time.Time) error
}

// InitiatedFlow is a Flow whose connection this side dials, from Local to
// Remote, and which can be re-dialled with Reconnect after it dies.
type InitiatedFlow struct {
	Local  string
	Remote string

	mu sync.RWMutex
	Flow
}

// Flow sends and receives length-prefixed packets over a single Conn.
type Flow struct {
	conn    Conn
	isAlive bool
}

// InitiateFlow returns an InitiatedFlow for the given local and remote TCP
// addresses. It does not dial; the flow is not alive until Reconnect succeeds.
func InitiateFlow(local, remote string) (*InitiatedFlow, error) {
	f := InitiatedFlow{
		Local:  local,
		Remote: remote,
	}

	return &f, nil
}

// Reconnect dials Remote from Local if the flow is not currently alive; it is
// a no-op for a live flow. On success the new connection replaces (and closes)
// any previous one and the flow is marked alive.
func (f *InitiatedFlow) Reconnect() error {
	f.mu.Lock()
	defer f.mu.Unlock()

	if f.isAlive {
		return nil
	}

	localAddr, err := net.ResolveTCPAddr("tcp", f.Local)
	if err != nil {
		return err
	}

	remoteAddr, err := net.ResolveTCPAddr("tcp", f.Remote)
	if err != nil {
		return err
	}

	conn, err := net.DialTCP("tcp", localAddr, remoteAddr)
	if err != nil {
		return err
	}

	if err = conn.SetWriteBuffer(0); err != nil {
		// Don't leak the socket we just opened.
		_ = conn.Close() // the SetWriteBuffer error is the one worth reporting
		return err
	}

	// The previous connection (if any) has already failed; release it so
	// repeated reconnects don't accumulate open descriptors.
	if old, ok := f.conn.(io.Closer); ok {
		_ = old.Close() // best-effort: the connection is already dead
	}

	f.conn = conn
	f.isAlive = true
	return nil
}

// IsAlive reports whether the flow is currently considered usable. A flow is
// marked dead after any read or write error on its connection.
//
// NOTE(review): isAlive is read here without locking and is written by
// Flow.Consume/Flow.Produce while InitiatedFlow holds only a read lock, so
// concurrent failures race on this field. Consider an atomic flag.
func (f *Flow) IsAlive() bool {
	return f.isAlive
}

// Consume sends p on the flow while holding the flow's read lock, so it cannot
// run concurrently with Reconnect.
func (f *InitiatedFlow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
	f.mu.RLock()
	defer f.mu.RUnlock()

	return f.Flow.Consume(p, g)
}

// Consume marshals p using g and writes it as a single length-prefixed frame.
// It returns shared.ErrDeadConnection if the flow is already dead; any write
// error marks the flow dead and is returned.
func (f *Flow) Consume(p proxy.Packet, g proxy.MacGenerator) error {
	if !f.isAlive {
		return shared.ErrDeadConnection
	}

	if err := f.consumeMarshalled(p.Marshal(g)); err != nil {
		f.isAlive = false
		return err
	}
	return nil
}

// consumeMarshalled writes data prefixed by its length as a little-endian
// uint32, in one Write call bounded by writeTimeout.
func (f *Flow) consumeMarshalled(data []byte) error {
	prefixedData := make([]byte, frameHeaderLen+len(data))
	binary.LittleEndian.PutUint32(prefixedData, uint32(len(data)))
	copy(prefixedData[frameHeaderLen:], data)

	if err := f.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
		return err
	}

	_, err := f.conn.Write(prefixedData)
	return err
}

// Produce receives the next packet from the flow while holding the flow's read
// lock, so it cannot run concurrently with Reconnect.
func (f *InitiatedFlow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	return f.Flow.Produce(v)
}

// Produce reads the next length-prefixed frame and unmarshals it using v.
// It returns shared.ErrDeadConnection if the flow is already dead; any read
// error marks the flow dead and is returned.
func (f *Flow) Produce(v proxy.MacVerifier) (proxy.Packet, error) {
	if !f.isAlive {
		return proxy.Packet{}, shared.ErrDeadConnection
	}

	data, err := f.produceMarshalled()
	if err != nil {
		f.isAlive = false
		return proxy.Packet{}, err
	}

	return proxy.UnmarshalPacket(data, v)
}

// produceMarshalled reads one frame: a little-endian uint32 length followed by
// that many payload bytes. io.ReadFull is used because a single TCP Read may
// legally return fewer bytes than requested.
//
// A clean end of stream before any header byte is returned as io.EOF; a stream
// that ends part-way through a frame yields ErrNotEnoughBytes.
func (f *Flow) produceMarshalled() ([]byte, error) {
	var header [frameHeaderLen]byte
	if _, err := io.ReadFull(f.conn, header[:]); err != nil {
		if errors.Is(err, io.ErrUnexpectedEOF) {
			return nil, ErrNotEnoughBytes
		}
		return nil, err
	}

	// NOTE(review): length is peer-controlled, so a hostile or corrupt peer can
	// force an allocation of up to 4 GiB here. Consider capping it at the
	// largest packet the proxy can legitimately carry.
	length := binary.LittleEndian.Uint32(header[:])

	data := make([]byte, length)
	if _, err := io.ReadFull(f.conn, data); err != nil {
		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
			return nil, ErrNotEnoughBytes
		}
		return nil, err
	}

	return data, nil
}