Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add cosine similarity query #3464

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions client/request/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ const (

DocIDArgName = "docID"

AverageFieldName = "_avg"
CountFieldName = "_count"
DocIDFieldName = "_docID"
GroupFieldName = "_group"
DeletedFieldName = "_deleted"
SumFieldName = "_sum"
VersionFieldName = "_version"
MaxFieldName = "_max"
MinFieldName = "_min"
AliasFieldName = "_alias"
AverageFieldName = "_avg"
CountFieldName = "_count"
DocIDFieldName = "_docID"
GroupFieldName = "_group"
DeletedFieldName = "_deleted"
SumFieldName = "_sum"
VersionFieldName = "_version"
MaxFieldName = "_max"
MinFieldName = "_min"
AliasFieldName = "_alias"
SimilarityFieldName = "_similarity"

// New generated document id from a backed up document,
// which might have a different _docID originally.
Expand Down Expand Up @@ -104,16 +105,17 @@ var (
}

ReservedFields = map[string]struct{}{
TypeNameFieldName: {},
VersionFieldName: {},
GroupFieldName: {},
CountFieldName: {},
SumFieldName: {},
AverageFieldName: {},
DocIDFieldName: {},
DeletedFieldName: {},
MaxFieldName: {},
MinFieldName: {},
TypeNameFieldName: {},
VersionFieldName: {},
GroupFieldName: {},
CountFieldName: {},
SumFieldName: {},
AverageFieldName: {},
DocIDFieldName: {},
DeletedFieldName: {},
MaxFieldName: {},
MinFieldName: {},
SimilarityFieldName: {},
}

Aggregates = map[string]struct{}{
Expand Down
26 changes: 26 additions & 0 deletions client/request/similarity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2025 Democratized Data Foundation
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package request

// Similarity is a functional field that defines the
// parameters to calculate the cosine similarity between two vectors.
type Similarity struct {
Field
// Vector contains the vector to compare the target field to.
//
// It will be of type Int, Float32 or Float64. It must be the same type and length as Target.
Vector any

// Target is the field in the host object that we will compare the the vector to.
//
// It must be a field of type Int, Float32 or Float64. It must be the same type and length as Vector.
Target string
}
9 changes: 9 additions & 0 deletions internal/planner/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var (
ErrUnknownRelationType = errors.New("failed sub selection, unknown relation type")
ErrUnknownExplainRequestType = errors.New("can not explain request of unknown type")
ErrUpsertMultipleDocuments = errors.New("cannot upsert multiple matching documents")
ErrMismatchLengthOnSimilarity = errors.New("source and vector must be of the same length")
)

func NewErrUnknownDependency(name string) error {
Expand All @@ -52,3 +53,11 @@ func NewErrFailedToCollectExecExplainInfo(inner error) error {
func NewErrSubTypeInit(inner error) error {
return errors.Wrap(errSubTypeInit, inner)
}

func NewErrMismatchLengthOnSimilarity(source, vector int) error {
return errors.WithStack(
ErrMismatchLengthOnSimilarity,
errors.NewKV("Source", source),
errors.NewKV("Vector", vector),
)
}
20 changes: 20 additions & 0 deletions internal/planner/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,26 @@ func getRequestables(
Key: getRenderKey(&f.Field),
})

mapping.Add(index, f.Name)
case *request.Similarity:
index := mapping.GetNextIndex()
fields = append(fields, &Similarity{
Field: Field{
Index: index,
Name: f.Name,
},
Vector: f.Vector,
SimilarityTarget: Targetable{
Field: Field{
Index: mapping.FirstIndexOfName(f.Target),
Name: f.Target,
},
},
})
mapping.RenderKeys = append(mapping.RenderKeys, core.RenderKey{
Index: index,
Key: getRenderKey(&f.Field),
})
mapping.Add(index, f.Name)
default:
return nil, nil, client.NewErrUnhandledType("field", field)
Expand Down
26 changes: 26 additions & 0 deletions internal/planner/mapper/similarity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2025 Democratized Data Foundation
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package mapper

import "github.com/sourcenetwork/defradb/internal/core"

// Similarity represents an cosine similarity operation definition.
type Similarity struct {
Field
// The mapping of this aggregate's parent/host.
*core.DocumentMapping

// The targetted field for the cosine similarity
SimilarityTarget Targetable

// The vector to compare the target field to.
Vector any
}
1 change: 1 addition & 0 deletions internal/planner/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var (
_ planNode = (*valuesNode)(nil)
_ planNode = (*viewNode)(nil)
_ planNode = (*lensNode)(nil)
_ planNode = (*similarityNode)(nil)

_ MultiNode = (*parallelNode)(nil)
_ MultiNode = (*topLevelNode)(nil)
Expand Down
9 changes: 9 additions & 0 deletions internal/planner/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ func (p *Planner) expandSelectTopNodePlan(plan *selectTopNode, parentPlan *selec
p.expandLimitPlan(plan, parentPlan)
}

p.expandSimilarityPlans(plan)

return nil
}

Expand All @@ -249,6 +251,13 @@ func (p *Planner) expandAggregatePlans(plan *selectTopNode) {
}
}

func (p *Planner) expandSimilarityPlans(plan *selectTopNode) {
for _, sim := range plan.similarity {
sim.SetPlan(plan.planNode)
plan.planNode = sim
}
}

func (p *Planner) expandMultiNode(multiNode MultiNode, parentPlan *selectTopNode) error {
for _, child := range multiNode.Children() {
if err := p.expandPlan(child, parentPlan); err != nil {
Expand Down
2 changes: 2 additions & 0 deletions internal/planner/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ func (n *scanNode) initFields(fields []mapper.Requestable) error {
n.tryAddFieldWithName(target.Field.Name)
}
}
case *mapper.Similarity:
n.tryAddFieldWithName(requestable.SimilarityTarget.Name)
}
}
return nil
Expand Down
39 changes: 25 additions & 14 deletions internal/planner/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@
// selectNode is used pre-wiring of the plan (before expansion and all).
selectNode *selectNode

// This is added temporarity until Planner is refactored
// https://github.com/sourcenetwork/defradb/issues/3467
similarity []*similarityNode

// plan is the top of the plan graph (the wired and finalized plan graph).
planNode planNode
}
Expand Down Expand Up @@ -236,14 +240,14 @@
// creating scanNodes, typeIndexJoinNodes, and splitting
// the necessary filters. Its designed to work with the
// planner.Select construction call.
func (n *selectNode) initSource() ([]aggregateNode, error) {
func (n *selectNode) initSource() ([]aggregateNode, []*similarityNode, error) {
if n.selectReq.CollectionName == "" {
n.selectReq.CollectionName = n.selectReq.Name
}

sourcePlan, err := n.planner.getSource(n.selectReq)
if err != nil {
return nil, err
return nil, nil, err
}
n.source = sourcePlan.plan
n.origSource = sourcePlan.plan
Expand All @@ -264,7 +268,7 @@
if n.selectReq.Cid.HasValue() {
c, err := cid.Decode(n.selectReq.Cid.Value())
if err != nil {
return nil, err
return nil, nil, err
}

// This exists because the fetcher interface demands a []Prefixes, yet the versioned
Expand Down Expand Up @@ -293,17 +297,17 @@
}
}

aggregates, err := n.initFields(n.selectReq)
aggregates, similarity, err := n.initFields(n.selectReq)
if err != nil {
return nil, err
return nil, nil, err
}

if isScanNode {
origScan.index = findIndexByFilteringField(origScan)
origScan.initFetcher(n.selectReq.Cid)
}

return aggregates, nil
return aggregates, similarity, nil
}

func findIndexByFilteringField(scanNode *scanNode) immutable.Option[client.IndexDescription] {
Expand Down Expand Up @@ -354,8 +358,9 @@
return immutable.None[client.IndexDescription]()
}

func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, error) {
func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, []*similarityNode, error) {
aggregates := []aggregateNode{}
similarity := []*similarityNode{}
// loop over the sub type
// at the moment, we're only testing a single sub selection
for _, field := range selectReq.Fields {
Expand All @@ -381,7 +386,7 @@
}

if aggregateError != nil {
return nil, aggregateError
return nil, nil, aggregateError

Check warning on line 389 in internal/planner/select.go

View check run for this annotation

Codecov / codecov/patch

internal/planner/select.go#L389

Added line #L389 was not covered by tests
}

if plan != nil {
Expand All @@ -408,11 +413,11 @@
commitPlan := n.planner.DAGScan(commitSlct)

if err := n.addSubPlan(f.Index, commitPlan); err != nil {
return nil, err
return nil, nil, err

Check warning on line 416 in internal/planner/select.go

View check run for this annotation

Codecov / codecov/patch

internal/planner/select.go#L416

Added line #L416 was not covered by tests
}
} else if f.Name == request.GroupFieldName {
if selectReq.GroupBy == nil {
return nil, ErrGroupOutsideOfGroupBy
return nil, nil, ErrGroupOutsideOfGroupBy
}
n.groupSelects = append(n.groupSelects, f)
} else if f.Name == request.LinksFieldName &&
Expand All @@ -427,13 +432,17 @@
// a traditional join here
err := n.addTypeIndexJoin(f)
if err != nil {
return nil, err
return nil, nil, err

Check warning on line 435 in internal/planner/select.go

View check run for this annotation

Codecov / codecov/patch

internal/planner/select.go#L435

Added line #L435 was not covered by tests
}
}
case *mapper.Similarity:
var simFilter *mapper.Filter
selectReq.Filter, simFilter = filter.SplitByFields(selectReq.Filter, f.Field)
similarity = append(similarity, n.planner.Similarity(f, simFilter))
}
}

return aggregates, nil
return aggregates, similarity, nil
}

func (n *selectNode) addTypeIndexJoin(subSelect *mapper.Select) error {
Expand Down Expand Up @@ -482,7 +491,7 @@
s.collection = col
}

aggregates, err := s.initFields(selectReq)
aggregates, similarity, err := s.initFields(selectReq)
if err != nil {
return nil, err
}
Expand All @@ -508,6 +517,7 @@
order: orderPlan,
group: groupPlan,
aggregates: aggregates,
similarity: similarity,
docMapper: docMapper{selectReq.DocumentMapping},
}
return top, nil
Expand All @@ -526,7 +536,7 @@
orderBy := selectReq.OrderBy
groupBy := selectReq.GroupBy

aggregates, err := s.initSource()
aggregates, similarity, err := s.initSource()
if err != nil {
return nil, err
}
Expand All @@ -552,6 +562,7 @@
order: orderPlan,
group: groupPlan,
aggregates: aggregates,
similarity: similarity,
docMapper: docMapper{selectReq.DocumentMapping},
}
return top, nil
Expand Down
Loading
Loading