diff --git a/go.mod b/go.mod index 38ae793c2a7b6..96e7e958e5887 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.16 require ( github.com/BurntSushi/toml v1.0.0 github.com/HdrHistogram/hdrhistogram-go v1.0.1 // indirect + github.com/Jeffail/tunny v0.1.4 // indirect github.com/StackExchange/wmi v1.2.1 // indirect github.com/antonmedv/expr v1.8.9 github.com/apache/arrow/go/v8 v8.0.0-20220322092137-778b1772fd20 @@ -29,6 +30,7 @@ require ( github.com/minio/minio-go/v7 v7.0.10 github.com/mitchellh/mapstructure v1.4.1 github.com/opentracing/opentracing-go v1.2.0 + github.com/panjf2000/ants/v2 v2.4.8 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pierrec/lz4 v2.5.2+incompatible // indirect github.com/pkg/errors v0.9.1 @@ -64,4 +66,3 @@ replace ( github.com/keybase/go-keychain => github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 google.golang.org/grpc => google.golang.org/grpc v1.38.0 ) - diff --git a/go.sum b/go.sum index ea9ce85fb137f..eca84ee7d49a2 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,8 @@ github.com/DataDog/zstd v1.4.6-0.20210211175136-c6db21d202f4 h1:++HGU87uq9UsSTlF github.com/DataDog/zstd v1.4.6-0.20210211175136-c6db21d202f4/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/HdrHistogram/hdrhistogram-go v1.0.1 h1:GX8GAYDuhlFQnI2fRDHQhTlkHMz8bEn0jTI6LJU0mpw= github.com/HdrHistogram/hdrhistogram-go v1.0.1/go.mod h1:BWJ+nMSHY3L41Zj7CA3uXnloDp7xxV0YvstAE7nKTaM= +github.com/Jeffail/tunny v0.1.4 h1:chtpdz+nUtaYQeCKlNBg6GycFF/kGVHOr6A3cmzTJXs= +github.com/Jeffail/tunny v0.1.4/go.mod h1:P8xAx4XQl0xsuhjX1DtfaMDCSuavzdb2rwbd0lk+fvo= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= @@ -121,8 +123,6 @@ github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/confluentinc/confluent-kafka-go v1.8.2 h1:PBdbvYpyOdFLehj8j+9ba7FL4c4Moxn79gy9cYKxG5E= -github.com/confluentinc/confluent-kafka-go v1.8.2/go.mod h1:u2zNLny2xq+5rWeTQjFHbDzzNuba4P1vo31r9r4uAdg= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -137,6 +137,8 @@ github.com/cockroachdb/errors v1.2.4/go.mod h1:rQD95gz6FARkaKkQXUksEje/d9a6wBJoC github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f h1:o/kfcElHqOiXqcou5a3rIlMc7oJbMQkeLk0VQJ7zgqY= github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f/go.mod h1:i/u985jwjWRlyHXQbwatDASoW0RMlZ/3i9yJHE2xLkI= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= +github.com/confluentinc/confluent-kafka-go v1.8.2 h1:PBdbvYpyOdFLehj8j+9ba7FL4c4Moxn79gy9cYKxG5E= +github.com/confluentinc/confluent-kafka-go v1.8.2/go.mod h1:u2zNLny2xq+5rWeTQjFHbDzzNuba4P1vo31r9r4uAdg= 
github.com/containerd/cgroups v1.0.2 h1:mZBclaSgNDfPWtfhj2xJY28LZ9nYIgzB0pwSURPl6JM= github.com/containerd/cgroups v1.0.2/go.mod h1:qpbpJ1jmlqsR9f2IyaLPsdkCdnt0rbDVqIDlhuu5tRY= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= @@ -529,6 +531,8 @@ github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJ github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= +github.com/panjf2000/ants/v2 v2.4.8 h1:JgTbolX6K6RreZ4+bfctI0Ifs+3mrE5BIHudQxUDQ9k= +github.com/panjf2000/ants/v2 v2.4.8/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= @@ -1189,6 +1193,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/internal/querynode/segment_loader.go b/internal/querynode/segment_loader.go index 3619b28cf8416..8ab2e51cd0e35 100644 --- a/internal/querynode/segment_loader.go +++ b/internal/querynode/segment_loader.go @@ -26,6 +26,7 @@ import ( "sync" "time" + "github.com/panjf2000/ants/v2" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/common" @@ -39,6 +40,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/concurrency" "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/timerecord" @@ -56,6 +58,9 @@ type segmentLoader struct { cm storage.ChunkManager // minio cm etcdKV *etcdkv.EtcdKV + ioPool *concurrency.Pool + cpuPool *concurrency.Pool + factory msgstream.Factory } @@ -73,21 +78,31 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme switch segmentType { case segmentTypeGrowing: metaReplica = loader.streamingReplica + case segmentTypeSealed: metaReplica = loader.historicalReplica + default: err := fmt.Errorf("illegal segment type when load segment, collectionID = %d", req.CollectionID) - log.Error("load segment failed, illegal segment type", zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), zap.Error(err)) + log.Error("load segment failed, illegal segment type", + zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), + zap.Error(err)) return err } log.Debug("segmentLoader start loading...", - zap.Any("collectionID", req.CollectionID), - zap.Any("numOfSegments", len(req.Infos)), + zap.Int64("collectionID", req.CollectionID), + 
zap.Int("numOfSegments", len(req.Infos)), zap.Any("loadType", segmentType), ) // check memory limit - concurrencyLevel := runtime.GOMAXPROCS(0) + concurrencyLevel := loader.cpuPool.Cap() + if len(req.Infos) > 0 && len(req.Infos[0].BinlogPaths) > 0 { + concurrencyLevel /= len(req.Infos[0].BinlogPaths) + if concurrencyLevel <= 0 { + concurrencyLevel = 1 + } + } for ; concurrencyLevel > 1; concurrencyLevel /= 2 { err := loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel) if err == nil { @@ -97,7 +112,9 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme err := loader.checkSegmentSize(req.CollectionID, req.Infos, concurrencyLevel) if err != nil { - log.Error("load failed, OOM if loaded", zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), zap.Error(err)) + log.Error("load failed, OOM if loaded", + zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), + zap.Error(err)) return err } @@ -118,6 +135,7 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme segmentGC() return err } + segment, err := newSegment(collection, segmentID, partitionID, collectionID, "", segmentType, true) if err != nil { log.Error("load segment failed when create new segment", @@ -139,8 +157,9 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme partitionID := loadInfo.PartitionID segmentID := loadInfo.SegmentID segment := newSegments[segmentID] + tr := timerecord.NewTimeRecorder("loadDurationPerSegment") - err = loader.loadSegmentInternal(segment, loadInfo) + err := loader.loadSegmentInternal(segment, loadInfo) if err != nil { log.Error("load segment failed when load data into memory", zap.Int64("collectionID", collectionID), @@ -156,7 +175,10 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme return nil } // start to load - err = funcutil.ProcessFuncParallel(len(req.Infos), concurrencyLevel, loadSegmentFunc, "loadSegmentFunc") + // Make sure we can always benefit from concurrency, and not spawn too many idle goroutines + err = funcutil.ProcessFuncParallel(len(req.Infos), + concurrencyLevel, + loadSegmentFunc, "loadSegmentFunc") if err != nil { segmentGC() return err @@ -218,16 +240,23 @@ func (loader *segmentLoader) loadSegmentInternal(segment *Segment, } } - err = loader.loadIndexedFieldData(segment, indexedFieldInfos) + loadVecFieldsTask := func() error { + return loader.loadVecFieldData(segment, indexedFieldInfos) + } + loadScalarFieldsTask := func() error { + return loader.loadSealedFields(segment, nonIndexedFieldBinlogs) + } + + err = funcutil.ProcessTaskParallel(2, "loadVecAndScalarFields", + loadVecFieldsTask, loadScalarFieldsTask) if err != nil { return err } } else { - nonIndexedFieldBinlogs = loadInfo.BinlogPaths - } - err = loader.loadFiledBinlogData(segment, nonIndexedFieldBinlogs) - if err != nil { - return err + err = loader.loadGrowingFields(segment, loadInfo.BinlogPaths) + if err != nil { + return err + } } if pkFieldID == common.InvalidFieldID { @@ -258,50 +287,116 @@ func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBi return result } -func (loader *segmentLoader) loadFiledBinlogData(segment *Segment, fieldBinlogs []*datapb.FieldBinlog) error { - segmentType := segment.getType() +// Load segments concurrency granularity: each binlog file +// Deserialize blobs concurrency granularity: each field +func (loader *segmentLoader) loadSealedFields(segment *Segment, fields []*datapb.FieldBinlog) error { + if segment.getType() != 
segmentTypeSealed {
+		return fmt.Errorf("illegal segment type when load segment, collectionID = %v, should be sealed segment", segment.collectionID)
+	}
+
+	futures := make([]*concurrency.Future, 0, len(fields))
 	iCodec := storage.InsertCodec{}
-	blobs := make([]*storage.Blob, 0)
-	for _, fieldBinlog := range fieldBinlogs {
-		for _, path := range fieldBinlog.Binlogs {
-			binLog, err := loader.cm.Read(path.GetLogPath())
+
+	for i := range fields {
+		field := fields[i]
+
+		// Acquire the CPU limiter first, then start loading the field binlogs,
+		// so that a CPU task can be committed as soon as possible
+		future := loader.cpuPool.Submit(func() (interface{}, error) {
+			blobs, err := loader.loadFieldBinlogs(field)
 			if err != nil {
-				return err
+				return nil, err
 			}
-			blob := &storage.Blob{
-				Key: path.GetLogPath(),
-				Value: binLog,
+
+			_, _, insertData, err := iCodec.Deserialize(blobs)
+			if err != nil {
+				return nil, err
 			}
-			blobs = append(blobs, blob)
-		}
+
+			log.Debug("deserialize blobs done",
+				zap.Int64("fieldID", field.FieldID),
+				zap.Int("len(insertData)", len(insertData.Data)))
+
+			return nil, loader.loadSealedSegments(segment, insertData)
+		})
+		futures = append(futures, future)
 	}
-	_, _, insertData, err := iCodec.Deserialize(blobs)
-	if err != nil {
-		log.Warn(err.Error())
-		return err
+	return concurrency.AwaitAll(futures)
+}
+
+func (loader *segmentLoader) loadGrowingFields(segment *Segment, fieldBinlogs []*datapb.FieldBinlog) error {
+	if segment.getType() != segmentTypeGrowing {
+		return fmt.Errorf("illegal segment type when load segment, collectionID = %v, should be growing segment", segment.collectionID)
 	}
-	switch segmentType {
-	case segmentTypeGrowing:
-		timestamps, ids, rowData, err := storage.TransferColumnBasedInsertDataToRowBased(insertData)
+	fieldBlobs := make([]*storage.Blob, 0)
+	for _, field := range fieldBinlogs {
+		blobs, err := loader.loadFieldBinlogs(field)
 		if err != nil {
 			return err
 		}
-		return loader.loadGrowingSegments(segment, ids, timestamps, rowData)
-	case segmentTypeSealed:
-		return loader.loadSealedSegments(segment, insertData)
-	default:
-		err := errors.New(fmt.Sprintln("illegal segment type when load segment, collectionID = ", segment.collectionID))
+		fieldBlobs = append(fieldBlobs, blobs...)
+	}
+
+	iCodec := storage.InsertCodec{}
+	_, _, insertData, err := iCodec.Deserialize(fieldBlobs)
+	if err != nil {
 		return err
 	}
+	timestamps, ids, rowData, err := storage.TransferColumnBasedInsertDataToRowBased(insertData)
+	if err != nil {
+		return err
+	}
+	return loader.loadGrowingSegments(segment, ids, timestamps, rowData)
+}
+
+// Load binlogs concurrently into memory from DataKV
+func (loader *segmentLoader) loadFieldBinlogs(field *datapb.FieldBinlog) ([]*storage.Blob, error) {
+	log.Debug("load field binlogs",
+		zap.Int64("fieldID", field.FieldID),
+		zap.Int("len(binlogs)", len(field.Binlogs)))
+
+	futures := make([]*concurrency.Future, 0, len(field.Binlogs))
+	for i := range field.Binlogs {
+		path := field.Binlogs[i].GetLogPath()
+		future := loader.ioPool.Submit(func() (interface{}, error) {
+			binLog, err := loader.cm.Read(path)
+			if err != nil {
+				return nil, err
+			}
+			blob := &storage.Blob{
+				Key: path,
+				Value: binLog,
+			}
+
+			return blob, nil
+		})
+
+		futures = append(futures, future)
+	}
+
+	blobs := make([]*storage.Blob, 0, len(field.Binlogs))
+	for _, future := range futures {
+		if !future.OK() {
+			return nil, future.Err()
+		}
+
+		blob := future.Value().(*storage.Blob)
+		blobs = append(blobs, blob)
+	}
+
+	log.Debug("load field binlogs done",
+		zap.Int64("fieldID", field.FieldID))
+
+	return blobs, nil
 }
 
-func (loader *segmentLoader) loadIndexedFieldData(segment *Segment, vecFieldInfos map[int64]*IndexedFieldInfo) error {
+func (loader *segmentLoader) loadVecFieldData(segment *Segment, vecFieldInfos map[int64]*IndexedFieldInfo) error {
 	for fieldID, fieldInfo := range vecFieldInfos {
 		if fieldInfo.indexInfo == nil || !fieldInfo.indexInfo.EnableIndex {
 			fieldBinlog := fieldInfo.fieldBinlog
-			err := loader.loadFiledBinlogData(segment, []*datapb.FieldBinlog{fieldBinlog})
+			err := loader.loadSealedFields(segment, []*datapb.FieldBinlog{fieldBinlog})
 			if err != nil {
 				return err
 			}
@@ -321,28 +416,47 @@
 }
 
 func (loader *segmentLoader) loadFieldIndexData(segment *Segment, indexInfo *querypb.FieldIndexInfo) error {
-	indexBuffer := make([][]byte, 0)
-	indexCodec := storage.NewIndexFileBinlogCodec()
+	indexBuffer := make([][]byte, 0, len(indexInfo.IndexFilePaths))
 	filteredPaths := make([]string, 0, len(indexInfo.IndexFilePaths))
+	futures := make([]*concurrency.Future, 0, len(indexInfo.IndexFilePaths))
+	indexCodec := storage.NewIndexFileBinlogCodec()
+
 	for _, p := range indexInfo.IndexFilePaths {
-		log.Debug("load index file", zap.String("path", p))
-		indexPiece, err := loader.cm.Read(p)
-		if err != nil {
-			return err
-		}
+		indexPath := p
+		if path.Base(indexPath) != storage.IndexParamsKey {
+			indexFuture := loader.cpuPool.Submit(func() (interface{}, error) {
+				indexBlobFuture := loader.ioPool.Submit(func() (interface{}, error) {
+					log.Debug("load index file", zap.String("path", indexPath))
+					return loader.cm.Read(indexPath)
+				})
+
+				indexBlob, err := indexBlobFuture.Await()
+				if err != nil {
+					return nil, err
+				}
-		if path.Base(p) != storage.IndexParamsKey {
-			data, _, _, _, err := indexCodec.Deserialize([]*storage.Blob{{Key: path.Base(p), Value: indexPiece}})
-			if err != nil {
-				return err
-			}
-			indexBuffer = append(indexBuffer, data[0].Value)
-			filteredPaths = append(filteredPaths, p)
+				data, _, _, _, err := indexCodec.Deserialize([]*storage.Blob{{Key: path.Base(indexPath), Value: indexBlob.([]byte)}})
+				return data, err
+			})
+
+			futures = append(futures, indexFuture)
+			filteredPaths = append(filteredPaths, indexPath)
 		}
 	}
+
+	err := concurrency.AwaitAll(futures)
+	if err != nil {
+		return err
+	}
+
+	for _, index := range futures {
+		blobs := index.Value().([]*storage.Blob)
+		indexBuffer = append(indexBuffer, blobs[0].Value)
+	}
+
 	// 2. use index bytes and index path to update segment
 	indexInfo.IndexFilePaths = filteredPaths
-	err := segment.segmentLoadIndexData(indexBuffer, indexInfo)
+	err = segment.segmentLoadIndexData(indexBuffer, indexInfo)
 	return err
 }
@@ -438,6 +552,9 @@ func (loader *segmentLoader) loadSealedSegments(segment *Segment, insertData *st
 		for _, numRow := range numRows {
 			totalNumRows += numRow
 		}
+		log.Debug("loadSealedSegments inserts field data",
+			zap.Int64("fieldID", fieldID),
+			zap.Int("totalNumRows", int(totalNumRows)))
 		err := segment.segmentLoadFieldData(fieldID, int(totalNumRows), data)
 		if err != nil {
 			// TODO: return or continue?
@@ -693,13 +810,35 @@ func newSegmentLoader(
 	cm storage.ChunkManager,
 	factory msgstream.Factory) *segmentLoader {
-	return &segmentLoader{
+	cpuNum := runtime.GOMAXPROCS(0)
+	// This error is non-nil only if the pool creation options are invalid
+	cpuPool, err := concurrency.NewPool(cpuNum, ants.WithPreAlloc(true))
+	if err != nil {
+		log.Error("failed to create goroutine pool for segment loader",
+			zap.Error(err))
+		panic(err)
+	}
+
+	ioPool, err := concurrency.NewPool(cpuNum*2, ants.WithPreAlloc(true))
+	if err != nil {
+		log.Error("failed to create goroutine pool for segment loader",
+			zap.Error(err))
+		panic(err)
+	}
+
+	loader := &segmentLoader{
 		historicalReplica: historicalReplica,
 		streamingReplica: streamingReplica,
 		cm: cm,
 		etcdKV: etcdKV,
+		// pools for IO-bound and CPU-bound tasks during loading
+		ioPool: ioPool,
+		cpuPool: cpuPool,
+
 		factory: factory,
 	}
+
+	return loader
 }
diff --git a/internal/querynode/segment_loader_test.go b/internal/querynode/segment_loader_test.go
index 1da190bc61082..263a20d71fd63 100644
--- a/internal/querynode/segment_loader_test.go
+++ b/internal/querynode/segment_loader_test.go
@@ -188,7 +188,7 @@ func TestSegmentLoader_loadSegmentFieldsData(t *testing.T) {
 	binlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema)
 	assert.NoError(t, err)
 
-	err = loader.loadFiledBinlogData(segment, binlog)
+	err = loader.loadSealedFields(segment, binlog)
 	assert.NoError(t, err)
 }
@@ -309,6 +309,31 @@ func TestSegmentLoader_invalid(t *testing.T) {
 		err = loader.loadSegment(req, segmentTypeSealed)
 		assert.Error(t, err)
 	})
+
+	t.Run("Test Invalid SegmentType", func(t *testing.T) {
+		node, err := genSimpleQueryNode(ctx)
+		assert.NoError(t, err)
+		loader := node.loader
+		assert.NotNil(t, loader)
+
+		req := &querypb.LoadSegmentsRequest{
+			Base: &commonpb.MsgBase{
+				MsgType: commonpb.MsgType_WatchQueryChannels,
+				MsgID: rand.Int63(),
+			},
+			DstNodeID: 0,
+			Infos: []*querypb.SegmentLoadInfo{
+				{
+					SegmentID: defaultSegmentID,
+					PartitionID: defaultPartitionID,
+					CollectionID: defaultCollectionID,
+				},
+			},
+		}
+
+		err = loader.loadSegment(req, commonpb.SegmentState_Dropped)
+		assert.Error(t, err)
+	})
 }
 
 func TestSegmentLoader_checkSegmentSize(t *testing.T) {
diff --git a/internal/storage/data_codec.go b/internal/storage/data_codec.go
index d562cac594977..8ffd2f56efaba 100644
--- a/internal/storage/data_codec.go
+++ b/internal/storage/data_codec.go
@@ -451,8 +451,9 @@ func (insertCodec *InsertCodec) DeserializeAll(blobs []*Blob) (
 	var cID UniqueID
 	var pID UniqueID
 	var sID UniqueID
-	resultData := &InsertData{}
-	resultData.Data = make(map[FieldID]FieldData)
+	resultData := &InsertData{
+		Data:
make(map[FieldID]FieldData), + } for _, blob := range blobList { binlogReader, err := NewBinlogReader(blob.Value) if err != nil { diff --git a/internal/util/concurrency/future.go b/internal/util/concurrency/future.go new file mode 100644 index 0000000000000..28e8fde36e850 --- /dev/null +++ b/internal/util/concurrency/future.go @@ -0,0 +1,51 @@ +package concurrency + +type Future struct { + ch chan struct{} + value interface{} + err error +} + +func newFuture() *Future { + return &Future{ + ch: make(chan struct{}), + } +} + +func (future *Future) Await() (interface{}, error) { + <-future.ch + + return future.value, future.err +} + +func (future *Future) Value() interface{} { + <-future.ch + + return future.value +} + +func (future *Future) OK() bool { + <-future.ch + + return future.err == nil +} + +func (future *Future) Err() error { + <-future.ch + + return future.err +} + +func (future *Future) Inner() <-chan struct{} { + return future.ch +} + +func AwaitAll(futures []*Future) error { + for i := range futures { + if !futures[i].OK() { + return futures[i].err + } + } + + return nil +} diff --git a/internal/util/concurrency/pool.go b/internal/util/concurrency/pool.go new file mode 100644 index 0000000000000..c8b24c8685494 --- /dev/null +++ b/internal/util/concurrency/pool.go @@ -0,0 +1,45 @@ +package concurrency + +import "github.com/panjf2000/ants/v2" + +type Pool struct { + inner *ants.Pool +} + +func NewPool(cap int, opts ...ants.Option) (*Pool, error) { + pool, err := ants.NewPool(cap, opts...) + if err != nil { + return nil, err + } + + return &Pool{ + inner: pool, + }, nil +} + +func (pool *Pool) Submit(method func() (interface{}, error)) *Future { + future := newFuture() + err := pool.inner.Submit(func() { + defer close(future.ch) + res, err := method() + if err != nil { + future.err = err + } else { + future.value = res + } + }) + if err != nil { + future.err = err + close(future.ch) + } + + return future +} + +func (pool *Pool) Cap() int { + return pool.inner.Cap() +} + +func (pool *Pool) Running() int { + return pool.inner.Running() +} diff --git a/internal/util/concurrency/pool_test.go b/internal/util/concurrency/pool_test.go new file mode 100644 index 0000000000000..57e660861d1c7 --- /dev/null +++ b/internal/util/concurrency/pool_test.go @@ -0,0 +1,42 @@ +package concurrency + +import ( + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestPool(t *testing.T) { + pool, err := NewPool(runtime.NumCPU()) + assert.NoError(t, err) + + taskNum := pool.Cap() * 2 + futures := make([]*Future, 0, taskNum) + for i := 0; i < taskNum; i++ { + res := i + future := pool.Submit(func() (interface{}, error) { + time.Sleep(500 * time.Millisecond) + return res, nil + }) + futures = append(futures, future) + } + + assert.Greater(t, pool.Running(), 0) + AwaitAll(futures) + for i, future := range futures { + res, err := future.Await() + assert.NoError(t, err) + assert.Equal(t, err, future.Err()) + assert.True(t, future.OK()) + assert.Equal(t, res, future.Value()) + assert.Equal(t, i, res.(int)) + + // Await() should be idempotent + <-future.Inner() + resDup, errDup := future.Await() + assert.Equal(t, res, resDup) + assert.Equal(t, err, errDup) + } +} diff --git a/internal/util/funcutil/parallel.go b/internal/util/funcutil/parallel.go index 18938ae43bb38..03afada3cfedf 100644 --- a/internal/util/funcutil/parallel.go +++ b/internal/util/funcutil/parallel.go @@ -31,13 +31,17 @@ func GetFunctionName(i interface{}) string { return 
runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() } -// ProcessFuncParallel process function in parallel. +type TaskFunc func() error +type ProcessFunc func(idx int) error +type DataProcessFunc func(data interface{}) error + +// ProcessFuncParallel processes function in parallel. // // ProcessFuncParallel waits for all goroutines done if no errors occur. // If some goroutines return error, ProcessFuncParallel cancels other goroutines as soon as possible and wait // for all other goroutines done, and returns the first error occurs. // Reference: https://stackoverflow.com/questions/40809504/idiomatic-goroutine-termination-and-error-handling -func ProcessFuncParallel(total, maxParallel int, f func(idx int) error, fname string) error { +func ProcessFuncParallel(total, maxParallel int, f ProcessFunc, fname string) error { if maxParallel <= 0 { maxParallel = 1 } @@ -64,7 +68,6 @@ func ProcessFuncParallel(total, maxParallel int, f func(idx int) error, fname st var wg sync.WaitGroup for begin := 0; begin < total; begin = begin + nPerBatch { j := begin - wg.Add(1) go func(begin int) { defer wg.Done() @@ -81,7 +84,112 @@ func ProcessFuncParallel(total, maxParallel int, f func(idx int) error, fname st for idx := begin; idx < end; idx++ { err = f(idx) if err != nil { - log.Debug(fname, zap.Error(err), zap.Any("idx", idx)) + log.Error(fname, zap.Error(err), zap.Any("idx", idx)) + break + } + } + + ch := done // send to done channel + if err != nil { + ch = errc // send to error channel + } + + select { + case ch <- err: + return + case <-quit: + return + } + }(j) + + routineNum++ + } + + log.Debug(fname, zap.Any("NumOfGoRoutines", routineNum)) + + if routineNum <= 0 { + return nil + } + + count := 0 + for { + select { + case err := <-errc: + close(quit) + wg.Wait() + return err + case <-done: + count++ + if count == routineNum { + wg.Wait() + return nil + } + } + } +} + +// ProcessTaskParallel processes tasks in parallel. +// Similar to ProcessFuncParallel +func ProcessTaskParallel(maxParallel int, fname string, tasks ...TaskFunc) error { + // option := parallelProcessOption{} + // for _, opt := range opts { + // opt(&option) + // } + + if maxParallel <= 0 { + maxParallel = 1 + } + + t := time.Now() + defer func() { + log.Debug(fname, zap.Any("time cost", time.Since(t))) + }() + + total := len(tasks) + nPerBatch := (total + maxParallel - 1) / maxParallel + log.Debug(fname, zap.Any("total", total)) + log.Debug(fname, zap.Any("nPerBatch", nPerBatch)) + + quit := make(chan bool) + errc := make(chan error) + done := make(chan error) + getMin := func(a, b int) int { + if a < b { + return a + } + return b + } + routineNum := 0 + var wg sync.WaitGroup + for begin := 0; begin < total; begin = begin + nPerBatch { + j := begin + + // if option.preExecute != nil { + // err := option.preExecute() + // if err != nil { + // close(quit) + // wg.Wait() + // return err + // } + // } + + wg.Add(1) + go func(begin int) { + defer wg.Done() + + select { + case <-quit: + return + default: + } + + err := error(nil) + + end := getMin(total, begin+nPerBatch) + for idx := begin; idx < end; idx++ { + err = tasks[idx]() + if err != nil { + log.Error(fname, zap.Error(err), zap.Any("idx", idx)) break } } @@ -98,6 +206,9 @@ func ProcessFuncParallel(total, maxParallel int, f func(idx int) error, fname st return } }(j) + // if option.postExecute != nil { + // option.postExecute() + // } routineNum++ }