From 1dc85c62274be5ddd05f560a96f5ea468d0b92f0 Mon Sep 17 00:00:00 2001
From: Marten Seemann <martenseemann@gmail.com>
Date: Tue, 21 Dec 2021 16:54:21 +0400
Subject: [PATCH] fix flaky tests

---
 autonat_test.go | 56 +++++++++++++++++++++++++++++--------------------
 1 file changed, 33 insertions(+), 23 deletions(-)

diff --git a/autonat_test.go b/autonat_test.go
index 1e598c2..51ab181 100644
--- a/autonat_test.go
+++ b/autonat_test.go
@@ -22,34 +22,44 @@ import (
 // these are mock service implementations for testing
 func makeAutoNATServicePrivate(t *testing.T) host.Host {
 	h := bhost.NewBlankHost(swarmt.GenSwarm(t))
-	h.SetStreamHandler(AutoNATProto, sayAutoNATPrivate)
+	h.SetStreamHandler(AutoNATProto, sayPrivateStreamHandler(t))
 	return h
 }
 
-func makeAutoNATServicePublic(t *testing.T) host.Host {
-	h := bhost.NewBlankHost(swarmt.GenSwarm(t))
-	h.SetStreamHandler(AutoNATProto, sayAutoNATPublic)
-	return h
-}
-
-func sayAutoNATPrivate(s network.Stream) {
-	defer s.Close()
-	w := protoio.NewDelimitedWriter(s)
-	res := pb.Message{
-		Type:         pb.Message_DIAL_RESPONSE.Enum(),
-		DialResponse: newDialResponseError(pb.Message_E_DIAL_ERROR, "no dialable addresses"),
+func sayPrivateStreamHandler(t *testing.T) network.StreamHandler {
+	return func(s network.Stream) {
+		defer s.Close()
+		r := protoio.NewDelimitedReader(s, network.MessageSizeMax)
+		if err := r.ReadMsg(&pb.Message{}); err != nil {
+			t.Error(err)
+			return
+		}
+		w := protoio.NewDelimitedWriter(s)
+		res := pb.Message{
+			Type:         pb.Message_DIAL_RESPONSE.Enum(),
+			DialResponse: newDialResponseError(pb.Message_E_DIAL_ERROR, "no dialable addresses"),
+		}
+		w.WriteMsg(&res)
 	}
-	w.WriteMsg(&res)
 }
 
-func sayAutoNATPublic(s network.Stream) {
-	defer s.Close()
-	w := protoio.NewDelimitedWriter(s)
-	res := pb.Message{
-		Type:         pb.Message_DIAL_RESPONSE.Enum(),
-		DialResponse: newDialResponseOK(s.Conn().RemoteMultiaddr()),
-	}
-	w.WriteMsg(&res)
+func makeAutoNATServicePublic(t *testing.T) host.Host {
+	h := bhost.NewBlankHost(swarmt.GenSwarm(t))
+	h.SetStreamHandler(AutoNATProto, func(s network.Stream) {
+		defer s.Close()
+		r := protoio.NewDelimitedReader(s, network.MessageSizeMax)
+		if err := r.ReadMsg(&pb.Message{}); err != nil {
+			t.Error(err)
+			return
+		}
+		w := protoio.NewDelimitedWriter(s)
+		res := pb.Message{
+			Type:         pb.Message_DIAL_RESPONSE.Enum(),
+			DialResponse: newDialResponseOK(s.Conn().RemoteMultiaddr()),
+		}
+		w.WriteMsg(&res)
+	})
+	return h
 }
 
 func makeAutoNAT(t *testing.T, ash host.Host) (host.Host, AutoNAT) {
@@ -173,7 +183,7 @@ func TestAutoNATPublictoPrivate(t *testing.T) {
 	)
 	expectEvent(t, s, network.ReachabilityPublic)
 
-	hs.SetStreamHandler(AutoNATProto, sayAutoNATPrivate)
+	hs.SetStreamHandler(AutoNATProto, sayPrivateStreamHandler(t))
 	hps := makeAutoNATServicePrivate(t)
 	connect(t, hps, hc)
 	identifyAsServer(hps, hc)