Skip to content

Commit

Permalink
Prevent stuck connections after local reply in proxy-wasm
Browse files Browse the repository at this point in the history
This change fixes envoyproxy#28826.
Some additional discussions for context can be found in
proxy-wasm/proxy-wasm-cpp-host#423.

The issue reported in envoyproxy#28826
happens when proxy-wasm plugin calls proxy_send_local_response during
the HTTP request proessing and HTTP response processing.

This happens because in attempt to mitigate a use-after-free issue
(see envoyproxy#23049) we added logic
to proxy-wasm that avoids calling sendLocalReply multiple times.

So now when proxy-wasm plugin calls proxy_send_local_response only
the first call will result in sendLocalReply, while all subsequent
calls will get ignored. At the same time, when proxy-wasm plugins
call proxy_send_local_response, because it's used to report an
error in the plugin, proxy-wasm also stops iteration.

During HTTP request processing this leads to the following chain
of events:

1. During request proxy-wasm plugin calls proxy_send_local_response
2. proxy_send_local_response calls sendLocalReply, which schedules
   the local reply to be processed later through the filter chain
3. Request processing filter chain gets aborted and Envoy sends the
   previous created local reply though the filter chain
4. Proxy-wasm plugin gets called to process the response it generated
   and it calls proxy_send_local_response
5. proxy_send_local_response **does not** call sendLocalReply, because
   proxy-wasm prevents multiple calls to sendLocalReply currently
6. proxy-wasm stops iteration

So in the end the filter chain iteration is stopped for the response
and because proxy_send_local_respose does not actually call
sendLocalReply we don't send another locally generated response
either.

I think we can do slightly better and close the connection in this
case. This change includes the following parts:

1. Partial rollback of envoyproxy#23049
2. Tests covering this case and some other using the actual FilterManager.

The most important question is why rolling back
envoyproxy#23049 now is safe?

The reason why it's safe, is that since introduction of
prepareLocalReplyViaFilterChain in
envoyproxy#24367, calling sendLocalReply
multiple times is safe - that PR basically address the issue in a generic
way for all the plugins, so a proxy-wasm specific fix is not needed anymore.

On top of being safe, there are additional benefits to making this change:

1. We don't end up with a stuck connection in case of errors, which is
   slightly better
2. We remove a failure mode from proxy_send_local_response that was
   introduced in envoyproxy#23049 - which
   is good, because proxy-wasm plugins don't have a good fallback when
   proxy_send_local_response is failing.

Finally, why replace the current mocks with a real FilterManager?

Mock implementation of sendLocalReply works fine for tests that just need
to assert that sendLocalReply gets called. However, in this case we rely
on the fact that it's safe to call sendLocalReply multiple times and it
will do the right thing and we want to assert that the connection will
get closed in the end - that cannot be tested by just checking that the
sendLocalReply gets called or by relying on a simplistic mock
implementation of sendLocalReply.

Signed-off-by: Mikhail Krinkin <krinkin.m.u@gmail.com>
  • Loading branch information
krinkinmu committed Oct 24, 2024
1 parent 742a3b0 commit 1c1a86b
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 53 deletions.
13 changes: 0 additions & 13 deletions source/extensions/common/wasm/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1595,12 +1595,6 @@ void Context::failStream(WasmStreamType stream_type) {
WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view body_text,
Pairs additional_headers, uint32_t grpc_status,
std::string_view details) {
// This flag is used to avoid calling sendLocalReply() twice, even if wasm code has this
// logic. We can't reuse "local_reply_sent_" here because it can't avoid calling nested
// sendLocalReply() during encodeHeaders().
if (local_reply_hold_) {
return WasmResult::BadArgument;
}
// "additional_headers" is a collection of string_views. These will no longer
// be valid when "modify_headers" is finally called below, so we must
// make copies of all the headers.
Expand All @@ -1625,11 +1619,6 @@ WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view b
modify_headers = std::move(modify_headers), grpc_status,
details = StringUtil::replaceAllEmptySpace(
absl::string_view(details.data(), details.size()))] {
// When the wasm vm fails, failStream() is called if the plugin is fail-closed, we need
// this flag to avoid calling sendLocalReply() twice.
if (local_reply_sent_) {
return;
}
// C++, Rust and other SDKs use -1 (InvalidCode) as the default value if gRPC code is not set,
// which should be mapped to nullopt in Envoy to prevent it from sending a grpc-status trailer
// at all.
Expand All @@ -1640,10 +1629,8 @@ WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view b
}
decoder_callbacks_->sendLocalReply(static_cast<Envoy::Http::Code>(response_code), body_text,
modify_headers, grpc_status_code, details);
local_reply_sent_ = true;
});
}
local_reply_hold_ = true;
return WasmResult::Ok;
}

Expand Down
1 change: 0 additions & 1 deletion source/extensions/common/wasm/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,6 @@ class Context : public proxy_wasm::ContextBase,
bool buffering_response_body_ = false;
bool end_of_stream_ = false;
bool local_reply_sent_ = false;
bool local_reply_hold_ = false;
ProtobufWkt::Struct temporary_metadata_;

// MB: must be a node-type map as we take persistent references to the entries.
Expand Down
1 change: 1 addition & 0 deletions test/extensions/common/wasm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ envoy_cc_test(
"//test/extensions/common/wasm/test_data:test_context_cpp_plugin",
"//test/extensions/common/wasm/test_data:test_cpp_plugin",
"//test/extensions/common/wasm/test_data:test_restriction_cpp_plugin",
"//test/mocks/local_reply:local_reply_mocks",
"//test/mocks/server:server_mocks",
"//test/test_common:environment_lib",
"//test/test_common:registry_lib",
Expand Down
59 changes: 51 additions & 8 deletions test/extensions/common/wasm/test_data/test_context_cpp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,59 @@ FilterDataStatus DupReplyContext::onRequestBody(size_t, bool) {
return FilterDataStatus::Continue;
}

class PanicReplyContext : public Context {
class LocalReplyInRequestAndResponseContext : public Context {
public:
explicit PanicReplyContext(uint32_t id, RootContext* root) : Context(id, root) {}
explicit LocalReplyInRequestAndResponseContext(uint32_t id, RootContext* root) : Context(id, root) {}
FilterHeadersStatus onRequestHeaders(uint32_t, bool) override;
FilterHeadersStatus onResponseHeaders(uint32_t, bool) override;
private:
EnvoyRootContext* root() { return static_cast<EnvoyRootContext*>(Context::root()); }
};

FilterHeadersStatus LocalReplyInRequestAndResponseContext::onRequestHeaders(uint32_t, bool) {
sendLocalResponse(200, "ok", "body", {});
return FilterHeadersStatus::Continue;
}

FilterHeadersStatus LocalReplyInRequestAndResponseContext::onResponseHeaders(uint32_t, bool) {
sendLocalResponse(200, "ok", "body", {});
return FilterHeadersStatus::Continue;
}

class PanicInRequestContext : public Context {
public:
explicit PanicInRequestContext(uint32_t id, RootContext* root) : Context(id, root) {}
FilterDataStatus onRequestBody(size_t body_buffer_length, bool end_of_stream) override;

private:
EnvoyRootContext* root() { return static_cast<EnvoyRootContext*>(Context::root()); }
};

FilterDataStatus PanicReplyContext::onRequestBody(size_t, bool) {
sendLocalResponse(200, "not send", "body", {});
int* badptr = nullptr;
*badptr = 0; // NOLINT(clang-analyzer-core.NullDereference)
FilterDataStatus PanicInRequestContext::onRequestBody(size_t, bool) {
abort();
return FilterDataStatus::Continue;
}

class PanicInResponseContext : public Context {
public:
explicit PanicInResponseContext(uint32_t id, RootContext* root) : Context(id, root) {}
FilterHeadersStatus onResponseHeaders(uint32_t, bool) override;
FilterHeadersStatus onRequestHeaders(uint32_t, bool) override;

private:
EnvoyRootContext* root() { return static_cast<EnvoyRootContext*>(Context::root()); }
};

FilterHeadersStatus PanicInResponseContext::onRequestHeaders(uint32_t, bool) {
sendLocalResponse(200, "ok", "body", {});
return FilterHeadersStatus::Continue;
}

FilterHeadersStatus PanicInResponseContext::onResponseHeaders(uint32_t, bool) {
abort();
return FilterHeadersStatus::Continue;
}

class InvalidGrpcStatusReplyContext : public Context {
public:
explicit InvalidGrpcStatusReplyContext(uint32_t id, RootContext* root) : Context(id, root) {}
Expand All @@ -127,9 +164,15 @@ FilterDataStatus InvalidGrpcStatusReplyContext::onRequestBody(size_t size, bool)
static RegisterContextFactory register_DupReplyContext(CONTEXT_FACTORY(DupReplyContext),
ROOT_FACTORY(EnvoyRootContext),
"send local reply twice");
static RegisterContextFactory register_PanicReplyContext(CONTEXT_FACTORY(PanicReplyContext),
static RegisterContextFactory register_LocalReplyInRequestAndResponseContext(CONTEXT_FACTORY(LocalReplyInRequestAndResponseContext),
ROOT_FACTORY(EnvoyRootContext),
"local reply in request and response");
static RegisterContextFactory register_PanicInRequestContext(CONTEXT_FACTORY(PanicInRequestContext),
ROOT_FACTORY(EnvoyRootContext),
"panic during request processing");
static RegisterContextFactory register_PanicInResponseContext(CONTEXT_FACTORY(PanicInResponseContext),
ROOT_FACTORY(EnvoyRootContext),
"panic after sending local reply");
"panic during response processing");

static RegisterContextFactory register_InvalidGrpcStatusReplyContext(CONTEXT_FACTORY(InvalidGrpcStatusReplyContext),
ROOT_FACTORY(EnvoyRootContext),
Expand Down
152 changes: 125 additions & 27 deletions test/extensions/common/wasm/wasm_test.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
#include "envoy/http/filter.h"
#include "envoy/http/filter_factory.h"
#include "envoy/server/lifecycle_notifier.h"

#include "source/common/common/hex.h"
#include "source/common/event/dispatcher_impl.h"
#include "source/common/http/filter_manager.h"
#include "source/common/stats/isolated_store_impl.h"
#include "source/extensions/common/wasm/wasm.h"

#include "test/extensions/common/wasm/wasm_runtime.h"
#include "test/mocks/local_reply/mocks.h"
#include "test/mocks/server/mocks.h"
#include "test/mocks/stats/mocks.h"
#include "test/mocks/upstream/mocks.h"
Expand Down Expand Up @@ -1310,7 +1314,6 @@ class WasmCommonContextTest : public Common::Wasm::WasmHttpFilterTestBase<
return new TestContext(wasm, plugin);
});
}

void setupContext() { setupFilterBase<TestContext>(); }

TestContext& rootContext() { return *static_cast<TestContext*>(root_context_); }
Expand Down Expand Up @@ -1392,43 +1395,19 @@ TEST_P(WasmCommonContextTest, EmptyContext) {
root_context_->validateConfiguration("", plugin_);
}

// test that we don't send the local reply twice, even though it's specified in the wasm code
TEST_P(WasmCommonContextTest, DuplicateLocalReply) {
std::string code;
if (std::get<0>(GetParam()) != "null") {
code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat(
"{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm")));
} else {
// The name of the Null VM plugin.
code = "CommonWasmTestContextCpp";
}
EXPECT_FALSE(code.empty());

setup(code, "context", "send local reply twice");
setupContext();
EXPECT_CALL(decoder_callbacks_, encodeHeaders_(_, _))
.WillOnce([this](Http::ResponseHeaderMap&, bool) { context().onResponseHeaders(0, false); });
EXPECT_CALL(decoder_callbacks_,
sendLocalReply(Envoy::Http::Code::OK, testing::Eq("body"), _, _, testing::Eq("ok")));

// Create in-VM context.
context().onCreate();
EXPECT_EQ(proxy_wasm::FilterDataStatus::StopIterationNoBuffer, context().onRequestBody(0, false));
}

// test that we don't send the local reply twice when the wasm code panics
TEST_P(WasmCommonContextTest, LocalReplyWhenPanic) {
std::string code;
if (std::get<0>(GetParam()) != "null") {
code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat(
"{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm")));
} else {
// no need test the Null VM plugin.
// Let's not cause crashes in Null VM
return;
}
EXPECT_FALSE(code.empty());

setup(code, "context", "panic after sending local reply");
setup(code, "context", "panic during request processing");
setupContext();
// In the case of VM failure, failStream is called, so we need to make sure that we don't send the
// local reply twice.
Expand Down Expand Up @@ -1492,6 +1471,125 @@ TEST_P(WasmCommonContextTest, ProcessValidGRPCStatusCodeAsEmptyInLocalReply) {
EXPECT_EQ(proxy_wasm::FilterDataStatus::StopIterationNoBuffer, context().onRequestBody(1, false));
}

class WasmLocalReplyTest : public WasmCommonContextTest {
public:
WasmLocalReplyTest() = default;

void setup(const std::string& code, std::string vm_configuration, std::string root_id = "") {
WasmCommonContextTest::setup(code, vm_configuration, root_id);
filter_manager_ = std::make_unique<Http::DownstreamFilterManager>(
filter_manager_callbacks_, dispatcher_, connection_, 0, nullptr, true, 10000,
filter_factory_, local_reply_, protocol_, time_source_, filter_state_, overload_manager_);
request_headers_ = Http::RequestHeaderMapPtr{
new Http::TestRequestHeaderMapImpl{{":path", "/"}, {":method", "GET"}}};
request_data_ = Envoy::Buffer::OwnedImpl("body");
}

Http::StreamFilterSharedPtr filter() { return context_; }

Http::FilterFactoryCb createWasmFilter() {
return [this](Http::FilterChainFactoryCallbacks& callbacks) {
callbacks.addStreamFilter(filter());
};
}

void setupContext() {
WasmCommonContextTest::setupContext();
ON_CALL(filter_factory_, createFilterChain(_))
.WillByDefault(Invoke([this](Http::FilterChainManager& manager) -> bool {
auto factory = createWasmFilter();
manager.applyFilterFactoryCb({}, factory);
return true;
}));
ON_CALL(filter_manager_callbacks_, requestHeaders())
.WillByDefault(Return(makeOptRef(*request_headers_)));
filter_manager_->createFilterChain();
filter_manager_->requestHeadersInitialized();
}

std::unique_ptr<Http::FilterManager> filter_manager_;
NiceMock<Http::MockFilterManagerCallbacks> filter_manager_callbacks_;
NiceMock<Event::MockDispatcher> dispatcher_;
NiceMock<Network::MockConnection> connection_;
NiceMock<Envoy::Http::MockFilterChainFactory> filter_factory_;
NiceMock<LocalReply::MockLocalReply> local_reply_;
Http::Protocol protocol_{Http::Protocol::Http2};
NiceMock<MockTimeSystem> time_source_;
StreamInfo::FilterStateSharedPtr filter_state_ =
std::make_shared<StreamInfo::FilterStateImpl>(StreamInfo::FilterState::LifeSpan::Connection);
NiceMock<Server::MockOverloadManager> overload_manager_;
Http::RequestHeaderMapPtr request_headers_;
Envoy::Buffer::OwnedImpl request_data_;
};

INSTANTIATE_TEST_SUITE_P(Runtimes, WasmLocalReplyTest,
Envoy::Extensions::Common::Wasm::runtime_and_cpp_values);

TEST_P(WasmLocalReplyTest, DuplicateLocalReply) {
std::string code;
if (std::get<0>(GetParam()) != "null") {
code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat(
"{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm")));
} else {
// Skip the Null plugin
return;
}
EXPECT_FALSE(code.empty());

setup(code, "context", "send local reply twice");
setupContext();

// Even if sendLocalReply is called multiple times it should only generate a single
// response to the client, so encodeHeaders should only be called once
EXPECT_CALL(filter_manager_callbacks_, encodeHeaders(_, _));
EXPECT_CALL(filter_manager_callbacks_, endStream());
filter_manager_->decodeHeaders(*request_headers_, false);
filter_manager_->decodeData(request_data_, false);
filter_manager_->destroyFilters();
}

TEST_P(WasmLocalReplyTest, LocalReplyInRequestAndResponse) {
std::string code;
if (std::get<0>(GetParam()) != "null") {
code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat(
"{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm")));
} else {
code = "CommonWasmTestContextCpp";
}
EXPECT_FALSE(code.empty());

setup(code, "context", "local reply in request and response");
setupContext();

EXPECT_CALL(filter_manager_callbacks_, encodeHeaders(_, _));
EXPECT_CALL(filter_manager_callbacks_, endStream());
filter_manager_->decodeHeaders(*request_headers_, false);
filter_manager_->decodeData(request_data_, false);
filter_manager_->destroyFilters();
}

TEST_P(WasmLocalReplyTest, PanicDuringResponse) {
std::string code;
if (std::get<0>(GetParam()) != "null") {
code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat(
"{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm")));
} else {
// Let's not cause crashes in Null VM
return;
}
EXPECT_FALSE(code.empty());

setup(code, "context", "panic during response processing");
setupContext();

EXPECT_CALL(filter_manager_callbacks_, encodeHeaders(_, _));
EXPECT_CALL(filter_manager_callbacks_, endStream());

filter_manager_->decodeHeaders(*request_headers_, false);
filter_manager_->decodeData(request_data_, false);
filter_manager_->destroyFilters();
}

} // namespace Wasm
} // namespace Common
} // namespace Extensions
Expand Down
8 changes: 4 additions & 4 deletions test/test_common/wasm_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,12 @@ template <typename Base = testing::Test> class WasmHttpFilterTestBase : public W
auto wasm = WasmTestBase<Base>::wasm_ ? WasmTestBase<Base>::wasm_->wasm().get() : nullptr;
int root_context_id = wasm ? wasm->getRootContext(WasmTestBase<Base>::plugin_, false)->id() : 0;
context_ =
std::make_unique<TestFilter>(wasm, root_context_id, WasmTestBase<Base>::plugin_handle_);
std::make_shared<TestFilter>(wasm, root_context_id, WasmTestBase<Base>::plugin_handle_);
context_->setDecoderFilterCallbacks(decoder_callbacks_);
context_->setEncoderFilterCallbacks(encoder_callbacks_);
}

std::unique_ptr<Context> context_;
std::shared_ptr<Context> context_;
NiceMock<Http::MockStreamDecoderFilterCallbacks> decoder_callbacks_;
NiceMock<Http::MockStreamEncoderFilterCallbacks> encoder_callbacks_;
NiceMock<Envoy::StreamInfo::MockStreamInfo> request_stream_info_;
Expand All @@ -160,12 +160,12 @@ class WasmNetworkFilterTestBase : public WasmTestBase<Base> {
auto wasm = WasmTestBase<Base>::wasm_ ? WasmTestBase<Base>::wasm_->wasm().get() : nullptr;
int root_context_id = wasm ? wasm->getRootContext(WasmTestBase<Base>::plugin_, false)->id() : 0;
context_ =
std::make_unique<TestFilter>(wasm, root_context_id, WasmTestBase<Base>::plugin_handle_);
std::make_shared<TestFilter>(wasm, root_context_id, WasmTestBase<Base>::plugin_handle_);
context_->initializeReadFilterCallbacks(read_filter_callbacks_);
context_->initializeWriteFilterCallbacks(write_filter_callbacks_);
}

std::unique_ptr<Context> context_;
std::shared_ptr<Context> context_;
NiceMock<Network::MockReadFilterCallbacks> read_filter_callbacks_;
NiceMock<Network::MockWriteFilterCallbacks> write_filter_callbacks_;
};
Expand Down

0 comments on commit 1c1a86b

Please sign in to comment.