diff --git a/dbms/CMakeLists.txt b/dbms/CMakeLists.txt index 8318567be72..1c903c52318 100644 --- a/dbms/CMakeLists.txt +++ b/dbms/CMakeLists.txt @@ -89,6 +89,7 @@ check_then_add_sources_compile_flag ( src/Columns/ColumnsCommon.cpp src/Columns/ColumnVector.cpp src/DataTypes/DataTypeString.cpp + src/Interpreters/Join.cpp ) list (APPEND tiflash_common_io_sources ${CONFIG_BUILD}) diff --git a/dbms/src/Common/TiFlashMetrics.h b/dbms/src/Common/TiFlashMetrics.h index d983aff40ad..6bc336418c0 100644 --- a/dbms/src/Common/TiFlashMetrics.h +++ b/dbms/src/Common/TiFlashMetrics.h @@ -88,16 +88,22 @@ namespace DB F(type_mpp_establish_conn, {{"type", "mpp_tunnel"}}), \ F(type_mpp_establish_conn_local, {{"type", "mpp_tunnel_local"}}), \ F(type_cancel_mpp_task, {{"type", "cancel_mpp_task"}})) \ - M(tiflash_exchange_data_bytes, "Total bytes sent by exchange operators", Counter, \ - F(type_hash_original, {"type", "hash_original"}), /*the original data size by hash exchange*/ \ - F(type_hash_none_compression_remote, {"type", "hash_none_compression_remote"}), /*the remote exchange data size by hash partition with no compression*/\ - F(type_hash_none_compression_local, {"type", "hash_none_compression_local"}), /*the local exchange data size by hash partition with no compression*/ \ - F(type_hash_lz4_compression, {"type", "hash_lz4_compression"}), /*the exchange data size by hash partition with lz4 compression*/ \ - F(type_hash_zstd_compression, {"type", "hash_zstd_compression"}), /*the exchange data size by hash partition with zstd compression*/ \ - F(type_broadcast_passthrough_original, {"type", "broadcast_passthrough_original"}), /*the original exchange data size by broadcast/passthough*/ \ - F(type_broadcast_passthrough_none_compression_local, {"type", "broadcast_passthrough_none_compression_local"}), /*the local exchange data size by broadcast/passthough with no compression*/ \ - F(type_broadcast_passthrough_none_compression_remote, {"type", "broadcast_passthrough_none_compression_remote"}), /*the remote exchange data size by broadcast/passthough with no compression*/ \ - ) \ + M(tiflash_exchange_data_bytes, "Total bytes sent by exchange operators", Counter, \ + F(type_hash_original, {"type", "hash_original"}), \ + F(type_hash_none_compression_remote, {"type", "hash_none_compression_remote"}), \ + F(type_hash_none_compression_local, {"type", "hash_none_compression_local"}), \ + F(type_hash_lz4_compression, {"type", "hash_lz4_compression"}), \ + F(type_hash_zstd_compression, {"type", "hash_zstd_compression"}), \ + F(type_broadcast_original, {"type", "broadcast_original"}), \ + F(type_broadcast_none_compression_local, {"type", "broadcast_none_compression_local"}), \ + F(type_broadcast_none_compression_remote, {"type", "broadcast_none_compression_remote"}), \ + F(type_broadcast_lz4_compression, {"type", "broadcast_lz4_compression"}), \ + F(type_broadcast_zstd_compression, {"type", "broadcast_zstd_compression"}), \ + F(type_passthrough_original, {"type", "passthrough_original"}), \ + F(type_passthrough_none_compression_local, {"type", "passthrough_none_compression_local"}), \ + F(type_passthrough_none_compression_remote, {"type", "passthrough_none_compression_remote"}), \ + F(type_passthrough_lz4_compression, {"type", "passthrough_lz4_compression"}), \ + F(type_passthrough_zstd_compression, {"type", "passthrough_zstd_compression"})) \ M(tiflash_schema_version, "Current version of tiflash cached schema", Gauge) \ M(tiflash_schema_applying, "Whether the schema is applying or not (holding lock)", Gauge) \ 
M(tiflash_schema_apply_count, "Total number of each kinds of apply", Counter, F(type_diff, {"type", "diff"}), \ diff --git a/dbms/src/Flash/Coprocessor/CHBlockChunkCodecV1.cpp b/dbms/src/Flash/Coprocessor/CHBlockChunkCodecV1.cpp index 5840b0e8e57..5eb75dbcfe3 100644 --- a/dbms/src/Flash/Coprocessor/CHBlockChunkCodecV1.cpp +++ b/dbms/src/Flash/Coprocessor/CHBlockChunkCodecV1.cpp @@ -248,6 +248,10 @@ struct CHBlockChunkCodecV1Impl { return encodeImpl(blocks, compression_method); } + CHBlockChunkCodecV1::EncodeRes encode(std::vector<Block> && blocks, CompressionMethod compression_method) + { + return encodeImpl(std::move(blocks), compression_method); + } static const ColumnPtr & toColumnPtr(const Columns & c, size_t index) { @@ -269,6 +273,10 @@ struct CHBlockChunkCodecV1Impl { return block.getByPosition(index).column; } + static ColumnPtr toColumnPtr(Block && block, size_t index) + { + return std::move(block.getByPosition(index).column); + } template <typename ColumnsHolder> static size_t getRows(ColumnsHolder && columns_holder) @@ -349,6 +357,13 @@ struct CHBlockChunkCodecV1Impl return encodeColumnImpl(block, ostr_ptr); } void encodeColumn(const std::vector<Block> & blocks, WriteBuffer * ostr_ptr) + { + for (auto && block : blocks) + { + encodeColumnImpl(block, ostr_ptr); + } + } + void encodeColumn(std::vector<Block> && blocks, WriteBuffer * ostr_ptr) { for (auto && block : blocks) { @@ -495,6 +510,19 @@ CHBlockChunkCodecV1::EncodeRes CHBlockChunkCodecV1::encode(const std::vector<Block> & blocks, CompressionMethod compression_method, bool check_schema) +CHBlockChunkCodecV1::EncodeRes CHBlockChunkCodecV1::encode(std::vector<Block> && blocks, CompressionMethod compression_method, bool check_schema) +{ + if (check_schema) + { + for (auto && block : blocks) + { + checkSchema(header, block); + } + } + + return CHBlockChunkCodecV1Impl{*this}.encode(std::move(blocks), compression_method); +} + static Block decodeCompression(const Block & header, ReadBuffer & istr) { size_t decoded_rows{}; @@ -504,6 +532,22 @@ static Block decodeCompression(const Block & header, ReadBuffer & istr) return decoded_block; } +template <typename Buffer> +extern size_t CompressionEncode( + std::string_view source, + const CompressionSettings & compression_settings, + Buffer & compressed_buffer); + +CHBlockChunkCodecV1::EncodeRes CHBlockChunkCodecV1::encode(std::string_view str, CompressionMethod compression_method) +{ + assert(compression_method != CompressionMethod::NONE); + + String compressed_buffer; + size_t compressed_size = CompressionEncode(str, CompressionSettings(compression_method), compressed_buffer); + compressed_buffer.resize(compressed_size); + return compressed_buffer; +} + Block CHBlockChunkCodecV1::decode(const Block & header, std::string_view str) { assert(!str.empty()); diff --git a/dbms/src/Flash/Coprocessor/CHBlockChunkCodecV1.h b/dbms/src/Flash/Coprocessor/CHBlockChunkCodecV1.h index 76331ce8314..4b3ddc35ba2 100644 --- a/dbms/src/Flash/Coprocessor/CHBlockChunkCodecV1.h +++ b/dbms/src/Flash/Coprocessor/CHBlockChunkCodecV1.h @@ -53,7 +53,9 @@ struct CHBlockChunkCodecV1 : boost::noncopyable EncodeRes encode(std::vector<Columns> && columns, CompressionMethod compression_method); EncodeRes encode(const Block & block, CompressionMethod compression_method, bool check_schema = true); EncodeRes encode(const std::vector<Block> & blocks, CompressionMethod compression_method, bool check_schema = true); + EncodeRes encode(std::vector<Block> && blocks, CompressionMethod compression_method, bool check_schema = true); // + static EncodeRes encode(std::string_view str, CompressionMethod compression_method); static Block decode(const Block & header, std::string_view str); }; diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_block_chunk_codec.cpp
b/dbms/src/Flash/Coprocessor/tests/gtest_block_chunk_codec.cpp index d9f7a5f5c78..18e26634f57 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_block_chunk_codec.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_block_chunk_codec.cpp @@ -126,6 +126,41 @@ TEST(CHBlockChunkCodec, ChunkCodecV1) auto decoded_block = CHBlockChunkCodecV1::decode(header, str); ASSERT_EQ(total_rows, decoded_block.rows()); } + { + std::vector blocks_to_move; + blocks_to_move.reserve(blocks.size()); + for (auto && block : blocks) + { + blocks_to_move.emplace_back(block); + } + for (auto && block : blocks_to_move) + { + for (auto && col : block) + { + ASSERT_TRUE(col.column); + } + } + auto codec = CHBlockChunkCodecV1{ + header, + }; + auto str = codec.encode(std::move(blocks_to_move), mode); + for (auto && block : blocks_to_move) + { + ASSERT_EQ(block.rows(), 0); + } + ASSERT_FALSE(str.empty()); + ASSERT_EQ(codec.encoded_rows, total_rows); + + if (mode == CompressionMethod::NONE) + ASSERT_EQ(codec.compressed_size, 0); + else + ASSERT_NE(codec.compressed_size, 0); + + ASSERT_NE(codec.original_size, 0); + + auto decoded_block = CHBlockChunkCodecV1::decode(header, str); + ASSERT_EQ(total_rows, decoded_block.rows()); + } { auto columns = prepareBlock(rows).getColumns(); auto codec = CHBlockChunkCodecV1{ @@ -179,5 +214,19 @@ TEST(CHBlockChunkCodec, ChunkCodecV1) } test_enocde_release_data(std::move(batch_columns), header, total_rows); } + { + auto source_str = CHBlockChunkCodecV1{header}.encode(blocks.front(), CompressionMethod::NONE); + ASSERT_FALSE(source_str.empty()); + ASSERT_EQ(static_cast(source_str[0]), CompressionMethodByte::NONE); + + for (auto mode : {CompressionMethod::LZ4, CompressionMethod::ZSTD}) + { + auto compressed_str_a = CHBlockChunkCodecV1::encode({&source_str[1], source_str.size() - 1}, mode); + auto compressed_str_b = CHBlockChunkCodecV1{header}.encode(blocks.front(), mode); + + ASSERT_EQ(compressed_str_a, compressed_str_b); + } + } } + } // namespace DB::tests diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp index 6e5dc744628..121027a6619 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp @@ -86,17 +86,48 @@ struct MockWriter return summary; } - void broadcastOrPassThroughWrite(Blocks & blocks) + void broadcastOrPassThroughWriteV0(Blocks & blocks) { auto && packet = MPPTunnelSetHelper::ToPacketV0(blocks, result_field_types); ++total_packets; if (!packet) return; - if (!packet->packet.chunks().empty()) - total_bytes += packet->packet.ByteSizeLong(); + total_bytes += packet->packet.ByteSizeLong(); queue->push(std::move(packet)); } + + void broadcastWrite(Blocks & blocks) + { + return broadcastOrPassThroughWriteV0(blocks); + } + void passThroughWrite(Blocks & blocks) + { + return broadcastOrPassThroughWriteV0(blocks); + } + void broadcastOrPassThroughWrite(Blocks & blocks, MPPDataPacketVersion version, CompressionMethod compression_method) + { + if (version == MPPDataPacketV0) + return broadcastOrPassThroughWriteV0(blocks); + + size_t original_size{}; + auto && packet = MPPTunnelSetHelper::ToPacket(std::move(blocks), version, compression_method, original_size); + ++total_packets; + if (!packet) + return; + + total_bytes += packet->packet.ByteSizeLong(); + queue->push(std::move(packet)); + } + void broadcastWrite(Blocks & blocks, MPPDataPacketVersion version, CompressionMethod 
compression_method) + { + return broadcastOrPassThroughWrite(blocks, version, compression_method); + } + void passThroughWrite(Blocks & blocks, MPPDataPacketVersion version, CompressionMethod compression_method) + { + return broadcastOrPassThroughWrite(blocks, version, compression_method); + } + void write(tipb::SelectResponse & response) { if (add_summary) @@ -119,10 +150,6 @@ struct MockWriter queue->push(tracked_packet); } uint16_t getPartitionNum() const { return 1; } - bool isLocal(size_t index) const - { - return index == 0; - } bool isReadyForWrite() const { throw Exception("Unsupport async write"); } std::vector result_field_types; @@ -357,7 +384,10 @@ class TestTiRemoteBlockInputStream : public testing::Test auto dag_writer = std::make_shared>( writer, batch_send_min_limit, - *dag_context_ptr); + *dag_context_ptr, + MPPDataPacketVersion::MPPDataPacketV1, + tipb::CompressionMode::FAST, + tipb::ExchangeType::Broadcast); // 2. encode all blocks for (const auto & block : source_blocks) diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp index 01df6814dec..d6b3fac6d2c 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -24,13 +25,41 @@ template BroadcastOrPassThroughWriter::BroadcastOrPassThroughWriter( ExchangeWriterPtr writer_, Int64 batch_send_min_limit_, - DAGContext & dag_context_) + DAGContext & dag_context_, + MPPDataPacketVersion data_codec_version_, + tipb::CompressionMode compression_mode_, + tipb::ExchangeType exchange_type_) : DAGResponseWriter(/*records_per_chunk=*/-1, dag_context_) , batch_send_min_limit(batch_send_min_limit_) , writer(writer_) + , exchange_type(exchange_type_) + , data_codec_version(data_codec_version_) + , compression_method(ToInternalCompressionMethod(compression_mode_)) { rows_in_blocks = 0; RUNTIME_CHECK(dag_context.encode_type == tipb::EncodeType::TypeCHBlock); + RUNTIME_CHECK(exchange_type == tipb::ExchangeType::Broadcast || exchange_type == tipb::ExchangeType::PassThrough); + + switch (data_codec_version) + { + case MPPDataPacketV0: + break; + case MPPDataPacketV1: + default: + { + // make `batch_send_min_limit` always GT 0 + if (batch_send_min_limit <= 0) + { + // set upper limit if not specified + batch_send_min_limit = 8 * 1024 /* 8K */; + } + for (const auto & field_type : dag_context.result_field_types) + { + expected_types.emplace_back(getDataTypeByFieldTypeForComputingLayer(field_type)); + } + break; + } + } } template @@ -66,10 +95,20 @@ void BroadcastOrPassThroughWriter::write(const Block & block) template void BroadcastOrPassThroughWriter::writeBlocks() { - if (unlikely(blocks.empty())) + if unlikely (blocks.empty()) return; - writer->broadcastOrPassThroughWrite(blocks); + // check schema + if (!expected_types.empty()) + { + for (auto && block : blocks) + assertBlockSchema(expected_types, block, "BroadcastOrPassThroughWriter"); + } + + if (exchange_type == tipb::ExchangeType::Broadcast) + writer->broadcastWrite(blocks, data_codec_version, compression_method); + else + writer->passThroughWrite(blocks, data_codec_version, compression_method); blocks.clear(); rows_in_blocks = 0; } diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h index 296fda38ba9..812eeb7c70b 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h +++ 
b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h @@ -22,6 +22,8 @@ namespace DB { class DAGContext; +enum class CompressionMethod; +enum MPPDataPacketVersion : int64_t; template class BroadcastOrPassThroughWriter : public DAGResponseWriter @@ -30,7 +32,10 @@ class BroadcastOrPassThroughWriter : public DAGResponseWriter BroadcastOrPassThroughWriter( ExchangeWriterPtr writer_, Int64 batch_send_min_limit_, - DAGContext & dag_context_); + DAGContext & dag_context_, + MPPDataPacketVersion data_codec_version_, + tipb::CompressionMode compression_mode_, + tipb::ExchangeType exchange_type_); void write(const Block & block) override; bool isReadyForWrite() const override; void flush() override; @@ -43,6 +48,12 @@ class BroadcastOrPassThroughWriter : public DAGResponseWriter ExchangeWriterPtr writer; std::vector blocks; size_t rows_in_blocks; + const tipb::ExchangeType exchange_type; + + // support data compression + DataTypes expected_types; + MPPDataPacketVersion data_codec_version; + CompressionMethod compression_method{}; }; } // namespace DB diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp index 890b0e41b5c..9fdd65b6a20 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp @@ -92,6 +92,10 @@ void MPPTunnelSetBase::registerTunnel(const MPPTaskId & receiver_task_id { ++external_thread_cnt; } + if (tunnel->isLocal()) + { + ++local_tunnel_cnt; + } } template diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.h b/dbms/src/Flash/Mpp/MPPTunnelSet.h index ab1d96b6e9c..84138d19a61 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.h @@ -60,6 +60,10 @@ class MPPTunnelSetBase : private boost::noncopyable { return external_thread_cnt; } + size_t getLocalTunnelCnt() + { + return local_tunnel_cnt; + } const std::vector & getTunnels() const { return tunnels; } @@ -73,6 +77,7 @@ class MPPTunnelSetBase : private boost::noncopyable const LoggerPtr log; int external_thread_cnt = 0; + size_t local_tunnel_cnt = 0; }; class MPPTunnelSet : public MPPTunnelSetBase diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp index f1bd57a8d4b..b5f47b7bc78 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp @@ -42,6 +42,31 @@ TrackedMppDataPacketPtr ToPacket( return tracked_packet; } +TrackedMppDataPacketPtr ToPacket( + Blocks && blocks, + MPPDataPacketVersion version, + CompressionMethod method, + size_t & original_size) +{ + assert(version > MPPDataPacketV0); + + if (blocks.empty()) + return nullptr; + const Block & header = blocks.front().cloneEmpty(); + auto && codec = CHBlockChunkCodecV1{header}; + auto && res = codec.encode( + std::move(blocks), + method, + false); + if unlikely (res.empty()) + return nullptr; + + auto tracked_packet = std::make_shared(version); + tracked_packet->addChunk(std::move(res)); + original_size += codec.original_size; + return tracked_packet; +} + TrackedMppDataPacketPtr ToPacketV0(Blocks & blocks, const std::vector & field_types) { if (blocks.empty()) @@ -57,7 +82,6 @@ TrackedMppDataPacketPtr ToPacketV0(Blocks & blocks, const std::vectoraddChunk(codec_stream->getString()); codec_stream->clear(); } - blocks.clear(); return tracked_packet; } @@ -142,4 +166,30 @@ TrackedMppDataPacketPtr ToFineGrainedPacketV0( } return tracked_packet; } + +TrackedMppDataPacketPtr ToCompressedPacket( + const TrackedMppDataPacketPtr & uncompressed_source, + MPPDataPacketVersion version, + CompressionMethod method) +{ 
+ assert(uncompressed_source); + for ([[maybe_unused]] const auto & chunk : uncompressed_source->getPacket().chunks()) + { + assert(!chunk.empty()); + assert(static_cast<CompressionMethodByte>(chunk[0]) == CompressionMethodByte::NONE); + } + + // re-encode by specified compression method + auto compressed_tracked_packet = std::make_shared<TrackedMppDataPacket>(version); + for (const auto & chunk : uncompressed_source->getPacket().chunks()) + { + auto && compressed_buffer = CHBlockChunkCodecV1::encode({&chunk[1], chunk.size() - 1}, method); + assert(!compressed_buffer.empty()); + + compressed_tracked_packet->addChunk(std::move(compressed_buffer)); + } + return compressed_tracked_packet; +} + + } // namespace DB::MPPTunnelSetHelper diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h index 050547dd330..4fad6f8aba4 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h @@ -27,6 +27,17 @@ namespace DB::MPPTunnelSetHelper { TrackedMppDataPacketPtr ToPacketV0(Blocks & blocks, const std::vector<tipb::FieldType> & field_types); +TrackedMppDataPacketPtr ToCompressedPacket( + const TrackedMppDataPacketPtr & uncompressed_source, + MPPDataPacketVersion version, + CompressionMethod method); + +TrackedMppDataPacketPtr ToPacket( + Blocks && blocks, + MPPDataPacketVersion version, + CompressionMethod method, + size_t & original_size); + TrackedMppDataPacketPtr ToPacket( const Block & header, std::vector<MutableColumns> && part_columns, diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp index 804417b8034..da92dfaa5f5 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp @@ -30,52 +30,167 @@ void checkPacketSize(size_t size) static constexpr size_t max_packet_size = 1u << 31; RUNTIME_CHECK_MSG(size < max_packet_size, "Packet is too large to send, size : {}", size); } +} // namespace + + +#define UPDATE_EXCHANGE_MATRIC_IMPL(type, compress, value, ...)
\ + do \ + { \ + GET_METRIC(tiflash_exchange_data_bytes, type_##type##_##compress##_compression##__VA_ARGS__).Increment(value); \ + } while (false) +#define UPDATE_EXCHANGE_MATRIC_NONE_COMPRESS_LOCAL(type, value) UPDATE_EXCHANGE_MATRIC_IMPL(type, none, value, _local) +#define UPDATE_EXCHANGE_MATRIC_NONE_COMPRESS_REMOTE(type, value) UPDATE_EXCHANGE_MATRIC_IMPL(type, none, value, _remote) +#define UPDATE_EXCHANGE_MATRIC_NONE_COMPRESS(type, is_local, value) \ + do \ + { \ + if (is_local) \ + { \ + UPDATE_EXCHANGE_MATRIC_NONE_COMPRESS_LOCAL(type, (value)); \ + } \ + else \ + { \ + UPDATE_EXCHANGE_MATRIC_NONE_COMPRESS_REMOTE(type, (value)); \ + } \ + } while (false) +#define UPDATE_EXCHANGE_MATRIC_LZ4_COMPRESS(type, value) UPDATE_EXCHANGE_MATRIC_IMPL(type, lz4, value) +#define UPDATE_EXCHANGE_MATRIC_ZSTD_COMPRESS(type, value) UPDATE_EXCHANGE_MATRIC_IMPL(type, zstd, value) +#define UPDATE_EXCHANGE_MATRIC_ORIGINAL(type, value) \ + do \ + { \ + GET_METRIC(tiflash_exchange_data_bytes, type_##type##_original).Increment(value); \ + } while (false) +#define UPDATE_EXCHANGE_MATRIC(type, compress_method, original_size, actual_size, is_local) \ + do \ + { \ + UPDATE_EXCHANGE_MATRIC_ORIGINAL(type, original_size); \ + switch (compress_method) \ + { \ + case CompressionMethod::NONE: \ + { \ + UPDATE_EXCHANGE_MATRIC_NONE_COMPRESS(type, is_local, actual_size); \ + break; \ + } \ + case CompressionMethod::LZ4: \ + { \ + UPDATE_EXCHANGE_MATRIC_LZ4_COMPRESS(type, actual_size); \ + break; \ + } \ + case CompressionMethod::ZSTD: \ + { \ + UPDATE_EXCHANGE_MATRIC_ZSTD_COMPRESS(type, actual_size); \ + break; \ + } \ + default: \ + break; \ + } \ + } while (false) -void updatePartitionWriterMetrics(size_t packet_bytes, bool is_local) +static inline void updatePartitionWriterMetrics(CompressionMethod method, size_t original_size, size_t actual_size, bool is_local) { - // statistic - GET_METRIC(tiflash_exchange_data_bytes, type_hash_original).Increment(packet_bytes); - // compression method is always NONE - if (is_local) - GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_local).Increment(packet_bytes); - else - GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_remote).Increment(packet_bytes); + UPDATE_EXCHANGE_MATRIC(hash, method, original_size, actual_size, is_local); } -void updatePartitionWriterMetrics(CompressionMethod method, size_t original_size, size_t sz, bool is_local) +template +static void broadcastOrPassThroughWriteImpl( + const size_t tunnel_cnt, + const size_t local_tunnel_cnt, // can be 0 for PassThrough writer + const size_t ori_packet_bytes, // original data packet size + TrackedMppDataPacketPtr && local_tracked_packet, // can be NULL if there is no local tunnel + TrackedMppDataPacketPtr && remote_tracked_packet, // can be NULL if all tunnels are local mode + const CompressionMethod compression_method, + FuncIsLocalTunnel && isLocalTunnel, + FuncWriteToTunnel && writeToTunnel) { - // statistic - GET_METRIC(tiflash_exchange_data_bytes, type_hash_original).Increment(original_size); + assert(ori_packet_bytes > 0); + + const size_t remote_tunnel_cnt = tunnel_cnt - local_tunnel_cnt; + auto remote_tracked_packet_bytes = remote_tracked_packet ? 
remote_tracked_packet->getPacket().ByteSizeLong() : 0; - switch (method) + if (!local_tracked_packet) { - case CompressionMethod::NONE: + assert(local_tunnel_cnt == 0); + assert(remote_tracked_packet); + } + else + { + checkPacketSize(ori_packet_bytes); + } + + if (!remote_tracked_packet) + { + assert(local_tracked_packet); + assert(local_tunnel_cnt == tunnel_cnt); + } + else + { + checkPacketSize(remote_tracked_packet_bytes); + } + + // TODO avoid copy packet for broadcast. + for (size_t i = 0, local_cnt = 0, remote_cnt = 0; i < tunnel_cnt; ++i) { - if (is_local) + if (isLocalTunnel(i)) { - GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_local).Increment(sz); + local_cnt++; + if (local_cnt == local_tunnel_cnt) + writeToTunnel(std::move(local_tracked_packet), i); + else + writeToTunnel(local_tracked_packet->copy(), i); // NOLINT } else { - GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_remote).Increment(sz); + remote_cnt++; + if (remote_cnt == remote_tunnel_cnt) + writeToTunnel(std::move(remote_tracked_packet), i); + else + writeToTunnel(remote_tracked_packet->copy(), i); // NOLINT } - break; } - case CompressionMethod::LZ4: + + if constexpr (is_broadcast) + { + UPDATE_EXCHANGE_MATRIC(broadcast, CompressionMethod::NONE, local_tunnel_cnt * ori_packet_bytes, local_tunnel_cnt * ori_packet_bytes, true); + UPDATE_EXCHANGE_MATRIC(broadcast, compression_method, remote_tunnel_cnt * ori_packet_bytes, remote_tunnel_cnt * remote_tracked_packet_bytes, false); + } + else { - GET_METRIC(tiflash_exchange_data_bytes, type_hash_lz4_compression).Increment(sz); - break; + UPDATE_EXCHANGE_MATRIC(passthrough, CompressionMethod::NONE, local_tunnel_cnt * ori_packet_bytes, local_tunnel_cnt * ori_packet_bytes, true); + UPDATE_EXCHANGE_MATRIC(passthrough, compression_method, remote_tunnel_cnt * ori_packet_bytes, remote_tunnel_cnt * remote_tracked_packet_bytes, false); } - case CompressionMethod::ZSTD: +} + +template +static void broadcastOrPassThroughWriteV0( + const size_t tunnel_cnt, + const size_t local_tunnel_cnt, + Blocks & blocks, + const std::vector & result_field_types, + FuncIsLocalTunnel && isLocalTunnel, + FuncWriteToTunnel && writeToTunnel) +{ + auto && ori_tracked_packet = MPPTunnelSetHelper::ToPacketV0(blocks, result_field_types); + if (!ori_tracked_packet) + return; + size_t tracked_packet_bytes = ori_tracked_packet->getPacket().ByteSizeLong(); + TrackedMppDataPacketPtr remote_tracked_packet = nullptr; + if (local_tunnel_cnt != tunnel_cnt) { - GET_METRIC(tiflash_exchange_data_bytes, type_hash_zstd_compression).Increment(sz); - break; + remote_tracked_packet = ori_tracked_packet; } - default: - break; + if (0 == local_tunnel_cnt) + { + ori_tracked_packet = nullptr; } + broadcastOrPassThroughWriteImpl( + tunnel_cnt, + local_tunnel_cnt, + tracked_packet_bytes, + std::move(ori_tracked_packet), + std::move(remote_tracked_packet), + CompressionMethod::NONE, + std::forward(isLocalTunnel), + std::forward(writeToTunnel)); } -} // namespace MPPTunnelSetWriterBase::MPPTunnelSetWriterBase( const MPPTunnelSetPtr & mpp_tunnel_set_, @@ -95,36 +210,113 @@ void MPPTunnelSetWriterBase::write(tipb::SelectResponse & response) writeToTunnel(response, 0); } -void MPPTunnelSetWriterBase::broadcastOrPassThroughWrite(Blocks & blocks) +void MPPTunnelSetWriterBase::broadcastWrite(Blocks & blocks) { - auto && tracked_packet = MPPTunnelSetHelper::ToPacketV0(blocks, result_field_types); - if (!tracked_packet) + return broadcastOrPassThroughWriteV0( + mpp_tunnel_set->getTunnels().size(), + 
mpp_tunnel_set->getLocalTunnelCnt(), + blocks, + result_field_types, + [&](size_t i) { return mpp_tunnel_set->isLocal(i); }, + [&](TrackedMppDataPacketPtr && data, size_t index) { + return writeToTunnel(std::move(data), index); + }); +} + +void MPPTunnelSetWriterBase::passThroughWrite(Blocks & blocks) +{ + return broadcastOrPassThroughWriteV0( + mpp_tunnel_set->getTunnels().size(), + mpp_tunnel_set->getLocalTunnelCnt(), + blocks, + result_field_types, + [&](size_t i) { return mpp_tunnel_set->isLocal(i); }, + [&](TrackedMppDataPacketPtr && data, size_t index) { + return writeToTunnel(std::move(data), index); + }); +} + +template +static void broadcastOrPassThroughWrite( + const size_t tunnel_cnt, + const size_t local_tunnel_cnt, + Blocks & blocks, + MPPDataPacketVersion version, + CompressionMethod compression_method, + FuncIsLocalTunnel && isLocalTunnel, + FuncWriteToTunnel && writeToTunnel) +{ + assert(version > MPPDataPacketV0); + + size_t original_size = 0; + // encode by method NONE + auto && ori_tracked_packet = MPPTunnelSetHelper::ToPacket(std::move(blocks), version, CompressionMethod::NONE, original_size); + if (!ori_tracked_packet) return; - auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); - checkPacketSize(packet_bytes); - // TODO avoid copy packet for broadcast. - for (size_t i = 1; i < getPartitionNum(); ++i) - writeToTunnel(tracked_packet->copy(), i); - writeToTunnel(std::move(tracked_packet), 0); + size_t tracked_packet_bytes = ori_tracked_packet->getPacket().ByteSizeLong(); + + TrackedMppDataPacketPtr remote_tunnel_tracked_packet = nullptr; + + if (local_tunnel_cnt != tunnel_cnt) { - // statistic - size_t data_bytes = 0; - size_t local_data_bytes = 0; - { - auto tunnel_cnt = getPartitionNum(); - size_t local_tunnel_cnt = 0; - for (size_t i = 0; i < tunnel_cnt; ++i) - { - local_tunnel_cnt += mpp_tunnel_set->isLocal(i); - } - data_bytes = packet_bytes * tunnel_cnt; - local_data_bytes = packet_bytes * local_tunnel_cnt; - } - GET_METRIC(tiflash_exchange_data_bytes, type_broadcast_passthrough_original).Increment(data_bytes); - GET_METRIC(tiflash_exchange_data_bytes, type_broadcast_passthrough_none_compression_local).Increment(local_data_bytes); - GET_METRIC(tiflash_exchange_data_bytes, type_broadcast_passthrough_none_compression_remote).Increment(data_bytes - local_data_bytes); + if (compression_method != CompressionMethod::NONE) + remote_tunnel_tracked_packet = MPPTunnelSetHelper::ToCompressedPacket(ori_tracked_packet, version, compression_method); + else + remote_tunnel_tracked_packet = ori_tracked_packet; + } + else + { + // remote packet will be NULL if local_tunnel_cnt == tunnel_cnt + } + + if (0 == local_tunnel_cnt) + { + // if no need local tunnel, just release early to reduce memory usage + ori_tracked_packet = nullptr; } + + return broadcastOrPassThroughWriteImpl( + tunnel_cnt, + local_tunnel_cnt, + tracked_packet_bytes, + std::move(ori_tracked_packet), + std::move(remote_tunnel_tracked_packet), + compression_method, + std::forward(isLocalTunnel), + std::forward(writeToTunnel)); +} + +void MPPTunnelSetWriterBase::broadcastWrite(Blocks & blocks, MPPDataPacketVersion version, CompressionMethod compression_method) +{ + if (MPPDataPacketV0 == version) + return broadcastWrite(blocks); + return broadcastOrPassThroughWrite( + mpp_tunnel_set->getTunnels().size(), + mpp_tunnel_set->getLocalTunnelCnt(), + blocks, + version, + compression_method, + [&](size_t i) { return mpp_tunnel_set->isLocal(i); }, + [&](TrackedMppDataPacketPtr && data, size_t index) { + return 
writeToTunnel(std::move(data), index); + }); +} + +void MPPTunnelSetWriterBase::passThroughWrite(Blocks & blocks, MPPDataPacketVersion version, CompressionMethod compression_method) +{ + if (MPPDataPacketV0 == version) + return passThroughWrite(blocks); + return broadcastOrPassThroughWrite( + mpp_tunnel_set->getTunnels().size(), + mpp_tunnel_set->getLocalTunnelCnt(), + blocks, + version, + compression_method, + [&](size_t i) { return mpp_tunnel_set->isLocal(i); }, + [&](TrackedMppDataPacketPtr && data, size_t index) { + return writeToTunnel(std::move(data), index); + }); } void MPPTunnelSetWriterBase::partitionWrite(Blocks & blocks, int16_t partition_id) @@ -135,7 +327,7 @@ void MPPTunnelSetWriterBase::partitionWrite(Blocks & blocks, int16_t partition_i auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); checkPacketSize(packet_bytes); writeToTunnel(std::move(tracked_packet), partition_id); - updatePartitionWriterMetrics(packet_bytes, mpp_tunnel_set->isLocal(partition_id)); + updatePartitionWriterMetrics(CompressionMethod::NONE, packet_bytes, packet_bytes, mpp_tunnel_set->isLocal(partition_id)); } void MPPTunnelSetWriterBase::partitionWrite( @@ -219,7 +411,7 @@ void MPPTunnelSetWriterBase::fineGrainedShuffleWrite( auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); checkPacketSize(packet_bytes); writeToTunnel(std::move(tracked_packet), partition_id); - updatePartitionWriterMetrics(packet_bytes, mpp_tunnel_set->isLocal(partition_id)); + updatePartitionWriterMetrics(CompressionMethod::NONE, packet_bytes, packet_bytes, mpp_tunnel_set->isLocal(partition_id)); } void SyncMPPTunnelSetWriter::writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h index 87bb4b63e6f..209d3c5b605 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h @@ -32,7 +32,11 @@ class MPPTunnelSetWriterBase : private boost::noncopyable void write(tipb::SelectResponse & response); // this is a broadcast or pass through writing. // data codec version V0 - void broadcastOrPassThroughWrite(Blocks & blocks); + void broadcastWrite(Blocks & blocks); + void passThroughWrite(Blocks & blocks); + // data codec version > V0 + void broadcastWrite(Blocks & blocks, MPPDataPacketVersion version, CompressionMethod compression_method); + void passThroughWrite(Blocks & blocks, MPPDataPacketVersion version, CompressionMethod compression_method); // this is a partition writing. // data codec version V0 void partitionWrite(Blocks & blocks, int16_t partition_id); diff --git a/dbms/src/Flash/Mpp/newMPPExchangeWriter.cpp b/dbms/src/Flash/Mpp/newMPPExchangeWriter.cpp index 0d43b6fd162..7e9a7f55efb 100644 --- a/dbms/src/Flash/Mpp/newMPPExchangeWriter.cpp +++ b/dbms/src/Flash/Mpp/newMPPExchangeWriter.cpp @@ -54,13 +54,16 @@ std::unique_ptr buildMPPExchangeWriter( } else { + auto mpp_version = dag_context.getMPPTaskMeta().mpp_version(); + auto data_codec_version = mpp_version == MppVersionV0 + ? MPPDataPacketV0 + : MPPDataPacketV1; + auto chosen_batch_send_min_limit = mpp_version == MppVersionV0 + ? batch_send_min_limit + : batch_send_min_limit_compression; + if (exchange_type == tipb::ExchangeType::Hash) { - auto mpp_version = dag_context.getMPPTaskMeta().mpp_version(); - auto data_codec_version = mpp_version == MppVersionV0 - ? 
MPPDataPacketV0 - : MPPDataPacketV1; - if (enable_fine_grained_shuffle) { return std::make_unique>( @@ -75,10 +78,6 @@ std::unique_ptr buildMPPExchangeWriter( } else { - auto chosen_batch_send_min_limit = mpp_version == MppVersionV0 - ? batch_send_min_limit - : batch_send_min_limit_compression; - return std::make_unique>( writer, partition_col_ids, @@ -91,14 +90,14 @@ std::unique_ptr buildMPPExchangeWriter( } else { - // TODO: support data compression if necessary - RUNTIME_CHECK(compression_mode == tipb::CompressionMode::NONE); - RUNTIME_CHECK(!enable_fine_grained_shuffle); return std::make_unique>( writer, - batch_send_min_limit, - dag_context); + chosen_batch_send_min_limit, + dag_context, + data_codec_version, + compression_mode, + exchange_type); } } } diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp index 14757049eaa..f30a3e8f190 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp @@ -166,10 +166,41 @@ struct MockExchangeWriter original_size); checker(tracked_packet, part_id); } - void broadcastOrPassThroughWrite(Blocks & blocks) + + void broadcastOrPassThroughWriteV0(Blocks & blocks) { checker(MPPTunnelSetHelper::ToPacketV0(blocks, result_field_types), 0); } + + void broadcastWrite(Blocks & blocks) + { + return broadcastOrPassThroughWriteV0(blocks); + } + void passThroughWrite(Blocks & blocks) + { + return broadcastOrPassThroughWriteV0(blocks); + } + void broadcastOrPassThroughWrite(Blocks & blocks, MPPDataPacketVersion version, CompressionMethod compression_method) + { + if (version == MPPDataPacketV0) + return broadcastOrPassThroughWriteV0(blocks); + + size_t original_size{}; + auto && packet = MPPTunnelSetHelper::ToPacket(std::move(blocks), version, compression_method, original_size); + if (!packet) + return; + + checker(packet, 0); + } + void broadcastWrite(Blocks & blocks, MPPDataPacketVersion version, CompressionMethod compression_method) + { + return broadcastOrPassThroughWrite(blocks, version, compression_method); + } + void passThroughWrite(Blocks & blocks, MPPDataPacketVersion version, CompressionMethod compression_method) + { + return broadcastOrPassThroughWrite(blocks, version, compression_method); + } + void partitionWrite(Blocks & blocks, uint16_t part_id) { checker(MPPTunnelSetHelper::ToPacketV0(blocks, result_field_types), part_id); @@ -504,7 +535,11 @@ try auto dag_writer = std::make_shared>>( mock_writer, batch_send_min_limit, - *dag_context_ptr); + *dag_context_ptr, + MPPDataPacketVersion::MPPDataPacketV0, + tipb::CompressionMode::NONE, + tipb::ExchangeType::Broadcast); + for (const auto & block : blocks) dag_writer->write(block); dag_writer->flush(); @@ -524,6 +559,60 @@ try } CATCH +TEST_F(TestMPPExchangeWriter, TestBroadcastOrPassThroughWriterV1) +try +{ + const size_t block_rows = 64; + const size_t block_num = 64; + const size_t batch_send_min_limit = 108; + + // 1. Build Blocks. + std::vector blocks; + for (size_t i = 0; i < block_num; ++i) + { + blocks.emplace_back(prepareRandomBlock(block_rows)); + blocks.emplace_back(prepareRandomBlock(0)); + } + Block header = blocks.back(); + for (auto mode : {tipb::CompressionMode::NONE, tipb::CompressionMode::FAST, tipb::CompressionMode::HIGH_COMPRESSION}) + { + // 2. Build MockExchangeWriter. 
+ TrackedMppDataPacketPtrs write_report; + auto checker = [&write_report](const TrackedMppDataPacketPtr & packet, uint16_t part_id) { + ASSERT_EQ(part_id, 0); + write_report.emplace_back(packet); + }; + auto mock_writer = std::make_shared(checker, 1, *dag_context_ptr); + + // 3. Start to write. + auto dag_writer = std::make_shared>>( + mock_writer, + batch_send_min_limit, + *dag_context_ptr, + MPPDataPacketVersion::MPPDataPacketV1, + mode, + tipb::ExchangeType::Broadcast); + + for (const auto & block : blocks) + dag_writer->write(block); + dag_writer->flush(); + + // 4. Start to check write_report. + size_t expect_rows = block_rows * block_num; + size_t decoded_block_rows = 0; + for (const auto & packet : write_report) + { + for (int i = 0; i < packet->getPacket().chunks_size(); ++i) + { + auto decoded_block = CHBlockChunkCodecV1::decode(header, packet->getPacket().chunks(i)); + decoded_block_rows += decoded_block.rows(); + } + } + ASSERT_EQ(decoded_block_rows, expect_rows); + } +} +CATCH + static CompressionMethodByte GetCompressionMethodByte(CompressionMethod m) { switch (m) diff --git a/dbms/src/IO/CompressedWriteBuffer.cpp b/dbms/src/IO/CompressedWriteBuffer.cpp index 738830217c0..d870a5b28af 100644 --- a/dbms/src/IO/CompressedWriteBuffer.cpp +++ b/dbms/src/IO/CompressedWriteBuffer.cpp @@ -32,16 +32,13 @@ extern const int CANNOT_COMPRESS; extern const int UNKNOWN_COMPRESSION_METHOD; } // namespace ErrorCodes - -template -void CompressedWriteBuffer::nextImpl() +template +size_t CompressionEncode( + std::string_view source, + const CompressionSettings & compression_settings, + Buffer & compressed_buffer) { - if (!offset()) - return; - - size_t uncompressed_size = offset(); size_t compressed_size = 0; - char * compressed_buffer_ptr = nullptr; /** The format of compressed block - see CompressedStream.h */ @@ -52,41 +49,35 @@ void CompressedWriteBuffer::nextImpl() case CompressionMethod::LZ4HC: { static constexpr size_t header_size = 1 + sizeof(UInt32) + sizeof(UInt32); - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wold-style-cast" - compressed_buffer.resize(header_size + LZ4_COMPRESSBOUND(uncompressed_size)); -#pragma GCC diagnostic pop - + compressed_buffer.resize(header_size + LZ4_COMPRESSBOUND(source.size())); compressed_buffer[0] = static_cast(CompressionMethodByte::LZ4); if (compression_settings.method == CompressionMethod::LZ4) - compressed_size = header_size + LZ4_compress_fast(working_buffer.begin(), &compressed_buffer[header_size], uncompressed_size, LZ4_COMPRESSBOUND(uncompressed_size), compression_settings.level); + compressed_size = header_size + LZ4_compress_fast(source.data(), &compressed_buffer[header_size], source.size(), LZ4_COMPRESSBOUND(source.size()), compression_settings.level); else - compressed_size = header_size + LZ4_compress_HC(working_buffer.begin(), &compressed_buffer[header_size], uncompressed_size, LZ4_COMPRESSBOUND(uncompressed_size), compression_settings.level); + compressed_size = header_size + LZ4_compress_HC(source.data(), &compressed_buffer[header_size], source.size(), LZ4_COMPRESSBOUND(source.size()), compression_settings.level); UInt32 compressed_size_32 = compressed_size; - UInt32 uncompressed_size_32 = uncompressed_size; + UInt32 uncompressed_size_32 = source.size(); unalignedStore(&compressed_buffer[1], compressed_size_32); unalignedStore(&compressed_buffer[5], uncompressed_size_32); - compressed_buffer_ptr = &compressed_buffer[0]; break; } case CompressionMethod::ZSTD: { static constexpr size_t header_size = 1 + sizeof(UInt32) + 
sizeof(UInt32); - compressed_buffer.resize(header_size + ZSTD_compressBound(uncompressed_size)); + compressed_buffer.resize(header_size + ZSTD_compressBound(source.size())); compressed_buffer[0] = static_cast(CompressionMethodByte::ZSTD); size_t res = ZSTD_compress( &compressed_buffer[header_size], compressed_buffer.size() - header_size, - working_buffer.begin(), - uncompressed_size, + source.data(), + source.size(), compression_settings.level); if (ZSTD_isError(res)) @@ -95,20 +86,19 @@ void CompressedWriteBuffer<add_checksum>::nextImpl() compressed_size = header_size + res; UInt32 compressed_size_32 = compressed_size; - UInt32 uncompressed_size_32 = uncompressed_size; + UInt32 uncompressed_size_32 = source.size(); unalignedStore(&compressed_buffer[1], compressed_size_32); unalignedStore(&compressed_buffer[5], uncompressed_size_32); - compressed_buffer_ptr = &compressed_buffer[0]; break; } case CompressionMethod::NONE: { static constexpr size_t header_size = 1 + sizeof(UInt32) + sizeof(UInt32); - compressed_size = header_size + uncompressed_size; - UInt32 uncompressed_size_32 = uncompressed_size; + compressed_size = header_size + source.size(); + UInt32 uncompressed_size_32 = source.size(); UInt32 compressed_size_32 = compressed_size; compressed_buffer.resize(compressed_size); @@ -117,15 +107,28 @@ void CompressedWriteBuffer<add_checksum>::nextImpl() unalignedStore(&compressed_buffer[1], compressed_size_32); unalignedStore(&compressed_buffer[5], uncompressed_size_32); - memcpy(&compressed_buffer[9], working_buffer.begin(), uncompressed_size); + memcpy(&compressed_buffer[9], source.data(), source.size()); - compressed_buffer_ptr = &compressed_buffer[0]; break; } default: throw Exception("Unknown compression method", ErrorCodes::UNKNOWN_COMPRESSION_METHOD); } + return compressed_size; +} + +template <bool add_checksum> +void CompressedWriteBuffer<add_checksum>::nextImpl() +{ + if (!offset()) + return; + + const char * source = working_buffer.begin(); + const size_t source_size = offset(); + size_t compressed_size = CompressionEncode({source, source_size}, compression_settings, compressed_buffer); + const auto * compressed_buffer_ptr = &compressed_buffer[0]; + if constexpr (add_checksum) { CityHash_v1_0_2::uint128 checksum = CityHash_v1_0_2::CityHash128(compressed_buffer_ptr, compressed_size); @@ -161,4 +164,12 @@ CompressedWriteBuffer<add_checksum>::~CompressedWriteBuffer() template class CompressedWriteBuffer<true>; template class CompressedWriteBuffer<false>; +template size_t CompressionEncode<PODArray<char>>( + std::string_view, + const CompressionSettings &, + PODArray<char> &); +template size_t CompressionEncode<String>( + std::string_view, + const CompressionSettings &, + String &); } // namespace DB diff --git a/dbms/src/Storages/StorageDisaggregated.cpp b/dbms/src/Storages/StorageDisaggregated.cpp index b0e87c50422..e3782abbaec 100644 --- a/dbms/src/Storages/StorageDisaggregated.cpp +++ b/dbms/src/Storages/StorageDisaggregated.cpp @@ -123,6 +123,10 @@ StorageDisaggregated::RequestAndRegionIDs StorageDisaggregated::buildDispatchMPP dispatch_req_meta->set_server_id(sender_target_mpp_task_id.query_id.server_id); dispatch_req_meta->set_task_id(sender_target_mpp_task_id.task_id); dispatch_req_meta->set_address(batch_cop_task.store_addr); + + // TODO: use different mpp version if necessary + // dispatch_req_meta->set_mpp_version(?); + const auto & settings = context.getSettings(); dispatch_req->set_timeout(60); dispatch_req->set_schema_ver(settings.schema_version); @@ -155,6 +159,10 @@ StorageDisaggregated::RequestAndRegionIDs StorageDisaggregated::buildDispatchMPP tipb::ExchangeSender * sender =
executor->mutable_exchange_sender(); sender->set_tp(tipb::ExchangeType::PassThrough); + + // TODO: enable data compression if necessary + // sender->set_compression(tipb::CompressionMode::FAST); + sender->add_encoded_task_meta(sender_target_task_meta.SerializeAsString()); auto * child = sender->mutable_child(); child->CopyFrom(buildTableScanTiPB()); diff --git a/dbms/src/Storages/Transaction/ReadIndexWorker.cpp b/dbms/src/Storages/Transaction/ReadIndexWorker.cpp index b3e342daf29..075e871f364 100644 --- a/dbms/src/Storages/Transaction/ReadIndexWorker.cpp +++ b/dbms/src/Storages/Transaction/ReadIndexWorker.cpp @@ -19,7 +19,6 @@ #include #include -#include #include namespace DB @@ -68,22 +67,22 @@ void F_TEST_LOG_FMT(const std::string &) namespace DB { -AsyncNotifier::Status AsyncWaker::Notifier::blockedWaitFor(std::chrono::milliseconds timeout) +AsyncNotifier::Status AsyncWaker::Notifier::blockedWaitUtil(const SteadyClock::time_point & time_point) { // if flag from false to false, wait for notification. // if flag from true to false, do nothing. auto res = AsyncNotifier::Status::Normal; - if (!wait_flag.exchange(false, std::memory_order_acq_rel)) + if (!is_awake->exchange(false, std::memory_order_acq_rel)) { { auto lock = genUniqueLock(); - if (!wait_flag.load(std::memory_order_acquire)) + if (!is_awake->load(std::memory_order_acquire)) { - if (cv.wait_for(lock, timeout) == std::cv_status::timeout) + if (condVar().wait_until(lock, time_point) == std::cv_status::timeout) res = AsyncNotifier::Status::Timeout; } } - wait_flag.store(false, std::memory_order_release); + is_awake->store(false, std::memory_order_release); } return res; } @@ -92,11 +91,13 @@ void AsyncWaker::Notifier::wake() { // if flag from false -> true, then wake up. // if flag from true -> true, do nothing. - if (!wait_flag.exchange(true, std::memory_order_acq_rel)) + if (is_awake->load(std::memory_order_acquire)) + return; + if (!is_awake->exchange(true, std::memory_order_acq_rel)) { // wake up notifier auto _ = genLockGuard(); - cv.notify_one(); + condVar().notify_one(); } } @@ -117,9 +118,9 @@ AsyncWaker::AsyncWaker(const TiFlashRaftProxyHelper & helper_, AsyncNotifier * n { } -AsyncNotifier::Status AsyncWaker::waitFor(std::chrono::milliseconds timeout) +AsyncNotifier::Status AsyncWaker::waitUtil(SteadyClock::time_point time_point) { - return notifier.blockedWaitFor(timeout); + return notifier.blockedWaitUtil(time_point); } RawVoidPtr AsyncWaker::getRaw() const @@ -130,30 +131,20 @@ RawVoidPtr AsyncWaker::getRaw() const struct BlockedReadIndexHelperTrait { explicit BlockedReadIndexHelperTrait(uint64_t timeout_ms_) - : timeout_ms(timeout_ms_) + : time_point(SteadyClock::now() + std::chrono::milliseconds{timeout_ms_}) {} - virtual AsyncNotifier::Status blockedWaitFor(std::chrono::milliseconds) = 0; + virtual AsyncNotifier::Status blockedWaitUtil(SteadyClock::time_point) = 0; // block current runtime and wait. 
virtual AsyncNotifier::Status blockedWait() { - auto time_cost_ms = check_watch.elapsedMilliseconds(); - - if (time_cost_ms >= timeout_ms) - { - return AsyncNotifier::Status::Timeout; - } - - auto remain = std::chrono::milliseconds(timeout_ms - time_cost_ms); - // TODO: use async process if supported by framework - return blockedWaitFor(remain); + return blockedWaitUtil(time_point); } virtual ~BlockedReadIndexHelperTrait() = default; protected: - Stopwatch check_watch; - uint64_t timeout_ms; + SteadyClock::time_point time_point; }; struct BlockedReadIndexHelper final : BlockedReadIndexHelperTrait @@ -170,9 +161,9 @@ struct BlockedReadIndexHelper final : BlockedReadIndexHelperTrait return waker; } - AsyncNotifier::Status blockedWaitFor(std::chrono::milliseconds tm) override + AsyncNotifier::Status blockedWaitUtil(SteadyClock::time_point time_point) override { - return waker.waitFor(tm); + return waker.waitUtil(time_point); } ~BlockedReadIndexHelper() override = default; @@ -189,9 +180,9 @@ struct BlockedReadIndexHelperV3 final : BlockedReadIndexHelperTrait { } - AsyncNotifier::Status blockedWaitFor(std::chrono::milliseconds tm) override + AsyncNotifier::Status blockedWaitUtil(SteadyClock::time_point time_point) override { - return notifier.blockedWaitFor(tm); + return notifier.blockedWaitUtil(time_point); } ~BlockedReadIndexHelperV3() override = default; @@ -342,6 +333,10 @@ struct RegionReadIndexNotifier final : AsyncNotifier notify->add(region_id, ts); notify->wake(); } + Status blockedWaitUtil(const SteadyClock::time_point &) override + { + return Status::Timeout; + } ~RegionReadIndexNotifier() override = default; @@ -446,7 +441,7 @@ void ReadIndexDataNode::ReadIndexElement::doPoll(const TiFlashRaftProxyHelper & clean_task = true; } - else if (std::chrono::steady_clock::now() > timeout + start_time) + else if (SteadyClock::now() > timeout + start_time) { TEST_LOG_FMT("poll ReadIndexElement timeout for region {}", region_id); @@ -458,7 +453,7 @@ void ReadIndexDataNode::ReadIndexElement::doPoll(const TiFlashRaftProxyHelper & TEST_LOG_FMT( "poll ReadIndexElement failed for region {}, time cost {}, timeout {}, start time {}", region_id, - std::chrono::steady_clock::now() - start_time, + SteadyClock::now() - start_time, timeout, start_time); } @@ -657,7 +652,7 @@ ReadIndexFuturePtr ReadIndexDataNode::insertTask(const kvrpcpb::ReadIndexRequest ReadIndexDataNodePtr ReadIndexWorker::DataMap::upsertDataNode(RegionID region_id) const { - auto _ = genWriteLockGuard(); + auto _ = genUniqueLock(); TEST_LOG_FMT("upsertDataNode for region {}", region_id); @@ -669,7 +664,7 @@ ReadIndexDataNodePtr ReadIndexWorker::DataMap::upsertDataNode(RegionID region_id ReadIndexDataNodePtr ReadIndexWorker::DataMap::tryGetDataNode(RegionID region_id) const { - auto _ = genReadLockGuard(); + auto _ = genSharedLock(); if (auto it = region_map.find(region_id); it != region_map.end()) { return it->second; @@ -686,13 +681,13 @@ ReadIndexDataNodePtr ReadIndexWorker::DataMap::getDataNode(RegionID region_id) c void ReadIndexWorker::DataMap::invoke(std::function &)> && cb) { - auto _ = genWriteLockGuard(); + auto _ = genUniqueLock(); cb(region_map); } void ReadIndexWorker::DataMap::removeRegion(RegionID region_id) { - auto _ = genWriteLockGuard(); + auto _ = genUniqueLock(); region_map.erase(region_id); } @@ -706,7 +701,7 @@ void ReadIndexWorker::consumeReadIndexNotifyCtrl() } } -void ReadIndexWorker::consumeRegionNotifies(std::chrono::steady_clock::duration min_dur) +void 
ReadIndexWorker::consumeRegionNotifies(SteadyClock::duration min_dur) { if (!lastRunTimeout(min_dur)) { @@ -721,7 +716,7 @@ void ReadIndexWorker::consumeRegionNotifies(std::chrono::steady_clock::duration } TEST_LOG_FMT("worker {} set last run time {}", getID(), Clock::now()); - last_run_time.store(std::chrono::steady_clock::now(), std::memory_order_release); + last_run_time.store(SteadyClock::now(), std::memory_order_release); } ReadIndexFuturePtr ReadIndexWorker::genReadIndexFuture(const kvrpcpb::ReadIndexRequest & req) @@ -737,7 +732,7 @@ ReadIndexFuturePtr ReadIndexWorkerManager::genReadIndexFuture(const kvrpcpb::Rea return getWorkerByRegion(req.context().region_id()).genReadIndexFuture(req); } -void ReadIndexWorker::runOneRound(std::chrono::steady_clock::duration min_dur) +void ReadIndexWorker::runOneRound(SteadyClock::duration min_dur) { if (!read_index_notify_ctrl->empty()) { @@ -759,10 +754,10 @@ ReadIndexWorker::ReadIndexWorker( { } -bool ReadIndexWorker::lastRunTimeout(std::chrono::steady_clock::duration timeout) const +bool ReadIndexWorker::lastRunTimeout(SteadyClock::duration timeout) const { TEST_LOG_FMT("worker {}, last run time {}, timeout {}", getID(), last_run_time.load(std::memory_order_relaxed), timeout); - return last_run_time.load(std::memory_order_relaxed) + timeout < std::chrono::steady_clock::now(); + return last_run_time.load(std::memory_order_relaxed) + timeout < SteadyClock::now(); } ReadIndexWorker & ReadIndexWorkerManager::getWorkerByRegion(RegionID region_id) @@ -828,13 +823,13 @@ ReadIndexWorkerManager::~ReadIndexWorkerManager() stop(); } -void ReadIndexWorkerManager::runOneRoundAll(std::chrono::steady_clock::duration min_dur) +void ReadIndexWorkerManager::runOneRoundAll(SteadyClock::duration min_dur) { for (size_t id = 0; id < runners.size(); ++id) runOneRound(min_dur, id); } -void ReadIndexWorkerManager::runOneRound(std::chrono::steady_clock::duration min_dur, size_t id) +void ReadIndexWorkerManager::runOneRound(SteadyClock::duration min_dur, size_t id) { runners[id]->runOneRound(min_dur); } @@ -1003,7 +998,7 @@ void ReadIndexWorkerManager::ReadIndexRunner::blockedWaitFor(std::chrono::millis global_notifier->blockedWaitFor(timeout); } -void ReadIndexWorkerManager::ReadIndexRunner::runOneRound(std::chrono::steady_clock::duration min_dur) +void ReadIndexWorkerManager::ReadIndexRunner::runOneRound(SteadyClock::duration min_dur) { for (size_t i = id; i < workers.size(); i += runner_cnt) workers[i]->runOneRound(min_dur); diff --git a/dbms/src/Storages/Transaction/ReadIndexWorker.h b/dbms/src/Storages/Transaction/ReadIndexWorker.h index ebc36874226..45ddba4d3dd 100644 --- a/dbms/src/Storages/Transaction/ReadIndexWorker.h +++ b/dbms/src/Storages/Transaction/ReadIndexWorker.h @@ -36,17 +36,19 @@ class ReadIndexTest; struct AsyncWaker { struct Notifier final : AsyncNotifier - , MutexLockWrap + , MutexCondVarWrap { - mutable std::condition_variable cv; - // multi notifiers single receiver model. use another flag to avoid waiting endlessly. - mutable std::atomic_bool wait_flag{false}; - - // usually sender invoke `wake`, receiver invoke `blockedWaitFor` - AsyncNotifier::Status blockedWaitFor(std::chrono::milliseconds timeout) override; + // usually sender invoke `wake`, receiver invoke `blockedWaitUtil` + // NOT thread safe + Status blockedWaitUtil(const SteadyClock::time_point &) override; + // thread safe void wake() override; ~Notifier() override = default; + + private: + // multi notifiers single receiver model. use another flag to avoid waiting endlessly. 
+ AlignedStruct<std::atomic_bool> is_awake{false}; }; using NotifierPtr = std::shared_ptr<Notifier>; @@ -56,7 +58,7 @@ struct AsyncWaker // create a `Notifier` in heap & let proxy wrap it and return as rust ptr with specific type. explicit AsyncWaker(const TiFlashRaftProxyHelper & helper_); - AsyncNotifier::Status waitFor(std::chrono::milliseconds timeout); + AsyncNotifier::Status waitUtil(SteadyClock::time_point); RawVoidPtr getRaw() const; @@ -88,7 +90,7 @@ class ReadIndexWorkerManager : boost::noncopyable void wakeAll(); // wake all runners to handle tasks void asyncRun(); - void runOneRound(std::chrono::steady_clock::duration min_dur, size_t id); + void runOneRound(SteadyClock::duration min_dur, size_t id); void stop(); ~ReadIndexWorkerManager(); BatchReadIndexRes batchReadIndex( @@ -104,7 +106,7 @@ class ReadIndexWorkerManager : boost::noncopyable ReadIndexFuturePtr genReadIndexFuture(const kvrpcpb::ReadIndexRequest & req); private: - void runOneRoundAll(std::chrono::steady_clock::duration min_dur = std::chrono::milliseconds{0}); + void runOneRoundAll(SteadyClock::duration min_dur = std::chrono::milliseconds{0}); enum class State : uint8_t { @@ -124,7 +126,7 @@ class ReadIndexWorkerManager : boost::noncopyable void blockedWaitFor(std::chrono::milliseconds timeout) const; /// Traverse its workers and try to execute tasks. - void runOneRound(std::chrono::steady_clock::duration min_dur); + void runOneRound(SteadyClock::duration min_dur); /// Create one thread to run asynchronously. void asyncRun(); @@ -219,7 +221,7 @@ struct ReadIndexDataNode : MutexLockWrap Task task_pair; kvrpcpb::ReadIndexResponse resp; std::deque<ReadIndexFuturePtr> callbacks; - std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now(); + SteadyClock::time_point start_time = SteadyClock::now(); }; struct WaitingTasks : MutexLockWrap @@ -298,12 +300,12 @@ struct ReadIndexWorker void consumeReadIndexNotifyCtrl(); - void consumeRegionNotifies(std::chrono::steady_clock::duration min_dur); + void consumeRegionNotifies(SteadyClock::duration min_dur); ReadIndexFuturePtr genReadIndexFuture(const kvrpcpb::ReadIndexRequest & req); // try to consume read-index response notifications & region waiting list - void runOneRound(std::chrono::steady_clock::duration min_dur); + void runOneRound(SteadyClock::duration min_dur); explicit ReadIndexWorker( const TiFlashRaftProxyHelper & proxy_helper_, @@ -329,7 +331,7 @@ struct ReadIndexWorker // x = x == 0 ?
1 : x; // max_read_index_history = x; // } - bool lastRunTimeout(std::chrono::steady_clock::duration timeout) const; + bool lastRunTimeout(SteadyClock::duration timeout) const; void removeRegion(uint64_t); @@ -348,7 +350,7 @@ struct ReadIndexWorker RegionNotifyMap region_notify_map; // no need to be protected - std::atomic<std::chrono::steady_clock::time_point> last_run_time{std::chrono::steady_clock::time_point::min()}; + std::atomic<SteadyClock::time_point> last_run_time{SteadyClock::time_point::min()}; }; struct MockStressTestCfg diff --git a/dbms/src/Storages/Transaction/RegionManager.h b/dbms/src/Storages/Transaction/RegionManager.h index f97eb144954..eebeb5183b2 100644 --- a/dbms/src/Storages/Transaction/RegionManager.h +++ b/dbms/src/Storages/Transaction/RegionManager.h @@ -57,12 +57,12 @@ struct RegionManager : SharedMutexLockWrap RegionReadLock genRegionReadLock() const { - return {genReadLockGuard(), regions, region_range_index}; + return {genSharedLock(), regions, region_range_index}; } RegionWriteLock genRegionWriteLock() { - return {genWriteLockGuard(), regions, region_range_index}; + return {genUniqueLock(), regions, region_range_index}; } /// Encapsulate the task lock for region diff --git a/dbms/src/Storages/Transaction/Utils.h b/dbms/src/Storages/Transaction/Utils.h index fa742d1c7e5..6b7d6c81f63 100644 --- a/dbms/src/Storages/Transaction/Utils.h +++ b/dbms/src/Storages/Transaction/Utils.h @@ -21,45 +21,85 @@ namespace DB { +using SteadyClock = std::chrono::steady_clock; +static constexpr size_t CPU_CACHE_LINE_SIZE = 64; + +template <typename Base, size_t alignment = CPU_CACHE_LINE_SIZE> +struct AlignedStruct +{ + template <typename... Args> + explicit AlignedStruct(Args &&... args) + : inner{std::forward<Args>(args)...} + {} + + Base & base() { return inner; } + const Base & base() const { return inner; } + Base * operator->() { return &inner; } + const Base * operator->() const { return &inner; } + Base & operator*() { return inner; } + const Base & operator*() const { return inner; } + +private: + // Wrapped with struct to guarantee that it is aligned to `alignment` + // DO NOT need padding byte + alignas(alignment) Base inner; +}; + class MutexLockWrap { public: - std::lock_guard<std::mutex> genLockGuard() const + using Mutex = std::mutex; + + std::lock_guard<Mutex> genLockGuard() const { - return std::lock_guard(mutex); + return std::lock_guard(*mutex); } - std::unique_lock<std::mutex> tryToLock() const + std::unique_lock<Mutex> tryToLock() const { - return std::unique_lock(mutex, std::try_to_lock); + return std::unique_lock(*mutex, std::try_to_lock); } - std::unique_lock<std::mutex> genUniqueLock() const + std::unique_lock<Mutex> genUniqueLock() const { - return std::unique_lock(mutex); + return std::unique_lock(*mutex); } private: - mutable std::mutex mutex; + mutable AlignedStruct<Mutex> mutex; }; class SharedMutexLockWrap { public: - std::shared_lock<std::shared_mutex> genReadLockGuard() const + using Mutex = std::shared_mutex; + + std::shared_lock<Mutex> genSharedLock() const { - return std::shared_lock(shared_mutex); + return std::shared_lock(*mutex); } - std::unique_lock<std::shared_mutex> genWriteLockGuard() const + std::unique_lock<Mutex> genUniqueLock() const { - return std::unique_lock(shared_mutex); + return std::unique_lock(*mutex); } private: - mutable std::shared_mutex shared_mutex; + mutable AlignedStruct<Mutex> mutex; }; +class MutexCondVarWrap : public MutexLockWrap +{ +public: + using CondVar = std::condition_variable; + + CondVar & condVar() const { return *cv; } + +private: + mutable AlignedStruct<CondVar> cv; +}; + + struct AsyncNotifier { enum class Status @@ -67,7 +107,11 @@ struct AsyncNotifier Timeout, Normal, }; - virtual Status blockedWaitFor(std::chrono::milliseconds) { return AsyncNotifier::Status::Timeout; }
+ virtual Status blockedWaitFor(const std::chrono::milliseconds & duration) + { + return blockedWaitUtil(SteadyClock::now() + duration); + } + virtual Status blockedWaitUtil(const SteadyClock::time_point &) = 0; virtual void wake() = 0; virtual ~AsyncNotifier() = default; }; diff --git a/libs/libcommon/include/common/avx2_byte_count.h b/libs/libcommon/include/common/avx2_byte_count.h index af2fb2d5844..c676bd8e097 100644 --- a/libs/libcommon/include/common/avx2_byte_count.h +++ b/libs/libcommon/include/common/avx2_byte_count.h @@ -27,7 +27,7 @@ ALWAYS_INLINE static inline #endif uint64_t avx2_byte_count(const char * src, size_t size, char target) { - uint64_t zero_bytes_cnt = 0; + uint64_t tar_byte_cnt = 0; const auto check_block32 = _mm256_set1_epi8(target); if (uint8_t right_offset = OFFSET_FROM_ALIGNED(size_t(src), BLOCK32_SIZE); right_offset != 0) @@ -49,7 +49,7 @@ uint64_t avx2_byte_count(const char * src, size_t size, char target) } mask >>= right_offset; - zero_bytes_cnt += std::popcount(mask); + tar_byte_cnt += std::popcount(mask); size -= left_remain; src += BLOCK32_SIZE; } @@ -61,7 +61,7 @@ uint64_t avx2_byte_count(const char * src, size_t size, char target) for (; size >= BLOCK32_SIZE;) { auto mask = get_block32_cmp_eq_mask(src, check_block32); - zero_bytes_cnt += std::popcount(mask); + tar_byte_cnt += std::popcount(mask); size -= BLOCK32_SIZE, src += BLOCK32_SIZE; } @@ -71,10 +71,10 @@ uint64_t avx2_byte_count(const char * src, size_t size, char target) uint32_t left_remain = BLOCK32_SIZE - size; mask <<= left_remain; mask >>= left_remain; - zero_bytes_cnt += std::popcount(mask); + tar_byte_cnt += std::popcount(mask); } - return zero_bytes_cnt; + return tar_byte_cnt; } } // namespace mem_utils::details