diff --git a/source/extensions/filters/listener/tls_inspector/tls_inspector.cc b/source/extensions/filters/listener/tls_inspector/tls_inspector.cc index abe1586de702..d5bceb8fb29b 100644 --- a/source/extensions/filters/listener/tls_inspector/tls_inspector.cc +++ b/source/extensions/filters/listener/tls_inspector/tls_inspector.cc @@ -176,8 +176,6 @@ void Filter::parseClientHello(const void* data, size_t len) { cb_->socket().setDetectedTransportProtocol(TransportSockets::TransportSocketNames::get().SSL); } else { config_->stats().tls_not_found_.inc(); - cb_->socket().setDetectedTransportProtocol( - TransportSockets::TransportSocketNames::get().RAW_BUFFER); } done(true); break; diff --git a/source/server/BUILD b/source/server/BUILD index cf876b89f16f..2c4364a87a78 100644 --- a/source/server/BUILD +++ b/source/server/BUILD @@ -66,6 +66,7 @@ envoy_cc_library( "//source/common/common:linked_object", "//source/common/common:non_copyable", "//source/common/network:connection_lib", + "//source/extensions/transport_sockets:well_known_names", ], ) diff --git a/source/server/connection_handler_impl.cc b/source/server/connection_handler_impl.cc index e32bfcb835e3..851ca337ad74 100644 --- a/source/server/connection_handler_impl.cc +++ b/source/server/connection_handler_impl.cc @@ -8,6 +8,8 @@ #include "common/network/connection_impl.h" #include "common/network/utility.h" +#include "extensions/transport_sockets/well_known_names.h" + namespace Envoy { namespace Server { @@ -148,6 +150,11 @@ void ConnectionHandlerImpl::ActiveSocket::continueFilterChain(bool success) { // prevent further redirection. new_listener->onAccept(std::move(socket_), false); } else { + // Set default transport protocol if none of the listener filters did it. + if (socket_->detectedTransportProtocol().empty()) { + socket_->setDetectedTransportProtocol( + Extensions::TransportSockets::TransportSocketNames::get().RAW_BUFFER); + } // Create a new connection on this listener. listener_.newConnection(std::move(socket_)); } diff --git a/test/extensions/filters/listener/tls_inspector/tls_inspector_test.cc b/test/extensions/filters/listener/tls_inspector/tls_inspector_test.cc index 84bffbb49715..fd1dda7cec6e 100644 --- a/test/extensions/filters/listener/tls_inspector/tls_inspector_test.cc +++ b/test/extensions/filters/listener/tls_inspector/tls_inspector_test.cc @@ -106,6 +106,8 @@ TEST_F(TlsInspectorTest, SniRegistered) { EXPECT_CALL(socket_, setDetectedTransportProtocol(absl::string_view("ssl"))); EXPECT_CALL(cb_, continueFilterChain(true)); file_event_callback_(Event::FileReadyType::Read); + EXPECT_EQ(1, cfg_->stats().tls_found_.value()); + EXPECT_EQ(1, cfg_->stats().sni_found_.value()); } // Test with the ClientHello spread over multiple socket reads. @@ -138,6 +140,8 @@ TEST_F(TlsInspectorTest, MultipleReads) { while (!got_continue) { file_event_callback_(Event::FileReadyType::Read); } + EXPECT_EQ(1, cfg_->stats().tls_found_.value()); + EXPECT_EQ(1, cfg_->stats().sni_found_.value()); } // Test that the filter correctly handles a ClientHello with no SNI present @@ -154,6 +158,8 @@ TEST_F(TlsInspectorTest, NoSni) { EXPECT_CALL(socket_, setDetectedTransportProtocol(absl::string_view("ssl"))); EXPECT_CALL(cb_, continueFilterChain(true)); file_event_callback_(Event::FileReadyType::Read); + EXPECT_EQ(1, cfg_->stats().tls_found_.value()); + EXPECT_EQ(1, cfg_->stats().sni_not_found_.value()); } // Test that the filter fails if the ClientHello is larger than the @@ -172,6 +178,7 @@ TEST_F(TlsInspectorTest, ClientHelloTooBig) { })); EXPECT_CALL(cb_, continueFilterChain(false)); file_event_callback_(Event::FileReadyType::Read); + EXPECT_EQ(1, cfg_->stats().client_hello_too_large_.value()); } // Test that the filter fails on non-SSL data @@ -188,9 +195,9 @@ TEST_F(TlsInspectorTest, NotSsl) { memcpy(buffer, data.data(), data.size()); return data.size(); })); - EXPECT_CALL(socket_, setDetectedTransportProtocol(absl::string_view("raw_buffer"))); EXPECT_CALL(cb_, continueFilterChain(true)); file_event_callback_(Event::FileReadyType::Read); + EXPECT_EQ(1, cfg_->stats().tls_not_found_.value()); } } // namespace TlsInspector diff --git a/test/server/connection_handler_test.cc b/test/server/connection_handler_test.cc index d71d15190259..448853f3ec61 100644 --- a/test/server/connection_handler_test.cc +++ b/test/server/connection_handler_test.cc @@ -445,5 +445,66 @@ TEST_F(ConnectionHandlerTest, WildcardListenerWithNoOriginalDst) { EXPECT_CALL(*listener1, onDestroy()); } +TEST_F(ConnectionHandlerTest, TransportProtocolDefault) { + TestListener* test_listener = addListener(1, true, false, "test_listener"); + Network::MockListener* listener = new Network::MockListener(); + Network::ListenerCallbacks* listener_callbacks; + EXPECT_CALL(dispatcher_, createListener_(_, _, _, false)) + .WillOnce(Invoke( + [&](Network::Socket&, Network::ListenerCallbacks& cb, bool, bool) -> Network::Listener* { + listener_callbacks = &cb; + return listener; + })); + EXPECT_CALL(test_listener->socket_, localAddress()); + handler_->addListener(*test_listener); + + Network::MockConnectionSocket* accepted_socket = new NiceMock(); + EXPECT_CALL(*accepted_socket, detectedTransportProtocol()) + .WillOnce(Return(absl::string_view(""))); + EXPECT_CALL(*accepted_socket, setDetectedTransportProtocol(absl::string_view("raw_buffer"))); + EXPECT_CALL(factory_, createNetworkFilterChain(_)).WillOnce(Return(true)); + Network::MockConnection* connection = new NiceMock(); + EXPECT_CALL(dispatcher_, createServerConnection_(_, _)).WillOnce(Return(connection)); + listener_callbacks->onAccept(Network::ConnectionSocketPtr{accepted_socket}, true); + + EXPECT_CALL(*listener, onDestroy()); +} + +TEST_F(ConnectionHandlerTest, TransportProtocolCustom) { + TestListener* test_listener = addListener(1, true, false, "test_listener"); + Network::MockListener* listener = new Network::MockListener(); + Network::ListenerCallbacks* listener_callbacks; + EXPECT_CALL(dispatcher_, createListener_(_, _, _, false)) + .WillOnce(Invoke( + [&](Network::Socket&, Network::ListenerCallbacks& cb, bool, bool) -> Network::Listener* { + listener_callbacks = &cb; + return listener; + })); + EXPECT_CALL(test_listener->socket_, localAddress()); + handler_->addListener(*test_listener); + + Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(); + EXPECT_CALL(factory_, createListenerFilterChain(_)) + .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { + manager.addAcceptFilter(Network::ListenerFilterPtr{test_filter}); + return true; + })); + absl::string_view dummy = "dummy"; + EXPECT_CALL(*test_filter, onAccept(_)) + .WillOnce(Invoke([&](Network::ListenerFilterCallbacks& cb) -> Network::FilterStatus { + cb.socket().setDetectedTransportProtocol(dummy); + return Network::FilterStatus::Continue; + })); + Network::MockConnectionSocket* accepted_socket = new NiceMock(); + EXPECT_CALL(*accepted_socket, setDetectedTransportProtocol(dummy)); + EXPECT_CALL(*accepted_socket, detectedTransportProtocol()).WillOnce(Return(dummy)); + EXPECT_CALL(factory_, createNetworkFilterChain(_)).WillOnce(Return(true)); + Network::MockConnection* connection = new NiceMock(); + EXPECT_CALL(dispatcher_, createServerConnection_(_, _)).WillOnce(Return(connection)); + listener_callbacks->onAccept(Network::ConnectionSocketPtr{accepted_socket}, true); + + EXPECT_CALL(*listener, onDestroy()); +} + } // namespace Server } // namespace Envoy