diff --git a/udp/congestion/newreno.go b/udp/congestion/newreno.go index 7234c7a..c1e3448 100644 --- a/udp/congestion/newreno.go +++ b/udp/congestion/newreno.go @@ -80,9 +80,12 @@ func NewNewReno() *NewReno { } func (c *NewReno) ReceivedPacket(seq, nack, ack uint32) { + // decide what acks and nacks to send if seq != 0 { c.receivedSequence(seq) } + + // decide how window size was affected if nack != 0 { c.receivedNack(nack) } @@ -147,16 +150,21 @@ func (c *NewReno) receivedSequence(seq uint32) { c.ackNackMu.Lock() defer c.ackNackMu.Unlock() - if seq != c.ack+1 && seq != c.nack+1 { - if seq > c.ack && seq > c.nack { - c.awaitingAck = append(c.awaitingAck, flightInfo{ - time: time.Now(), - sequence: seq, - }) - } // else discard as it's already been cumulatively ACKed/NACKed - return // if this seq doesn't change the ack field, awaitingAck will be unchanged + if seq < c.nack || seq < c.ack { + // packet received out of order has already been cumulatively NACKed + // or duplicate packet received and already ACKed + return } + if seq != c.ack+1 && seq != c.nack+1 { + c.awaitingAck = append(c.awaitingAck, flightInfo{ + time: time.Now(), + sequence: seq, + }) + return // if this seq doesn't change the ack field, awaitingAck will still not do anything useful + } + + sort.Sort(c.awaitingAck) c.updateAck(seq) } @@ -169,14 +177,15 @@ func (c *NewReno) checkNack() { } sort.Sort(c.awaitingAck) + rtt := time.Duration(c.rttNanos * RttLossDelay) - - if !c.awaitingAck[0].time.Before(time.Now().Add(-rtt)) { - return + if c.awaitingAck[0].time.Before(time.Now().Add(-rtt)) { + // if the next packet sequence to ack was received more than an rttlossdelay ago + // mark the packet(s) blocking it as missing with a nack + // then update ack from the delayed packet + c.nack = c.awaitingAck[0].sequence - 1 + c.updateAck(c.nack) } - - c.nack = c.awaitingAck[0].sequence - 1 - c.updateAck(c.nack) } func (c *NewReno) updateAck(a uint32) { @@ -184,8 +193,8 @@ func (c *NewReno) updateAck(a uint32) { var e flightInfo for i, e = range c.awaitingAck { - if a+1 == e.sequence { - a += 1 + if e.sequence == a+1 { + a = e.sequence } else { break } @@ -195,7 +204,6 @@ func (c *NewReno) updateAck(a uint32) { c.awaitingAck = c.awaitingAck[i:] } -// It is assumed that ReceivedNack will only be called by one thread func (c *NewReno) receivedNack(nack uint32) { c.ackNackMu.Lock() defer c.ackNackMu.Unlock() @@ -209,14 +217,14 @@ func (c *NewReno) receivedNack(nack uint32) { i++ } - c.slowStart = false if i == 0 { return } - c.inFlight = c.inFlight[i-1:] + c.slowStart = false + c.inFlight = c.inFlight[i:] - for i > 0 { + for { s := c.windowSize if s > 1 && atomic.CompareAndSwapUint32(&c.windowSize, s, s/2) { break @@ -229,7 +237,6 @@ func (c *NewReno) receivedNack(nack uint32) { } } -// It is assumed that ReceivedAck will only be called by one thread func (c *NewReno) receivedAck(ack uint32) { c.ackNackMu.Lock() defer c.ackNackMu.Unlock() diff --git a/udp/congestion/newreno_test.go b/udp/congestion/newreno_test.go index 7c37688..44e6f6b 100644 --- a/udp/congestion/newreno_test.go +++ b/udp/congestion/newreno_test.go @@ -22,7 +22,7 @@ type newRenoTest struct { halfRtt time.Duration } -func newNewRenoTest(halfRtt time.Duration) *newRenoTest { +func newNewRenoTest(rtt time.Duration) *newRenoTest { return &newRenoTest{ sideA: NewNewReno(), sideB: NewNewReno(), @@ -33,7 +33,7 @@ func newNewRenoTest(halfRtt time.Duration) *newRenoTest { aInbound: make(chan congestionPacket), bInbound: make(chan congestionPacket), - halfRtt: halfRtt, + halfRtt: rtt / 2, } } @@ -126,20 +126,21 @@ func TestNewReno_Congestion(t *testing.T) { t.Run("OneWay", func(t *testing.T) { t.Run("Lossless", func(t *testing.T) { // ASSIGN + rtt := 80*time.Millisecond + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - c := newNewRenoTest(10 * time.Millisecond) + c := newNewRenoTest(rtt) c.Start(ctx) c.RunSideA(ctx) c.RunSideB(ctx) // ACT - for i := 0; i < 20; i++ { - t1 := time.Now() + for i := 0; i < 50; i++ { + // sleep to simulate preparing packet + time.Sleep(1*time.Millisecond) seq := c.sideA.Sequence() - t2 := time.Now() - fmt.Printf("waited %dms for sequence\n", t2.Sub(t1).Milliseconds()) c.aOutbound <- congestionPacket{ seq: seq, @@ -153,17 +154,11 @@ func TestNewReno_Congestion(t *testing.T) { // ASSERT - assert.InDelta(t, - float64(2*10*time.Millisecond.Nanoseconds()), - c.sideA.rttNanos, - float64(time.Millisecond.Nanoseconds()), - ) - assert.Equal(t, uint32(0), c.sideA.nack) assert.Equal(t, uint32(0), c.sideA.ack) assert.Equal(t, uint32(0), c.sideB.nack) - assert.Equal(t, uint32(20), c.sideB.ack) + assert.Equal(t, uint32(50), c.sideB.ack) }) }) }