Skip to content

Commit

Permalink
workaround for go generics
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo committed Feb 26, 2025
1 parent 6a85b40 commit ad1932c
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 18 deletions.
2 changes: 1 addition & 1 deletion util.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,5 @@ func ForEach(m Multiaddr, cb func(c Component) bool) {

func (m Multiaddr) Match(p ...meg.Pattern) (bool, error) {
matcher := meg.PatternToMatcher(p...)
return meg.Match(matcher, m)
return meg.Match(matcher, m, func(c *Component) meg.Matchable { return c })
}
10 changes: 7 additions & 3 deletions x/meg/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ func preallocateCapture() *preallocatedCapture {

var webrtcMatchPrealloc *preallocatedCapture

func componentPtrToMatchable(c *multiaddr.Component) *multiaddr.Component {
return c
}

func (p *preallocatedCapture) IsWebRTCDirectMultiaddr(addr multiaddr.Multiaddr) (bool, int) {
found, _ := meg.Match(p.matcher, addr)
found, _ := meg.Match(p.matcher, addr, componentPtrToMatchable)
return found, len(p.certHashes)
}

Expand Down Expand Up @@ -107,7 +111,7 @@ func isWebTransportMultiaddrPrealloc() *preallocatedCapture {

func IsWebTransportMultiaddrPrealloc(m multiaddr.Multiaddr) (bool, int) {
p := isWebTransportMultiaddrPrealloc()
found, _ := meg.Match(p.matcher, m)
found, _ := meg.Match(p.matcher, m, componentPtrToMatchable)
return found, len(p.certHashes)
}

Expand Down Expand Up @@ -365,7 +369,7 @@ func BenchmarkIsWebTransportMultiaddrNoCapturePrealloc(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
isWT, _ := meg.Match(wtPreallocNoCapture, addr)
isWT, _ := meg.Match(wtPreallocNoCapture, addr, componentPtrToMatchable)
if !isWT {
b.Fatal("unexpected result")
}
Expand Down
12 changes: 8 additions & 4 deletions x/meg/meg.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ type Matchable interface {
// Match returns whether the given Components match the Pattern defined in MatchState.
// Errors are used to communicate capture errors.
// If the error is non-nil the returned bool will be false.
func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) {
// The ptrToMatchable function is used to convert type *T to a Matchable..
// This is due to a limitation of Go generics, where we cannot say *T implements Matchable.
// When meg moves out of the x/ directory, we can reference the `*Component` type directly and avoid this limitation.
func Match[S ~[]T, T any, G Matchable](matcher Matcher, components S, ptrToMatchable func(*T) G) (bool, error) {
states := matcher.states
startStateIdx := matcher.startIdx

Expand All @@ -92,19 +95,20 @@ func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) {

currentStates = appendState(currentStates, states, startStateIdx, nil, visitedBitSet)

for _, c := range components {
for ic, _ := range components {

Check failure on line 98 in x/meg/meg.go

View workflow job for this annotation

GitHub Actions / go-check / All

unnecessary assignment to the blank identifier (S1005)
clear(visitedBitSet)
if len(currentStates.states) == 0 {
return false, nil
}
for i, stateIndex := range currentStates.states {
s := states[stateIndex]
if s.codeOrKind == matchAny || (s.codeOrKind >= 0 && s.codeOrKind == c.Code()) {
cPtr := ptrToMatchable(&components[ic])
if s.codeOrKind == matchAny || (s.codeOrKind >= 0 && s.codeOrKind == cPtr.Code()) {
cm := currentStates.captures[i]
if s.capture != nil {
next := &capture{
f: s.capture,
v: c,
v: cPtr,
}
if cm == nil {
cm = next
Expand Down
24 changes: 14 additions & 10 deletions x/meg/meg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,30 @@ type codeAndValue struct {
}

// Code implements Matchable.
func (c codeAndValue) Code() int {
func (c *codeAndValue) Code() int {
return c.code
}

// Value implements Matchable.
func (c codeAndValue) Value() string {
func (c *codeAndValue) Value() string {
return c.val
}

// Bytes implements Matchable.
func (c codeAndValue) Bytes() []byte {
func (c *codeAndValue) Bytes() []byte {
return []byte(c.val)
}

// RawValue implements Matchable.
func (c codeAndValue) RawValue() []byte {
func (c *codeAndValue) RawValue() []byte {
return []byte(c.val)
}

var _ Matchable = codeAndValue{}
var _ Matchable = &codeAndValue{}

func codeAndValuePtrToMatchable(c *codeAndValue) *codeAndValue {
return c
}

func TestSimple(t *testing.T) {
type testCase struct {
Expand Down Expand Up @@ -106,12 +110,12 @@ func TestSimple(t *testing.T) {

for i, tc := range testCases {
for _, m := range tc.shouldMatch {
if matches, err := Match(tc.pattern, codesToCodeAndValue(m)); !matches {
if matches, err := Match(tc.pattern, codesToCodeAndValue(m), codeAndValuePtrToMatchable); !matches {
t.Fatalf("failed to match %v with %v. idx=%d. err=%v", m, tc.pattern, i, err)
}
}
for _, m := range tc.shouldNotMatch {
if matches, _ := Match(tc.pattern, codesToCodeAndValue(m)); matches {
if matches, _ := Match(tc.pattern, codesToCodeAndValue(m), codeAndValuePtrToMatchable); matches {
t.Fatalf("failed to not match %v with %v. idx=%d", m, tc.pattern, i)
}
}
Expand All @@ -125,7 +129,7 @@ func TestSimple(t *testing.T) {
return true
}
}
matches, _ := Match(tc.pattern, codesToCodeAndValue(notMatch))
matches, _ := Match(tc.pattern, codesToCodeAndValue(notMatch), codeAndValuePtrToMatchable)
return !matches
}, &quick.Config{}); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -172,7 +176,7 @@ func TestCapture(t *testing.T) {
_ = testCases
for _, tc := range testCases {
state, assert := tc.setup()
if matches, _ := Match(state, tc.parts); !matches {
if matches, _ := Match(state, tc.parts, codeAndValuePtrToMatchable); !matches {
t.Fatalf("failed to match %v with %v", tc.parts, state)
}
assert()
Expand Down Expand Up @@ -255,7 +259,7 @@ func FuzzMatchesRegexpBehavior(f *testing.F) {
return
}
p := PatternToMatcher(pattern...)
otherMatched, _ := Match(p, bytesToCodeAndValue(corpus))
otherMatched, _ := Match(p, bytesToCodeAndValue(corpus), codeAndValuePtrToMatchable)
if otherMatched != matched {
t.Log("regexp", string(regexpPattern))
t.Log("corpus", string(corpus))
Expand Down

0 comments on commit ad1932c

Please sign in to comment.